pass error reason to csrf abort handler

This commit is contained in:
Pavel 2019-12-12 12:05:48 +03:00
parent 5a497fcc73
commit 20dc1f1aa8
4 changed files with 81 additions and 19 deletions

View File

@ -15,8 +15,31 @@ import (
"github.com/gorilla/sessions" "github.com/gorilla/sessions"
) )
// CSRFErrorReason is a error reason type
type CSRFErrorReason uint8
// CSRFTokenGetter func type // CSRFTokenGetter func type
type CSRFTokenGetter func(c *gin.Context) string type CSRFTokenGetter func(*gin.Context) string
// CSRFAbortFunc is a callback which
type CSRFAbortFunc func(*gin.Context, CSRFErrorReason)
const (
// CSRFErrorNoTokenInSession will be returned if token is not present in session
CSRFErrorNoTokenInSession CSRFErrorReason = iota
// CSRFErrorCannotStoreTokenInSession will be returned if middleware cannot store token in session
CSRFErrorCannotStoreTokenInSession
// CSRFErrorIncorrectTokenType will be returned if data type of token in session is not string
CSRFErrorIncorrectTokenType
// CSRFErrorEmptyToken will be returned if token in session is empty
CSRFErrorEmptyToken
// CSRFErrorTokenMismatch will be returned in case of invalid token
CSRFErrorTokenMismatch
)
// DefaultCSRFTokenGetter default getter // DefaultCSRFTokenGetter default getter
var DefaultCSRFTokenGetter = func(c *gin.Context) string { var DefaultCSRFTokenGetter = func(c *gin.Context) string {
@ -50,7 +73,7 @@ type CSRF struct {
salt string salt string
secret string secret string
sessionName string sessionName string
abortFunc gin.HandlerFunc abortFunc CSRFAbortFunc
csrfTokenGetter CSRFTokenGetter csrfTokenGetter CSRFTokenGetter
store sessions.Store store sessions.Store
} }
@ -60,7 +83,7 @@ type CSRF struct {
// Salt must be different every time (pass empty salt to use random), secret must be provided, sessionName is optional - pass empty to use default, // Salt must be different every time (pass empty salt to use random), secret must be provided, sessionName is optional - pass empty to use default,
// store will be used to store sessions, abortFunc will be called to return error if token is invalid, csrfTokenGetter will be used to obtain token. // store will be used to store sessions, abortFunc will be called to return error if token is invalid, csrfTokenGetter will be used to obtain token.
// Usage (with random salt): // Usage (with random salt):
// core.NewCSRF("", "super secret", "csrf_session", store, func (c *gin.Context) { // core.NewCSRF("", "super secret", "csrf_session", store, func (c *gin.Context, reason core.CSRFErrorReason) {
// c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "Invalid CSRF token"}) // c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "Invalid CSRF token"})
// }, core.DefaultCSRFTokenGetter) // }, core.DefaultCSRFTokenGetter)
// Note for csrfTokenGetter: if you want to read token from request body (for example, from form field) - don't forget to restore Body data! // Note for csrfTokenGetter: if you want to read token from request body (for example, from form field) - don't forget to restore Body data!
@ -70,7 +93,8 @@ type CSRF struct {
// } // }
// will close body - and all next middlewares won't be able to read body at all! // will close body - and all next middlewares won't be able to read body at all!
// Use DefaultCSRFTokenGetter as example to implement your own token getter. // Use DefaultCSRFTokenGetter as example to implement your own token getter.
func NewCSRF(salt, secret, sessionName string, store sessions.Store, abortFunc gin.HandlerFunc, csrfTokenGetter CSRFTokenGetter) *CSRF { // CSRFErrorReason will be passed to abortFunc and can be used for better error messages.
func NewCSRF(salt, secret, sessionName string, store sessions.Store, abortFunc CSRFAbortFunc, csrfTokenGetter CSRFTokenGetter) *CSRF {
if store == nil { if store == nil {
panic("store must not be nil") panic("store must not be nil")
} }
@ -178,14 +202,14 @@ func (x *CSRF) GenerateCSRFMiddleware() gin.HandlerFunc {
if i, ok := session.Values["csrf_token"]; ok { if i, ok := session.Values["csrf_token"]; ok {
if i, ok := i.(string); !ok || i == "" { if i, ok := i.(string); !ok || i == "" {
if x.fillToken(session, c) != nil { if x.fillToken(session, c) != nil {
x.abortFunc(c) x.abortFunc(c, CSRFErrorCannotStoreTokenInSession)
c.Abort() c.Abort()
return return
} }
} }
} else { } else {
if x.fillToken(session, c) != nil { if x.fillToken(session, c) != nil {
x.abortFunc(c) x.abortFunc(c, CSRFErrorCannotStoreTokenInSession)
c.Abort() c.Abort()
return return
} }
@ -216,22 +240,45 @@ func (x *CSRF) VerifyCSRFMiddleware(ignoredMethods []string) gin.HandlerFunc {
if i, ok := session.Values["csrf_token"]; ok { if i, ok := session.Values["csrf_token"]; ok {
var v string var v string
if v, ok = i.(string); !ok || v == "" { if v, ok = i.(string); !ok || v == "" {
x.abortFunc(c) if !ok {
x.abortFunc(c, CSRFErrorIncorrectTokenType)
} else if v == "" {
x.abortFunc(c, CSRFErrorEmptyToken)
}
c.Abort() c.Abort()
return return
} }
token = v token = v
} else { } else {
x.abortFunc(c) x.abortFunc(c, CSRFErrorNoTokenInSession)
c.Abort() c.Abort()
return return
} }
if x.csrfTokenGetter(c) != token { if x.csrfTokenGetter(c) != token {
x.abortFunc(c) x.abortFunc(c, CSRFErrorTokenMismatch)
c.Abort() c.Abort()
return return
} }
} }
} }
// GetCSRFErrorMessage returns generic error message for CSRFErrorReason in English (useful for logs)
func GetCSRFErrorMessage(r CSRFErrorReason) string {
switch r {
case CSRFErrorNoTokenInSession:
return "token is not present in session"
case CSRFErrorCannotStoreTokenInSession:
return "cannot store token in session"
case CSRFErrorIncorrectTokenType:
return "incorrect token type"
case CSRFErrorEmptyToken:
return "empty token present in session"
case CSRFErrorTokenMismatch:
return "token mismatch"
default:
return "unknown error"
}
}

View File

@ -101,7 +101,7 @@ func TestCSRF_NewCSRF_NilStore(t *testing.T) {
assert.NotNil(t, recover()) assert.NotNil(t, recover())
}() }()
NewCSRF("salt", "secret", "csrf", nil, func(context *gin.Context) {}, DefaultCSRFTokenGetter) NewCSRF("salt", "secret", "csrf", nil, func(c *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter)
} }
func TestCSRF_NewCSRF_EmptySecret(t *testing.T) { func TestCSRF_NewCSRF_EmptySecret(t *testing.T) {
@ -110,23 +110,38 @@ func TestCSRF_NewCSRF_EmptySecret(t *testing.T) {
}() }()
store := sessions.NewCookieStore([]byte("keys")) store := sessions.NewCookieStore([]byte("keys"))
NewCSRF("salt", "", "csrf", store, func(context *gin.Context) {}, DefaultCSRFTokenGetter) NewCSRF("salt", "", "csrf", store, func(c *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter)
} }
func TestCSRF_NewCSRF_SaltAndSessionNotEmpty(t *testing.T) { func TestCSRF_NewCSRF_SaltAndSessionNotEmpty(t *testing.T) {
store := sessions.NewCookieStore([]byte("keys")) store := sessions.NewCookieStore([]byte("keys"))
csrf := NewCSRF("salt", "secret", "", store, func(context *gin.Context) {}, DefaultCSRFTokenGetter) csrf := NewCSRF("salt", "secret", "", store, func(c *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter)
assert.NotEmpty(t, csrf.salt) assert.NotEmpty(t, csrf.salt)
assert.NotEmpty(t, csrf.sessionName) assert.NotEmpty(t, csrf.sessionName)
} }
func TestCSRF_GetCSRFErrorMessage(t *testing.T) {
items := map[CSRFErrorReason]string{
CSRFErrorNoTokenInSession: "token is not present in session",
CSRFErrorCannotStoreTokenInSession: "cannot store token in session",
CSRFErrorIncorrectTokenType: "incorrect token type",
CSRFErrorEmptyToken: "empty token present in session",
CSRFErrorTokenMismatch: "token mismatch",
99: "unknown error",
}
for reason, message := range items {
assert.Equal(t, message, GetCSRFErrorMessage(reason))
}
}
func TestCSRF_Suite(t *testing.T) { func TestCSRF_Suite(t *testing.T) {
suite.Run(t, new(CSRFTest)) suite.Run(t, new(CSRFTest))
} }
func (x *CSRFTest) SetupSuite() { func (x *CSRFTest) SetupSuite() {
store := sessions.NewCookieStore([]byte("keys")) store := sessions.NewCookieStore([]byte("keys"))
x.csrf = NewCSRF("salt", "secret", "", store, func(context *gin.Context) { x.csrf = NewCSRF("salt", "secret", "", store, func(context *gin.Context, r CSRFErrorReason) {
context.AbortWithStatus(900) context.AbortWithStatus(900)
}, DefaultCSRFTokenGetter) }, DefaultCSRFTokenGetter)
} }

View File

@ -194,7 +194,7 @@ func (e *Engine) WithFilesystemSessions(path string, keyLength ...int) *Engine {
// InitCSRF initializes CSRF middleware. engine.Sessions must be already initialized, // InitCSRF initializes CSRF middleware. engine.Sessions must be already initialized,
// use engine.WithCookieStore or engine.WithFilesystemStore for that. // use engine.WithCookieStore or engine.WithFilesystemStore for that.
// Syntax is similar to core.NewCSRF, but you shouldn't pass sessionName, store and salt. // Syntax is similar to core.NewCSRF, but you shouldn't pass sessionName, store and salt.
func (e *Engine) InitCSRF(secret string, abortFunc gin.HandlerFunc, getter CSRFTokenGetter) *Engine { func (e *Engine) InitCSRF(secret string, abortFunc CSRFAbortFunc, getter CSRFTokenGetter) *Engine {
if e.Sessions == nil { if e.Sessions == nil {
panic("engine.Sessions must be initialized first") panic("engine.Sessions must be initialized first")
} }

View File

@ -204,7 +204,7 @@ func (e *EngineTest) Test_InitCSRF_Fail() {
e.engine.csrf = nil e.engine.csrf = nil
e.engine.Sessions = nil e.engine.Sessions = nil
e.engine.InitCSRF("test", func(context *gin.Context) {}, DefaultCSRFTokenGetter) e.engine.InitCSRF("test", func(context *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter)
assert.Nil(e.T(), e.engine.csrf) assert.Nil(e.T(), e.engine.csrf)
} }
@ -215,7 +215,7 @@ func (e *EngineTest) Test_InitCSRF() {
e.engine.csrf = nil e.engine.csrf = nil
e.engine.WithCookieSessions(4) e.engine.WithCookieSessions(4)
e.engine.InitCSRF("test", func(context *gin.Context) {}, DefaultCSRFTokenGetter) e.engine.InitCSRF("test", func(context *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter)
assert.NotNil(e.T(), e.engine.csrf) assert.NotNil(e.T(), e.engine.csrf)
} }
@ -235,7 +235,7 @@ func (e *EngineTest) Test_VerifyCSRFMiddleware() {
e.engine.csrf = nil e.engine.csrf = nil
e.engine.WithCookieSessions(4) e.engine.WithCookieSessions(4)
e.engine.InitCSRF("test", func(context *gin.Context) {}, DefaultCSRFTokenGetter) e.engine.InitCSRF("test", func(context *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter)
e.engine.VerifyCSRFMiddleware(DefaultIgnoredMethods) e.engine.VerifyCSRFMiddleware(DefaultIgnoredMethods)
} }
@ -255,7 +255,7 @@ func (e *EngineTest) Test_GenerateCSRFMiddleware() {
e.engine.csrf = nil e.engine.csrf = nil
e.engine.WithCookieSessions(4) e.engine.WithCookieSessions(4)
e.engine.InitCSRF("test", func(context *gin.Context) {}, DefaultCSRFTokenGetter) e.engine.InitCSRF("test", func(context *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter)
e.engine.GenerateCSRFMiddleware() e.engine.GenerateCSRFMiddleware()
} }
@ -284,7 +284,7 @@ func (e *EngineTest) Test_GetCSRFToken() {
e.engine.csrf = nil e.engine.csrf = nil
e.engine.WithCookieSessions(4) e.engine.WithCookieSessions(4)
e.engine.InitCSRF("test", func(context *gin.Context) {}, DefaultCSRFTokenGetter) e.engine.InitCSRF("test", func(context *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter)
assert.NotEmpty(e.T(), e.engine.GetCSRFToken(c)) assert.NotEmpty(e.T(), e.engine.GetCSRFToken(c))
assert.Equal(e.T(), "token", e.engine.GetCSRFToken(c)) assert.Equal(e.T(), "token", e.engine.GetCSRFToken(c))
} }