add skipPaths to gin log middleware, refactor validator

This commit is contained in:
Pavel 2024-12-17 13:33:04 +03:00
parent b8ccfa8c8c
commit f5e2bfefb2
11 changed files with 367 additions and 69 deletions

View File

@ -27,7 +27,10 @@ import (
"github.com/retailcrm/mg-transport-core/v2/core/logger" "github.com/retailcrm/mg-transport-core/v2/core/logger"
) )
const DefaultHTTPClientTimeout time.Duration = 30 const (
DefaultHTTPClientTimeout time.Duration = 30
AppContextKey = "app"
)
var boolTrue = true var boolTrue = true
@ -110,6 +113,9 @@ func (e *Engine) initGin() {
} }
r := gin.New() r := gin.New()
r.Use(func(c *gin.Context) {
c.Set(AppContextKey, e)
})
e.buildSentryConfig() e.buildSentryConfig()
e.InitSentrySDK() e.InitSentrySDK()
@ -414,3 +420,19 @@ func (e *Engine) buildSentryConfig() {
Debug: e.Config.IsDebug(), Debug: e.Config.IsDebug(),
} }
} }
func GetApp(c *gin.Context) (app *Engine, exists bool) {
item, exists := c.Get(AppContextKey)
if !exists {
return nil, false
}
converted, ok := item.(*Engine)
if !ok {
return nil, false
}
return converted, true
}
func MustGetApp(c *gin.Context) *Engine {
return c.MustGet(AppContextKey).(*Engine)
}

View File

@ -5,7 +5,6 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"io" "io"
"os"
"sync" "sync"
"github.com/guregu/null/v5" "github.com/guregu/null/v5"
@ -32,7 +31,7 @@ type jSONRecordScanner struct {
func newJSONBufferedLogger(buf *bufferLogger) *jSONRecordScanner { func newJSONBufferedLogger(buf *bufferLogger) *jSONRecordScanner {
if buf == nil { if buf == nil {
buf = newBufferLogger() buf = newBufferLoggerSilent()
} }
return &jSONRecordScanner{scan: bufio.NewScanner(buf), buf: buf} return &jSONRecordScanner{scan: bufio.NewScanner(buf), buf: buf}
} }
@ -59,13 +58,17 @@ type bufferLogger struct {
buf lockableBuffer buf lockableBuffer
} }
// NewBufferedLogger returns new BufferedLogger instance. // newBufferLoggerSilent returns new BufferedLogger instance which won't duplicate entries to stdout/stderr.
func newBufferLogger() *bufferLogger { func newBufferLoggerSilent(level ...zapcore.Level) *bufferLogger {
lvl := zapcore.DebugLevel
if len(level) > 0 {
lvl = level[0]
}
bl := &bufferLogger{} bl := &bufferLogger{}
bl.Logger = zap.New( bl.Logger = zap.New(
zapcore.NewCore( zapcore.NewCore(
NewJSONWithContextEncoder( NewJSONWithContextEncoder(
EncoderConfigJSON()), zap.CombineWriteSyncers(os.Stdout, os.Stderr, &bl.buf), zapcore.DebugLevel)) EncoderConfigJSON()), &bl.buf, lvl))
return bl return bl
} }

View File

@ -32,7 +32,7 @@ func (s *TestDefaultSuite) TestNewDefault_Panic() {
} }
func (s *TestDefaultSuite) TestWith() { func (s *TestDefaultSuite) TestWith() {
log := newBufferLogger() log := newBufferLoggerSilent()
log.With(zap.String(HandlerAttr, "Handler")).Info("test") log.With(zap.String(HandlerAttr, "Handler")).Info("test")
items, err := newJSONBufferedLogger(log).ScanAll() items, err := newJSONBufferedLogger(log).ScanAll()
@ -42,7 +42,7 @@ func (s *TestDefaultSuite) TestWith() {
} }
func (s *TestDefaultSuite) TestWithLazy() { func (s *TestDefaultSuite) TestWithLazy() {
log := newBufferLogger() log := newBufferLoggerSilent()
log.WithLazy(zap.String(HandlerAttr, "Handler")).Info("test") log.WithLazy(zap.String(HandlerAttr, "Handler")).Info("test")
items, err := newJSONBufferedLogger(log).ScanAll() items, err := newJSONBufferedLogger(log).ScanAll()
@ -52,7 +52,7 @@ func (s *TestDefaultSuite) TestWithLazy() {
} }
func (s *TestDefaultSuite) TestForHandler() { func (s *TestDefaultSuite) TestForHandler() {
log := newBufferLogger() log := newBufferLoggerSilent()
log.ForHandler("Handler").Info("test") log.ForHandler("Handler").Info("test")
items, err := newJSONBufferedLogger(log).ScanAll() items, err := newJSONBufferedLogger(log).ScanAll()
@ -62,7 +62,7 @@ func (s *TestDefaultSuite) TestForHandler() {
} }
func (s *TestDefaultSuite) TestForHandlerNoDuplicate() { func (s *TestDefaultSuite) TestForHandlerNoDuplicate() {
log := newBufferLogger() log := newBufferLoggerSilent()
log.ForHandler("handler1").ForHandler("handler2").Info("test") log.ForHandler("handler1").ForHandler("handler2").Info("test")
s.Assert().Contains(log.String(), "handler2") s.Assert().Contains(log.String(), "handler2")
@ -70,7 +70,7 @@ func (s *TestDefaultSuite) TestForHandlerNoDuplicate() {
} }
func (s *TestDefaultSuite) TestForConnection() { func (s *TestDefaultSuite) TestForConnection() {
log := newBufferLogger() log := newBufferLoggerSilent()
log.ForConnection("connection").Info("test") log.ForConnection("connection").Info("test")
items, err := newJSONBufferedLogger(log).ScanAll() items, err := newJSONBufferedLogger(log).ScanAll()
@ -80,7 +80,7 @@ func (s *TestDefaultSuite) TestForConnection() {
} }
func (s *TestDefaultSuite) TestForConnectionNoDuplicate() { func (s *TestDefaultSuite) TestForConnectionNoDuplicate() {
log := newBufferLogger() log := newBufferLoggerSilent()
log.ForConnection("conn1").ForConnection("conn2").Info("test") log.ForConnection("conn1").ForConnection("conn2").Info("test")
s.Assert().Contains(log.String(), "conn2") s.Assert().Contains(log.String(), "conn2")
@ -88,7 +88,7 @@ func (s *TestDefaultSuite) TestForConnectionNoDuplicate() {
} }
func (s *TestDefaultSuite) TestForAccount() { func (s *TestDefaultSuite) TestForAccount() {
log := newBufferLogger() log := newBufferLoggerSilent()
log.ForAccount("account").Info("test") log.ForAccount("account").Info("test")
items, err := newJSONBufferedLogger(log).ScanAll() items, err := newJSONBufferedLogger(log).ScanAll()
@ -98,7 +98,7 @@ func (s *TestDefaultSuite) TestForAccount() {
} }
func (s *TestDefaultSuite) TestForAccountNoDuplicate() { func (s *TestDefaultSuite) TestForAccountNoDuplicate() {
log := newBufferLogger() log := newBufferLoggerSilent()
log.ForAccount("acc1").ForAccount("acc2").Info("test") log.ForAccount("acc1").ForAccount("acc2").Info("test")
s.Assert().Contains(log.String(), "acc2") s.Assert().Contains(log.String(), "acc2")
@ -106,7 +106,7 @@ func (s *TestDefaultSuite) TestForAccountNoDuplicate() {
} }
func (s *TestDefaultSuite) TestNoDuplicatesPersistRecords() { func (s *TestDefaultSuite) TestNoDuplicatesPersistRecords() {
log := newBufferLogger() log := newBufferLoggerSilent()
log. log.
ForHandler("handler1"). ForHandler("handler1").
ForHandler("handler2"). ForHandler("handler2").
@ -126,7 +126,7 @@ func (s *TestDefaultSuite) TestNoDuplicatesPersistRecords() {
// TestPersistRecordsIncompatibleWith is not a unit test, but rather a demonstration how you shouldn't use For* methods. // TestPersistRecordsIncompatibleWith is not a unit test, but rather a demonstration how you shouldn't use For* methods.
func (s *TestDefaultSuite) TestPersistRecordsIncompatibleWith() { func (s *TestDefaultSuite) TestPersistRecordsIncompatibleWith() {
log := newBufferLogger() log := newBufferLoggerSilent()
log. log.
ForHandler("handler1"). ForHandler("handler1").
With(zap.Int("f1", 1)). With(zap.Int("f1", 1)).

View File

@ -1,24 +1,64 @@
package logger package logger
import ( import (
"path"
"regexp"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"go.uber.org/zap" "go.uber.org/zap"
) )
const (
LoggerContextKey = "logger"
LoggerRealContextKey = "loggerReal"
)
// GinMiddleware will construct Gin middleware which will log requests and provide logger with unique request ID. // GinMiddleware will construct Gin middleware which will log requests and provide logger with unique request ID.
func GinMiddleware(log Logger) gin.HandlerFunc { func GinMiddleware(log Logger, skipPaths ...string) gin.HandlerFunc {
var (
skip map[string]struct{}
matchSkip []*skippedPath
)
if length := len(skipPaths); length > 0 {
skip = make(map[string]struct{}, length)
for _, path := range skipPaths {
if skipped, ok := newSkippedPath(path); ok {
matchSkip = append(matchSkip, skipped)
continue
}
skip[path] = struct{}{}
}
}
nilLogger := NewNil()
return func(c *gin.Context) { return func(c *gin.Context) {
// Start timer // Start timer
start := time.Now() start := time.Now()
path := c.Request.URL.Path path := c.Request.URL.Path
raw := c.Request.URL.RawQuery raw := c.Request.URL.RawQuery
_, shouldSkip := skip[path]
streamID := generateStreamID() streamID := generateStreamID()
log := log.With(StreamID(streamID)) log := log.With(StreamID(streamID))
if !shouldSkip && len(matchSkip) > 0 {
for _, skipper := range matchSkip {
if skipper.match(path) {
shouldSkip = true
break
}
}
}
if shouldSkip {
c.Set(LoggerRealContextKey, log)
log = nilLogger
}
c.Set(StreamIDAttr, streamID) c.Set(StreamIDAttr, streamID)
c.Set("logger", log) c.Set(LoggerContextKey, log)
// Process request // Process request
c.Next() c.Next()
@ -28,19 +68,68 @@ func GinMiddleware(log Logger) gin.HandlerFunc {
path = path + "?" + raw path = path + "?" + raw
} }
log.Info("request", if !shouldSkip {
zap.String(HandlerAttr, "GIN"), log.Info("request",
zap.String("startTime", start.Format(time.RFC3339)), zap.String(HandlerAttr, "GIN"),
zap.String("endTime", end.Format(time.RFC3339)), zap.String("startTime", start.Format(time.RFC3339)),
zap.Any("latency", end.Sub(start)/time.Millisecond), zap.String("endTime", end.Format(time.RFC3339)),
zap.String("remoteAddress", c.ClientIP()), zap.Any("latency", end.Sub(start)/time.Millisecond),
zap.String(HTTPMethodAttr, c.Request.Method), zap.String("remoteAddress", c.ClientIP()),
zap.String("path", path), zap.String(HTTPMethodAttr, c.Request.Method),
zap.Int("bodySize", c.Writer.Size()), zap.String("path", path),
) zap.Int("bodySize", c.Writer.Size()),
)
}
} }
} }
func MustGet(c *gin.Context) Logger { func MustGet(c *gin.Context) Logger {
return c.MustGet("logger").(Logger) return c.MustGet("logger").(Logger)
} }
func MustGetReal(c *gin.Context) Logger {
log, ok := c.Get(LoggerContextKey)
if _, isNil := log.(*Nil); !ok || isNil {
return c.MustGet(LoggerRealContextKey).(Logger)
}
return log.(Logger)
}
var (
hasParamsMatcher = regexp.MustCompile(`/:\w+`)
hasWildcardParamsMatcher = regexp.MustCompile(`/\*\w+.*`)
)
type skippedPath struct {
path string
expr *regexp.Regexp
}
// newSkippedPath returns new path skipping struct. It returns nil, false if expr is simple and
// no complex logic is needed.
func newSkippedPath(expr string) (result *skippedPath, compatible bool) {
hasParams, hasWildcard := hasParamsMatcher.MatchString(expr), hasWildcardParamsMatcher.MatchString(expr)
if !hasParams && !hasWildcard {
return nil, false
}
if hasWildcard {
return &skippedPath{expr: matcherForWildcard(expr)}, true
}
return &skippedPath{path: matcherForPath(expr)}, true
}
func matcherForWildcard(expr string) *regexp.Regexp {
return regexp.MustCompile(hasWildcardParamsMatcher.ReplaceAllString(expr, "/[\\w/]+"))
}
func matcherForPath(expr string) string {
return hasParamsMatcher.ReplaceAllString(expr, "/*")
}
func (p *skippedPath) match(route string) bool {
if p.expr != nil {
return p.expr.MatchString(route)
}
result, err := path.Match(p.path, route)
return result && err == nil
}

View File

@ -1,6 +1,8 @@
package logger package logger
import ( import (
"encoding/json"
"fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"testing" "testing"
@ -11,7 +13,7 @@ import (
) )
func TestGinMiddleware(t *testing.T) { func TestGinMiddleware(t *testing.T) {
log := newBufferLogger() log := newBufferLoggerSilent()
rr := httptest.NewRecorder() rr := httptest.NewRecorder()
r := gin.New() r := gin.New()
r.Use(GinMiddleware(log)) r.Use(GinMiddleware(log))
@ -41,3 +43,81 @@ func TestGinMiddleware(t *testing.T) {
assert.NotEmpty(t, items[1].Context["path"]) assert.NotEmpty(t, items[1].Context["path"])
assert.NotEmpty(t, items[1].Context["bodySize"]) assert.NotEmpty(t, items[1].Context["bodySize"])
} }
func TestGinMiddleware_SkipPaths(t *testing.T) {
log := newBufferLoggerSilent()
rr := httptest.NewRecorder()
r := gin.New()
r.Use(GinMiddleware(log, "/hidden", "/hidden/:id", "/superhidden/*id"))
r.GET("/hidden", func(c *gin.Context) {
log := MustGet(c)
log.Info("hidden message from /hidden")
realLog := MustGetReal(c)
realLog.Info("visible message from /hidden")
c.JSON(http.StatusOK, gin.H{})
})
r.GET("/logged", func(c *gin.Context) {
log := MustGet(c)
log.Info("visible message from /logged")
c.JSON(http.StatusOK, gin.H{})
})
r.GET("/hidden/:id", func(c *gin.Context) {
log := MustGet(c)
log.Info("hidden message from /hidden/:id")
c.JSON(http.StatusOK, gin.H{})
})
r.GET("/superhidden/*id", func(c *gin.Context) {
log := MustGet(c)
log.Info("hidden message from /superhidden/*id")
c.JSON(http.StatusOK, gin.H{})
})
r.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/hidden", nil))
require.Equal(t, http.StatusOK, rr.Code)
r.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/logged", nil))
require.Equal(t, http.StatusOK, rr.Code)
r.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/hidden/param", nil))
require.Equal(t, http.StatusOK, rr.Code)
r.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/superhidden/param/1/2/3", nil))
require.Equal(t, http.StatusOK, rr.Code)
items, err := newJSONBufferedLogger(log).ScanAll()
require.NoError(t, err)
require.Len(t, items, 3, printEntries(items))
assert.Equal(t, "visible message from /hidden", items[0].Message, printEntries(items))
assert.Equal(t, "visible message from /logged", items[1].Message, printEntries(items))
}
func TestSkippedPath(t *testing.T) {
cases := map[string]map[string]bool{
"/hidden/:id": {
"/hidden/1": true,
"/hidden/2/3": false,
},
"/hidden/*id": {
"/hidden/1": true,
"/hidden/2/3": true,
},
}
for pattern, items := range cases {
matcher, ok := newSkippedPath(pattern)
require.True(t, ok)
for item, result := range items {
assert.Equal(t, result, matcher.match(item), fmt.Sprintf(`"%s" does not match "%s", internals: %#v`,
pattern, item, matcher))
}
}
}
func printEntries(entries []logRecord) string {
data, err := json.MarshalIndent(entries, "", " ")
if err != nil {
panic(err)
}
return string(data)
}

View File

@ -9,7 +9,7 @@ import (
) )
func TestWriterAdapter(t *testing.T) { func TestWriterAdapter(t *testing.T) {
log := newBufferLogger() log := newBufferLoggerSilent()
adapter := WriterAdapter(log, zap.InfoLevel) adapter := WriterAdapter(log, zap.InfoLevel)
msg := []byte("hello world") msg := []byte("hello world")

View File

@ -9,7 +9,7 @@ import (
) )
func TestZabbixCollectorAdapter(t *testing.T) { func TestZabbixCollectorAdapter(t *testing.T) {
log := newBufferLogger() log := newBufferLoggerSilent()
adapter := ZabbixCollectorAdapter(log) adapter := ZabbixCollectorAdapter(log)
adapter.Errorf("highly unexpected error: %s", "unexpected error") adapter.Errorf("highly unexpected error: %s", "unexpected error")
adapter.Errorf("cannot stop collector: %s", "app error") adapter.Errorf("cannot stop collector: %s", "app error")

View File

@ -42,12 +42,30 @@ type BufferLogger struct {
} }
// NewBufferedLogger returns new BufferedLogger instance. // NewBufferedLogger returns new BufferedLogger instance.
func NewBufferedLogger() BufferedLogger { func NewBufferedLogger(level ...zapcore.Level) BufferedLogger {
lvl := zapcore.DebugLevel
if len(level) > 0 {
lvl = level[0]
}
bl := &BufferLogger{} bl := &BufferLogger{}
bl.Logger = zap.New( bl.Logger = zap.New(
zapcore.NewCore( zapcore.NewCore(
logger.NewJSONWithContextEncoder( logger.NewJSONWithContextEncoder(
logger.EncoderConfigJSON()), zap.CombineWriteSyncers(os.Stdout, os.Stderr, &bl.buf), zapcore.DebugLevel)) logger.EncoderConfigJSON()), zap.CombineWriteSyncers(os.Stdout, os.Stderr, &bl.buf), lvl))
return bl
}
// NewBufferedLoggerSilent returns new BufferedLogger instance which won't duplicate entries to stdout/stderr.
func NewBufferedLoggerSilent(level ...zapcore.Level) BufferedLogger {
lvl := zapcore.DebugLevel
if len(level) > 0 {
lvl = level[0]
}
bl := &BufferLogger{}
bl.Logger = zap.New(
zapcore.NewCore(
logger.NewJSONWithContextEncoder(
logger.EncoderConfigJSON()), &bl.buf, lvl))
return bl return bl
} }

View File

@ -2,6 +2,7 @@ package testutil
import ( import (
"io" "io"
"sync"
"testing" "testing"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
@ -9,7 +10,8 @@ import (
type BufferLoggerTest struct { type BufferLoggerTest struct {
suite.Suite suite.Suite
logger BufferedLogger logger BufferedLogger
silentLogger BufferedLogger
} }
func TestBufferLogger(t *testing.T) { func TestBufferLogger(t *testing.T) {
@ -18,10 +20,12 @@ func TestBufferLogger(t *testing.T) {
func (t *BufferLoggerTest) SetupSuite() { func (t *BufferLoggerTest) SetupSuite() {
t.logger = NewBufferedLogger() t.logger = NewBufferedLogger()
t.silentLogger = NewBufferedLoggerSilent()
} }
func (t *BufferLoggerTest) SetupTest() { func (t *BufferLoggerTest) SetupTest() {
t.logger.Reset() t.logger.Reset()
t.silentLogger.Reset()
} }
func (t *BufferLoggerTest) Test_Read() { func (t *BufferLoggerTest) Test_Read() {
@ -31,25 +35,62 @@ func (t *BufferLoggerTest) Test_Read() {
t.Require().NoError(err) t.Require().NoError(err)
t.Assert().Contains(string(data), "\"level_name\":\"DEBUG\"") t.Assert().Contains(string(data), "\"level_name\":\"DEBUG\"")
t.Assert().Contains(string(data), "\"message\":\"test\"") t.Assert().Contains(string(data), "\"message\":\"test\"")
t.silentLogger.Debug("test")
data, err = io.ReadAll(t.silentLogger)
t.Require().NoError(err)
t.Assert().Contains(string(data), "\"level_name\":\"DEBUG\"")
t.Assert().Contains(string(data), "\"message\":\"test\"")
} }
func (t *BufferLoggerTest) Test_Bytes() { func (t *BufferLoggerTest) Test_Bytes() {
t.logger.Debug("test") t.logger.Debug("test")
t.Assert().Contains(string(t.logger.Bytes()), "\"level_name\":\"DEBUG\"") t.Assert().Contains(string(t.logger.Bytes()), "\"level_name\":\"DEBUG\"")
t.Assert().Contains(string(t.logger.Bytes()), "\"message\":\"test\"") t.Assert().Contains(string(t.logger.Bytes()), "\"message\":\"test\"")
t.silentLogger.Debug("test")
t.Assert().Contains(string(t.silentLogger.Bytes()), "\"level_name\":\"DEBUG\"")
t.Assert().Contains(string(t.silentLogger.Bytes()), "\"message\":\"test\"")
} }
func (t *BufferLoggerTest) Test_String() { func (t *BufferLoggerTest) Test_String() {
t.logger.Debug("test") t.logger.Debug("test")
t.Assert().Contains(t.logger.String(), "\"level_name\":\"DEBUG\"") t.Assert().Contains(t.logger.String(), "\"level_name\":\"DEBUG\"")
t.Assert().Contains(t.logger.String(), "\"message\":\"test\"") t.Assert().Contains(t.logger.String(), "\"message\":\"test\"")
t.silentLogger.Debug("test")
t.Assert().Contains(t.silentLogger.String(), "\"level_name\":\"DEBUG\"")
t.Assert().Contains(t.silentLogger.String(), "\"message\":\"test\"")
} }
func (t *BufferLoggerTest) TestRace() { func (t *BufferLoggerTest) TestRace() {
var (
wg sync.WaitGroup
starter sync.WaitGroup
)
starter.Add(1)
wg.Add(4)
go func() { go func() {
starter.Wait()
t.logger.Debug("test") t.logger.Debug("test")
wg.Done()
}() }()
go func() { go func() {
starter.Wait()
t.logger.String() t.logger.String()
wg.Done()
}() }()
go func() {
starter.Wait()
t.silentLogger.Debug("test")
wg.Done()
}()
go func() {
starter.Wait()
t.silentLogger.String()
wg.Done()
}()
starter.Done()
wg.Wait()
} }

View File

@ -1,13 +1,77 @@
package core package core
import ( import (
"go.uber.org/atomic"
"net/url" "net/url"
"strings" "strings"
"sync"
"time"
"github.com/gin-gonic/gin/binding" "github.com/gin-gonic/gin/binding"
"github.com/go-playground/validator/v10" "github.com/go-playground/validator/v10"
) )
var (
crmDomainStore = &domainStore{
source: GetSaasDomains,
matcher: func(domain string, domains []Domain) bool {
if len(domains) == 0 {
return false
}
secondLevel := strings.Join(strings.Split(domain, ".")[1:], ".")
for _, crmDomain := range domains {
if crmDomain.Domain == secondLevel {
return true
}
}
return false
},
}
boxDomainStore = &domainStore{
source: GetBoxDomains,
matcher: func(domain string, domains []Domain) bool {
if len(domains) == 0 {
return false
}
for _, crmDomain := range domains {
if crmDomain.Domain == domain {
return true
}
}
return false
},
}
)
type domainStore struct {
domains []Domain
mutex sync.RWMutex
source func() []Domain
matcher func(string, []Domain) bool
lastUpdate atomic.Time
}
func (ds *domainStore) match(domain string) bool {
if time.Since(ds.lastUpdate.Load()) > time.Hour {
ds.update()
}
defer ds.mutex.RUnlock()
ds.mutex.RLock()
return ds.matcher(domain, ds.domains)
}
func (ds *domainStore) update() {
defer ds.mutex.Unlock()
ds.mutex.Lock()
ds.domains = ds.source()
ds.lastUpdate.Store(time.Now())
}
// init here will register `validateCrmURL` function for gin validator. // init here will register `validateCrmURL` function for gin validator.
func init() { func init() {
if v, ok := binding.Validator.Engine().(*validator.Validate); ok { if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
@ -26,40 +90,16 @@ func validateCrmURL(fl validator.FieldLevel) bool {
func isDomainValid(crmURL string) bool { func isDomainValid(crmURL string) bool {
parseURL, err := url.ParseRequestURI(crmURL) parseURL, err := url.ParseRequestURI(crmURL)
if err != nil || nil == parseURL || !checkURLString(parseURL) { if err != nil || nil == parseURL || !checkURLString(parseURL) {
return false return false
} }
mainDomain := getMainDomain(parseURL.Hostname()) hostname := parseURL.Hostname()
if crmDomainStore.match(hostname) {
if checkDomains(GetSaasDomains(), mainDomain) {
return true return true
} }
if checkDomains(GetBoxDomains(), parseURL.Hostname()) { return boxDomainStore.match(hostname)
return true
}
return false
}
func checkDomains(crmDomains []Domain, domain string) bool {
if nil == crmDomains {
return false
}
for _, crmDomain := range crmDomains {
if crmDomain.Domain == domain {
return true
}
}
return false
}
func getMainDomain(hostname string) (mainDomain string) {
return strings.Join(strings.Split(hostname, ".")[1:], ".")
} }
func checkURLString(parseURL *url.URL) bool { func checkURLString(parseURL *url.URL) bool {

View File

@ -29,6 +29,11 @@ func (s *ValidatorSuite) SetupSuite() {
} }
} }
func (s *ValidatorSuite) SetupTest() {
crmDomainStore.update()
boxDomainStore.update()
}
func (s *ValidatorSuite) getError(err error) string { func (s *ValidatorSuite) getError(err error) string {
if err == nil { if err == nil {
return "" return ""
@ -72,14 +77,14 @@ func (s *ValidatorSuite) Test_ValidationSuccess() {
"https://test.retailcrm.pro", "https://test.retailcrm.pro",
"https://raisa.retailcrm.es", "https://raisa.retailcrm.es",
"https://blabla.simla.com", "https://blabla.simla.com",
"https://blabla.ecomlogic.com",
"https://crm.baucenter.ru",
"https://crm.holodilnik.ru",
"https://crm.eco.lanit.ru",
"https://ecom.inventive.ru",
"https://retailcrm.tvoydom.ru",
} }
for _, domain := range boxDomainStore.domains {
crmDomains = append(crmDomains, "https://"+domain.Domain)
}
s.Assert().True(len(crmDomains) > 4, "No box domains were tested, test is incomplete!")
for _, domain := range crmDomains { for _, domain := range crmDomains {
conn := models.Connection{ conn := models.Connection{
Key: "key", Key: "key",
@ -87,6 +92,6 @@ func (s *ValidatorSuite) Test_ValidationSuccess() {
} }
err := s.engine.Struct(conn) err := s.engine.Struct(conn)
assert.NoError(s.T(), err, s.getError(err)) assert.NoError(s.T(), err, domain+": "+s.getError(err))
} }
} }