mirror of
https://github.com/retailcrm/mg-transport-core.git
synced 2024-11-25 06:36:03 +03:00
pass error reason to csrf abort handler
This commit is contained in:
parent
5a497fcc73
commit
20dc1f1aa8
65
core/csrf.go
65
core/csrf.go
@ -15,8 +15,31 @@ import (
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
// CSRFErrorReason is a error reason type
|
||||
type CSRFErrorReason uint8
|
||||
|
||||
// CSRFTokenGetter func type
|
||||
type CSRFTokenGetter func(c *gin.Context) string
|
||||
type CSRFTokenGetter func(*gin.Context) string
|
||||
|
||||
// CSRFAbortFunc is a callback which
|
||||
type CSRFAbortFunc func(*gin.Context, CSRFErrorReason)
|
||||
|
||||
const (
|
||||
// CSRFErrorNoTokenInSession will be returned if token is not present in session
|
||||
CSRFErrorNoTokenInSession CSRFErrorReason = iota
|
||||
|
||||
// CSRFErrorCannotStoreTokenInSession will be returned if middleware cannot store token in session
|
||||
CSRFErrorCannotStoreTokenInSession
|
||||
|
||||
// CSRFErrorIncorrectTokenType will be returned if data type of token in session is not string
|
||||
CSRFErrorIncorrectTokenType
|
||||
|
||||
// CSRFErrorEmptyToken will be returned if token in session is empty
|
||||
CSRFErrorEmptyToken
|
||||
|
||||
// CSRFErrorTokenMismatch will be returned in case of invalid token
|
||||
CSRFErrorTokenMismatch
|
||||
)
|
||||
|
||||
// DefaultCSRFTokenGetter default getter
|
||||
var DefaultCSRFTokenGetter = func(c *gin.Context) string {
|
||||
@ -50,7 +73,7 @@ type CSRF struct {
|
||||
salt string
|
||||
secret string
|
||||
sessionName string
|
||||
abortFunc gin.HandlerFunc
|
||||
abortFunc CSRFAbortFunc
|
||||
csrfTokenGetter CSRFTokenGetter
|
||||
store sessions.Store
|
||||
}
|
||||
@ -60,7 +83,7 @@ type CSRF struct {
|
||||
// Salt must be different every time (pass empty salt to use random), secret must be provided, sessionName is optional - pass empty to use default,
|
||||
// store will be used to store sessions, abortFunc will be called to return error if token is invalid, csrfTokenGetter will be used to obtain token.
|
||||
// Usage (with random salt):
|
||||
// core.NewCSRF("", "super secret", "csrf_session", store, func (c *gin.Context) {
|
||||
// core.NewCSRF("", "super secret", "csrf_session", store, func (c *gin.Context, reason core.CSRFErrorReason) {
|
||||
// c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "Invalid CSRF token"})
|
||||
// }, core.DefaultCSRFTokenGetter)
|
||||
// Note for csrfTokenGetter: if you want to read token from request body (for example, from form field) - don't forget to restore Body data!
|
||||
@ -70,7 +93,8 @@ type CSRF struct {
|
||||
// }
|
||||
// will close body - and all next middlewares won't be able to read body at all!
|
||||
// Use DefaultCSRFTokenGetter as example to implement your own token getter.
|
||||
func NewCSRF(salt, secret, sessionName string, store sessions.Store, abortFunc gin.HandlerFunc, csrfTokenGetter CSRFTokenGetter) *CSRF {
|
||||
// CSRFErrorReason will be passed to abortFunc and can be used for better error messages.
|
||||
func NewCSRF(salt, secret, sessionName string, store sessions.Store, abortFunc CSRFAbortFunc, csrfTokenGetter CSRFTokenGetter) *CSRF {
|
||||
if store == nil {
|
||||
panic("store must not be nil")
|
||||
}
|
||||
@ -178,14 +202,14 @@ func (x *CSRF) GenerateCSRFMiddleware() gin.HandlerFunc {
|
||||
if i, ok := session.Values["csrf_token"]; ok {
|
||||
if i, ok := i.(string); !ok || i == "" {
|
||||
if x.fillToken(session, c) != nil {
|
||||
x.abortFunc(c)
|
||||
x.abortFunc(c, CSRFErrorCannotStoreTokenInSession)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if x.fillToken(session, c) != nil {
|
||||
x.abortFunc(c)
|
||||
x.abortFunc(c, CSRFErrorCannotStoreTokenInSession)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
@ -216,22 +240,45 @@ func (x *CSRF) VerifyCSRFMiddleware(ignoredMethods []string) gin.HandlerFunc {
|
||||
if i, ok := session.Values["csrf_token"]; ok {
|
||||
var v string
|
||||
if v, ok = i.(string); !ok || v == "" {
|
||||
x.abortFunc(c)
|
||||
if !ok {
|
||||
x.abortFunc(c, CSRFErrorIncorrectTokenType)
|
||||
} else if v == "" {
|
||||
x.abortFunc(c, CSRFErrorEmptyToken)
|
||||
}
|
||||
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
token = v
|
||||
} else {
|
||||
x.abortFunc(c)
|
||||
x.abortFunc(c, CSRFErrorNoTokenInSession)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
if x.csrfTokenGetter(c) != token {
|
||||
x.abortFunc(c)
|
||||
x.abortFunc(c, CSRFErrorTokenMismatch)
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetCSRFErrorMessage returns generic error message for CSRFErrorReason in English (useful for logs)
|
||||
func GetCSRFErrorMessage(r CSRFErrorReason) string {
|
||||
switch r {
|
||||
case CSRFErrorNoTokenInSession:
|
||||
return "token is not present in session"
|
||||
case CSRFErrorCannotStoreTokenInSession:
|
||||
return "cannot store token in session"
|
||||
case CSRFErrorIncorrectTokenType:
|
||||
return "incorrect token type"
|
||||
case CSRFErrorEmptyToken:
|
||||
return "empty token present in session"
|
||||
case CSRFErrorTokenMismatch:
|
||||
return "token mismatch"
|
||||
default:
|
||||
return "unknown error"
|
||||
}
|
||||
}
|
||||
|
@ -101,7 +101,7 @@ func TestCSRF_NewCSRF_NilStore(t *testing.T) {
|
||||
assert.NotNil(t, recover())
|
||||
}()
|
||||
|
||||
NewCSRF("salt", "secret", "csrf", nil, func(context *gin.Context) {}, DefaultCSRFTokenGetter)
|
||||
NewCSRF("salt", "secret", "csrf", nil, func(c *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter)
|
||||
}
|
||||
|
||||
func TestCSRF_NewCSRF_EmptySecret(t *testing.T) {
|
||||
@ -110,23 +110,38 @@ func TestCSRF_NewCSRF_EmptySecret(t *testing.T) {
|
||||
}()
|
||||
|
||||
store := sessions.NewCookieStore([]byte("keys"))
|
||||
NewCSRF("salt", "", "csrf", store, func(context *gin.Context) {}, DefaultCSRFTokenGetter)
|
||||
NewCSRF("salt", "", "csrf", store, func(c *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter)
|
||||
}
|
||||
|
||||
func TestCSRF_NewCSRF_SaltAndSessionNotEmpty(t *testing.T) {
|
||||
store := sessions.NewCookieStore([]byte("keys"))
|
||||
csrf := NewCSRF("salt", "secret", "", store, func(context *gin.Context) {}, DefaultCSRFTokenGetter)
|
||||
csrf := NewCSRF("salt", "secret", "", store, func(c *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter)
|
||||
assert.NotEmpty(t, csrf.salt)
|
||||
assert.NotEmpty(t, csrf.sessionName)
|
||||
}
|
||||
|
||||
func TestCSRF_GetCSRFErrorMessage(t *testing.T) {
|
||||
items := map[CSRFErrorReason]string{
|
||||
CSRFErrorNoTokenInSession: "token is not present in session",
|
||||
CSRFErrorCannotStoreTokenInSession: "cannot store token in session",
|
||||
CSRFErrorIncorrectTokenType: "incorrect token type",
|
||||
CSRFErrorEmptyToken: "empty token present in session",
|
||||
CSRFErrorTokenMismatch: "token mismatch",
|
||||
99: "unknown error",
|
||||
}
|
||||
|
||||
for reason, message := range items {
|
||||
assert.Equal(t, message, GetCSRFErrorMessage(reason))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCSRF_Suite(t *testing.T) {
|
||||
suite.Run(t, new(CSRFTest))
|
||||
}
|
||||
|
||||
func (x *CSRFTest) SetupSuite() {
|
||||
store := sessions.NewCookieStore([]byte("keys"))
|
||||
x.csrf = NewCSRF("salt", "secret", "", store, func(context *gin.Context) {
|
||||
x.csrf = NewCSRF("salt", "secret", "", store, func(context *gin.Context, r CSRFErrorReason) {
|
||||
context.AbortWithStatus(900)
|
||||
}, DefaultCSRFTokenGetter)
|
||||
}
|
||||
|
@ -194,7 +194,7 @@ func (e *Engine) WithFilesystemSessions(path string, keyLength ...int) *Engine {
|
||||
// InitCSRF initializes CSRF middleware. engine.Sessions must be already initialized,
|
||||
// use engine.WithCookieStore or engine.WithFilesystemStore for that.
|
||||
// Syntax is similar to core.NewCSRF, but you shouldn't pass sessionName, store and salt.
|
||||
func (e *Engine) InitCSRF(secret string, abortFunc gin.HandlerFunc, getter CSRFTokenGetter) *Engine {
|
||||
func (e *Engine) InitCSRF(secret string, abortFunc CSRFAbortFunc, getter CSRFTokenGetter) *Engine {
|
||||
if e.Sessions == nil {
|
||||
panic("engine.Sessions must be initialized first")
|
||||
}
|
||||
|
@ -204,7 +204,7 @@ func (e *EngineTest) Test_InitCSRF_Fail() {
|
||||
|
||||
e.engine.csrf = nil
|
||||
e.engine.Sessions = nil
|
||||
e.engine.InitCSRF("test", func(context *gin.Context) {}, DefaultCSRFTokenGetter)
|
||||
e.engine.InitCSRF("test", func(context *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter)
|
||||
assert.Nil(e.T(), e.engine.csrf)
|
||||
}
|
||||
|
||||
@ -215,7 +215,7 @@ func (e *EngineTest) Test_InitCSRF() {
|
||||
|
||||
e.engine.csrf = nil
|
||||
e.engine.WithCookieSessions(4)
|
||||
e.engine.InitCSRF("test", func(context *gin.Context) {}, DefaultCSRFTokenGetter)
|
||||
e.engine.InitCSRF("test", func(context *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter)
|
||||
assert.NotNil(e.T(), e.engine.csrf)
|
||||
}
|
||||
|
||||
@ -235,7 +235,7 @@ func (e *EngineTest) Test_VerifyCSRFMiddleware() {
|
||||
|
||||
e.engine.csrf = nil
|
||||
e.engine.WithCookieSessions(4)
|
||||
e.engine.InitCSRF("test", func(context *gin.Context) {}, DefaultCSRFTokenGetter)
|
||||
e.engine.InitCSRF("test", func(context *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter)
|
||||
e.engine.VerifyCSRFMiddleware(DefaultIgnoredMethods)
|
||||
}
|
||||
|
||||
@ -255,7 +255,7 @@ func (e *EngineTest) Test_GenerateCSRFMiddleware() {
|
||||
|
||||
e.engine.csrf = nil
|
||||
e.engine.WithCookieSessions(4)
|
||||
e.engine.InitCSRF("test", func(context *gin.Context) {}, DefaultCSRFTokenGetter)
|
||||
e.engine.InitCSRF("test", func(context *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter)
|
||||
e.engine.GenerateCSRFMiddleware()
|
||||
}
|
||||
|
||||
@ -284,7 +284,7 @@ func (e *EngineTest) Test_GetCSRFToken() {
|
||||
|
||||
e.engine.csrf = nil
|
||||
e.engine.WithCookieSessions(4)
|
||||
e.engine.InitCSRF("test", func(context *gin.Context) {}, DefaultCSRFTokenGetter)
|
||||
e.engine.InitCSRF("test", func(context *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter)
|
||||
assert.NotEmpty(e.T(), e.engine.GetCSRFToken(c))
|
||||
assert.Equal(e.T(), "token", e.engine.GetCSRFToken(c))
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user