mirror of
https://github.com/retailcrm/mg-transport-core.git
synced 2024-11-22 13:16:04 +03:00
306 lines
8.2 KiB
Go
306 lines
8.2 KiB
Go
package middleware
|
|
|
|
import (
	"bytes"
	// nolint:gosec
	"crypto/sha1"
	"crypto/subtle"
	"encoding/base64"
	"io"
	"math/rand"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/gorilla/securecookie"
	"github.com/gorilla/sessions"
)
|
|
|
|
// CSRFErrorReason identifies why CSRF token generation or verification failed.
// A value of this type is handed to the CSRFAbortFunc and can be converted to
// a human-readable message with GetCSRFErrorMessage.
type CSRFErrorReason uint8
|
|
|
|
// CSRFTokenGetter extracts the CSRF token sent by the client from the current
// request context. It returns the token as a string; DefaultCSRFTokenGetter
// returns an empty string when no token could be found.
type CSRFTokenGetter func(*gin.Context) string
|
|
|
|
// CSRFAbortFunc is a callback which is invoked by the CSRF middlewares when a
// token cannot be stored or verified. It receives the request context and the
// reason of the failure and is expected to write the error response.
type CSRFAbortFunc func(*gin.Context, CSRFErrorReason)
|
|
|
|
// Reasons reported to CSRFAbortFunc. The values are assigned with iota, so the
// declaration order defines the numeric value of each reason.
const (
	// CSRFErrorNoTokenInSession will be returned if token is not present in session.
	CSRFErrorNoTokenInSession CSRFErrorReason = iota

	// CSRFErrorCannotStoreTokenInSession will be returned if middleware cannot store token in session.
	CSRFErrorCannotStoreTokenInSession

	// CSRFErrorIncorrectTokenType will be returned if data type of token in session is not string.
	CSRFErrorIncorrectTokenType

	// CSRFErrorEmptyToken will be returned if token in session is empty.
	CSRFErrorEmptyToken

	// CSRFErrorTokenMismatch will be returned in case of invalid token,
	// i.e. when the token sent with the request differs from the one in session.
	CSRFErrorTokenMismatch
)
|
|
|
|
const (
	// keySize is the number of random bytes requested from
	// securecookie.GenerateRandomKey when a salt is generated automatically.
	keySize = 8
	// randomStringSize is the length of the pseudo-random fallback salt which
	// is used when the secure random key cannot be generated.
	randomStringSize = 64
)
|
|
|
|
// DefaultCSRFTokenGetter default getter.
|
|
var DefaultCSRFTokenGetter = func(c *gin.Context) string {
|
|
r := c.Request
|
|
|
|
if t := r.URL.Query().Get("csrf_token"); len(t) > 0 {
|
|
return t
|
|
} else if t := r.Header.Get("X-CSRF-Token"); len(t) > 0 {
|
|
return t
|
|
} else if t := r.Header.Get("X-XSRF-Token"); len(t) > 0 {
|
|
return t
|
|
} else if c.Request.Body != nil {
|
|
data, _ := io.ReadAll(c.Request.Body)
|
|
c.Request.Body = io.NopCloser(bytes.NewReader(data))
|
|
t := r.FormValue("csrf_token")
|
|
c.Request.Body = io.NopCloser(bytes.NewReader(data))
|
|
|
|
if len(t) > 0 {
|
|
return t
|
|
}
|
|
}
|
|
|
|
return ""
|
|
}
|
|
|
|
// DefaultIgnoredMethods lists HTTP methods which are not checked by the CSRF
// verifier middleware. It is meant to be passed to VerifyCSRFMiddleware, which
// compares each entry with the request method using exact string equality.
var DefaultIgnoredMethods = []string{"GET", "HEAD", "OPTIONS"}
|
|
|
|
// CSRF struct. Provides CSRF token generation and verification middlewares.
// Use NewCSRF to create a properly configured instance.
type CSRF struct {
	// store is the session storage the token is kept in.
	store sessions.Store
	// abortFunc is called when the token cannot be stored or verified.
	abortFunc CSRFAbortFunc
	// csrfTokenGetter extracts the token sent by the client from the request.
	csrfTokenGetter CSRFTokenGetter
	// salt is hashed together with secret to produce the token.
	salt string
	// secret is hashed together with salt to produce the token.
	secret string
	// sessionName is the name of the session which holds the token.
	sessionName string
}
|
|
|
|
// NewCSRF creates CSRF struct with specified configuration and session store.
|
|
// GenerateCSRFMiddleware and VerifyCSRFMiddleware returns CSRF middlewares.
|
|
// Salt must be different every time (pass empty salt to use random), secret must be provided,
|
|
// sessionName is optional - pass empty to use default,
|
|
// store will be used to store sessions, abortFunc will be called to return error if token is invalid,
|
|
// csrfTokenGetter will be used to obtain token.
|
|
//
|
|
// Usage (with random salt):
|
|
//
|
|
// core.NewCSRF("", "super secret", "csrf_session", store, func (c *gin.Context, reason core.CSRFErrorReason) {
|
|
// c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "Invalid CSRF token"})
|
|
// }, core.DefaultCSRFTokenGetter)
|
|
//
|
|
// Note for csrfTokenGetter: if you want to read token from request body (for example, from form field)
|
|
// - don't forget to restore Body data!
|
|
//
|
|
// Body in http.Request is io.ReadCloser instance. Reading CSRF token from form like that:
|
|
//
|
|
// if t := r.FormValue("csrf_token"); len(t) > 0 {
|
|
// return t
|
|
// }
|
|
//
|
|
// will close body - and all next middlewares won't be able to read body at all!
|
|
//
|
|
// Use DefaultCSRFTokenGetter as example to implement your own token getter.
|
|
// CSRFErrorReason will be passed to abortFunc and can be used for better error messages.
|
|
func NewCSRF(
|
|
salt, secret, sessionName string,
|
|
store sessions.Store,
|
|
abortFunc CSRFAbortFunc,
|
|
csrfTokenGetter CSRFTokenGetter,
|
|
) *CSRF {
|
|
if store == nil {
|
|
panic("store must not be nil")
|
|
}
|
|
|
|
if secret == "" {
|
|
panic("at least secret must be provided")
|
|
}
|
|
|
|
csrf := &CSRF{
|
|
store: store,
|
|
secret: secret,
|
|
abortFunc: abortFunc,
|
|
csrfTokenGetter: csrfTokenGetter,
|
|
}
|
|
|
|
if salt == "" {
|
|
salt = csrf.generateSalt()
|
|
}
|
|
|
|
if sessionName == "" {
|
|
sessionName = "csrf_token_session"
|
|
}
|
|
|
|
csrf.salt = salt
|
|
csrf.sessionName = sessionName
|
|
|
|
return csrf
|
|
}
|
|
|
|
// strInSlice checks whether string exists in slice.
|
|
func (x *CSRF) strInSlice(slice []string, v string) bool {
|
|
exists := false
|
|
|
|
for _, i := range slice {
|
|
if i == v {
|
|
exists = true
|
|
break
|
|
}
|
|
}
|
|
|
|
return exists
|
|
}
|
|
|
|
// generateCSRFToken generates new CSRF token.
|
|
func (x *CSRF) generateCSRFToken() string {
|
|
// nolint:gosec
|
|
h := sha1.New()
|
|
// Fallback to less secure method - token must be always filled even if we cannot properly generate it
|
|
if _, err := io.WriteString(h, x.salt+"#"+x.secret); err != nil {
|
|
return base64.URLEncoding.EncodeToString([]byte(time.Now().String()))
|
|
}
|
|
hash := base64.URLEncoding.EncodeToString(h.Sum(nil))
|
|
|
|
return hash
|
|
}
|
|
|
|
// generateSalt generates salt from random bytes. If it fails to generate cryptographically
|
|
// secure salt - it will generate pseudo-random, weaker salt.
|
|
// It will be used automatically if no salt provided.
|
|
// Default secure salt length: 8 bytes.
|
|
// Default pseudo-random salt length: 64 bytes.
|
|
func (x *CSRF) generateSalt() string {
|
|
salt := securecookie.GenerateRandomKey(keySize)
|
|
|
|
if salt == nil {
|
|
return x.pseudoRandomString(randomStringSize)
|
|
}
|
|
|
|
return string(salt)
|
|
}
|
|
|
|
// pseudoRandomString generates pseudo-random string with specified length.
|
|
func (x *CSRF) pseudoRandomString(length int) string {
|
|
r := rand.New(rand.NewSource(time.Now().UnixNano())) // nolint:gosec
|
|
data := make([]byte, length)
|
|
|
|
for i := 0; i < length; i++ { // it is supposed to use pseudo-random data.
|
|
data[i] = byte(65 + r.Intn(90-65)) // nolint:gosec,gomnd
|
|
}
|
|
|
|
return string(data)
|
|
}
|
|
|
|
// CSRFFromContext returns csrf token or random token. It shouldn't return empty string because
|
|
// it will make csrf protection useless. e.g. any request without token will work fine, which is unacceptable.
|
|
func (x *CSRF) CSRFFromContext(c *gin.Context) string {
|
|
if i, ok := c.Get("csrf_token"); ok {
|
|
if token, ok := i.(string); ok {
|
|
return token
|
|
}
|
|
}
|
|
|
|
return x.generateCSRFToken()
|
|
}
|
|
|
|
// GenerateCSRFMiddleware returns gin.HandlerFunc which will generate CSRF token
|
|
// Usage:
|
|
//
|
|
// engine := gin.New()
|
|
// csrf := NewCSRF("salt", "secret", "not_found", "incorrect", localizer)
|
|
// engine.Use(csrf.GenerateCSRFMiddleware())
|
|
func (x *CSRF) GenerateCSRFMiddleware() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
session, _ := x.store.Get(c.Request, x.sessionName)
|
|
|
|
if i, ok := session.Values["csrf_token"]; ok { // nolint:nestif
|
|
if i, ok := i.(string); !ok || i == "" { // nolint:nestif
|
|
if x.fillToken(session, c) != nil {
|
|
x.abortFunc(c, CSRFErrorCannotStoreTokenInSession)
|
|
c.Abort()
|
|
return
|
|
}
|
|
}
|
|
} else {
|
|
if x.fillToken(session, c) != nil {
|
|
x.abortFunc(c, CSRFErrorCannotStoreTokenInSession)
|
|
c.Abort()
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// fillToken stores token in session and context.
|
|
func (x *CSRF) fillToken(s *sessions.Session, c *gin.Context) error {
|
|
s.Values["csrf_token"] = x.generateCSRFToken()
|
|
c.Set("csrf_token", s.Values["csrf_token"])
|
|
return s.Save(c.Request, c.Writer)
|
|
}
|
|
|
|
// VerifyCSRFMiddleware verifies CSRF token
|
|
// Usage:
|
|
//
|
|
// engine := gin.New()
|
|
// engine.Use(csrf.VerifyCSRFMiddleware())
|
|
func (x *CSRF) VerifyCSRFMiddleware(ignoredMethods []string) gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
if x.strInSlice(ignoredMethods, c.Request.Method) {
|
|
return
|
|
}
|
|
|
|
var token string
|
|
session, _ := x.store.Get(c.Request, x.sessionName)
|
|
|
|
if i, ok := session.Values["csrf_token"]; ok { // nolint:nestif
|
|
var v string
|
|
if v, ok = i.(string); !ok || v == "" { // nolint:nestif
|
|
if !ok {
|
|
x.abortFunc(c, CSRFErrorIncorrectTokenType)
|
|
} else if v == "" {
|
|
x.abortFunc(c, CSRFErrorEmptyToken)
|
|
}
|
|
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
token = v
|
|
} else {
|
|
x.abortFunc(c, CSRFErrorNoTokenInSession)
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
if x.csrfTokenGetter(c) != token {
|
|
x.abortFunc(c, CSRFErrorTokenMismatch)
|
|
c.Abort()
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
// GetCSRFErrorMessage returns generic error message for CSRFErrorReason in English (useful for logs).
|
|
func GetCSRFErrorMessage(r CSRFErrorReason) string {
|
|
switch r {
|
|
case CSRFErrorNoTokenInSession:
|
|
return "token is not present in session"
|
|
case CSRFErrorCannotStoreTokenInSession:
|
|
return "cannot store token in session"
|
|
case CSRFErrorIncorrectTokenType:
|
|
return "incorrect token type"
|
|
case CSRFErrorEmptyToken:
|
|
return "empty token present in session"
|
|
case CSRFErrorTokenMismatch:
|
|
return "token mismatch"
|
|
default:
|
|
return "unknown error"
|
|
}
|
|
}
|