mirror of
https://github.com/retailcrm/mg-transport-core.git
synced 2025-01-29 21:01:41 +03:00
add skipPaths to gin log middleware, refactor validator
This commit is contained in:
parent
b8ccfa8c8c
commit
f5e2bfefb2
@ -27,7 +27,10 @@ import (
|
||||
"github.com/retailcrm/mg-transport-core/v2/core/logger"
|
||||
)
|
||||
|
||||
const DefaultHTTPClientTimeout time.Duration = 30
|
||||
const (
|
||||
DefaultHTTPClientTimeout time.Duration = 30
|
||||
AppContextKey = "app"
|
||||
)
|
||||
|
||||
var boolTrue = true
|
||||
|
||||
@ -110,6 +113,9 @@ func (e *Engine) initGin() {
|
||||
}
|
||||
|
||||
r := gin.New()
|
||||
r.Use(func(c *gin.Context) {
|
||||
c.Set(AppContextKey, e)
|
||||
})
|
||||
|
||||
e.buildSentryConfig()
|
||||
e.InitSentrySDK()
|
||||
@ -414,3 +420,19 @@ func (e *Engine) buildSentryConfig() {
|
||||
Debug: e.Config.IsDebug(),
|
||||
}
|
||||
}
|
||||
|
||||
func GetApp(c *gin.Context) (app *Engine, exists bool) {
|
||||
item, exists := c.Get(AppContextKey)
|
||||
if !exists {
|
||||
return nil, false
|
||||
}
|
||||
converted, ok := item.(*Engine)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return converted, true
|
||||
}
|
||||
|
||||
func MustGetApp(c *gin.Context) *Engine {
|
||||
return c.MustGet(AppContextKey).(*Engine)
|
||||
}
|
||||
|
@ -5,7 +5,6 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/guregu/null/v5"
|
||||
@ -32,7 +31,7 @@ type jSONRecordScanner struct {
|
||||
|
||||
func newJSONBufferedLogger(buf *bufferLogger) *jSONRecordScanner {
|
||||
if buf == nil {
|
||||
buf = newBufferLogger()
|
||||
buf = newBufferLoggerSilent()
|
||||
}
|
||||
return &jSONRecordScanner{scan: bufio.NewScanner(buf), buf: buf}
|
||||
}
|
||||
@ -59,13 +58,17 @@ type bufferLogger struct {
|
||||
buf lockableBuffer
|
||||
}
|
||||
|
||||
// NewBufferedLogger returns new BufferedLogger instance.
|
||||
func newBufferLogger() *bufferLogger {
|
||||
// newBufferLoggerSilent returns new BufferedLogger instance which won't duplicate entries to stdout/stderr.
|
||||
func newBufferLoggerSilent(level ...zapcore.Level) *bufferLogger {
|
||||
lvl := zapcore.DebugLevel
|
||||
if len(level) > 0 {
|
||||
lvl = level[0]
|
||||
}
|
||||
bl := &bufferLogger{}
|
||||
bl.Logger = zap.New(
|
||||
zapcore.NewCore(
|
||||
NewJSONWithContextEncoder(
|
||||
EncoderConfigJSON()), zap.CombineWriteSyncers(os.Stdout, os.Stderr, &bl.buf), zapcore.DebugLevel))
|
||||
EncoderConfigJSON()), &bl.buf, lvl))
|
||||
return bl
|
||||
}
|
||||
|
||||
|
@ -32,7 +32,7 @@ func (s *TestDefaultSuite) TestNewDefault_Panic() {
|
||||
}
|
||||
|
||||
func (s *TestDefaultSuite) TestWith() {
|
||||
log := newBufferLogger()
|
||||
log := newBufferLoggerSilent()
|
||||
log.With(zap.String(HandlerAttr, "Handler")).Info("test")
|
||||
items, err := newJSONBufferedLogger(log).ScanAll()
|
||||
|
||||
@ -42,7 +42,7 @@ func (s *TestDefaultSuite) TestWith() {
|
||||
}
|
||||
|
||||
func (s *TestDefaultSuite) TestWithLazy() {
|
||||
log := newBufferLogger()
|
||||
log := newBufferLoggerSilent()
|
||||
log.WithLazy(zap.String(HandlerAttr, "Handler")).Info("test")
|
||||
items, err := newJSONBufferedLogger(log).ScanAll()
|
||||
|
||||
@ -52,7 +52,7 @@ func (s *TestDefaultSuite) TestWithLazy() {
|
||||
}
|
||||
|
||||
func (s *TestDefaultSuite) TestForHandler() {
|
||||
log := newBufferLogger()
|
||||
log := newBufferLoggerSilent()
|
||||
log.ForHandler("Handler").Info("test")
|
||||
items, err := newJSONBufferedLogger(log).ScanAll()
|
||||
|
||||
@ -62,7 +62,7 @@ func (s *TestDefaultSuite) TestForHandler() {
|
||||
}
|
||||
|
||||
func (s *TestDefaultSuite) TestForHandlerNoDuplicate() {
|
||||
log := newBufferLogger()
|
||||
log := newBufferLoggerSilent()
|
||||
log.ForHandler("handler1").ForHandler("handler2").Info("test")
|
||||
|
||||
s.Assert().Contains(log.String(), "handler2")
|
||||
@ -70,7 +70,7 @@ func (s *TestDefaultSuite) TestForHandlerNoDuplicate() {
|
||||
}
|
||||
|
||||
func (s *TestDefaultSuite) TestForConnection() {
|
||||
log := newBufferLogger()
|
||||
log := newBufferLoggerSilent()
|
||||
log.ForConnection("connection").Info("test")
|
||||
items, err := newJSONBufferedLogger(log).ScanAll()
|
||||
|
||||
@ -80,7 +80,7 @@ func (s *TestDefaultSuite) TestForConnection() {
|
||||
}
|
||||
|
||||
func (s *TestDefaultSuite) TestForConnectionNoDuplicate() {
|
||||
log := newBufferLogger()
|
||||
log := newBufferLoggerSilent()
|
||||
log.ForConnection("conn1").ForConnection("conn2").Info("test")
|
||||
|
||||
s.Assert().Contains(log.String(), "conn2")
|
||||
@ -88,7 +88,7 @@ func (s *TestDefaultSuite) TestForConnectionNoDuplicate() {
|
||||
}
|
||||
|
||||
func (s *TestDefaultSuite) TestForAccount() {
|
||||
log := newBufferLogger()
|
||||
log := newBufferLoggerSilent()
|
||||
log.ForAccount("account").Info("test")
|
||||
items, err := newJSONBufferedLogger(log).ScanAll()
|
||||
|
||||
@ -98,7 +98,7 @@ func (s *TestDefaultSuite) TestForAccount() {
|
||||
}
|
||||
|
||||
func (s *TestDefaultSuite) TestForAccountNoDuplicate() {
|
||||
log := newBufferLogger()
|
||||
log := newBufferLoggerSilent()
|
||||
log.ForAccount("acc1").ForAccount("acc2").Info("test")
|
||||
|
||||
s.Assert().Contains(log.String(), "acc2")
|
||||
@ -106,7 +106,7 @@ func (s *TestDefaultSuite) TestForAccountNoDuplicate() {
|
||||
}
|
||||
|
||||
func (s *TestDefaultSuite) TestNoDuplicatesPersistRecords() {
|
||||
log := newBufferLogger()
|
||||
log := newBufferLoggerSilent()
|
||||
log.
|
||||
ForHandler("handler1").
|
||||
ForHandler("handler2").
|
||||
@ -126,7 +126,7 @@ func (s *TestDefaultSuite) TestNoDuplicatesPersistRecords() {
|
||||
|
||||
// TestPersistRecordsIncompatibleWith is not a unit test, but rather a demonstration how you shouldn't use For* methods.
|
||||
func (s *TestDefaultSuite) TestPersistRecordsIncompatibleWith() {
|
||||
log := newBufferLogger()
|
||||
log := newBufferLoggerSilent()
|
||||
log.
|
||||
ForHandler("handler1").
|
||||
With(zap.Int("f1", 1)).
|
||||
|
@ -1,24 +1,64 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"path"
|
||||
"regexp"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const (
|
||||
LoggerContextKey = "logger"
|
||||
LoggerRealContextKey = "loggerReal"
|
||||
)
|
||||
|
||||
// GinMiddleware will construct Gin middleware which will log requests and provide logger with unique request ID.
|
||||
func GinMiddleware(log Logger) gin.HandlerFunc {
|
||||
func GinMiddleware(log Logger, skipPaths ...string) gin.HandlerFunc {
|
||||
var (
|
||||
skip map[string]struct{}
|
||||
matchSkip []*skippedPath
|
||||
)
|
||||
if length := len(skipPaths); length > 0 {
|
||||
skip = make(map[string]struct{}, length)
|
||||
|
||||
for _, path := range skipPaths {
|
||||
if skipped, ok := newSkippedPath(path); ok {
|
||||
matchSkip = append(matchSkip, skipped)
|
||||
continue
|
||||
}
|
||||
skip[path] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
nilLogger := NewNil()
|
||||
|
||||
return func(c *gin.Context) {
|
||||
// Start timer
|
||||
start := time.Now()
|
||||
path := c.Request.URL.Path
|
||||
raw := c.Request.URL.RawQuery
|
||||
|
||||
_, shouldSkip := skip[path]
|
||||
streamID := generateStreamID()
|
||||
log := log.With(StreamID(streamID))
|
||||
|
||||
if !shouldSkip && len(matchSkip) > 0 {
|
||||
for _, skipper := range matchSkip {
|
||||
if skipper.match(path) {
|
||||
shouldSkip = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if shouldSkip {
|
||||
c.Set(LoggerRealContextKey, log)
|
||||
log = nilLogger
|
||||
}
|
||||
|
||||
c.Set(StreamIDAttr, streamID)
|
||||
c.Set("logger", log)
|
||||
c.Set(LoggerContextKey, log)
|
||||
|
||||
// Process request
|
||||
c.Next()
|
||||
@ -28,19 +68,68 @@ func GinMiddleware(log Logger) gin.HandlerFunc {
|
||||
path = path + "?" + raw
|
||||
}
|
||||
|
||||
log.Info("request",
|
||||
zap.String(HandlerAttr, "GIN"),
|
||||
zap.String("startTime", start.Format(time.RFC3339)),
|
||||
zap.String("endTime", end.Format(time.RFC3339)),
|
||||
zap.Any("latency", end.Sub(start)/time.Millisecond),
|
||||
zap.String("remoteAddress", c.ClientIP()),
|
||||
zap.String(HTTPMethodAttr, c.Request.Method),
|
||||
zap.String("path", path),
|
||||
zap.Int("bodySize", c.Writer.Size()),
|
||||
)
|
||||
if !shouldSkip {
|
||||
log.Info("request",
|
||||
zap.String(HandlerAttr, "GIN"),
|
||||
zap.String("startTime", start.Format(time.RFC3339)),
|
||||
zap.String("endTime", end.Format(time.RFC3339)),
|
||||
zap.Any("latency", end.Sub(start)/time.Millisecond),
|
||||
zap.String("remoteAddress", c.ClientIP()),
|
||||
zap.String(HTTPMethodAttr, c.Request.Method),
|
||||
zap.String("path", path),
|
||||
zap.Int("bodySize", c.Writer.Size()),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func MustGet(c *gin.Context) Logger {
|
||||
return c.MustGet("logger").(Logger)
|
||||
}
|
||||
|
||||
func MustGetReal(c *gin.Context) Logger {
|
||||
log, ok := c.Get(LoggerContextKey)
|
||||
if _, isNil := log.(*Nil); !ok || isNil {
|
||||
return c.MustGet(LoggerRealContextKey).(Logger)
|
||||
}
|
||||
return log.(Logger)
|
||||
}
|
||||
|
||||
var (
|
||||
hasParamsMatcher = regexp.MustCompile(`/:\w+`)
|
||||
hasWildcardParamsMatcher = regexp.MustCompile(`/\*\w+.*`)
|
||||
)
|
||||
|
||||
type skippedPath struct {
|
||||
path string
|
||||
expr *regexp.Regexp
|
||||
}
|
||||
|
||||
// newSkippedPath returns new path skipping struct. It returns nil, false if expr is simple and
|
||||
// no complex logic is needed.
|
||||
func newSkippedPath(expr string) (result *skippedPath, compatible bool) {
|
||||
hasParams, hasWildcard := hasParamsMatcher.MatchString(expr), hasWildcardParamsMatcher.MatchString(expr)
|
||||
if !hasParams && !hasWildcard {
|
||||
return nil, false
|
||||
}
|
||||
if hasWildcard {
|
||||
return &skippedPath{expr: matcherForWildcard(expr)}, true
|
||||
}
|
||||
return &skippedPath{path: matcherForPath(expr)}, true
|
||||
}
|
||||
|
||||
func matcherForWildcard(expr string) *regexp.Regexp {
|
||||
return regexp.MustCompile(hasWildcardParamsMatcher.ReplaceAllString(expr, "/[\\w/]+"))
|
||||
}
|
||||
|
||||
func matcherForPath(expr string) string {
|
||||
return hasParamsMatcher.ReplaceAllString(expr, "/*")
|
||||
}
|
||||
|
||||
func (p *skippedPath) match(route string) bool {
|
||||
if p.expr != nil {
|
||||
return p.expr.MatchString(route)
|
||||
}
|
||||
result, err := path.Match(p.path, route)
|
||||
return result && err == nil
|
||||
}
|
||||
|
@ -1,6 +1,8 @@
|
||||
package logger
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
@ -11,7 +13,7 @@ import (
|
||||
)
|
||||
|
||||
func TestGinMiddleware(t *testing.T) {
|
||||
log := newBufferLogger()
|
||||
log := newBufferLoggerSilent()
|
||||
rr := httptest.NewRecorder()
|
||||
r := gin.New()
|
||||
r.Use(GinMiddleware(log))
|
||||
@ -41,3 +43,81 @@ func TestGinMiddleware(t *testing.T) {
|
||||
assert.NotEmpty(t, items[1].Context["path"])
|
||||
assert.NotEmpty(t, items[1].Context["bodySize"])
|
||||
}
|
||||
|
||||
func TestGinMiddleware_SkipPaths(t *testing.T) {
|
||||
log := newBufferLoggerSilent()
|
||||
rr := httptest.NewRecorder()
|
||||
r := gin.New()
|
||||
r.Use(GinMiddleware(log, "/hidden", "/hidden/:id", "/superhidden/*id"))
|
||||
r.GET("/hidden", func(c *gin.Context) {
|
||||
log := MustGet(c)
|
||||
log.Info("hidden message from /hidden")
|
||||
realLog := MustGetReal(c)
|
||||
realLog.Info("visible message from /hidden")
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
r.GET("/logged", func(c *gin.Context) {
|
||||
log := MustGet(c)
|
||||
log.Info("visible message from /logged")
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
r.GET("/hidden/:id", func(c *gin.Context) {
|
||||
log := MustGet(c)
|
||||
log.Info("hidden message from /hidden/:id")
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
r.GET("/superhidden/*id", func(c *gin.Context) {
|
||||
log := MustGet(c)
|
||||
log.Info("hidden message from /superhidden/*id")
|
||||
c.JSON(http.StatusOK, gin.H{})
|
||||
})
|
||||
|
||||
r.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/hidden", nil))
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
r.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/logged", nil))
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
r.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/hidden/param", nil))
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
r.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/superhidden/param/1/2/3", nil))
|
||||
require.Equal(t, http.StatusOK, rr.Code)
|
||||
|
||||
items, err := newJSONBufferedLogger(log).ScanAll()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, items, 3, printEntries(items))
|
||||
assert.Equal(t, "visible message from /hidden", items[0].Message, printEntries(items))
|
||||
assert.Equal(t, "visible message from /logged", items[1].Message, printEntries(items))
|
||||
}
|
||||
|
||||
func TestSkippedPath(t *testing.T) {
|
||||
cases := map[string]map[string]bool{
|
||||
"/hidden/:id": {
|
||||
"/hidden/1": true,
|
||||
"/hidden/2/3": false,
|
||||
},
|
||||
"/hidden/*id": {
|
||||
"/hidden/1": true,
|
||||
"/hidden/2/3": true,
|
||||
},
|
||||
}
|
||||
|
||||
for pattern, items := range cases {
|
||||
matcher, ok := newSkippedPath(pattern)
|
||||
require.True(t, ok)
|
||||
|
||||
for item, result := range items {
|
||||
assert.Equal(t, result, matcher.match(item), fmt.Sprintf(`"%s" does not match "%s", internals: %#v`,
|
||||
pattern, item, matcher))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func printEntries(entries []logRecord) string {
|
||||
data, err := json.MarshalIndent(entries, "", " ")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return string(data)
|
||||
}
|
||||
|
@ -9,7 +9,7 @@ import (
|
||||
)
|
||||
|
||||
func TestWriterAdapter(t *testing.T) {
|
||||
log := newBufferLogger()
|
||||
log := newBufferLoggerSilent()
|
||||
adapter := WriterAdapter(log, zap.InfoLevel)
|
||||
|
||||
msg := []byte("hello world")
|
||||
|
@ -9,7 +9,7 @@ import (
|
||||
)
|
||||
|
||||
func TestZabbixCollectorAdapter(t *testing.T) {
|
||||
log := newBufferLogger()
|
||||
log := newBufferLoggerSilent()
|
||||
adapter := ZabbixCollectorAdapter(log)
|
||||
adapter.Errorf("highly unexpected error: %s", "unexpected error")
|
||||
adapter.Errorf("cannot stop collector: %s", "app error")
|
||||
|
@ -42,12 +42,30 @@ type BufferLogger struct {
|
||||
}
|
||||
|
||||
// NewBufferedLogger returns new BufferedLogger instance.
|
||||
func NewBufferedLogger() BufferedLogger {
|
||||
func NewBufferedLogger(level ...zapcore.Level) BufferedLogger {
|
||||
lvl := zapcore.DebugLevel
|
||||
if len(level) > 0 {
|
||||
lvl = level[0]
|
||||
}
|
||||
bl := &BufferLogger{}
|
||||
bl.Logger = zap.New(
|
||||
zapcore.NewCore(
|
||||
logger.NewJSONWithContextEncoder(
|
||||
logger.EncoderConfigJSON()), zap.CombineWriteSyncers(os.Stdout, os.Stderr, &bl.buf), zapcore.DebugLevel))
|
||||
logger.EncoderConfigJSON()), zap.CombineWriteSyncers(os.Stdout, os.Stderr, &bl.buf), lvl))
|
||||
return bl
|
||||
}
|
||||
|
||||
// NewBufferedLoggerSilent returns new BufferedLogger instance which won't duplicate entries to stdout/stderr.
|
||||
func NewBufferedLoggerSilent(level ...zapcore.Level) BufferedLogger {
|
||||
lvl := zapcore.DebugLevel
|
||||
if len(level) > 0 {
|
||||
lvl = level[0]
|
||||
}
|
||||
bl := &BufferLogger{}
|
||||
bl.Logger = zap.New(
|
||||
zapcore.NewCore(
|
||||
logger.NewJSONWithContextEncoder(
|
||||
logger.EncoderConfigJSON()), &bl.buf, lvl))
|
||||
return bl
|
||||
}
|
||||
|
||||
|
@ -2,6 +2,7 @@ package testutil
|
||||
|
||||
import (
|
||||
"io"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/suite"
|
||||
@ -9,7 +10,8 @@ import (
|
||||
|
||||
type BufferLoggerTest struct {
|
||||
suite.Suite
|
||||
logger BufferedLogger
|
||||
logger BufferedLogger
|
||||
silentLogger BufferedLogger
|
||||
}
|
||||
|
||||
func TestBufferLogger(t *testing.T) {
|
||||
@ -18,10 +20,12 @@ func TestBufferLogger(t *testing.T) {
|
||||
|
||||
func (t *BufferLoggerTest) SetupSuite() {
|
||||
t.logger = NewBufferedLogger()
|
||||
t.silentLogger = NewBufferedLoggerSilent()
|
||||
}
|
||||
|
||||
func (t *BufferLoggerTest) SetupTest() {
|
||||
t.logger.Reset()
|
||||
t.silentLogger.Reset()
|
||||
}
|
||||
|
||||
func (t *BufferLoggerTest) Test_Read() {
|
||||
@ -31,25 +35,62 @@ func (t *BufferLoggerTest) Test_Read() {
|
||||
t.Require().NoError(err)
|
||||
t.Assert().Contains(string(data), "\"level_name\":\"DEBUG\"")
|
||||
t.Assert().Contains(string(data), "\"message\":\"test\"")
|
||||
|
||||
t.silentLogger.Debug("test")
|
||||
|
||||
data, err = io.ReadAll(t.silentLogger)
|
||||
t.Require().NoError(err)
|
||||
t.Assert().Contains(string(data), "\"level_name\":\"DEBUG\"")
|
||||
t.Assert().Contains(string(data), "\"message\":\"test\"")
|
||||
}
|
||||
|
||||
func (t *BufferLoggerTest) Test_Bytes() {
|
||||
t.logger.Debug("test")
|
||||
t.Assert().Contains(string(t.logger.Bytes()), "\"level_name\":\"DEBUG\"")
|
||||
t.Assert().Contains(string(t.logger.Bytes()), "\"message\":\"test\"")
|
||||
|
||||
t.silentLogger.Debug("test")
|
||||
t.Assert().Contains(string(t.silentLogger.Bytes()), "\"level_name\":\"DEBUG\"")
|
||||
t.Assert().Contains(string(t.silentLogger.Bytes()), "\"message\":\"test\"")
|
||||
}
|
||||
|
||||
func (t *BufferLoggerTest) Test_String() {
|
||||
t.logger.Debug("test")
|
||||
t.Assert().Contains(t.logger.String(), "\"level_name\":\"DEBUG\"")
|
||||
t.Assert().Contains(t.logger.String(), "\"message\":\"test\"")
|
||||
|
||||
t.silentLogger.Debug("test")
|
||||
t.Assert().Contains(t.silentLogger.String(), "\"level_name\":\"DEBUG\"")
|
||||
t.Assert().Contains(t.silentLogger.String(), "\"message\":\"test\"")
|
||||
}
|
||||
|
||||
func (t *BufferLoggerTest) TestRace() {
|
||||
var (
|
||||
wg sync.WaitGroup
|
||||
starter sync.WaitGroup
|
||||
)
|
||||
starter.Add(1)
|
||||
wg.Add(4)
|
||||
go func() {
|
||||
starter.Wait()
|
||||
t.logger.Debug("test")
|
||||
wg.Done()
|
||||
}()
|
||||
go func() {
|
||||
starter.Wait()
|
||||
t.logger.String()
|
||||
wg.Done()
|
||||
}()
|
||||
go func() {
|
||||
starter.Wait()
|
||||
t.silentLogger.Debug("test")
|
||||
wg.Done()
|
||||
}()
|
||||
go func() {
|
||||
starter.Wait()
|
||||
t.silentLogger.String()
|
||||
wg.Done()
|
||||
}()
|
||||
starter.Done()
|
||||
wg.Wait()
|
||||
}
|
||||
|
@ -1,13 +1,77 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"go.uber.org/atomic"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin/binding"
|
||||
"github.com/go-playground/validator/v10"
|
||||
)
|
||||
|
||||
var (
|
||||
crmDomainStore = &domainStore{
|
||||
source: GetSaasDomains,
|
||||
matcher: func(domain string, domains []Domain) bool {
|
||||
if len(domains) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
secondLevel := strings.Join(strings.Split(domain, ".")[1:], ".")
|
||||
|
||||
for _, crmDomain := range domains {
|
||||
if crmDomain.Domain == secondLevel {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
},
|
||||
}
|
||||
boxDomainStore = &domainStore{
|
||||
source: GetBoxDomains,
|
||||
matcher: func(domain string, domains []Domain) bool {
|
||||
if len(domains) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, crmDomain := range domains {
|
||||
if crmDomain.Domain == domain {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
type domainStore struct {
|
||||
domains []Domain
|
||||
mutex sync.RWMutex
|
||||
source func() []Domain
|
||||
matcher func(string, []Domain) bool
|
||||
lastUpdate atomic.Time
|
||||
}
|
||||
|
||||
func (ds *domainStore) match(domain string) bool {
|
||||
if time.Since(ds.lastUpdate.Load()) > time.Hour {
|
||||
ds.update()
|
||||
}
|
||||
defer ds.mutex.RUnlock()
|
||||
ds.mutex.RLock()
|
||||
return ds.matcher(domain, ds.domains)
|
||||
}
|
||||
|
||||
func (ds *domainStore) update() {
|
||||
defer ds.mutex.Unlock()
|
||||
ds.mutex.Lock()
|
||||
ds.domains = ds.source()
|
||||
ds.lastUpdate.Store(time.Now())
|
||||
}
|
||||
|
||||
// init here will register `validateCrmURL` function for gin validator.
|
||||
func init() {
|
||||
if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
|
||||
@ -26,40 +90,16 @@ func validateCrmURL(fl validator.FieldLevel) bool {
|
||||
|
||||
func isDomainValid(crmURL string) bool {
|
||||
parseURL, err := url.ParseRequestURI(crmURL)
|
||||
|
||||
if err != nil || nil == parseURL || !checkURLString(parseURL) {
|
||||
return false
|
||||
}
|
||||
|
||||
mainDomain := getMainDomain(parseURL.Hostname())
|
||||
|
||||
if checkDomains(GetSaasDomains(), mainDomain) {
|
||||
hostname := parseURL.Hostname()
|
||||
if crmDomainStore.match(hostname) {
|
||||
return true
|
||||
}
|
||||
|
||||
if checkDomains(GetBoxDomains(), parseURL.Hostname()) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func checkDomains(crmDomains []Domain, domain string) bool {
|
||||
if nil == crmDomains {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, crmDomain := range crmDomains {
|
||||
if crmDomain.Domain == domain {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func getMainDomain(hostname string) (mainDomain string) {
|
||||
return strings.Join(strings.Split(hostname, ".")[1:], ".")
|
||||
return boxDomainStore.match(hostname)
|
||||
}
|
||||
|
||||
func checkURLString(parseURL *url.URL) bool {
|
||||
|
@ -29,6 +29,11 @@ func (s *ValidatorSuite) SetupSuite() {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *ValidatorSuite) SetupTest() {
|
||||
crmDomainStore.update()
|
||||
boxDomainStore.update()
|
||||
}
|
||||
|
||||
func (s *ValidatorSuite) getError(err error) string {
|
||||
if err == nil {
|
||||
return ""
|
||||
@ -72,14 +77,14 @@ func (s *ValidatorSuite) Test_ValidationSuccess() {
|
||||
"https://test.retailcrm.pro",
|
||||
"https://raisa.retailcrm.es",
|
||||
"https://blabla.simla.com",
|
||||
"https://blabla.ecomlogic.com",
|
||||
"https://crm.baucenter.ru",
|
||||
"https://crm.holodilnik.ru",
|
||||
"https://crm.eco.lanit.ru",
|
||||
"https://ecom.inventive.ru",
|
||||
"https://retailcrm.tvoydom.ru",
|
||||
}
|
||||
|
||||
for _, domain := range boxDomainStore.domains {
|
||||
crmDomains = append(crmDomains, "https://"+domain.Domain)
|
||||
}
|
||||
|
||||
s.Assert().True(len(crmDomains) > 4, "No box domains were tested, test is incomplete!")
|
||||
|
||||
for _, domain := range crmDomains {
|
||||
conn := models.Connection{
|
||||
Key: "key",
|
||||
@ -87,6 +92,6 @@ func (s *ValidatorSuite) Test_ValidationSuccess() {
|
||||
}
|
||||
|
||||
err := s.engine.Struct(conn)
|
||||
assert.NoError(s.T(), err, s.getError(err))
|
||||
assert.NoError(s.T(), err, domain+": "+s.getError(err))
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user