mg-transport-core/core/middleware/csrf.go

302 lines
8.1 KiB
Go
Raw Normal View History

package middleware
2019-10-31 14:21:39 +03:00
import (
"bytes"
2019-12-12 09:35:05 +03:00
// nolint:gosec
2019-10-31 14:21:39 +03:00
"crypto/sha1"
"encoding/base64"
"io"
"io/ioutil"
"math/rand"
"time"
"github.com/gin-gonic/gin"
"github.com/gorilla/securecookie"
"github.com/gorilla/sessions"
)
// CSRFErrorReason is an error reason type. Its values are passed to CSRFAbortFunc to describe
// why CSRF token generation or verification failed.
type CSRFErrorReason uint8

// CSRFTokenGetter extracts the CSRF token submitted with the request held by the gin context.
// It should return an empty string when no token was submitted.
type CSRFTokenGetter func(*gin.Context) string

// CSRFAbortFunc is a callback which is invoked with the failure reason when a CSRF middleware
// rejects a request; it is expected to write the error response. The middleware calls c.Abort()
// itself after the callback returns.
type CSRFAbortFunc func(*gin.Context, CSRFErrorReason)
const (
// CSRFErrorNoTokenInSession will be returned if token is not present in session.
CSRFErrorNoTokenInSession CSRFErrorReason = iota
// CSRFErrorCannotStoreTokenInSession will be returned if middleware cannot store token in session.
CSRFErrorCannotStoreTokenInSession
// CSRFErrorIncorrectTokenType will be returned if data type of token in session is not string.
CSRFErrorIncorrectTokenType
// CSRFErrorEmptyToken will be returned if token in session is empty.
CSRFErrorEmptyToken
// CSRFErrorTokenMismatch will be returned in case of invalid token.
CSRFErrorTokenMismatch
)
2019-10-31 14:21:39 +03:00
const (
keySize = 8
randomStringSize = 64
)
// DefaultCSRFTokenGetter default getter.
2019-10-31 14:21:39 +03:00
var DefaultCSRFTokenGetter = func(c *gin.Context) string {
r := c.Request
if t := r.URL.Query().Get("csrf_token"); len(t) > 0 {
return t
} else if t := r.Header.Get("X-CSRF-Token"); len(t) > 0 {
return t
} else if t := r.Header.Get("X-XSRF-Token"); len(t) > 0 {
return t
} else if c.Request.Body != nil {
data, _ := ioutil.ReadAll(c.Request.Body)
c.Request.Body = ioutil.NopCloser(bytes.NewReader(data))
t := r.FormValue("csrf_token")
c.Request.Body = ioutil.NopCloser(bytes.NewReader(data))
if len(t) > 0 {
return t
}
}
return ""
}
// DefaultIgnoredMethods is the default list of HTTP methods to pass to VerifyCSRFMiddleware
// as ignoredMethods; requests using these methods are not checked for a CSRF token.
var DefaultIgnoredMethods = []string{
	"GET",
	"HEAD",
	"OPTIONS",
}
// CSRF struct. Provides CSRF token verification.
2019-10-31 14:21:39 +03:00
type CSRF struct {
store sessions.Store
abortFunc CSRFAbortFunc
csrfTokenGetter CSRFTokenGetter
2019-10-31 14:21:39 +03:00
salt string
secret string
sessionName string
}
// NewCSRF creates CSRF struct with specified configuration and session store.
// GenerateCSRFMiddleware and VerifyCSRFMiddleware returns CSRF middlewares.
// Salt must be different every time (pass empty salt to use random), secret must be provided,
// sessionName is optional - pass empty to use default,
// store will be used to store sessions, abortFunc will be called to return error if token is invalid,
// csrfTokenGetter will be used to obtain token.
//
2019-10-31 14:21:39 +03:00
// Usage (with random salt):
// core.NewCSRF("", "super secret", "csrf_session", store, func (c *gin.Context, reason core.CSRFErrorReason) {
2019-10-31 14:21:39 +03:00
// c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "Invalid CSRF token"})
// }, core.DefaultCSRFTokenGetter)
//
// Note for csrfTokenGetter: if you want to read token from request body (for example, from form field)
// - don't forget to restore Body data!
//
2019-10-31 14:21:39 +03:00
// Body in http.Request is io.ReadCloser instance. Reading CSRF token from form like that:
// if t := r.FormValue("csrf_token"); len(t) > 0 {
// return t
// }
// will close body - and all next middlewares won't be able to read body at all!
//
2019-10-31 14:21:39 +03:00
// Use DefaultCSRFTokenGetter as example to implement your own token getter.
// CSRFErrorReason will be passed to abortFunc and can be used for better error messages.
func NewCSRF(
salt, secret, sessionName string,
store sessions.Store,
abortFunc CSRFAbortFunc,
csrfTokenGetter CSRFTokenGetter,
) *CSRF {
2019-10-31 14:21:39 +03:00
if store == nil {
panic("store must not be nil")
}
if secret == "" {
panic("at least secret must be provided")
}
csrf := &CSRF{
store: store,
secret: secret,
abortFunc: abortFunc,
csrfTokenGetter: csrfTokenGetter,
}
if salt == "" {
salt = csrf.generateSalt()
}
if sessionName == "" {
sessionName = "csrf_token_session"
}
csrf.salt = salt
csrf.sessionName = sessionName
return csrf
}
// strInSlice checks whether string exists in slice.
2019-10-31 14:21:39 +03:00
func (x *CSRF) strInSlice(slice []string, v string) bool {
exists := false
for _, i := range slice {
if i == v {
exists = true
break
}
}
return exists
}
// generateCSRFToken generates new CSRF token.
2019-10-31 14:21:39 +03:00
func (x *CSRF) generateCSRFToken() string {
2019-12-12 09:35:05 +03:00
// nolint:gosec
2019-10-31 14:21:39 +03:00
h := sha1.New()
2019-12-12 09:35:05 +03:00
// Fallback to less secure method - token must be always filled even if we cannot properly generate it
if _, err := io.WriteString(h, x.salt+"#"+x.secret); err != nil {
return base64.URLEncoding.EncodeToString([]byte(time.Now().String()))
}
2019-10-31 14:21:39 +03:00
hash := base64.URLEncoding.EncodeToString(h.Sum(nil))
return hash
}
// generateSalt generates salt from random bytes. If it fails to generate cryptographically
// secure salt - it will generate pseudo-random, weaker salt.
// It will be used automatically if no salt provided.
// Default secure salt length: 8 bytes.
// Default pseudo-random salt length: 64 bytes.
func (x *CSRF) generateSalt() string {
salt := securecookie.GenerateRandomKey(keySize)
2019-10-31 14:21:39 +03:00
if salt == nil {
return x.pseudoRandomString(randomStringSize)
2019-10-31 14:21:39 +03:00
}
return string(salt)
}
// pseudoRandomString generates pseudo-random string with specified length.
2019-10-31 14:21:39 +03:00
func (x *CSRF) pseudoRandomString(length int) string {
rand.Seed(time.Now().UnixNano())
data := make([]byte, length)
for i := 0; i < length; i++ { // it is supposed to use pseudo-random data.
data[i] = byte(65 + rand.Intn(90-65)) // nolint:gosec,gomnd
2019-10-31 14:21:39 +03:00
}
return string(data)
}
// CSRFFromContext returns csrf token or random token. It shouldn't return empty string because
// it will make csrf protection useless. e.g. any request without token will work fine, which is unacceptable.
2019-10-31 14:21:39 +03:00
func (x *CSRF) CSRFFromContext(c *gin.Context) string {
if i, ok := c.Get("csrf_token"); ok {
if token, ok := i.(string); ok {
return token
}
}
2019-12-12 09:35:05 +03:00
return x.generateCSRFToken()
2019-10-31 14:21:39 +03:00
}
// GenerateCSRFMiddleware returns gin.HandlerFunc which will generate CSRF token
// Usage:
// engine := gin.New()
// csrf := NewCSRF("salt", "secret", "not_found", "incorrect", localizer)
// engine.Use(csrf.GenerateCSRFMiddleware())
func (x *CSRF) GenerateCSRFMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
session, _ := x.store.Get(c.Request, x.sessionName)
if i, ok := session.Values["csrf_token"]; ok {
if i, ok := i.(string); !ok || i == "" {
if x.fillToken(session, c) != nil {
x.abortFunc(c, CSRFErrorCannotStoreTokenInSession)
2019-10-31 14:21:39 +03:00
c.Abort()
return
}
}
} else {
if x.fillToken(session, c) != nil {
x.abortFunc(c, CSRFErrorCannotStoreTokenInSession)
2019-10-31 14:21:39 +03:00
c.Abort()
return
}
}
}
}
// fillToken stores token in session and context.
2019-10-31 14:21:39 +03:00
func (x *CSRF) fillToken(s *sessions.Session, c *gin.Context) error {
s.Values["csrf_token"] = x.generateCSRFToken()
c.Set("csrf_token", s.Values["csrf_token"])
return s.Save(c.Request, c.Writer)
}
// VerifyCSRFMiddleware verifies CSRF token
// Usage:
// engine := gin.New()
// engine.Use(csrf.VerifyCSRFMiddleware())
func (x *CSRF) VerifyCSRFMiddleware(ignoredMethods []string) gin.HandlerFunc {
return func(c *gin.Context) {
if x.strInSlice(ignoredMethods, c.Request.Method) {
return
}
var token string
session, _ := x.store.Get(c.Request, x.sessionName)
if i, ok := session.Values["csrf_token"]; ok {
2019-12-12 09:35:05 +03:00
var v string
if v, ok = i.(string); !ok || v == "" {
if !ok {
x.abortFunc(c, CSRFErrorIncorrectTokenType)
} else if v == "" {
x.abortFunc(c, CSRFErrorEmptyToken)
}
2019-10-31 14:21:39 +03:00
c.Abort()
return
}
2019-12-12 09:35:05 +03:00
token = v
2019-10-31 14:21:39 +03:00
} else {
x.abortFunc(c, CSRFErrorNoTokenInSession)
2019-10-31 14:21:39 +03:00
c.Abort()
return
}
if x.csrfTokenGetter(c) != token {
x.abortFunc(c, CSRFErrorTokenMismatch)
2019-10-31 14:21:39 +03:00
c.Abort()
return
}
}
}
// GetCSRFErrorMessage maps a CSRFErrorReason to a short, generic English description that is
// useful for logs. Reasons it does not recognize are reported as "unknown error".
func GetCSRFErrorMessage(r CSRFErrorReason) string {
	messages := map[CSRFErrorReason]string{
		CSRFErrorNoTokenInSession:          "token is not present in session",
		CSRFErrorCannotStoreTokenInSession: "cannot store token in session",
		CSRFErrorIncorrectTokenType:        "incorrect token type",
		CSRFErrorEmptyToken:                "empty token present in session",
		CSRFErrorTokenMismatch:             "token mismatch",
	}

	if msg, found := messages[r]; found {
		return msg
	}

	return "unknown error"
}