mg-transport-core/core/csrf.go
Neur0toxine d15ed7ffec
Static analysis (#23)
* lint stage for a workflow
* golangci-lint config
* lint only new code or last commit
* run lint only for pull requests
2021-02-09 14:57:14 +03:00

285 lines
8.0 KiB
Go

package core
import (
	"bytes"
	// nolint:gosec
	"crypto/sha1"
	"crypto/subtle"
	"encoding/base64"
	"io"
	"io/ioutil"
	"math/rand"
	"net/http"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/gorilla/securecookie"
	"github.com/gorilla/sessions"
)
// CSRFErrorReason is an error reason type. A value of this type is passed to CSRFAbortFunc
// to describe why CSRF token storage or verification failed; GetCSRFErrorMessage converts
// it into a short English description.
type CSRFErrorReason uint8

// CSRFTokenGetter func type. It extracts the CSRF token sent by the client from the request.
// See DefaultCSRFTokenGetter for the default implementation.
type CSRFTokenGetter func(*gin.Context) string

// CSRFAbortFunc is a callback which is called by the CSRF middlewares when the token cannot be
// stored or the request fails verification. It receives the failure reason and is expected to
// write an error response; the middleware aborts the handler chain right after calling it.
type CSRFAbortFunc func(*gin.Context, CSRFErrorReason)
// CSRF error reasons passed to CSRFAbortFunc.
// Values are assigned by iota in declaration order, so new reasons must be appended at the end
// to keep the numeric values of existing reasons stable.
const (
	// CSRFErrorNoTokenInSession is reported if token is not present in session.
	CSRFErrorNoTokenInSession CSRFErrorReason = iota
	// CSRFErrorCannotStoreTokenInSession is reported if middleware cannot store token in session.
	CSRFErrorCannotStoreTokenInSession
	// CSRFErrorIncorrectTokenType is reported if data type of token in session is not string.
	CSRFErrorIncorrectTokenType
	// CSRFErrorEmptyToken is reported if token in session is empty.
	CSRFErrorEmptyToken
	// CSRFErrorTokenMismatch is reported if the token sent by the client differs from the one in session.
	CSRFErrorTokenMismatch
)
// DefaultCSRFTokenGetter is the default CSRFTokenGetter.
// It looks for the token in the following places, in this order:
//   - the "csrf_token" URL query parameter;
//   - the "X-CSRF-Token" header;
//   - the "X-XSRF-Token" header;
//   - the "csrf_token" form field (read via http.Request.FormValue).
//
// To read the form field, the whole request body is buffered in memory and restored afterwards,
// so later middlewares and handlers can still read it. An empty string is returned when no token
// was found or when the body could not be read completely.
var DefaultCSRFTokenGetter = func(c *gin.Context) string {
	r := c.Request

	if t := r.URL.Query().Get("csrf_token"); t != "" {
		return t
	}
	if t := r.Header.Get("X-CSRF-Token"); t != "" {
		return t
	}
	if t := r.Header.Get("X-XSRF-Token"); t != "" {
		return t
	}
	if r.Body == nil {
		return ""
	}

	data, err := ioutil.ReadAll(r.Body)
	// Put back whatever was read so that downstream handlers still see the body, even on error.
	r.Body = ioutil.NopCloser(bytes.NewReader(data))
	if err != nil {
		// The body is incomplete; parsing a truncated form could yield garbage, so give up.
		return ""
	}

	t := r.FormValue("csrf_token")
	// FormValue may consume the body while parsing the form - restore it once more.
	r.Body = ioutil.NopCloser(bytes.NewReader(data))

	return t
}
// DefaultIgnoredMethods lists the HTTP methods that VerifyCSRFMiddleware should let through
// without checking the token when passed as its ignoredMethods argument.
var DefaultIgnoredMethods = []string{
	"GET",
	"HEAD",
	"OPTIONS",
}
// CSRF struct. Provides CSRF token generation and verification middlewares.
// Create instances with NewCSRF: the methods use store, abortFunc and csrfTokenGetter directly,
// so a zero value would panic on nil fields.
type CSRF struct {
	salt            string          // mixed with secret when generating tokens
	secret          string          // non-empty secret provided to NewCSRF
	sessionName     string          // name of the session holding the token under the "csrf_token" key
	abortFunc       CSRFAbortFunc   // called with the failure reason when storing/verifying the token fails
	csrfTokenGetter CSRFTokenGetter // extracts the client-supplied token from the request
	store           sessions.Store  // session storage used to persist the token
}
// NewCSRF creates CSRF struct with specified configuration and session store.
// GenerateCSRFMiddleware and VerifyCSRFMiddleware return the CSRF middlewares.
//
// Parameters:
//   - salt is mixed into every generated token; it should differ between deployments.
//     Pass an empty string to use a randomly generated salt.
//   - secret must be provided; NewCSRF panics if it is empty.
//   - sessionName is optional; an empty value selects "csrf_token_session".
//   - store is used to store sessions; NewCSRF panics if it is nil.
//   - abortFunc is called with a CSRFErrorReason when the token cannot be stored or is invalid.
//     If nil, a default that aborts the request with 403 Forbidden is used.
//   - csrfTokenGetter is used to obtain the token sent by the client.
//     If nil, DefaultCSRFTokenGetter is used.
//
// Usage (with random salt):
//
//	core.NewCSRF("", "super secret", "csrf_session", store, func (c *gin.Context, reason core.CSRFErrorReason) {
//		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "Invalid CSRF token"})
//	}, core.DefaultCSRFTokenGetter)
//
// Note for csrfTokenGetter: if you want to read token from request body (for example, from form field) - don't
// forget to restore Body data! Body in http.Request is io.ReadCloser instance. Reading CSRF token from form like that:
//
//	if t := r.FormValue("csrf_token"); len(t) > 0 {
//		return t
//	}
//
// will consume the body - and all next middlewares won't be able to read it at all!
// Use DefaultCSRFTokenGetter as example to implement your own token getter.
// CSRFErrorReason will be passed to abortFunc and can be used for better error messages.
func NewCSRF(salt, secret, sessionName string, store sessions.Store, abortFunc CSRFAbortFunc, csrfTokenGetter CSRFTokenGetter) *CSRF {
	if store == nil {
		panic("store must not be nil")
	}
	if secret == "" {
		panic("at least secret must be provided")
	}

	// A nil callback would otherwise cause a nil function call panic on the first failing request.
	if abortFunc == nil {
		abortFunc = func(c *gin.Context, _ CSRFErrorReason) {
			c.AbortWithStatus(http.StatusForbidden)
		}
	}
	if csrfTokenGetter == nil {
		csrfTokenGetter = DefaultCSRFTokenGetter
	}

	csrf := &CSRF{
		store:           store,
		secret:          secret,
		abortFunc:       abortFunc,
		csrfTokenGetter: csrfTokenGetter,
	}

	if salt == "" {
		salt = csrf.generateSalt()
	}
	if sessionName == "" {
		sessionName = "csrf_token_session"
	}

	csrf.salt = salt
	csrf.sessionName = sessionName

	return csrf
}
// strInSlice reports whether v is present in slice, using exact (case-sensitive) comparison.
// A nil or empty slice never contains v.
func (x *CSRF) strInSlice(slice []string, v string) bool {
	for idx := range slice {
		if slice[idx] == v {
			return true
		}
	}

	return false
}
// generateCSRFToken generates a new, unpredictable CSRF token.
//
// The token is the URL-safe base64 encoding of SHA-1(salt + "#" + secret + "#" + nonce), where nonce is
// fresh random data obtained on every call. The nonce makes tokens differ between sessions: a token derived
// only from salt and secret would be identical for every user of this CSRF instance, so anyone could read it
// from their own session and forge requests on behalf of other users.
// If no cryptographically secure random data is available, a pseudo-random nonce is used instead -
// the token must always be filled even if it cannot be generated properly.
func (x *CSRF) generateCSRFToken() string {
	nonce := securecookie.GenerateRandomKey(32)
	if nonce == nil {
		nonce = []byte(x.pseudoRandomString(32))
	}

	// nolint:gosec
	h := sha1.New()
	// hash.Hash.Write never returns an error; the check is kept only as a defensive fallback.
	if _, err := io.WriteString(h, x.salt+"#"+x.secret+"#"+string(nonce)); err != nil {
		return base64.URLEncoding.EncodeToString([]byte(time.Now().String()))
	}

	return base64.URLEncoding.EncodeToString(h.Sum(nil))
}

// generateSalt generates salt from random bytes. If it fails to generate cryptographically
// secure salt - it will generate pseudo-random, weaker salt.
// It is used automatically by NewCSRF if no salt is provided.
// Secure salt length: 8 bytes.
// Pseudo-random (fallback) salt length: 64 bytes.
func (x *CSRF) generateSalt() string {
	salt := securecookie.GenerateRandomKey(8)
	if salt == nil {
		return x.pseudoRandomString(64)
	}

	return string(salt)
}

// pseudoRandomString generates a pseudo-random string of the specified length consisting of
// uppercase ASCII letters 'A'..'Z'. It is NOT cryptographically secure and is only used as a
// fallback when secure random data is unavailable.
// A private generator is used so the global math/rand source is not reseeded on every call.
func (x *CSRF) pseudoRandomString(length int) string {
	// nolint:gosec
	rnd := rand.New(rand.NewSource(time.Now().UnixNano()))
	data := make([]byte, length)
	for i := range data {
		// 'Z'-'A'+1 == 26, so 'Z' is included in the alphabet.
		data[i] = byte('A' + rnd.Intn('Z'-'A'+1))
	}

	return string(data)
}

// CSRFFromContext returns the CSRF token for the current request.
// It checks, in order: the token stored in the gin context (set by GenerateCSRFMiddleware), the token
// stored in the CSRF session, and finally falls back to a freshly generated random token.
// It never returns an empty string, because that would make CSRF protection useless
// (any request without a token would pass); a random fallback token simply fails verification.
func (x *CSRF) CSRFFromContext(c *gin.Context) string {
	if i, ok := c.Get("csrf_token"); ok {
		if token, ok := i.(string); ok && token != "" {
			return token
		}
	}

	if token := x.csrfTokenFromSession(c); token != "" {
		return token
	}

	return x.generateCSRFToken()
}

// csrfTokenFromSession returns the string token stored in the CSRF session, or "" if there is none.
func (x *CSRF) csrfTokenFromSession(c *gin.Context) string {
	session, _ := x.store.Get(c.Request, x.sessionName)
	if session == nil {
		return ""
	}

	token, _ := session.Values["csrf_token"].(string)

	return token
}

// GenerateCSRFMiddleware returns gin.HandlerFunc which makes sure the session holds a CSRF token and
// exposes it to later handlers through the gin context (read it with CSRFFromContext).
// An existing non-empty string token is reused; otherwise a new token is generated and saved.
// If the token cannot be stored, abortFunc is called with CSRFErrorCannotStoreTokenInSession
// and the handler chain is aborted.
// Usage:
//
//	engine := gin.New()
//	csrf := NewCSRF("salt", "secret", "csrf_session", store, abortFunc, DefaultCSRFTokenGetter)
//	engine.Use(csrf.GenerateCSRFMiddleware())
func (x *CSRF) GenerateCSRFMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		// The store error is ignored as before: a store may return a fresh session together with a
		// decode error, which can still be filled. NOTE(review): confirm for the store in use.
		session, _ := x.store.Get(c.Request, x.sessionName)
		if session == nil {
			x.abortFunc(c, CSRFErrorCannotStoreTokenInSession)
			c.Abort()
			return
		}

		if token, ok := session.Values["csrf_token"].(string); ok && token != "" {
			// Expose the stored token; otherwise CSRFFromContext would not see it and would
			// hand out a token that does not match the session.
			c.Set("csrf_token", token)
			return
		}

		if x.fillToken(session, c) != nil {
			x.abortFunc(c, CSRFErrorCannotStoreTokenInSession)
			c.Abort()
		}
	}
}
// fillToken generates a fresh CSRF token, stores it under the "csrf_token" key both in the
// session values and in the gin context, and then saves the session.
// The returned error is the result of saving the session.
func (x *CSRF) fillToken(s *sessions.Session, c *gin.Context) error {
	token := x.generateCSRFToken()

	s.Values["csrf_token"] = token
	c.Set("csrf_token", token)

	return s.Save(c.Request, c.Writer)
}
// VerifyCSRFMiddleware returns gin.HandlerFunc which verifies the CSRF token of the request.
// Requests whose method is listed in ignoredMethods (e.g. DefaultIgnoredMethods) pass unchecked.
// For all other requests the session must hold a non-empty string token equal to the one returned
// by csrfTokenGetter; otherwise abortFunc is called with the corresponding CSRFErrorReason and the
// handler chain is aborted.
// Usage:
//
//	engine := gin.New()
//	engine.Use(csrf.VerifyCSRFMiddleware(core.DefaultIgnoredMethods))
func (x *CSRF) VerifyCSRFMiddleware(ignoredMethods []string) gin.HandlerFunc {
	return func(c *gin.Context) {
		if x.strInSlice(ignoredMethods, c.Request.Method) {
			return
		}

		fail := func(reason CSRFErrorReason) {
			x.abortFunc(c, reason)
			c.Abort()
		}

		session, _ := x.store.Get(c.Request, x.sessionName)
		if session == nil {
			fail(CSRFErrorNoTokenInSession)
			return
		}

		raw, found := session.Values["csrf_token"]
		if !found {
			fail(CSRFErrorNoTokenInSession)
			return
		}

		token, isString := raw.(string)
		switch {
		case !isString:
			fail(CSRFErrorIncorrectTokenType)
			return
		case token == "":
			fail(CSRFErrorEmptyToken)
			return
		}

		// Constant-time comparison prevents leaking, via response timing, how much of the token matched.
		if subtle.ConstantTimeCompare([]byte(x.csrfTokenGetter(c)), []byte(token)) != 1 {
			fail(CSRFErrorTokenMismatch)
		}
	}
}
// GetCSRFErrorMessage converts a CSRFErrorReason into a short English description (useful for logs).
// Any reason without a known description yields "unknown error".
func GetCSRFErrorMessage(r CSRFErrorReason) string {
	// Keyed by the reason constants, so the table stays correct regardless of their declaration order.
	messages := [...]string{
		CSRFErrorNoTokenInSession:          "token is not present in session",
		CSRFErrorCannotStoreTokenInSession: "cannot store token in session",
		CSRFErrorIncorrectTokenType:        "incorrect token type",
		CSRFErrorEmptyToken:                "empty token present in session",
		CSRFErrorTokenMismatch:             "token mismatch",
	}

	if int(r) >= len(messages) || messages[r] == "" {
		return "unknown error"
	}

	return messages[r]
}