mirror of
https://github.com/retailcrm/mg-transport-core.git
synced 2024-11-22 13:16:04 +03:00
234 lines
6.1 KiB
Go
234 lines
6.1 KiB
Go
package core
|
|
|
|
import (
	"bytes"
	"crypto/sha1"
	"crypto/subtle"
	"encoding/base64"
	"io"
	"io/ioutil"
	"math/rand"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/gorilla/securecookie"
	"github.com/gorilla/sessions"
)
|
|
|
|
// CSRFTokenGetter func type
|
|
type CSRFTokenGetter func(c *gin.Context) string
|
|
|
|
// DefaultCSRFTokenGetter default getter
|
|
var DefaultCSRFTokenGetter = func(c *gin.Context) string {
|
|
r := c.Request
|
|
|
|
if t := r.URL.Query().Get("csrf_token"); len(t) > 0 {
|
|
return t
|
|
} else if t := r.Header.Get("X-CSRF-Token"); len(t) > 0 {
|
|
return t
|
|
} else if t := r.Header.Get("X-XSRF-Token"); len(t) > 0 {
|
|
return t
|
|
} else if c.Request.Body != nil {
|
|
data, _ := ioutil.ReadAll(c.Request.Body)
|
|
c.Request.Body = ioutil.NopCloser(bytes.NewReader(data))
|
|
t := r.FormValue("csrf_token")
|
|
c.Request.Body = ioutil.NopCloser(bytes.NewReader(data))
|
|
|
|
if len(t) > 0 {
|
|
return t
|
|
}
|
|
}
|
|
|
|
return ""
|
|
}
|
|
|
|
// DefaultIgnoredMethods is a ready-made ignoredMethods argument for
// VerifyCSRFMiddleware: requests using one of these methods (GET, HEAD,
// OPTIONS) pass through without a token check.
var DefaultIgnoredMethods = []string{
	"GET",
	"HEAD",
	"OPTIONS",
}
|
|
|
|
type CSRF struct {
|
|
salt string
|
|
secret string
|
|
sessionName string
|
|
abortFunc gin.HandlerFunc
|
|
csrfTokenGetter CSRFTokenGetter
|
|
store sessions.Store
|
|
locale *Localizer
|
|
}
|
|
|
|
// NewCSRF creates CSRF struct with specified configuration and session store.
|
|
// GenerateCSRFMiddleware and VerifyCSRFMiddleware returns CSRF middlewares.
|
|
// Salt must be different every time (pass empty salt to use random), secret must be provided, sessionName is optional - pass empty to use default,
|
|
// store will be used to store sessions, abortFunc will be called to return error if token is invalid, csrfTokenGetter will be used to obtain token.
|
|
// Usage (with random salt):
|
|
// core.NewCSRF("", "super secret", "csrf_session", store, func (c *gin.Context) {
|
|
// c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "Invalid CSRF token"})
|
|
// }, core.DefaultCSRFTokenGetter)
|
|
// Note for csrfTokenGetter: if you want to read token from request body (for example, from form field) - don't forget to restore Body data!
|
|
// Body in http.Request is io.ReadCloser instance. Reading CSRF token from form like that:
|
|
// if t := r.FormValue("csrf_token"); len(t) > 0 {
|
|
// return t
|
|
// }
|
|
// will close body - and all next middlewares won't be able to read body at all!
|
|
// Use DefaultCSRFTokenGetter as example to implement your own token getter.
|
|
func NewCSRF(salt, secret, sessionName string, store sessions.Store, abortFunc gin.HandlerFunc, csrfTokenGetter CSRFTokenGetter) *CSRF {
|
|
if store == nil {
|
|
panic("store must not be nil")
|
|
}
|
|
|
|
if secret == "" {
|
|
panic("at least secret must be provided")
|
|
}
|
|
|
|
csrf := &CSRF{
|
|
store: store,
|
|
secret: secret,
|
|
abortFunc: abortFunc,
|
|
csrfTokenGetter: csrfTokenGetter,
|
|
}
|
|
|
|
if salt == "" {
|
|
salt = csrf.generateSalt()
|
|
}
|
|
|
|
if sessionName == "" {
|
|
sessionName = "csrf_token_session"
|
|
}
|
|
|
|
csrf.salt = salt
|
|
csrf.sessionName = sessionName
|
|
|
|
return csrf
|
|
}
|
|
|
|
// strInSlice checks whether string exists in slice
|
|
func (x *CSRF) strInSlice(slice []string, v string) bool {
|
|
exists := false
|
|
|
|
for _, i := range slice {
|
|
if i == v {
|
|
exists = true
|
|
break
|
|
}
|
|
}
|
|
|
|
return exists
|
|
}
|
|
|
|
// generateCSRFToken generates new CSRF token
|
|
func (x *CSRF) generateCSRFToken() string {
|
|
h := sha1.New()
|
|
io.WriteString(h, x.salt+"#"+x.secret)
|
|
hash := base64.URLEncoding.EncodeToString(h.Sum(nil))
|
|
|
|
return hash
|
|
}
|
|
|
|
// generateSalt generates salt from random bytes. If it fails to generate cryptographically
|
|
// secure salt - it will generate pseudo-random, weaker salt.
|
|
// It will be used automatically if no salt provided.
|
|
// Default secure salt length: 8 bytes.
|
|
// Default pseudo-random salt length: 64 bytes.
|
|
func (x *CSRF) generateSalt() string {
|
|
salt := securecookie.GenerateRandomKey(8)
|
|
|
|
if salt == nil {
|
|
return x.pseudoRandomString(64)
|
|
}
|
|
|
|
return string(salt)
|
|
}
|
|
|
|
// pseudoRandomString generates pseudo-random string with specified length
|
|
func (x *CSRF) pseudoRandomString(length int) string {
|
|
rand.Seed(time.Now().UnixNano())
|
|
data := make([]byte, length)
|
|
|
|
for i := 0; i < length; i++ {
|
|
data[i] = byte(65 + rand.Intn(90-65))
|
|
}
|
|
|
|
return string(data)
|
|
}
|
|
|
|
// CSRFFromContext returns csrf token or random token. It shouldn't return empty string because it will make csrf protection useless.
|
|
// e.g. any request without token will work fine, which is inacceptable.
|
|
func (x *CSRF) CSRFFromContext(c *gin.Context) string {
|
|
if i, ok := c.Get("csrf_token"); ok {
|
|
if token, ok := i.(string); ok {
|
|
return token
|
|
} else {
|
|
return x.generateCSRFToken()
|
|
}
|
|
} else {
|
|
return x.generateCSRFToken()
|
|
}
|
|
}
|
|
|
|
// GenerateCSRFMiddleware returns gin.HandlerFunc which will generate CSRF token
|
|
// Usage:
|
|
// engine := gin.New()
|
|
// csrf := NewCSRF("salt", "secret", "not_found", "incorrect", localizer)
|
|
// engine.Use(csrf.GenerateCSRFMiddleware())
|
|
func (x *CSRF) GenerateCSRFMiddleware() gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
session, _ := x.store.Get(c.Request, x.sessionName)
|
|
|
|
if i, ok := session.Values["csrf_token"]; ok {
|
|
if i, ok := i.(string); !ok || i == "" {
|
|
if x.fillToken(session, c) != nil {
|
|
x.abortFunc(c)
|
|
c.Abort()
|
|
return
|
|
}
|
|
}
|
|
} else {
|
|
if x.fillToken(session, c) != nil {
|
|
x.abortFunc(c)
|
|
c.Abort()
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// fillToken stores token in session and context
|
|
func (x *CSRF) fillToken(s *sessions.Session, c *gin.Context) error {
|
|
s.Values["csrf_token"] = x.generateCSRFToken()
|
|
c.Set("csrf_token", s.Values["csrf_token"])
|
|
return s.Save(c.Request, c.Writer)
|
|
}
|
|
|
|
// VerifyCSRFMiddleware verifies CSRF token
|
|
// Usage:
|
|
// engine := gin.New()
|
|
// engine.Use(csrf.VerifyCSRFMiddleware())
|
|
func (x *CSRF) VerifyCSRFMiddleware(ignoredMethods []string) gin.HandlerFunc {
|
|
return func(c *gin.Context) {
|
|
if x.strInSlice(ignoredMethods, c.Request.Method) {
|
|
return
|
|
}
|
|
|
|
var token string
|
|
session, _ := x.store.Get(c.Request, x.sessionName)
|
|
|
|
if i, ok := session.Values["csrf_token"]; ok {
|
|
if i, ok := i.(string); !ok || i == "" {
|
|
x.abortFunc(c)
|
|
c.Abort()
|
|
return
|
|
} else {
|
|
token = i
|
|
}
|
|
} else {
|
|
x.abortFunc(c)
|
|
c.Abort()
|
|
return
|
|
}
|
|
|
|
if x.csrfTokenGetter(c) != token {
|
|
x.abortFunc(c)
|
|
c.Abort()
|
|
return
|
|
}
|
|
}
|
|
}
|