add skipPaths to gin log middleware, refactor validator

This commit is contained in:
Pavel 2024-12-17 13:33:04 +03:00
parent b8ccfa8c8c
commit f5e2bfefb2
11 changed files with 367 additions and 69 deletions

View File

@ -27,7 +27,10 @@ import (
"github.com/retailcrm/mg-transport-core/v2/core/logger"
)
const DefaultHTTPClientTimeout time.Duration = 30
const (
DefaultHTTPClientTimeout time.Duration = 30
AppContextKey = "app"
)
var boolTrue = true
@ -110,6 +113,9 @@ func (e *Engine) initGin() {
}
r := gin.New()
r.Use(func(c *gin.Context) {
c.Set(AppContextKey, e)
})
e.buildSentryConfig()
e.InitSentrySDK()
@ -414,3 +420,19 @@ func (e *Engine) buildSentryConfig() {
Debug: e.Config.IsDebug(),
}
}
func GetApp(c *gin.Context) (app *Engine, exists bool) {
item, exists := c.Get(AppContextKey)
if !exists {
return nil, false
}
converted, ok := item.(*Engine)
if !ok {
return nil, false
}
return converted, true
}
func MustGetApp(c *gin.Context) *Engine {
return c.MustGet(AppContextKey).(*Engine)
}

View File

@ -5,7 +5,6 @@ import (
"bytes"
"encoding/json"
"io"
"os"
"sync"
"github.com/guregu/null/v5"
@ -32,7 +31,7 @@ type jSONRecordScanner struct {
func newJSONBufferedLogger(buf *bufferLogger) *jSONRecordScanner {
if buf == nil {
buf = newBufferLogger()
buf = newBufferLoggerSilent()
}
return &jSONRecordScanner{scan: bufio.NewScanner(buf), buf: buf}
}
@ -59,13 +58,17 @@ type bufferLogger struct {
buf lockableBuffer
}
// NewBufferedLogger returns new BufferedLogger instance.
func newBufferLogger() *bufferLogger {
// newBufferLoggerSilent returns new BufferedLogger instance which won't duplicate entries to stdout/stderr.
func newBufferLoggerSilent(level ...zapcore.Level) *bufferLogger {
lvl := zapcore.DebugLevel
if len(level) > 0 {
lvl = level[0]
}
bl := &bufferLogger{}
bl.Logger = zap.New(
zapcore.NewCore(
NewJSONWithContextEncoder(
EncoderConfigJSON()), zap.CombineWriteSyncers(os.Stdout, os.Stderr, &bl.buf), zapcore.DebugLevel))
EncoderConfigJSON()), &bl.buf, lvl))
return bl
}

View File

@ -32,7 +32,7 @@ func (s *TestDefaultSuite) TestNewDefault_Panic() {
}
func (s *TestDefaultSuite) TestWith() {
log := newBufferLogger()
log := newBufferLoggerSilent()
log.With(zap.String(HandlerAttr, "Handler")).Info("test")
items, err := newJSONBufferedLogger(log).ScanAll()
@ -42,7 +42,7 @@ func (s *TestDefaultSuite) TestWith() {
}
func (s *TestDefaultSuite) TestWithLazy() {
log := newBufferLogger()
log := newBufferLoggerSilent()
log.WithLazy(zap.String(HandlerAttr, "Handler")).Info("test")
items, err := newJSONBufferedLogger(log).ScanAll()
@ -52,7 +52,7 @@ func (s *TestDefaultSuite) TestWithLazy() {
}
func (s *TestDefaultSuite) TestForHandler() {
log := newBufferLogger()
log := newBufferLoggerSilent()
log.ForHandler("Handler").Info("test")
items, err := newJSONBufferedLogger(log).ScanAll()
@ -62,7 +62,7 @@ func (s *TestDefaultSuite) TestForHandler() {
}
func (s *TestDefaultSuite) TestForHandlerNoDuplicate() {
log := newBufferLogger()
log := newBufferLoggerSilent()
log.ForHandler("handler1").ForHandler("handler2").Info("test")
s.Assert().Contains(log.String(), "handler2")
@ -70,7 +70,7 @@ func (s *TestDefaultSuite) TestForHandlerNoDuplicate() {
}
func (s *TestDefaultSuite) TestForConnection() {
log := newBufferLogger()
log := newBufferLoggerSilent()
log.ForConnection("connection").Info("test")
items, err := newJSONBufferedLogger(log).ScanAll()
@ -80,7 +80,7 @@ func (s *TestDefaultSuite) TestForConnection() {
}
func (s *TestDefaultSuite) TestForConnectionNoDuplicate() {
log := newBufferLogger()
log := newBufferLoggerSilent()
log.ForConnection("conn1").ForConnection("conn2").Info("test")
s.Assert().Contains(log.String(), "conn2")
@ -88,7 +88,7 @@ func (s *TestDefaultSuite) TestForConnectionNoDuplicate() {
}
func (s *TestDefaultSuite) TestForAccount() {
log := newBufferLogger()
log := newBufferLoggerSilent()
log.ForAccount("account").Info("test")
items, err := newJSONBufferedLogger(log).ScanAll()
@ -98,7 +98,7 @@ func (s *TestDefaultSuite) TestForAccount() {
}
func (s *TestDefaultSuite) TestForAccountNoDuplicate() {
log := newBufferLogger()
log := newBufferLoggerSilent()
log.ForAccount("acc1").ForAccount("acc2").Info("test")
s.Assert().Contains(log.String(), "acc2")
@ -106,7 +106,7 @@ func (s *TestDefaultSuite) TestForAccountNoDuplicate() {
}
func (s *TestDefaultSuite) TestNoDuplicatesPersistRecords() {
log := newBufferLogger()
log := newBufferLoggerSilent()
log.
ForHandler("handler1").
ForHandler("handler2").
@ -126,7 +126,7 @@ func (s *TestDefaultSuite) TestNoDuplicatesPersistRecords() {
// TestPersistRecordsIncompatibleWith is not a unit test, but rather a demonstration how you shouldn't use For* methods.
func (s *TestDefaultSuite) TestPersistRecordsIncompatibleWith() {
log := newBufferLogger()
log := newBufferLoggerSilent()
log.
ForHandler("handler1").
With(zap.Int("f1", 1)).

View File

@ -1,24 +1,64 @@
package logger
import (
"path"
"regexp"
"time"
"github.com/gin-gonic/gin"
"go.uber.org/zap"
)
const (
LoggerContextKey = "logger"
LoggerRealContextKey = "loggerReal"
)
// GinMiddleware will construct Gin middleware which will log requests and provide logger with unique request ID.
func GinMiddleware(log Logger) gin.HandlerFunc {
func GinMiddleware(log Logger, skipPaths ...string) gin.HandlerFunc {
var (
skip map[string]struct{}
matchSkip []*skippedPath
)
if length := len(skipPaths); length > 0 {
skip = make(map[string]struct{}, length)
for _, path := range skipPaths {
if skipped, ok := newSkippedPath(path); ok {
matchSkip = append(matchSkip, skipped)
continue
}
skip[path] = struct{}{}
}
}
nilLogger := NewNil()
return func(c *gin.Context) {
// Start timer
start := time.Now()
path := c.Request.URL.Path
raw := c.Request.URL.RawQuery
_, shouldSkip := skip[path]
streamID := generateStreamID()
log := log.With(StreamID(streamID))
if !shouldSkip && len(matchSkip) > 0 {
for _, skipper := range matchSkip {
if skipper.match(path) {
shouldSkip = true
break
}
}
}
if shouldSkip {
c.Set(LoggerRealContextKey, log)
log = nilLogger
}
c.Set(StreamIDAttr, streamID)
c.Set("logger", log)
c.Set(LoggerContextKey, log)
// Process request
c.Next()
@ -28,19 +68,68 @@ func GinMiddleware(log Logger) gin.HandlerFunc {
path = path + "?" + raw
}
log.Info("request",
zap.String(HandlerAttr, "GIN"),
zap.String("startTime", start.Format(time.RFC3339)),
zap.String("endTime", end.Format(time.RFC3339)),
zap.Any("latency", end.Sub(start)/time.Millisecond),
zap.String("remoteAddress", c.ClientIP()),
zap.String(HTTPMethodAttr, c.Request.Method),
zap.String("path", path),
zap.Int("bodySize", c.Writer.Size()),
)
if !shouldSkip {
log.Info("request",
zap.String(HandlerAttr, "GIN"),
zap.String("startTime", start.Format(time.RFC3339)),
zap.String("endTime", end.Format(time.RFC3339)),
zap.Any("latency", end.Sub(start)/time.Millisecond),
zap.String("remoteAddress", c.ClientIP()),
zap.String(HTTPMethodAttr, c.Request.Method),
zap.String("path", path),
zap.Int("bodySize", c.Writer.Size()),
)
}
}
}
func MustGet(c *gin.Context) Logger {
return c.MustGet("logger").(Logger)
}
func MustGetReal(c *gin.Context) Logger {
log, ok := c.Get(LoggerContextKey)
if _, isNil := log.(*Nil); !ok || isNil {
return c.MustGet(LoggerRealContextKey).(Logger)
}
return log.(Logger)
}
var (
hasParamsMatcher = regexp.MustCompile(`/:\w+`)
hasWildcardParamsMatcher = regexp.MustCompile(`/\*\w+.*`)
)
type skippedPath struct {
path string
expr *regexp.Regexp
}
// newSkippedPath returns new path skipping struct. It returns nil, false if expr is simple and
// no complex logic is needed.
func newSkippedPath(expr string) (result *skippedPath, compatible bool) {
hasParams, hasWildcard := hasParamsMatcher.MatchString(expr), hasWildcardParamsMatcher.MatchString(expr)
if !hasParams && !hasWildcard {
return nil, false
}
if hasWildcard {
return &skippedPath{expr: matcherForWildcard(expr)}, true
}
return &skippedPath{path: matcherForPath(expr)}, true
}
func matcherForWildcard(expr string) *regexp.Regexp {
return regexp.MustCompile(hasWildcardParamsMatcher.ReplaceAllString(expr, "/[\\w/]+"))
}
func matcherForPath(expr string) string {
return hasParamsMatcher.ReplaceAllString(expr, "/*")
}
func (p *skippedPath) match(route string) bool {
if p.expr != nil {
return p.expr.MatchString(route)
}
result, err := path.Match(p.path, route)
return result && err == nil
}

View File

@ -1,6 +1,8 @@
package logger
import (
"encoding/json"
"fmt"
"net/http"
"net/http/httptest"
"testing"
@ -11,7 +13,7 @@ import (
)
func TestGinMiddleware(t *testing.T) {
log := newBufferLogger()
log := newBufferLoggerSilent()
rr := httptest.NewRecorder()
r := gin.New()
r.Use(GinMiddleware(log))
@ -41,3 +43,81 @@ func TestGinMiddleware(t *testing.T) {
assert.NotEmpty(t, items[1].Context["path"])
assert.NotEmpty(t, items[1].Context["bodySize"])
}
func TestGinMiddleware_SkipPaths(t *testing.T) {
log := newBufferLoggerSilent()
rr := httptest.NewRecorder()
r := gin.New()
r.Use(GinMiddleware(log, "/hidden", "/hidden/:id", "/superhidden/*id"))
r.GET("/hidden", func(c *gin.Context) {
log := MustGet(c)
log.Info("hidden message from /hidden")
realLog := MustGetReal(c)
realLog.Info("visible message from /hidden")
c.JSON(http.StatusOK, gin.H{})
})
r.GET("/logged", func(c *gin.Context) {
log := MustGet(c)
log.Info("visible message from /logged")
c.JSON(http.StatusOK, gin.H{})
})
r.GET("/hidden/:id", func(c *gin.Context) {
log := MustGet(c)
log.Info("hidden message from /hidden/:id")
c.JSON(http.StatusOK, gin.H{})
})
r.GET("/superhidden/*id", func(c *gin.Context) {
log := MustGet(c)
log.Info("hidden message from /superhidden/*id")
c.JSON(http.StatusOK, gin.H{})
})
r.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/hidden", nil))
require.Equal(t, http.StatusOK, rr.Code)
r.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/logged", nil))
require.Equal(t, http.StatusOK, rr.Code)
r.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/hidden/param", nil))
require.Equal(t, http.StatusOK, rr.Code)
r.ServeHTTP(rr, httptest.NewRequest(http.MethodGet, "/superhidden/param/1/2/3", nil))
require.Equal(t, http.StatusOK, rr.Code)
items, err := newJSONBufferedLogger(log).ScanAll()
require.NoError(t, err)
require.Len(t, items, 3, printEntries(items))
assert.Equal(t, "visible message from /hidden", items[0].Message, printEntries(items))
assert.Equal(t, "visible message from /logged", items[1].Message, printEntries(items))
}
func TestSkippedPath(t *testing.T) {
cases := map[string]map[string]bool{
"/hidden/:id": {
"/hidden/1": true,
"/hidden/2/3": false,
},
"/hidden/*id": {
"/hidden/1": true,
"/hidden/2/3": true,
},
}
for pattern, items := range cases {
matcher, ok := newSkippedPath(pattern)
require.True(t, ok)
for item, result := range items {
assert.Equal(t, result, matcher.match(item), fmt.Sprintf(`"%s" does not match "%s", internals: %#v`,
pattern, item, matcher))
}
}
}
func printEntries(entries []logRecord) string {
data, err := json.MarshalIndent(entries, "", " ")
if err != nil {
panic(err)
}
return string(data)
}

View File

@ -9,7 +9,7 @@ import (
)
func TestWriterAdapter(t *testing.T) {
log := newBufferLogger()
log := newBufferLoggerSilent()
adapter := WriterAdapter(log, zap.InfoLevel)
msg := []byte("hello world")

View File

@ -9,7 +9,7 @@ import (
)
func TestZabbixCollectorAdapter(t *testing.T) {
log := newBufferLogger()
log := newBufferLoggerSilent()
adapter := ZabbixCollectorAdapter(log)
adapter.Errorf("highly unexpected error: %s", "unexpected error")
adapter.Errorf("cannot stop collector: %s", "app error")

View File

@ -42,12 +42,30 @@ type BufferLogger struct {
}
// NewBufferedLogger returns new BufferedLogger instance.
func NewBufferedLogger() BufferedLogger {
func NewBufferedLogger(level ...zapcore.Level) BufferedLogger {
lvl := zapcore.DebugLevel
if len(level) > 0 {
lvl = level[0]
}
bl := &BufferLogger{}
bl.Logger = zap.New(
zapcore.NewCore(
logger.NewJSONWithContextEncoder(
logger.EncoderConfigJSON()), zap.CombineWriteSyncers(os.Stdout, os.Stderr, &bl.buf), zapcore.DebugLevel))
logger.EncoderConfigJSON()), zap.CombineWriteSyncers(os.Stdout, os.Stderr, &bl.buf), lvl))
return bl
}
// NewBufferedLoggerSilent returns new BufferedLogger instance which won't duplicate entries to stdout/stderr.
func NewBufferedLoggerSilent(level ...zapcore.Level) BufferedLogger {
lvl := zapcore.DebugLevel
if len(level) > 0 {
lvl = level[0]
}
bl := &BufferLogger{}
bl.Logger = zap.New(
zapcore.NewCore(
logger.NewJSONWithContextEncoder(
logger.EncoderConfigJSON()), &bl.buf, lvl))
return bl
}

View File

@ -2,6 +2,7 @@ package testutil
import (
"io"
"sync"
"testing"
"github.com/stretchr/testify/suite"
@ -9,7 +10,8 @@ import (
type BufferLoggerTest struct {
suite.Suite
logger BufferedLogger
logger BufferedLogger
silentLogger BufferedLogger
}
func TestBufferLogger(t *testing.T) {
@ -18,10 +20,12 @@ func TestBufferLogger(t *testing.T) {
func (t *BufferLoggerTest) SetupSuite() {
t.logger = NewBufferedLogger()
t.silentLogger = NewBufferedLoggerSilent()
}
func (t *BufferLoggerTest) SetupTest() {
t.logger.Reset()
t.silentLogger.Reset()
}
func (t *BufferLoggerTest) Test_Read() {
@ -31,25 +35,62 @@ func (t *BufferLoggerTest) Test_Read() {
t.Require().NoError(err)
t.Assert().Contains(string(data), "\"level_name\":\"DEBUG\"")
t.Assert().Contains(string(data), "\"message\":\"test\"")
t.silentLogger.Debug("test")
data, err = io.ReadAll(t.silentLogger)
t.Require().NoError(err)
t.Assert().Contains(string(data), "\"level_name\":\"DEBUG\"")
t.Assert().Contains(string(data), "\"message\":\"test\"")
}
func (t *BufferLoggerTest) Test_Bytes() {
t.logger.Debug("test")
t.Assert().Contains(string(t.logger.Bytes()), "\"level_name\":\"DEBUG\"")
t.Assert().Contains(string(t.logger.Bytes()), "\"message\":\"test\"")
t.silentLogger.Debug("test")
t.Assert().Contains(string(t.silentLogger.Bytes()), "\"level_name\":\"DEBUG\"")
t.Assert().Contains(string(t.silentLogger.Bytes()), "\"message\":\"test\"")
}
func (t *BufferLoggerTest) Test_String() {
t.logger.Debug("test")
t.Assert().Contains(t.logger.String(), "\"level_name\":\"DEBUG\"")
t.Assert().Contains(t.logger.String(), "\"message\":\"test\"")
t.silentLogger.Debug("test")
t.Assert().Contains(t.silentLogger.String(), "\"level_name\":\"DEBUG\"")
t.Assert().Contains(t.silentLogger.String(), "\"message\":\"test\"")
}
func (t *BufferLoggerTest) TestRace() {
var (
wg sync.WaitGroup
starter sync.WaitGroup
)
starter.Add(1)
wg.Add(4)
go func() {
starter.Wait()
t.logger.Debug("test")
wg.Done()
}()
go func() {
starter.Wait()
t.logger.String()
wg.Done()
}()
go func() {
starter.Wait()
t.silentLogger.Debug("test")
wg.Done()
}()
go func() {
starter.Wait()
t.silentLogger.String()
wg.Done()
}()
starter.Done()
wg.Wait()
}

View File

@ -1,13 +1,77 @@
package core
import (
"go.uber.org/atomic"
"net/url"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin/binding"
"github.com/go-playground/validator/v10"
)
var (
crmDomainStore = &domainStore{
source: GetSaasDomains,
matcher: func(domain string, domains []Domain) bool {
if len(domains) == 0 {
return false
}
secondLevel := strings.Join(strings.Split(domain, ".")[1:], ".")
for _, crmDomain := range domains {
if crmDomain.Domain == secondLevel {
return true
}
}
return false
},
}
boxDomainStore = &domainStore{
source: GetBoxDomains,
matcher: func(domain string, domains []Domain) bool {
if len(domains) == 0 {
return false
}
for _, crmDomain := range domains {
if crmDomain.Domain == domain {
return true
}
}
return false
},
}
)
type domainStore struct {
domains []Domain
mutex sync.RWMutex
source func() []Domain
matcher func(string, []Domain) bool
lastUpdate atomic.Time
}
func (ds *domainStore) match(domain string) bool {
if time.Since(ds.lastUpdate.Load()) > time.Hour {
ds.update()
}
defer ds.mutex.RUnlock()
ds.mutex.RLock()
return ds.matcher(domain, ds.domains)
}
func (ds *domainStore) update() {
defer ds.mutex.Unlock()
ds.mutex.Lock()
ds.domains = ds.source()
ds.lastUpdate.Store(time.Now())
}
// init here will register `validateCrmURL` function for gin validator.
func init() {
if v, ok := binding.Validator.Engine().(*validator.Validate); ok {
@ -26,40 +90,16 @@ func validateCrmURL(fl validator.FieldLevel) bool {
func isDomainValid(crmURL string) bool {
parseURL, err := url.ParseRequestURI(crmURL)
if err != nil || nil == parseURL || !checkURLString(parseURL) {
return false
}
mainDomain := getMainDomain(parseURL.Hostname())
if checkDomains(GetSaasDomains(), mainDomain) {
hostname := parseURL.Hostname()
if crmDomainStore.match(hostname) {
return true
}
if checkDomains(GetBoxDomains(), parseURL.Hostname()) {
return true
}
return false
}
func checkDomains(crmDomains []Domain, domain string) bool {
if nil == crmDomains {
return false
}
for _, crmDomain := range crmDomains {
if crmDomain.Domain == domain {
return true
}
}
return false
}
func getMainDomain(hostname string) (mainDomain string) {
return strings.Join(strings.Split(hostname, ".")[1:], ".")
return boxDomainStore.match(hostname)
}
func checkURLString(parseURL *url.URL) bool {

View File

@ -29,6 +29,11 @@ func (s *ValidatorSuite) SetupSuite() {
}
}
func (s *ValidatorSuite) SetupTest() {
crmDomainStore.update()
boxDomainStore.update()
}
func (s *ValidatorSuite) getError(err error) string {
if err == nil {
return ""
@ -72,14 +77,14 @@ func (s *ValidatorSuite) Test_ValidationSuccess() {
"https://test.retailcrm.pro",
"https://raisa.retailcrm.es",
"https://blabla.simla.com",
"https://blabla.ecomlogic.com",
"https://crm.baucenter.ru",
"https://crm.holodilnik.ru",
"https://crm.eco.lanit.ru",
"https://ecom.inventive.ru",
"https://retailcrm.tvoydom.ru",
}
for _, domain := range boxDomainStore.domains {
crmDomains = append(crmDomains, "https://"+domain.Domain)
}
s.Assert().True(len(crmDomains) > 4, "No box domains were tested, test is incomplete!")
for _, domain := range crmDomains {
conn := models.Connection{
Key: "key",
@ -87,6 +92,6 @@ func (s *ValidatorSuite) Test_ValidationSuccess() {
}
err := s.engine.Struct(conn)
assert.NoError(s.T(), err, s.getError(err))
assert.NoError(s.T(), err, domain+": "+s.getError(err))
}
}