Skip to content

Commit

Permalink
fix: authentication middleware is implemented by changing from framew…
Browse files Browse the repository at this point in the history
…ork droplet to framework gin (#2254)
  • Loading branch information
nic-chen authored Dec 19, 2021
1 parent ffa596d commit b565f7c
Show file tree
Hide file tree
Showing 15 changed files with 163 additions and 195 deletions.
3 changes: 1 addition & 2 deletions api/internal/core/server/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ import (

"github.com/apisix/manager-api/internal"
"github.com/apisix/manager-api/internal/conf"
"github.com/apisix/manager-api/internal/filter"
"github.com/apisix/manager-api/internal/handler"
)

Expand All @@ -37,7 +36,7 @@ func (s *server) setupAPI() {
var newMws []droplet.Middleware
// default middleware order: resp_reshape, auto_input, traffic_log
// We should put err_transform at second to catch all error
newMws = append(newMws, mws[0], &handler.ErrorTransformMiddleware{}, &filter.AuthenticationMiddleware{})
newMws = append(newMws, mws[0], &handler.ErrorTransformMiddleware{})
newMws = append(newMws, mws[1:]...)
return newMws
}
Expand Down
2 changes: 1 addition & 1 deletion api/internal/core/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ type server struct {
options *Options
}

type Options struct {}
type Options struct{}

// NewServer Create a server manager
func NewServer(options *Options) (*server, error) {
Expand Down
123 changes: 53 additions & 70 deletions api/internal/filter/authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,84 +17,67 @@
package filter

import (
"errors"
"net/http"
"strings"

"github.com/dgrijalva/jwt-go"
"github.com/shiningrush/droplet"
"github.com/shiningrush/droplet/data"
"github.com/shiningrush/droplet/middleware"
"github.com/gin-gonic/gin"

"github.com/apisix/manager-api/internal/conf"
"github.com/apisix/manager-api/internal/log"
)

type AuthenticationMiddleware struct {
middleware.BaseMiddleware
}

func (mw *AuthenticationMiddleware) Handle(ctx droplet.Context) error {
httpReq := ctx.Get(middleware.KeyHttpRequest)
if httpReq == nil {
err := errors.New("input middleware cannot get http request")

// Wrong usage, just panic here and let recoverHandler to deal with
panic(err)
}

req := httpReq.(*http.Request)

if req.URL.Path == "/apisix/admin/tool/version" || req.URL.Path == "/apisix/admin/user/login" {
return mw.BaseMiddleware.Handle(ctx)
}

if !strings.HasPrefix(req.URL.Path, "/apisix") {
return mw.BaseMiddleware.Handle(ctx)
}

// Need check the auth header
tokenStr := req.Header.Get("Authorization")

// verify token
token, err := jwt.ParseWithClaims(tokenStr, &jwt.StandardClaims{}, func(token *jwt.Token) (interface{}, error) {
return []byte(conf.AuthConf.Secret), nil
})

// TODO: design the response error code
response := data.Response{Code: 010013, Message: "request unauthorized"}

if err != nil || token == nil || !token.Valid {
log.Warnf("token validate failed: %s", err)
log.Warn("please check the secret in conf.yaml")
ctx.SetOutput(&data.SpecCodeResponse{StatusCode: http.StatusUnauthorized, Response: response})
return nil
}

claims, ok := token.Claims.(*jwt.StandardClaims)
if !ok {
log.Warnf("token validate failed: %s, %v", err, token.Valid)
ctx.SetOutput(&data.SpecCodeResponse{StatusCode: http.StatusUnauthorized, Response: response})
return nil
}

if err := token.Claims.Valid(); err != nil {
log.Warnf("token claims validate failed: %s", err)
ctx.SetOutput(&data.SpecCodeResponse{StatusCode: http.StatusUnauthorized, Response: response})
return nil
func Authentication() gin.HandlerFunc {
return func(c *gin.Context) {
if c.Request.URL.Path == "/apisix/admin/user/login" ||
c.Request.URL.Path == "/apisix/admin/tool/version" ||
!strings.HasPrefix(c.Request.URL.Path, "/apisix") {
c.Next()
return
}

tokenStr := c.GetHeader("Authorization")
// verify token
token, err := jwt.ParseWithClaims(tokenStr, &jwt.StandardClaims{}, func(token *jwt.Token) (interface{}, error) {
return []byte(conf.AuthConf.Secret), nil
})

errResp := gin.H{
"code": 010013,
"message": "request unauthorized",
}

if err != nil || token == nil || !token.Valid {
log.Warnf("token validate failed: %s", err)
c.AbortWithStatusJSON(http.StatusUnauthorized, errResp)
return
}

claims, ok := token.Claims.(*jwt.StandardClaims)
if !ok {
log.Warnf("token validate failed: %s, %v", err, token.Valid)
c.AbortWithStatusJSON(http.StatusUnauthorized, errResp)
return
}

if err := token.Claims.Valid(); err != nil {
log.Warnf("token claims validate failed: %s", err)
c.AbortWithStatusJSON(http.StatusUnauthorized, errResp)
return
}

if claims.Subject == "" {
log.Warn("token claims subject empty")
c.AbortWithStatusJSON(http.StatusUnauthorized, errResp)
return
}

if _, ok := conf.UserList[claims.Subject]; !ok {
log.Warnf("user not exists by token claims subject %s", claims.Subject)
c.AbortWithStatusJSON(http.StatusUnauthorized, errResp)
return
}

c.Next()
}

if claims.Subject == "" {
log.Warn("token claims subject empty")
ctx.SetOutput(&data.SpecCodeResponse{StatusCode: http.StatusUnauthorized, Response: response})
return nil
}

if _, ok := conf.UserList[claims.Subject]; !ok {
log.Warnf("user not exists by token claims subject %s", claims.Subject)
ctx.SetOutput(&data.SpecCodeResponse{StatusCode: http.StatusUnauthorized, Response: response})
return nil
}

return mw.BaseMiddleware.Handle(ctx)
}
86 changes: 22 additions & 64 deletions api/internal/filter/authentication_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,12 @@
package filter

import (
"errors"
"net/http"
"net/url"
"testing"
"time"

"github.com/dgrijalva/jwt-go"
"github.com/shiningrush/droplet"
"github.com/shiningrush/droplet/data"
"github.com/shiningrush/droplet/middleware"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"

"github.com/apisix/manager-api/internal/conf"
Expand All @@ -44,73 +40,35 @@ func genToken(username string, issueAt, expireAt int64) string {
return signedToken
}

type mockMiddleware struct {
middleware.BaseMiddleware
}

func (mw *mockMiddleware) Handle(ctx droplet.Context) error {
return errors.New("next middleware")
}

func testPanic(t *testing.T, mw AuthenticationMiddleware, ctx droplet.Context) {
defer func() {
panicErr := recover()
assert.Contains(t, panicErr.(error).Error(), "input middleware cannot get http request")
}()
_ = mw.Handle(ctx)
}

func TestAuthenticationMiddleware_Handle(t *testing.T) {
ctx := droplet.NewContext()
fakeReq, _ := http.NewRequest(http.MethodGet, "", nil)
expectOutput := &data.SpecCodeResponse{
Response: data.Response{
Code: 010013,
Message: "request unauthorized",
},
StatusCode: http.StatusUnauthorized,
}

mw := AuthenticationMiddleware{}
mockMw := mockMiddleware{}
mw.SetNext(&mockMw)

// test without http.Request
testPanic(t, mw, ctx)

ctx.Set(middleware.KeyHttpRequest, fakeReq)
r := gin.New()
r.Use(Authentication())
r.GET("/*path", func(c *gin.Context) {
})

// test without token check
fakeReq.URL = &url.URL{Path: "/apisix/admin/user/login"}
assert.Equal(t, mw.Handle(ctx), errors.New("next middleware"))
w := performRequest(r, "GET", "/apisix/admin/user/login", nil)
assert.Equal(t, http.StatusOK, w.Code)

// test without authorization header
fakeReq.URL = &url.URL{Path: "/apisix/admin/routes"}
assert.Nil(t, mw.Handle(ctx))
assert.Equal(t, expectOutput, ctx.Output().(*data.SpecCodeResponse))
w = performRequest(r, "GET", "/apisix/admin/routes", nil)
assert.Equal(t, http.StatusUnauthorized, w.Code)

// test with token expire
expireToken := genToken("admin", time.Now().Unix(), time.Now().Unix()-60*3600)
fakeReq.Header.Set("Authorization", expireToken)
assert.Nil(t, mw.Handle(ctx))
assert.Equal(t, expectOutput, ctx.Output().(*data.SpecCodeResponse))
w = performRequest(r, "GET", "/apisix/admin/routes", map[string]string{"Authorization": expireToken})
assert.Equal(t, http.StatusUnauthorized, w.Code)

// test with temp subject
tempSubjectToken := genToken("", time.Now().Unix(), time.Now().Unix()+60*3600)
fakeReq.Header.Set("Authorization", tempSubjectToken)
assert.Nil(t, mw.Handle(ctx))
assert.Equal(t, expectOutput, ctx.Output().(*data.SpecCodeResponse))
// test with empty subject
emptySubjectToken := genToken("", time.Now().Unix(), time.Now().Unix()+60*3600)
w = performRequest(r, "GET", "/apisix/admin/routes", map[string]string{"Authorization": emptySubjectToken})
assert.Equal(t, http.StatusUnauthorized, w.Code)

// test username doesn't exist
userToken := genToken("user1", time.Now().Unix(), time.Now().Unix()+60*3600)
fakeReq.Header.Set("Authorization", userToken)
assert.Nil(t, mw.Handle(ctx))
assert.Equal(t, expectOutput, ctx.Output().(*data.SpecCodeResponse))
// test token with nonexistent username
nonexistentUserToken := genToken("user1", time.Now().Unix(), time.Now().Unix()+60*3600)
w = performRequest(r, "GET", "/apisix/admin/routes", map[string]string{"Authorization": nonexistentUserToken})
assert.Equal(t, http.StatusUnauthorized, w.Code)

// test auth success
adminToken := genToken("admin", time.Now().Unix(), time.Now().Unix()+60*3600)
fakeReq.Header.Set("Authorization", adminToken)
ctx.SetOutput("test data")
assert.Equal(t, mw.Handle(ctx), errors.New("next middleware"))
assert.Equal(t, "test data", ctx.Output().(string))
validToken := genToken("admin", time.Now().Unix(), time.Now().Unix()+60*3600)
w = performRequest(r, "GET", "/apisix/admin/routes", map[string]string{"Authorization": validToken})
assert.Equal(t, http.StatusOK, w.Code)
}
6 changes: 3 additions & 3 deletions api/internal/filter/ip_filter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func TestIPFilter_Handle(t *testing.T) {
r.GET("/", func(c *gin.Context) {
})

w := performRequest(r, "GET", "/")
w := performRequest(r, "GET", "/", nil)
assert.Equal(t, 200, w.Code)

// should forbidden
Expand All @@ -45,7 +45,7 @@ func TestIPFilter_Handle(t *testing.T) {
r.GET("/fbd", func(c *gin.Context) {
})

w = performRequest(r, "GET", "/fbd")
w = performRequest(r, "GET", "/fbd", nil)
assert.Equal(t, 403, w.Code)

// should allowed
Expand All @@ -54,7 +54,7 @@ func TestIPFilter_Handle(t *testing.T) {
r.Use(IPFilter())
r.GET("/test", func(c *gin.Context) {
})
w = performRequest(r, "GET", "/test")
w = performRequest(r, "GET", "/test", nil)
assert.Equal(t, 200, w.Code)

// should forbidden
Expand Down
7 changes: 5 additions & 2 deletions api/internal/filter/logging_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,11 @@ import (
"github.com/apisix/manager-api/internal/log"
)

func performRequest(r http.Handler, method, path string) *httptest.ResponseRecorder {
func performRequest(r http.Handler, method, path string, headers map[string]string) *httptest.ResponseRecorder {
req := httptest.NewRequest(method, path, nil)
for key, val := range headers {
req.Header.Add(key, val)
}
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
return w
Expand All @@ -41,6 +44,6 @@ func TestRequestLogHandler(t *testing.T) {
r.GET("/", func(c *gin.Context) {
})

w := performRequest(r, "GET", "/")
w := performRequest(r, "GET", "/", nil)
assert.Equal(t, 200, w.Code)
}
Loading

0 comments on commit b565f7c

Please sign in to comment.