package core import ( "bytes" // nolint:gosec "crypto/sha1" "encoding/base64" "io" "io/ioutil" "math/rand" "time" "github.com/gin-gonic/gin" "github.com/gorilla/securecookie" "github.com/gorilla/sessions" ) // CSRFErrorReason is a error reason type. type CSRFErrorReason uint8 // CSRFTokenGetter func type. type CSRFTokenGetter func(*gin.Context) string // CSRFAbortFunc is a callback which. type CSRFAbortFunc func(*gin.Context, CSRFErrorReason) const ( // CSRFErrorNoTokenInSession will be returned if token is not present in session. CSRFErrorNoTokenInSession CSRFErrorReason = iota // CSRFErrorCannotStoreTokenInSession will be returned if middleware cannot store token in session. CSRFErrorCannotStoreTokenInSession // CSRFErrorIncorrectTokenType will be returned if data type of token in session is not string. CSRFErrorIncorrectTokenType // CSRFErrorEmptyToken will be returned if token in session is empty. CSRFErrorEmptyToken // CSRFErrorTokenMismatch will be returned in case of invalid token. CSRFErrorTokenMismatch ) // DefaultCSRFTokenGetter default getter. var DefaultCSRFTokenGetter = func(c *gin.Context) string { r := c.Request if t := r.URL.Query().Get("csrf_token"); len(t) > 0 { return t } else if t := r.Header.Get("X-CSRF-Token"); len(t) > 0 { return t } else if t := r.Header.Get("X-XSRF-Token"); len(t) > 0 { return t } else if c.Request.Body != nil { data, _ := ioutil.ReadAll(c.Request.Body) c.Request.Body = ioutil.NopCloser(bytes.NewReader(data)) t := r.FormValue("csrf_token") c.Request.Body = ioutil.NopCloser(bytes.NewReader(data)) if len(t) > 0 { return t } } return "" } // DefaultIgnoredMethods ignored methods for CSRF verifier middleware. var DefaultIgnoredMethods = []string{"GET", "HEAD", "OPTIONS"} // CSRF struct. Provides CSRF token verification. type CSRF struct { salt string secret string sessionName string abortFunc CSRFAbortFunc csrfTokenGetter CSRFTokenGetter store sessions.Store } // NewCSRF creates CSRF struct with specified configuration and session store. // GenerateCSRFMiddleware and VerifyCSRFMiddleware returns CSRF middlewares. // Salt must be different every time (pass empty salt to use random), secret must be provided, sessionName is optional - pass empty to use default, // store will be used to store sessions, abortFunc will be called to return error if token is invalid, csrfTokenGetter will be used to obtain token. // Usage (with random salt): // core.NewCSRF("", "super secret", "csrf_session", store, func (c *gin.Context, reason core.CSRFErrorReason) { // c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "Invalid CSRF token"}) // }, core.DefaultCSRFTokenGetter) // Note for csrfTokenGetter: if you want to read token from request body (for example, from form field) - don't forget to restore Body data! // Body in http.Request is io.ReadCloser instance. Reading CSRF token from form like that: // if t := r.FormValue("csrf_token"); len(t) > 0 { // return t // } // will close body - and all next middlewares won't be able to read body at all! // Use DefaultCSRFTokenGetter as example to implement your own token getter. // CSRFErrorReason will be passed to abortFunc and can be used for better error messages. func NewCSRF(salt, secret, sessionName string, store sessions.Store, abortFunc CSRFAbortFunc, csrfTokenGetter CSRFTokenGetter) *CSRF { if store == nil { panic("store must not be nil") } if secret == "" { panic("at least secret must be provided") } csrf := &CSRF{ store: store, secret: secret, abortFunc: abortFunc, csrfTokenGetter: csrfTokenGetter, } if salt == "" { salt = csrf.generateSalt() } if sessionName == "" { sessionName = "csrf_token_session" } csrf.salt = salt csrf.sessionName = sessionName return csrf } // strInSlice checks whether string exists in slice. func (x *CSRF) strInSlice(slice []string, v string) bool { exists := false for _, i := range slice { if i == v { exists = true break } } return exists } // generateCSRFToken generates new CSRF token. func (x *CSRF) generateCSRFToken() string { // nolint:gosec h := sha1.New() // Fallback to less secure method - token must be always filled even if we cannot properly generate it if _, err := io.WriteString(h, x.salt+"#"+x.secret); err != nil { return base64.URLEncoding.EncodeToString([]byte(time.Now().String())) } hash := base64.URLEncoding.EncodeToString(h.Sum(nil)) return hash } // generateSalt generates salt from random bytes. If it fails to generate cryptographically // secure salt - it will generate pseudo-random, weaker salt. // It will be used automatically if no salt provided. // Default secure salt length: 8 bytes. // Default pseudo-random salt length: 64 bytes. func (x *CSRF) generateSalt() string { salt := securecookie.GenerateRandomKey(8) if salt == nil { return x.pseudoRandomString(64) } return string(salt) } // pseudoRandomString generates pseudo-random string with specified length. func (x *CSRF) pseudoRandomString(length int) string { rand.Seed(time.Now().UnixNano()) data := make([]byte, length) for i := 0; i < length; i++ { data[i] = byte(65 + rand.Intn(90-65)) } return string(data) } // CSRFFromContext returns csrf token or random token. It shouldn't return empty string because it will make csrf protection useless. // e.g. any request without token will work fine, which is inacceptable. func (x *CSRF) CSRFFromContext(c *gin.Context) string { if i, ok := c.Get("csrf_token"); ok { if token, ok := i.(string); ok { return token } } return x.generateCSRFToken() } // GenerateCSRFMiddleware returns gin.HandlerFunc which will generate CSRF token // Usage: // engine := gin.New() // csrf := NewCSRF("salt", "secret", "not_found", "incorrect", localizer) // engine.Use(csrf.GenerateCSRFMiddleware()) func (x *CSRF) GenerateCSRFMiddleware() gin.HandlerFunc { return func(c *gin.Context) { session, _ := x.store.Get(c.Request, x.sessionName) if i, ok := session.Values["csrf_token"]; ok { if i, ok := i.(string); !ok || i == "" { if x.fillToken(session, c) != nil { x.abortFunc(c, CSRFErrorCannotStoreTokenInSession) c.Abort() return } } } else { if x.fillToken(session, c) != nil { x.abortFunc(c, CSRFErrorCannotStoreTokenInSession) c.Abort() return } } } } // fillToken stores token in session and context. func (x *CSRF) fillToken(s *sessions.Session, c *gin.Context) error { s.Values["csrf_token"] = x.generateCSRFToken() c.Set("csrf_token", s.Values["csrf_token"]) return s.Save(c.Request, c.Writer) } // VerifyCSRFMiddleware verifies CSRF token // Usage: // engine := gin.New() // engine.Use(csrf.VerifyCSRFMiddleware()) func (x *CSRF) VerifyCSRFMiddleware(ignoredMethods []string) gin.HandlerFunc { return func(c *gin.Context) { if x.strInSlice(ignoredMethods, c.Request.Method) { return } var token string session, _ := x.store.Get(c.Request, x.sessionName) if i, ok := session.Values["csrf_token"]; ok { var v string if v, ok = i.(string); !ok || v == "" { if !ok { x.abortFunc(c, CSRFErrorIncorrectTokenType) } else if v == "" { x.abortFunc(c, CSRFErrorEmptyToken) } c.Abort() return } token = v } else { x.abortFunc(c, CSRFErrorNoTokenInSession) c.Abort() return } if x.csrfTokenGetter(c) != token { x.abortFunc(c, CSRFErrorTokenMismatch) c.Abort() return } } } // GetCSRFErrorMessage returns generic error message for CSRFErrorReason in English (useful for logs). func GetCSRFErrorMessage(r CSRFErrorReason) string { switch r { case CSRFErrorNoTokenInSession: return "token is not present in session" case CSRFErrorCannotStoreTokenInSession: return "cannot store token in session" case CSRFErrorIncorrectTokenType: return "incorrect token type" case CSRFErrorEmptyToken: return "empty token present in session" case CSRFErrorTokenMismatch: return "token mismatch" default: return "unknown error" } }