pass error reason to csrf abort handler

This commit is contained in:
Pavel 2019-12-12 12:05:48 +03:00
parent 5a497fcc73
commit 20dc1f1aa8
4 changed files with 81 additions and 19 deletions

View File

@ -15,8 +15,31 @@ import (
"github.com/gorilla/sessions"
)
// CSRFErrorReason is a error reason type
type CSRFErrorReason uint8
// CSRFTokenGetter func type
type CSRFTokenGetter func(c *gin.Context) string
type CSRFTokenGetter func(*gin.Context) string
// CSRFAbortFunc is a callback which
type CSRFAbortFunc func(*gin.Context, CSRFErrorReason)
const (
// CSRFErrorNoTokenInSession will be returned if token is not present in session
CSRFErrorNoTokenInSession CSRFErrorReason = iota
// CSRFErrorCannotStoreTokenInSession will be returned if middleware cannot store token in session
CSRFErrorCannotStoreTokenInSession
// CSRFErrorIncorrectTokenType will be returned if data type of token in session is not string
CSRFErrorIncorrectTokenType
// CSRFErrorEmptyToken will be returned if token in session is empty
CSRFErrorEmptyToken
// CSRFErrorTokenMismatch will be returned in case of invalid token
CSRFErrorTokenMismatch
)
// DefaultCSRFTokenGetter default getter
var DefaultCSRFTokenGetter = func(c *gin.Context) string {
@ -50,7 +73,7 @@ type CSRF struct {
salt string
secret string
sessionName string
abortFunc gin.HandlerFunc
abortFunc CSRFAbortFunc
csrfTokenGetter CSRFTokenGetter
store sessions.Store
}
@ -60,7 +83,7 @@ type CSRF struct {
// Salt must be different every time (pass empty salt to use random), secret must be provided, sessionName is optional - pass empty to use default,
// store will be used to store sessions, abortFunc will be called to return error if token is invalid, csrfTokenGetter will be used to obtain token.
// Usage (with random salt):
// core.NewCSRF("", "super secret", "csrf_session", store, func (c *gin.Context) {
// core.NewCSRF("", "super secret", "csrf_session", store, func (c *gin.Context, reason core.CSRFErrorReason) {
// c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "Invalid CSRF token"})
// }, core.DefaultCSRFTokenGetter)
// Note for csrfTokenGetter: if you want to read token from request body (for example, from form field) - don't forget to restore Body data!
@ -70,7 +93,8 @@ type CSRF struct {
// }
// will close body - and all next middlewares won't be able to read body at all!
// Use DefaultCSRFTokenGetter as example to implement your own token getter.
func NewCSRF(salt, secret, sessionName string, store sessions.Store, abortFunc gin.HandlerFunc, csrfTokenGetter CSRFTokenGetter) *CSRF {
// CSRFErrorReason will be passed to abortFunc and can be used for better error messages.
func NewCSRF(salt, secret, sessionName string, store sessions.Store, abortFunc CSRFAbortFunc, csrfTokenGetter CSRFTokenGetter) *CSRF {
if store == nil {
panic("store must not be nil")
}
@ -178,14 +202,14 @@ func (x *CSRF) GenerateCSRFMiddleware() gin.HandlerFunc {
if i, ok := session.Values["csrf_token"]; ok {
if i, ok := i.(string); !ok || i == "" {
if x.fillToken(session, c) != nil {
x.abortFunc(c)
x.abortFunc(c, CSRFErrorCannotStoreTokenInSession)
c.Abort()
return
}
}
} else {
if x.fillToken(session, c) != nil {
x.abortFunc(c)
x.abortFunc(c, CSRFErrorCannotStoreTokenInSession)
c.Abort()
return
}
@ -216,22 +240,45 @@ func (x *CSRF) VerifyCSRFMiddleware(ignoredMethods []string) gin.HandlerFunc {
if i, ok := session.Values["csrf_token"]; ok {
var v string
if v, ok = i.(string); !ok || v == "" {
x.abortFunc(c)
if !ok {
x.abortFunc(c, CSRFErrorIncorrectTokenType)
} else if v == "" {
x.abortFunc(c, CSRFErrorEmptyToken)
}
c.Abort()
return
}
token = v
} else {
x.abortFunc(c)
x.abortFunc(c, CSRFErrorNoTokenInSession)
c.Abort()
return
}
if x.csrfTokenGetter(c) != token {
x.abortFunc(c)
x.abortFunc(c, CSRFErrorTokenMismatch)
c.Abort()
return
}
}
}
// GetCSRFErrorMessage returns generic error message for CSRFErrorReason in English (useful for logs)
func GetCSRFErrorMessage(r CSRFErrorReason) string {
switch r {
case CSRFErrorNoTokenInSession:
return "token is not present in session"
case CSRFErrorCannotStoreTokenInSession:
return "cannot store token in session"
case CSRFErrorIncorrectTokenType:
return "incorrect token type"
case CSRFErrorEmptyToken:
return "empty token present in session"
case CSRFErrorTokenMismatch:
return "token mismatch"
default:
return "unknown error"
}
}

View File

@ -101,7 +101,7 @@ func TestCSRF_NewCSRF_NilStore(t *testing.T) {
assert.NotNil(t, recover())
}()
NewCSRF("salt", "secret", "csrf", nil, func(context *gin.Context) {}, DefaultCSRFTokenGetter)
NewCSRF("salt", "secret", "csrf", nil, func(c *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter)
}
func TestCSRF_NewCSRF_EmptySecret(t *testing.T) {
@ -110,23 +110,38 @@ func TestCSRF_NewCSRF_EmptySecret(t *testing.T) {
}()
store := sessions.NewCookieStore([]byte("keys"))
NewCSRF("salt", "", "csrf", store, func(context *gin.Context) {}, DefaultCSRFTokenGetter)
NewCSRF("salt", "", "csrf", store, func(c *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter)
}
func TestCSRF_NewCSRF_SaltAndSessionNotEmpty(t *testing.T) {
store := sessions.NewCookieStore([]byte("keys"))
csrf := NewCSRF("salt", "secret", "", store, func(context *gin.Context) {}, DefaultCSRFTokenGetter)
csrf := NewCSRF("salt", "secret", "", store, func(c *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter)
assert.NotEmpty(t, csrf.salt)
assert.NotEmpty(t, csrf.sessionName)
}
func TestCSRF_GetCSRFErrorMessage(t *testing.T) {
items := map[CSRFErrorReason]string{
CSRFErrorNoTokenInSession: "token is not present in session",
CSRFErrorCannotStoreTokenInSession: "cannot store token in session",
CSRFErrorIncorrectTokenType: "incorrect token type",
CSRFErrorEmptyToken: "empty token present in session",
CSRFErrorTokenMismatch: "token mismatch",
99: "unknown error",
}
for reason, message := range items {
assert.Equal(t, message, GetCSRFErrorMessage(reason))
}
}
func TestCSRF_Suite(t *testing.T) {
suite.Run(t, new(CSRFTest))
}
func (x *CSRFTest) SetupSuite() {
store := sessions.NewCookieStore([]byte("keys"))
x.csrf = NewCSRF("salt", "secret", "", store, func(context *gin.Context) {
x.csrf = NewCSRF("salt", "secret", "", store, func(context *gin.Context, r CSRFErrorReason) {
context.AbortWithStatus(900)
}, DefaultCSRFTokenGetter)
}

View File

@ -194,7 +194,7 @@ func (e *Engine) WithFilesystemSessions(path string, keyLength ...int) *Engine {
// InitCSRF initializes CSRF middleware. engine.Sessions must be already initialized,
// use engine.WithCookieStore or engine.WithFilesystemStore for that.
// Syntax is similar to core.NewCSRF, but you shouldn't pass sessionName, store and salt.
func (e *Engine) InitCSRF(secret string, abortFunc gin.HandlerFunc, getter CSRFTokenGetter) *Engine {
func (e *Engine) InitCSRF(secret string, abortFunc CSRFAbortFunc, getter CSRFTokenGetter) *Engine {
if e.Sessions == nil {
panic("engine.Sessions must be initialized first")
}

View File

@ -204,7 +204,7 @@ func (e *EngineTest) Test_InitCSRF_Fail() {
e.engine.csrf = nil
e.engine.Sessions = nil
e.engine.InitCSRF("test", func(context *gin.Context) {}, DefaultCSRFTokenGetter)
e.engine.InitCSRF("test", func(context *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter)
assert.Nil(e.T(), e.engine.csrf)
}
@ -215,7 +215,7 @@ func (e *EngineTest) Test_InitCSRF() {
e.engine.csrf = nil
e.engine.WithCookieSessions(4)
e.engine.InitCSRF("test", func(context *gin.Context) {}, DefaultCSRFTokenGetter)
e.engine.InitCSRF("test", func(context *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter)
assert.NotNil(e.T(), e.engine.csrf)
}
@ -235,7 +235,7 @@ func (e *EngineTest) Test_VerifyCSRFMiddleware() {
e.engine.csrf = nil
e.engine.WithCookieSessions(4)
e.engine.InitCSRF("test", func(context *gin.Context) {}, DefaultCSRFTokenGetter)
e.engine.InitCSRF("test", func(context *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter)
e.engine.VerifyCSRFMiddleware(DefaultIgnoredMethods)
}
@ -255,7 +255,7 @@ func (e *EngineTest) Test_GenerateCSRFMiddleware() {
e.engine.csrf = nil
e.engine.WithCookieSessions(4)
e.engine.InitCSRF("test", func(context *gin.Context) {}, DefaultCSRFTokenGetter)
e.engine.InitCSRF("test", func(context *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter)
e.engine.GenerateCSRFMiddleware()
}
@ -284,7 +284,7 @@ func (e *EngineTest) Test_GetCSRFToken() {
e.engine.csrf = nil
e.engine.WithCookieSessions(4)
e.engine.InitCSRF("test", func(context *gin.Context) {}, DefaultCSRFTokenGetter)
e.engine.InitCSRF("test", func(context *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter)
assert.NotEmpty(e.T(), e.engine.GetCSRFToken(c))
assert.Equal(e.T(), "token", e.engine.GetCSRFToken(c))
}