diff --git a/core/engine.go b/core/engine.go index 72dfd48..133b6c3 100644 --- a/core/engine.go +++ b/core/engine.go @@ -27,7 +27,10 @@ import ( "github.com/retailcrm/mg-transport-core/v2/core/logger" ) -const DefaultHTTPClientTimeout time.Duration = 30 +const ( + DefaultHTTPClientTimeout time.Duration = 30 + AppContextKey = "app" +) var boolTrue = true @@ -110,6 +113,9 @@ func (e *Engine) initGin() { } r := gin.New() + r.Use(func(c *gin.Context) { + c.Set(AppContextKey, e) + }) e.buildSentryConfig() e.InitSentrySDK() @@ -414,3 +420,19 @@ func (e *Engine) buildSentryConfig() { Debug: e.Config.IsDebug(), } } + +func GetApp(c *gin.Context) (app *Engine, exists bool) { + item, exists := c.Get(AppContextKey) + if !exists { + return nil, false + } + converted, ok := item.(*Engine) + if !ok { + return nil, false + } + return converted, true +} + +func MustGetApp(c *gin.Context) *Engine { + return c.MustGet(AppContextKey).(*Engine) +} diff --git a/core/logger/buffer_logger_test.go b/core/logger/buffer_logger_test.go index f21ca56..b7b9fe4 100644 --- a/core/logger/buffer_logger_test.go +++ b/core/logger/buffer_logger_test.go @@ -5,7 +5,6 @@ import ( "bytes" "encoding/json" "io" - "os" "sync" "github.com/guregu/null/v5" @@ -32,7 +31,7 @@ type jSONRecordScanner struct { func newJSONBufferedLogger(buf *bufferLogger) *jSONRecordScanner { if buf == nil { - buf = newBufferLogger() + buf = newBufferLoggerSilent() } return &jSONRecordScanner{scan: bufio.NewScanner(buf), buf: buf} } @@ -59,13 +58,17 @@ type bufferLogger struct { buf lockableBuffer } -// NewBufferedLogger returns new BufferedLogger instance. -func newBufferLogger() *bufferLogger { +// newBufferLoggerSilent returns new BufferedLogger instance which won't duplicate entries to stdout/stderr. +func newBufferLoggerSilent(level ...zapcore.Level) *bufferLogger { + lvl := zapcore.DebugLevel + if len(level) > 0 { + lvl = level[0] + } bl := &bufferLogger{} bl.Logger = zap.New( zapcore.NewCore( NewJSONWithContextEncoder( - EncoderConfigJSON()), zap.CombineWriteSyncers(os.Stdout, os.Stderr, &bl.buf), zapcore.DebugLevel)) + EncoderConfigJSON()), &bl.buf, lvl)) return bl } diff --git a/core/logger/default_test.go b/core/logger/default_test.go index 7a70b47..5a17bb0 100644 --- a/core/logger/default_test.go +++ b/core/logger/default_test.go @@ -32,7 +32,7 @@ func (s *TestDefaultSuite) TestNewDefault_Panic() { } func (s *TestDefaultSuite) TestWith() { - log := newBufferLogger() + log := newBufferLoggerSilent() log.With(zap.String(HandlerAttr, "Handler")).Info("test") items, err := newJSONBufferedLogger(log).ScanAll() @@ -42,7 +42,7 @@ func (s *TestDefaultSuite) TestWith() { } func (s *TestDefaultSuite) TestWithLazy() { - log := newBufferLogger() + log := newBufferLoggerSilent() log.WithLazy(zap.String(HandlerAttr, "Handler")).Info("test") items, err := newJSONBufferedLogger(log).ScanAll() @@ -52,7 +52,7 @@ func (s *TestDefaultSuite) TestWithLazy() { } func (s *TestDefaultSuite) TestForHandler() { - log := newBufferLogger() + log := newBufferLoggerSilent() log.ForHandler("Handler").Info("test") items, err := newJSONBufferedLogger(log).ScanAll() @@ -62,7 +62,7 @@ func (s *TestDefaultSuite) TestForHandler() { } func (s *TestDefaultSuite) TestForHandlerNoDuplicate() { - log := newBufferLogger() + log := newBufferLoggerSilent() log.ForHandler("handler1").ForHandler("handler2").Info("test") s.Assert().Contains(log.String(), "handler2") @@ -70,7 +70,7 @@ func (s *TestDefaultSuite) TestForHandlerNoDuplicate() { } func (s *TestDefaultSuite) TestForConnection() { - log := newBufferLogger() + log := newBufferLoggerSilent() log.ForConnection("connection").Info("test") items, err := newJSONBufferedLogger(log).ScanAll() @@ -80,7 +80,7 @@ func (s *TestDefaultSuite) TestForConnection() { } func (s *TestDefaultSuite) TestForConnectionNoDuplicate() { - log := newBufferLogger() + log := newBufferLoggerSilent() log.ForConnection("conn1").ForConnection("conn2").Info("test") s.Assert().Contains(log.String(), "conn2") @@ -88,7 +88,7 @@ func (s *TestDefaultSuite) TestForConnectionNoDuplicate() { } func (s *TestDefaultSuite) TestForAccount() { - log := newBufferLogger() + log := newBufferLoggerSilent() log.ForAccount("account").Info("test") items, err := newJSONBufferedLogger(log).ScanAll() @@ -98,7 +98,7 @@ func (s *TestDefaultSuite) TestForAccount() { } func (s *TestDefaultSuite) TestForAccountNoDuplicate() { - log := newBufferLogger() + log := newBufferLoggerSilent() log.ForAccount("acc1").ForAccount("acc2").Info("test") s.Assert().Contains(log.String(), "acc2") @@ -106,7 +106,7 @@ func (s *TestDefaultSuite) TestForAccountNoDuplicate() { } func (s *TestDefaultSuite) TestNoDuplicatesPersistRecords() { - log := newBufferLogger() + log := newBufferLoggerSilent() log. ForHandler("handler1"). ForHandler("handler2"). @@ -126,7 +126,7 @@ func (s *TestDefaultSuite) TestNoDuplicatesPersistRecords() { // TestPersistRecordsIncompatibleWith is not a unit test, but rather a demonstration how you shouldn't use For* methods. func (s *TestDefaultSuite) TestPersistRecordsIncompatibleWith() { - log := newBufferLogger() + log := newBufferLoggerSilent() log. ForHandler("handler1"). With(zap.Int("f1", 1)). diff --git a/core/logger/gin.go b/core/logger/gin.go index 9c6ded0..73afc96 100644 --- a/core/logger/gin.go +++ b/core/logger/gin.go @@ -1,24 +1,64 @@ package logger import ( + "path" + "regexp" "time" "github.com/gin-gonic/gin" "go.uber.org/zap" ) +const ( + LoggerContextKey = "logger" + LoggerRealContextKey = "loggerReal" +) + // GinMiddleware will construct Gin middleware which will log requests and provide logger with unique request ID. -func GinMiddleware(log Logger) gin.HandlerFunc { +func GinMiddleware(log Logger, skipPaths ...string) gin.HandlerFunc { + var ( + skip map[string]struct{} + matchSkip []*skippedPath + ) + if length := len(skipPaths); length > 0 { + skip = make(map[string]struct{}, length) + + for _, path := range skipPaths { + if skipped, ok := newSkippedPath(path); ok { + matchSkip = append(matchSkip, skipped) + continue + } + skip[path] = struct{}{} + } + } + + nilLogger := NewNil() + return func(c *gin.Context) { // Start timer start := time.Now() path := c.Request.URL.Path raw := c.Request.URL.RawQuery - + _, shouldSkip := skip[path] streamID := generateStreamID() log := log.With(StreamID(streamID)) + + if !shouldSkip && len(matchSkip) > 0 { + for _, skipper := range matchSkip { + if skipper.match(path) { + shouldSkip = true + break + } + } + } + + if shouldSkip { + c.Set(LoggerRealContextKey, log) + log = nilLogger + } + c.Set(StreamIDAttr, streamID) - c.Set("logger", log) + c.Set(LoggerContextKey, log) // Process request c.Next() @@ -28,19 +68,68 @@ func GinMiddleware(log Logger) gin.HandlerFunc { path = path + "?" + raw } - log.Info("request", - zap.String(HandlerAttr, "GIN"), - zap.String("startTime", start.Format(time.RFC3339)), - zap.String("endTime", end.Format(time.RFC3339)), - zap.Any("latency", end.Sub(start)/time.Millisecond), - zap.String("remoteAddress", c.ClientIP()), - zap.String(HTTPMethodAttr, c.Request.Method), - zap.String("path", path), - zap.Int("bodySize", c.Writer.Size()), - ) + if !shouldSkip { + log.Info("request", + zap.String(HandlerAttr, "GIN"), + zap.String("startTime", start.Format(time.RFC3339)), + zap.String("endTime", end.Format(time.RFC3339)), + zap.Any("latency", end.Sub(start)/time.Millisecond), + zap.String("remoteAddress", c.ClientIP()), + zap.String(HTTPMethodAttr, c.Request.Method), + zap.String("path", path), + zap.Int("bodySize", c.Writer.Size()), + ) + } } } func MustGet(c *gin.Context) Logger { return c.MustGet("logger").(Logger) } + +func MustGetReal(c *gin.Context) Logger { + log, ok := c.Get(LoggerContextKey) + if _, isNil := log.(*Nil); !ok || isNil { + return c.MustGet(LoggerRealContextKey).(Logger) + } + return log.(Logger) +} + +var ( + hasParamsMatcher = regexp.MustCompile(`/:\w+`) + hasWildcardParamsMatcher = regexp.MustCompile(`/\*\w+.*`) +) + +type skippedPath struct { + path string + expr *regexp.Regexp +} + +// newSkippedPath returns new path skipping struct. It returns nil, false if expr is simple and +// no complex logic is needed. +func newSkippedPath(expr string) (result *skippedPath, compatible bool) { + hasParams, hasWildcard := hasParamsMatcher.MatchString(expr), hasWildcardParamsMatcher.MatchString(expr) + if !hasParams && !hasWildcard { + return nil, false + } + if hasWildcard { + return &skippedPath{expr: matcherForWildcard(expr)}, true + } + return &skippedPath{path: matcherForPath(expr)}, true +} + +func matcherForWildcard(expr string) *regexp.Regexp { + return regexp.MustCompile(hasWildcardParamsMatcher.ReplaceAllString(expr, "/[\\w/]+")) +} + +func matcherForPath(expr string) string { + return hasParamsMatcher.ReplaceAllString(expr, "/*") +} + +func (p *skippedPath) match(route string) bool { + if p.expr != nil { + return p.expr.MatchString(route) + } + result, err := path.Match(p.path, route) + return result && err == nil +} diff --git a/core/logger/gin_test.go b/core/logger/gin_test.go index 6558eda..fb35633 100644 --- a/core/logger/gin_test.go +++ b/core/logger/gin_test.go @@ -1,6 +1,8 @@ package logger import ( + "encoding/json" + "fmt" "net/http" "net/http/httptest" "testing" @@ -11,7 +13,7 @@ import ( ) func TestGinMiddleware(t *testing.T) { - log := newBufferLogger() + log := newBufferLoggerSilent() rr := httptest.NewRecorder() r := gin.New() r.Use(GinMiddleware(log)) @@ -41,3 +43,81 @@ func TestGinMiddleware(t *testing.T) { assert.NotEmpty(t, items[1].Context["path"]) assert.NotEmpty(t, items[1].Context["bodySize"]) } + +func TestGinMiddleware_SkipPaths(t *testing.T) { + log := newBufferLoggerSilent() + rr := httptest.NewRecorder() + r := gin.New() + r.Use(GinMiddleware(log, "/hidden", "/hidden/:id", "/superhidden/*id")) + r.GET("/hidden", func(c *gin.Context) { + log := MustGet(c) + log.Info("hidden message from /hidden") + realLog := MustGetReal(c) + realLog.Info("visible message from /hidden") + c.JSON(http.StatusOK, gin.H{}) + }) + r.GET("/logged", func(c *gin.Context) { + log := MustGet(c) + log.Info("visible message from /logged") + c.JSON(http.StatusOK, gin.H{}) + }) + r.GET("/hidden/:id", func(c *gin.Context) { + log := MustGet(c) + log.Info("hidden message from /hidden/:id") + c.JSON(http.StatusOK, gin.H{}) + }) + r.GET("/superhidden/*id", func(c *gin.Context) { + log := MustGet(c) + log.Info("hidden message from /superhidden/*id") + c.JSON(http.StatusOK, gin.H{}) + }) + + r.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/hidden", nil)) + require.Equal(t, http.StatusOK, rr.Code) + + r.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/logged", nil)) + require.Equal(t, http.StatusOK, rr.Code) + + r.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/hidden/param", nil)) + require.Equal(t, http.StatusOK, rr.Code) + + r.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/superhidden/param/1/2/3", nil)) + require.Equal(t, http.StatusOK, rr.Code) + + items, err := newJSONBufferedLogger(log).ScanAll() + require.NoError(t, err) + require.Len(t, items, 3, printEntries(items)) + assert.Equal(t, "visible message from /hidden", items[0].Message, printEntries(items)) + assert.Equal(t, "visible message from /logged", items[1].Message, printEntries(items)) +} + +func TestSkippedPath(t *testing.T) { + cases := map[string]map[string]bool{ + "/hidden/:id": { + "/hidden/1": true, + "/hidden/2/3": false, + }, + "/hidden/*id": { + "/hidden/1": true, + "/hidden/2/3": true, + }, + } + + for pattern, items := range cases { + matcher, ok := newSkippedPath(pattern) + require.True(t, ok) + + for item, result := range items { + assert.Equal(t, result, matcher.match(item), fmt.Sprintf(`"%s" does not match "%s", internals: %#v`, + pattern, item, matcher)) + } + } +} + +func printEntries(entries []logRecord) string { + data, err := json.MarshalIndent(entries, "", " ") + if err != nil { + panic(err) + } + return string(data) +} diff --git a/core/logger/writer_adapter_test.go b/core/logger/writer_adapter_test.go index 68e34d2..8662d7c 100644 --- a/core/logger/writer_adapter_test.go +++ b/core/logger/writer_adapter_test.go @@ -9,7 +9,7 @@ import ( ) func TestWriterAdapter(t *testing.T) { - log := newBufferLogger() + log := newBufferLoggerSilent() adapter := WriterAdapter(log, zap.InfoLevel) msg := []byte("hello world") diff --git a/core/logger/zabbix_collector_adapter_test.go b/core/logger/zabbix_collector_adapter_test.go index 08f9dfb..b4190ea 100644 --- a/core/logger/zabbix_collector_adapter_test.go +++ b/core/logger/zabbix_collector_adapter_test.go @@ -9,7 +9,7 @@ import ( ) func TestZabbixCollectorAdapter(t *testing.T) { - log := newBufferLogger() + log := newBufferLoggerSilent() adapter := ZabbixCollectorAdapter(log) adapter.Errorf("highly unexpected error: %s", "unexpected error") adapter.Errorf("cannot stop collector: %s", "app error") diff --git a/core/util/testutil/buffer_logger.go b/core/util/testutil/buffer_logger.go index 5659ede..0ea0ddf 100644 --- a/core/util/testutil/buffer_logger.go +++ b/core/util/testutil/buffer_logger.go @@ -42,12 +42,30 @@ type BufferLogger struct { } // NewBufferedLogger returns new BufferedLogger instance. -func NewBufferedLogger() BufferedLogger { +func NewBufferedLogger(level ...zapcore.Level) BufferedLogger { + lvl := zapcore.DebugLevel + if len(level) > 0 { + lvl = level[0] + } bl := &BufferLogger{} bl.Logger = zap.New( zapcore.NewCore( logger.NewJSONWithContextEncoder( - logger.EncoderConfigJSON()), zap.CombineWriteSyncers(os.Stdout, os.Stderr, &bl.buf), zapcore.DebugLevel)) + logger.EncoderConfigJSON()), zap.CombineWriteSyncers(os.Stdout, os.Stderr, &bl.buf), lvl)) + return bl +} + +// NewBufferedLoggerSilent returns new BufferedLogger instance which won't duplicate entries to stdout/stderr. +func NewBufferedLoggerSilent(level ...zapcore.Level) BufferedLogger { + lvl := zapcore.DebugLevel + if len(level) > 0 { + lvl = level[0] + } + bl := &BufferLogger{} + bl.Logger = zap.New( + zapcore.NewCore( + logger.NewJSONWithContextEncoder( + logger.EncoderConfigJSON()), &bl.buf, lvl)) return bl } diff --git a/core/util/testutil/buffer_logger_test.go b/core/util/testutil/buffer_logger_test.go index 8d0771d..8eb5ac8 100644 --- a/core/util/testutil/buffer_logger_test.go +++ b/core/util/testutil/buffer_logger_test.go @@ -2,6 +2,7 @@ package testutil import ( "io" + "sync" "testing" "github.com/stretchr/testify/suite" @@ -9,7 +10,8 @@ import ( type BufferLoggerTest struct { suite.Suite - logger BufferedLogger + logger BufferedLogger + silentLogger BufferedLogger } func TestBufferLogger(t *testing.T) { @@ -18,10 +20,12 @@ func TestBufferLogger(t *testing.T) { func (t *BufferLoggerTest) SetupSuite() { t.logger = NewBufferedLogger() + t.silentLogger = NewBufferedLoggerSilent() } func (t *BufferLoggerTest) SetupTest() { t.logger.Reset() + t.silentLogger.Reset() } func (t *BufferLoggerTest) Test_Read() { @@ -31,25 +35,62 @@ func (t *BufferLoggerTest) Test_Read() { t.Require().NoError(err) t.Assert().Contains(string(data), "\"level_name\":\"DEBUG\"") t.Assert().Contains(string(data), "\"message\":\"test\"") + + t.silentLogger.Debug("test") + + data, err = io.ReadAll(t.silentLogger) + t.Require().NoError(err) + t.Assert().Contains(string(data), "\"level_name\":\"DEBUG\"") + t.Assert().Contains(string(data), "\"message\":\"test\"") } func (t *BufferLoggerTest) Test_Bytes() { t.logger.Debug("test") t.Assert().Contains(string(t.logger.Bytes()), "\"level_name\":\"DEBUG\"") t.Assert().Contains(string(t.logger.Bytes()), "\"message\":\"test\"") + + t.silentLogger.Debug("test") + t.Assert().Contains(string(t.silentLogger.Bytes()), "\"level_name\":\"DEBUG\"") + t.Assert().Contains(string(t.silentLogger.Bytes()), "\"message\":\"test\"") } func (t *BufferLoggerTest) Test_String() { t.logger.Debug("test") t.Assert().Contains(t.logger.String(), "\"level_name\":\"DEBUG\"") t.Assert().Contains(t.logger.String(), "\"message\":\"test\"") + + t.silentLogger.Debug("test") + t.Assert().Contains(t.silentLogger.String(), "\"level_name\":\"DEBUG\"") + t.Assert().Contains(t.silentLogger.String(), "\"message\":\"test\"") } func (t *BufferLoggerTest) TestRace() { + var ( + wg sync.WaitGroup + starter sync.WaitGroup + ) + starter.Add(1) + wg.Add(4) go func() { + starter.Wait() t.logger.Debug("test") + wg.Done() }() go func() { + starter.Wait() t.logger.String() + wg.Done() }() + go func() { + starter.Wait() + t.silentLogger.Debug("test") + wg.Done() + }() + go func() { + starter.Wait() + t.silentLogger.String() + wg.Done() + }() + starter.Done() + wg.Wait() } diff --git a/core/validator.go b/core/validator.go index 61e93d4..92f1b87 100644 --- a/core/validator.go +++ b/core/validator.go @@ -1,13 +1,77 @@ package core import ( + "go.uber.org/atomic" "net/url" "strings" + "sync" + "time" "github.com/gin-gonic/gin/binding" "github.com/go-playground/validator/v10" ) +var ( + crmDomainStore = &domainStore{ + source: GetSaasDomains, + matcher: func(domain string, domains []Domain) bool { + if len(domains) == 0 { + return false + } + + secondLevel := strings.Join(strings.Split(domain, ".")[1:], ".") + + for _, crmDomain := range domains { + if crmDomain.Domain == secondLevel { + return true + } + } + + return false + }, + } + boxDomainStore = &domainStore{ + source: GetBoxDomains, + matcher: func(domain string, domains []Domain) bool { + if len(domains) == 0 { + return false + } + + for _, crmDomain := range domains { + if crmDomain.Domain == domain { + return true + } + } + + return false + }, + } +) + +type domainStore struct { + domains []Domain + mutex sync.RWMutex + source func() []Domain + matcher func(string, []Domain) bool + lastUpdate atomic.Time +} + +func (ds *domainStore) match(domain string) bool { + if time.Since(ds.lastUpdate.Load()) > time.Hour { + ds.update() + } + defer ds.mutex.RUnlock() + ds.mutex.RLock() + return ds.matcher(domain, ds.domains) +} + +func (ds *domainStore) update() { + defer ds.mutex.Unlock() + ds.mutex.Lock() + ds.domains = ds.source() + ds.lastUpdate.Store(time.Now()) +} + // init here will register `validateCrmURL` function for gin validator. func init() { if v, ok := binding.Validator.Engine().(*validator.Validate); ok { @@ -26,40 +90,16 @@ func validateCrmURL(fl validator.FieldLevel) bool { func isDomainValid(crmURL string) bool { parseURL, err := url.ParseRequestURI(crmURL) - if err != nil || nil == parseURL || !checkURLString(parseURL) { return false } - mainDomain := getMainDomain(parseURL.Hostname()) - - if checkDomains(GetSaasDomains(), mainDomain) { + hostname := parseURL.Hostname() + if crmDomainStore.match(hostname) { return true } - if checkDomains(GetBoxDomains(), parseURL.Hostname()) { - return true - } - - return false -} - -func checkDomains(crmDomains []Domain, domain string) bool { - if nil == crmDomains { - return false - } - - for _, crmDomain := range crmDomains { - if crmDomain.Domain == domain { - return true - } - } - - return false -} - -func getMainDomain(hostname string) (mainDomain string) { - return strings.Join(strings.Split(hostname, ".")[1:], ".") + return boxDomainStore.match(hostname) } func checkURLString(parseURL *url.URL) bool { diff --git a/core/validator_test.go b/core/validator_test.go index 274436c..a260f03 100644 --- a/core/validator_test.go +++ b/core/validator_test.go @@ -29,6 +29,11 @@ func (s *ValidatorSuite) SetupSuite() { } } +func (s *ValidatorSuite) SetupTest() { + crmDomainStore.update() + boxDomainStore.update() +} + func (s *ValidatorSuite) getError(err error) string { if err == nil { return "" @@ -72,14 +77,14 @@ func (s *ValidatorSuite) Test_ValidationSuccess() { "https://test.retailcrm.pro", "https://raisa.retailcrm.es", "https://blabla.simla.com", - "https://blabla.ecomlogic.com", - "https://crm.baucenter.ru", - "https://crm.holodilnik.ru", - "https://crm.eco.lanit.ru", - "https://ecom.inventive.ru", - "https://retailcrm.tvoydom.ru", } + for _, domain := range boxDomainStore.domains { + crmDomains = append(crmDomains, "https://"+domain.Domain) + } + + s.Assert().True(len(crmDomains) > 4, "No box domains were tested, test is incomplete!") + for _, domain := range crmDomains { conn := models.Connection{ Key: "key", @@ -87,6 +92,6 @@ func (s *ValidatorSuite) Test_ValidationSuccess() { } err := s.engine.Struct(conn) - assert.NoError(s.T(), err, s.getError(err)) + assert.NoError(s.T(), err, domain+": "+s.getError(err)) } }