mirror of
https://github.com/retailcrm/mg-transport-core.git
synced 2024-11-25 06:36:03 +03:00
pass error reason to csrf abort handler
This commit is contained in:
parent
5a497fcc73
commit
20dc1f1aa8
65
core/csrf.go
65
core/csrf.go
@ -15,8 +15,31 @@ import (
|
|||||||
"github.com/gorilla/sessions"
|
"github.com/gorilla/sessions"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// CSRFErrorReason is a error reason type
|
||||||
|
type CSRFErrorReason uint8
|
||||||
|
|
||||||
// CSRFTokenGetter func type
|
// CSRFTokenGetter func type
|
||||||
type CSRFTokenGetter func(c *gin.Context) string
|
type CSRFTokenGetter func(*gin.Context) string
|
||||||
|
|
||||||
|
// CSRFAbortFunc is a callback which
|
||||||
|
type CSRFAbortFunc func(*gin.Context, CSRFErrorReason)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// CSRFErrorNoTokenInSession will be returned if token is not present in session
|
||||||
|
CSRFErrorNoTokenInSession CSRFErrorReason = iota
|
||||||
|
|
||||||
|
// CSRFErrorCannotStoreTokenInSession will be returned if middleware cannot store token in session
|
||||||
|
CSRFErrorCannotStoreTokenInSession
|
||||||
|
|
||||||
|
// CSRFErrorIncorrectTokenType will be returned if data type of token in session is not string
|
||||||
|
CSRFErrorIncorrectTokenType
|
||||||
|
|
||||||
|
// CSRFErrorEmptyToken will be returned if token in session is empty
|
||||||
|
CSRFErrorEmptyToken
|
||||||
|
|
||||||
|
// CSRFErrorTokenMismatch will be returned in case of invalid token
|
||||||
|
CSRFErrorTokenMismatch
|
||||||
|
)
|
||||||
|
|
||||||
// DefaultCSRFTokenGetter default getter
|
// DefaultCSRFTokenGetter default getter
|
||||||
var DefaultCSRFTokenGetter = func(c *gin.Context) string {
|
var DefaultCSRFTokenGetter = func(c *gin.Context) string {
|
||||||
@ -50,7 +73,7 @@ type CSRF struct {
|
|||||||
salt string
|
salt string
|
||||||
secret string
|
secret string
|
||||||
sessionName string
|
sessionName string
|
||||||
abortFunc gin.HandlerFunc
|
abortFunc CSRFAbortFunc
|
||||||
csrfTokenGetter CSRFTokenGetter
|
csrfTokenGetter CSRFTokenGetter
|
||||||
store sessions.Store
|
store sessions.Store
|
||||||
}
|
}
|
||||||
@ -60,7 +83,7 @@ type CSRF struct {
|
|||||||
// Salt must be different every time (pass empty salt to use random), secret must be provided, sessionName is optional - pass empty to use default,
|
// Salt must be different every time (pass empty salt to use random), secret must be provided, sessionName is optional - pass empty to use default,
|
||||||
// store will be used to store sessions, abortFunc will be called to return error if token is invalid, csrfTokenGetter will be used to obtain token.
|
// store will be used to store sessions, abortFunc will be called to return error if token is invalid, csrfTokenGetter will be used to obtain token.
|
||||||
// Usage (with random salt):
|
// Usage (with random salt):
|
||||||
// core.NewCSRF("", "super secret", "csrf_session", store, func (c *gin.Context) {
|
// core.NewCSRF("", "super secret", "csrf_session", store, func (c *gin.Context, reason core.CSRFErrorReason) {
|
||||||
// c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "Invalid CSRF token"})
|
// c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "Invalid CSRF token"})
|
||||||
// }, core.DefaultCSRFTokenGetter)
|
// }, core.DefaultCSRFTokenGetter)
|
||||||
// Note for csrfTokenGetter: if you want to read token from request body (for example, from form field) - don't forget to restore Body data!
|
// Note for csrfTokenGetter: if you want to read token from request body (for example, from form field) - don't forget to restore Body data!
|
||||||
@ -70,7 +93,8 @@ type CSRF struct {
|
|||||||
// }
|
// }
|
||||||
// will close body - and all next middlewares won't be able to read body at all!
|
// will close body - and all next middlewares won't be able to read body at all!
|
||||||
// Use DefaultCSRFTokenGetter as example to implement your own token getter.
|
// Use DefaultCSRFTokenGetter as example to implement your own token getter.
|
||||||
func NewCSRF(salt, secret, sessionName string, store sessions.Store, abortFunc gin.HandlerFunc, csrfTokenGetter CSRFTokenGetter) *CSRF {
|
// CSRFErrorReason will be passed to abortFunc and can be used for better error messages.
|
||||||
|
func NewCSRF(salt, secret, sessionName string, store sessions.Store, abortFunc CSRFAbortFunc, csrfTokenGetter CSRFTokenGetter) *CSRF {
|
||||||
if store == nil {
|
if store == nil {
|
||||||
panic("store must not be nil")
|
panic("store must not be nil")
|
||||||
}
|
}
|
||||||
@ -178,14 +202,14 @@ func (x *CSRF) GenerateCSRFMiddleware() gin.HandlerFunc {
|
|||||||
if i, ok := session.Values["csrf_token"]; ok {
|
if i, ok := session.Values["csrf_token"]; ok {
|
||||||
if i, ok := i.(string); !ok || i == "" {
|
if i, ok := i.(string); !ok || i == "" {
|
||||||
if x.fillToken(session, c) != nil {
|
if x.fillToken(session, c) != nil {
|
||||||
x.abortFunc(c)
|
x.abortFunc(c, CSRFErrorCannotStoreTokenInSession)
|
||||||
c.Abort()
|
c.Abort()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if x.fillToken(session, c) != nil {
|
if x.fillToken(session, c) != nil {
|
||||||
x.abortFunc(c)
|
x.abortFunc(c, CSRFErrorCannotStoreTokenInSession)
|
||||||
c.Abort()
|
c.Abort()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -216,22 +240,45 @@ func (x *CSRF) VerifyCSRFMiddleware(ignoredMethods []string) gin.HandlerFunc {
|
|||||||
if i, ok := session.Values["csrf_token"]; ok {
|
if i, ok := session.Values["csrf_token"]; ok {
|
||||||
var v string
|
var v string
|
||||||
if v, ok = i.(string); !ok || v == "" {
|
if v, ok = i.(string); !ok || v == "" {
|
||||||
x.abortFunc(c)
|
if !ok {
|
||||||
|
x.abortFunc(c, CSRFErrorIncorrectTokenType)
|
||||||
|
} else if v == "" {
|
||||||
|
x.abortFunc(c, CSRFErrorEmptyToken)
|
||||||
|
}
|
||||||
|
|
||||||
c.Abort()
|
c.Abort()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
token = v
|
token = v
|
||||||
} else {
|
} else {
|
||||||
x.abortFunc(c)
|
x.abortFunc(c, CSRFErrorNoTokenInSession)
|
||||||
c.Abort()
|
c.Abort()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if x.csrfTokenGetter(c) != token {
|
if x.csrfTokenGetter(c) != token {
|
||||||
x.abortFunc(c)
|
x.abortFunc(c, CSRFErrorTokenMismatch)
|
||||||
c.Abort()
|
c.Abort()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetCSRFErrorMessage returns generic error message for CSRFErrorReason in English (useful for logs)
|
||||||
|
func GetCSRFErrorMessage(r CSRFErrorReason) string {
|
||||||
|
switch r {
|
||||||
|
case CSRFErrorNoTokenInSession:
|
||||||
|
return "token is not present in session"
|
||||||
|
case CSRFErrorCannotStoreTokenInSession:
|
||||||
|
return "cannot store token in session"
|
||||||
|
case CSRFErrorIncorrectTokenType:
|
||||||
|
return "incorrect token type"
|
||||||
|
case CSRFErrorEmptyToken:
|
||||||
|
return "empty token present in session"
|
||||||
|
case CSRFErrorTokenMismatch:
|
||||||
|
return "token mismatch"
|
||||||
|
default:
|
||||||
|
return "unknown error"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -101,7 +101,7 @@ func TestCSRF_NewCSRF_NilStore(t *testing.T) {
|
|||||||
assert.NotNil(t, recover())
|
assert.NotNil(t, recover())
|
||||||
}()
|
}()
|
||||||
|
|
||||||
NewCSRF("salt", "secret", "csrf", nil, func(context *gin.Context) {}, DefaultCSRFTokenGetter)
|
NewCSRF("salt", "secret", "csrf", nil, func(c *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCSRF_NewCSRF_EmptySecret(t *testing.T) {
|
func TestCSRF_NewCSRF_EmptySecret(t *testing.T) {
|
||||||
@ -110,23 +110,38 @@ func TestCSRF_NewCSRF_EmptySecret(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
store := sessions.NewCookieStore([]byte("keys"))
|
store := sessions.NewCookieStore([]byte("keys"))
|
||||||
NewCSRF("salt", "", "csrf", store, func(context *gin.Context) {}, DefaultCSRFTokenGetter)
|
NewCSRF("salt", "", "csrf", store, func(c *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCSRF_NewCSRF_SaltAndSessionNotEmpty(t *testing.T) {
|
func TestCSRF_NewCSRF_SaltAndSessionNotEmpty(t *testing.T) {
|
||||||
store := sessions.NewCookieStore([]byte("keys"))
|
store := sessions.NewCookieStore([]byte("keys"))
|
||||||
csrf := NewCSRF("salt", "secret", "", store, func(context *gin.Context) {}, DefaultCSRFTokenGetter)
|
csrf := NewCSRF("salt", "secret", "", store, func(c *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter)
|
||||||
assert.NotEmpty(t, csrf.salt)
|
assert.NotEmpty(t, csrf.salt)
|
||||||
assert.NotEmpty(t, csrf.sessionName)
|
assert.NotEmpty(t, csrf.sessionName)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCSRF_GetCSRFErrorMessage(t *testing.T) {
|
||||||
|
items := map[CSRFErrorReason]string{
|
||||||
|
CSRFErrorNoTokenInSession: "token is not present in session",
|
||||||
|
CSRFErrorCannotStoreTokenInSession: "cannot store token in session",
|
||||||
|
CSRFErrorIncorrectTokenType: "incorrect token type",
|
||||||
|
CSRFErrorEmptyToken: "empty token present in session",
|
||||||
|
CSRFErrorTokenMismatch: "token mismatch",
|
||||||
|
99: "unknown error",
|
||||||
|
}
|
||||||
|
|
||||||
|
for reason, message := range items {
|
||||||
|
assert.Equal(t, message, GetCSRFErrorMessage(reason))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestCSRF_Suite(t *testing.T) {
|
func TestCSRF_Suite(t *testing.T) {
|
||||||
suite.Run(t, new(CSRFTest))
|
suite.Run(t, new(CSRFTest))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (x *CSRFTest) SetupSuite() {
|
func (x *CSRFTest) SetupSuite() {
|
||||||
store := sessions.NewCookieStore([]byte("keys"))
|
store := sessions.NewCookieStore([]byte("keys"))
|
||||||
x.csrf = NewCSRF("salt", "secret", "", store, func(context *gin.Context) {
|
x.csrf = NewCSRF("salt", "secret", "", store, func(context *gin.Context, r CSRFErrorReason) {
|
||||||
context.AbortWithStatus(900)
|
context.AbortWithStatus(900)
|
||||||
}, DefaultCSRFTokenGetter)
|
}, DefaultCSRFTokenGetter)
|
||||||
}
|
}
|
||||||
|
@ -194,7 +194,7 @@ func (e *Engine) WithFilesystemSessions(path string, keyLength ...int) *Engine {
|
|||||||
// InitCSRF initializes CSRF middleware. engine.Sessions must be already initialized,
|
// InitCSRF initializes CSRF middleware. engine.Sessions must be already initialized,
|
||||||
// use engine.WithCookieStore or engine.WithFilesystemStore for that.
|
// use engine.WithCookieStore or engine.WithFilesystemStore for that.
|
||||||
// Syntax is similar to core.NewCSRF, but you shouldn't pass sessionName, store and salt.
|
// Syntax is similar to core.NewCSRF, but you shouldn't pass sessionName, store and salt.
|
||||||
func (e *Engine) InitCSRF(secret string, abortFunc gin.HandlerFunc, getter CSRFTokenGetter) *Engine {
|
func (e *Engine) InitCSRF(secret string, abortFunc CSRFAbortFunc, getter CSRFTokenGetter) *Engine {
|
||||||
if e.Sessions == nil {
|
if e.Sessions == nil {
|
||||||
panic("engine.Sessions must be initialized first")
|
panic("engine.Sessions must be initialized first")
|
||||||
}
|
}
|
||||||
|
@ -204,7 +204,7 @@ func (e *EngineTest) Test_InitCSRF_Fail() {
|
|||||||
|
|
||||||
e.engine.csrf = nil
|
e.engine.csrf = nil
|
||||||
e.engine.Sessions = nil
|
e.engine.Sessions = nil
|
||||||
e.engine.InitCSRF("test", func(context *gin.Context) {}, DefaultCSRFTokenGetter)
|
e.engine.InitCSRF("test", func(context *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter)
|
||||||
assert.Nil(e.T(), e.engine.csrf)
|
assert.Nil(e.T(), e.engine.csrf)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -215,7 +215,7 @@ func (e *EngineTest) Test_InitCSRF() {
|
|||||||
|
|
||||||
e.engine.csrf = nil
|
e.engine.csrf = nil
|
||||||
e.engine.WithCookieSessions(4)
|
e.engine.WithCookieSessions(4)
|
||||||
e.engine.InitCSRF("test", func(context *gin.Context) {}, DefaultCSRFTokenGetter)
|
e.engine.InitCSRF("test", func(context *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter)
|
||||||
assert.NotNil(e.T(), e.engine.csrf)
|
assert.NotNil(e.T(), e.engine.csrf)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -235,7 +235,7 @@ func (e *EngineTest) Test_VerifyCSRFMiddleware() {
|
|||||||
|
|
||||||
e.engine.csrf = nil
|
e.engine.csrf = nil
|
||||||
e.engine.WithCookieSessions(4)
|
e.engine.WithCookieSessions(4)
|
||||||
e.engine.InitCSRF("test", func(context *gin.Context) {}, DefaultCSRFTokenGetter)
|
e.engine.InitCSRF("test", func(context *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter)
|
||||||
e.engine.VerifyCSRFMiddleware(DefaultIgnoredMethods)
|
e.engine.VerifyCSRFMiddleware(DefaultIgnoredMethods)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -255,7 +255,7 @@ func (e *EngineTest) Test_GenerateCSRFMiddleware() {
|
|||||||
|
|
||||||
e.engine.csrf = nil
|
e.engine.csrf = nil
|
||||||
e.engine.WithCookieSessions(4)
|
e.engine.WithCookieSessions(4)
|
||||||
e.engine.InitCSRF("test", func(context *gin.Context) {}, DefaultCSRFTokenGetter)
|
e.engine.InitCSRF("test", func(context *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter)
|
||||||
e.engine.GenerateCSRFMiddleware()
|
e.engine.GenerateCSRFMiddleware()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -284,7 +284,7 @@ func (e *EngineTest) Test_GetCSRFToken() {
|
|||||||
|
|
||||||
e.engine.csrf = nil
|
e.engine.csrf = nil
|
||||||
e.engine.WithCookieSessions(4)
|
e.engine.WithCookieSessions(4)
|
||||||
e.engine.InitCSRF("test", func(context *gin.Context) {}, DefaultCSRFTokenGetter)
|
e.engine.InitCSRF("test", func(context *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter)
|
||||||
assert.NotEmpty(e.T(), e.engine.GetCSRFToken(c))
|
assert.NotEmpty(e.T(), e.engine.GetCSRFToken(c))
|
||||||
assert.Equal(e.T(), "token", e.engine.GetCSRFToken(c))
|
assert.Equal(e.T(), "token", e.engine.GetCSRFToken(c))
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user