diff --git a/core/csrf.go b/core/csrf.go index 0f423d6..669e051 100644 --- a/core/csrf.go +++ b/core/csrf.go @@ -15,8 +15,31 @@ import ( "github.com/gorilla/sessions" ) +// CSRFErrorReason is a error reason type +type CSRFErrorReason uint8 + // CSRFTokenGetter func type -type CSRFTokenGetter func(c *gin.Context) string +type CSRFTokenGetter func(*gin.Context) string + +// CSRFAbortFunc is a callback which +type CSRFAbortFunc func(*gin.Context, CSRFErrorReason) + +const ( + // CSRFErrorNoTokenInSession will be returned if token is not present in session + CSRFErrorNoTokenInSession CSRFErrorReason = iota + + // CSRFErrorCannotStoreTokenInSession will be returned if middleware cannot store token in session + CSRFErrorCannotStoreTokenInSession + + // CSRFErrorIncorrectTokenType will be returned if data type of token in session is not string + CSRFErrorIncorrectTokenType + + // CSRFErrorEmptyToken will be returned if token in session is empty + CSRFErrorEmptyToken + + // CSRFErrorTokenMismatch will be returned in case of invalid token + CSRFErrorTokenMismatch +) // DefaultCSRFTokenGetter default getter var DefaultCSRFTokenGetter = func(c *gin.Context) string { @@ -50,7 +73,7 @@ type CSRF struct { salt string secret string sessionName string - abortFunc gin.HandlerFunc + abortFunc CSRFAbortFunc csrfTokenGetter CSRFTokenGetter store sessions.Store } @@ -60,7 +83,7 @@ type CSRF struct { // Salt must be different every time (pass empty salt to use random), secret must be provided, sessionName is optional - pass empty to use default, // store will be used to store sessions, abortFunc will be called to return error if token is invalid, csrfTokenGetter will be used to obtain token. // Usage (with random salt): -// core.NewCSRF("", "super secret", "csrf_session", store, func (c *gin.Context) { +// core.NewCSRF("", "super secret", "csrf_session", store, func (c *gin.Context, reason core.CSRFErrorReason) { // c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "Invalid CSRF token"}) // }, core.DefaultCSRFTokenGetter) // Note for csrfTokenGetter: if you want to read token from request body (for example, from form field) - don't forget to restore Body data! @@ -70,7 +93,8 @@ type CSRF struct { // } // will close body - and all next middlewares won't be able to read body at all! // Use DefaultCSRFTokenGetter as example to implement your own token getter. -func NewCSRF(salt, secret, sessionName string, store sessions.Store, abortFunc gin.HandlerFunc, csrfTokenGetter CSRFTokenGetter) *CSRF { +// CSRFErrorReason will be passed to abortFunc and can be used for better error messages. +func NewCSRF(salt, secret, sessionName string, store sessions.Store, abortFunc CSRFAbortFunc, csrfTokenGetter CSRFTokenGetter) *CSRF { if store == nil { panic("store must not be nil") } @@ -178,14 +202,14 @@ func (x *CSRF) GenerateCSRFMiddleware() gin.HandlerFunc { if i, ok := session.Values["csrf_token"]; ok { if i, ok := i.(string); !ok || i == "" { if x.fillToken(session, c) != nil { - x.abortFunc(c) + x.abortFunc(c, CSRFErrorCannotStoreTokenInSession) c.Abort() return } } } else { if x.fillToken(session, c) != nil { - x.abortFunc(c) + x.abortFunc(c, CSRFErrorCannotStoreTokenInSession) c.Abort() return } @@ -216,22 +240,45 @@ func (x *CSRF) VerifyCSRFMiddleware(ignoredMethods []string) gin.HandlerFunc { if i, ok := session.Values["csrf_token"]; ok { var v string if v, ok = i.(string); !ok || v == "" { - x.abortFunc(c) + if !ok { + x.abortFunc(c, CSRFErrorIncorrectTokenType) + } else if v == "" { + x.abortFunc(c, CSRFErrorEmptyToken) + } + c.Abort() return } token = v } else { - x.abortFunc(c) + x.abortFunc(c, CSRFErrorNoTokenInSession) c.Abort() return } if x.csrfTokenGetter(c) != token { - x.abortFunc(c) + x.abortFunc(c, CSRFErrorTokenMismatch) c.Abort() return } } } + +// GetCSRFErrorMessage returns generic error message for CSRFErrorReason in English (useful for logs) +func GetCSRFErrorMessage(r CSRFErrorReason) string { + switch r { + case CSRFErrorNoTokenInSession: + return "token is not present in session" + case CSRFErrorCannotStoreTokenInSession: + return "cannot store token in session" + case CSRFErrorIncorrectTokenType: + return "incorrect token type" + case CSRFErrorEmptyToken: + return "empty token present in session" + case CSRFErrorTokenMismatch: + return "token mismatch" + default: + return "unknown error" + } +} diff --git a/core/csrf_test.go b/core/csrf_test.go index ec5556c..28e8467 100644 --- a/core/csrf_test.go +++ b/core/csrf_test.go @@ -101,7 +101,7 @@ func TestCSRF_NewCSRF_NilStore(t *testing.T) { assert.NotNil(t, recover()) }() - NewCSRF("salt", "secret", "csrf", nil, func(context *gin.Context) {}, DefaultCSRFTokenGetter) + NewCSRF("salt", "secret", "csrf", nil, func(c *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter) } func TestCSRF_NewCSRF_EmptySecret(t *testing.T) { @@ -110,23 +110,38 @@ func TestCSRF_NewCSRF_EmptySecret(t *testing.T) { }() store := sessions.NewCookieStore([]byte("keys")) - NewCSRF("salt", "", "csrf", store, func(context *gin.Context) {}, DefaultCSRFTokenGetter) + NewCSRF("salt", "", "csrf", store, func(c *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter) } func TestCSRF_NewCSRF_SaltAndSessionNotEmpty(t *testing.T) { store := sessions.NewCookieStore([]byte("keys")) - csrf := NewCSRF("salt", "secret", "", store, func(context *gin.Context) {}, DefaultCSRFTokenGetter) + csrf := NewCSRF("salt", "secret", "", store, func(c *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter) assert.NotEmpty(t, csrf.salt) assert.NotEmpty(t, csrf.sessionName) } +func TestCSRF_GetCSRFErrorMessage(t *testing.T) { + items := map[CSRFErrorReason]string{ + CSRFErrorNoTokenInSession: "token is not present in session", + CSRFErrorCannotStoreTokenInSession: "cannot store token in session", + CSRFErrorIncorrectTokenType: "incorrect token type", + CSRFErrorEmptyToken: "empty token present in session", + CSRFErrorTokenMismatch: "token mismatch", + 99: "unknown error", + } + + for reason, message := range items { + assert.Equal(t, message, GetCSRFErrorMessage(reason)) + } +} + func TestCSRF_Suite(t *testing.T) { suite.Run(t, new(CSRFTest)) } func (x *CSRFTest) SetupSuite() { store := sessions.NewCookieStore([]byte("keys")) - x.csrf = NewCSRF("salt", "secret", "", store, func(context *gin.Context) { + x.csrf = NewCSRF("salt", "secret", "", store, func(context *gin.Context, r CSRFErrorReason) { context.AbortWithStatus(900) }, DefaultCSRFTokenGetter) } diff --git a/core/engine.go b/core/engine.go index 10b6a0a..396d3cd 100644 --- a/core/engine.go +++ b/core/engine.go @@ -194,7 +194,7 @@ func (e *Engine) WithFilesystemSessions(path string, keyLength ...int) *Engine { // InitCSRF initializes CSRF middleware. engine.Sessions must be already initialized, // use engine.WithCookieStore or engine.WithFilesystemStore for that. // Syntax is similar to core.NewCSRF, but you shouldn't pass sessionName, store and salt. -func (e *Engine) InitCSRF(secret string, abortFunc gin.HandlerFunc, getter CSRFTokenGetter) *Engine { +func (e *Engine) InitCSRF(secret string, abortFunc CSRFAbortFunc, getter CSRFTokenGetter) *Engine { if e.Sessions == nil { panic("engine.Sessions must be initialized first") } diff --git a/core/engine_test.go b/core/engine_test.go index e8de080..422b1f7 100644 --- a/core/engine_test.go +++ b/core/engine_test.go @@ -204,7 +204,7 @@ func (e *EngineTest) Test_InitCSRF_Fail() { e.engine.csrf = nil e.engine.Sessions = nil - e.engine.InitCSRF("test", func(context *gin.Context) {}, DefaultCSRFTokenGetter) + e.engine.InitCSRF("test", func(context *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter) assert.Nil(e.T(), e.engine.csrf) } @@ -215,7 +215,7 @@ func (e *EngineTest) Test_InitCSRF() { e.engine.csrf = nil e.engine.WithCookieSessions(4) - e.engine.InitCSRF("test", func(context *gin.Context) {}, DefaultCSRFTokenGetter) + e.engine.InitCSRF("test", func(context *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter) assert.NotNil(e.T(), e.engine.csrf) } @@ -235,7 +235,7 @@ func (e *EngineTest) Test_VerifyCSRFMiddleware() { e.engine.csrf = nil e.engine.WithCookieSessions(4) - e.engine.InitCSRF("test", func(context *gin.Context) {}, DefaultCSRFTokenGetter) + e.engine.InitCSRF("test", func(context *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter) e.engine.VerifyCSRFMiddleware(DefaultIgnoredMethods) } @@ -255,7 +255,7 @@ func (e *EngineTest) Test_GenerateCSRFMiddleware() { e.engine.csrf = nil e.engine.WithCookieSessions(4) - e.engine.InitCSRF("test", func(context *gin.Context) {}, DefaultCSRFTokenGetter) + e.engine.InitCSRF("test", func(context *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter) e.engine.GenerateCSRFMiddleware() } @@ -284,7 +284,7 @@ func (e *EngineTest) Test_GetCSRFToken() { e.engine.csrf = nil e.engine.WithCookieSessions(4) - e.engine.InitCSRF("test", func(context *gin.Context) {}, DefaultCSRFTokenGetter) + e.engine.InitCSRF("test", func(context *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter) assert.NotEmpty(e.T(), e.engine.GetCSRFToken(c)) assert.Equal(e.T(), "token", e.engine.GetCSRFToken(c)) }