Refactoring and new features (#35)

* new error collector & improvements for errors
* update golangci-lint, microoptimizations, linter fixes
* UseTLS10 method
* remove dead code
* add 1.17 to the test matrix
* fix for docstring
* split the core package symbols into the subpackages (if feasible)
This commit is contained in:
Pavel 2021-12-01 15:40:23 +03:00 committed by GitHub
parent b0d5488f5a
commit 52109ee4ba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
73 changed files with 2324 additions and 754 deletions

View File

@ -22,14 +22,14 @@ jobs:
- name: Lint code with golangci-lint - name: Lint code with golangci-lint
uses: golangci/golangci-lint-action@v2 uses: golangci/golangci-lint-action@v2
with: with:
version: v1.36 version: v1.42.1
only-new-issues: true only-new-issues: true
tests: tests:
name: Tests name: Tests
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
matrix: matrix:
go-version: ['1.16'] go-version: ['1.16', '1.17']
steps: steps:
- name: Set up Go ${{ matrix.go-version }} - name: Set up Go ${{ matrix.go-version }}
uses: actions/setup-go@v2 uses: actions/setup-go@v2

View File

@ -32,19 +32,15 @@ linters:
- gocyclo - gocyclo
- godot - godot
- goimports - goimports
- golint - revive
- gomnd
- gosec - gosec
- ifshort - ifshort
- interfacer
- lll - lll
- makezero - makezero
- maligned
- misspell - misspell
- nestif - nestif
- prealloc - prealloc
- predeclared - predeclared
- scopelint
- sqlclosecheck - sqlclosecheck
- unconvert - unconvert
- whitespace - whitespace
@ -56,9 +52,11 @@ linters-settings:
enable: enable:
- assign - assign
- atomic - atomic
- atomicalign
- bools - bools
- buildtag - buildtag
- copylocks - copylocks
- fieldalignment
- httpresponse - httpresponse
- loopclosure - loopclosure
- lostcancel - lostcancel
@ -152,12 +150,10 @@ linters-settings:
local-prefixes: github.com/retailcrm/mg-transport-core local-prefixes: github.com/retailcrm/mg-transport-core
lll: lll:
line-length: 120 line-length: 120
maligned:
suggest-new: true
misspell: misspell:
locale: US locale: US
nestif: nestif:
min-complexity: 4 min-complexity: 6
whitespace: whitespace:
multi-if: false multi-if: false
multi-func: false multi-func: false
@ -166,7 +162,6 @@ issues:
exclude-rules: exclude-rules:
- path: _test\.go - path: _test\.go
linters: linters:
- gomnd
- lll - lll
- bodyclose - bodyclose
- errcheck - errcheck
@ -175,7 +170,6 @@ issues:
- ineffassign - ineffassign
- whitespace - whitespace
- makezero - makezero
- maligned
- ifshort - ifshort
- errcheck - errcheck
- funlen - funlen

View File

@ -5,7 +5,7 @@ import (
"github.com/jessevdk/go-flags" "github.com/jessevdk/go-flags"
"github.com/retailcrm/mg-transport-core/core" "github.com/retailcrm/mg-transport-core/v2/core/db"
) )
// Options for tool command. // Options for tool command.
@ -20,7 +20,7 @@ func init() {
_, err := parser.AddCommand("migration", _, err := parser.AddCommand("migration",
"Create new empty migration in specified directory.", "Create new empty migration in specified directory.",
"Create new empty migration in specified directory.", "Create new empty migration in specified directory.",
&core.NewMigrationCommand{}, &db.NewMigrationCommand{},
) )
if err != nil { if err != nil {
@ -30,7 +30,7 @@ func init() {
func main() { func main() {
if _, err := parser.Parse(); err != nil { if _, err := parser.Parse(); err != nil {
if flagsErr, ok := err.(*flags.Error); ok && flagsErr.Type == flags.ErrHelp { if flagsErr, ok := err.(*flags.Error); ok && flagsErr.Type == flags.ErrHelp { // nolint:errorlint
os.Exit(0) os.Exit(0)
} else { } else {
os.Exit(1) os.Exit(1)

View File

@ -1,28 +1,22 @@
package core package config
import ( import (
"io/ioutil" "io/ioutil"
"path/filepath" "path/filepath"
"regexp"
"time" "time"
"github.com/op/go-logging" "github.com/op/go-logging"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
) )
var ( // Configuration settings data structure.
markdownSymbols = []string{"*", "_", "`", "["} type Configuration interface {
slashRegex = regexp.MustCompile(`/+$`)
)
// ConfigInterface settings data structure.
type ConfigInterface interface {
GetVersion() string GetVersion() string
GetSentryDSN() string GetSentryDSN() string
GetLogLevel() logging.Level GetLogLevel() logging.Level
GetHTTPConfig() HTTPServerConfig GetHTTPConfig() HTTPServerConfig
GetDBConfig() DatabaseConfig GetDBConfig() DatabaseConfig
GetAWSConfig() ConfigAWS GetAWSConfig() AWS
GetTransportInfo() InfoInterface GetTransportInfo() InfoInterface
GetHTTPClientConfig() *HTTPClientConfig GetHTTPClientConfig() *HTTPClientConfig
GetUpdateInterval() int GetUpdateInterval() int
@ -39,16 +33,16 @@ type InfoInterface interface {
// Config struct. // Config struct.
type Config struct { type Config struct {
Version string `yaml:"version"`
LogLevel logging.Level `yaml:"log_level"`
Database DatabaseConfig `yaml:"database"`
SentryDSN string `yaml:"sentry_dsn"`
HTTPServer HTTPServerConfig `yaml:"http_server"`
Debug bool `yaml:"debug"`
UpdateInterval int `yaml:"update_interval"`
ConfigAWS ConfigAWS `yaml:"config_aws"`
TransportInfo Info `yaml:"transport_info"`
HTTPClientConfig *HTTPClientConfig `yaml:"http_client"` HTTPClientConfig *HTTPClientConfig `yaml:"http_client"`
ConfigAWS AWS `yaml:"config_aws"`
TransportInfo Info `yaml:"transport_info"`
HTTPServer HTTPServerConfig `yaml:"http_server"`
Version string `yaml:"version"`
SentryDSN string `yaml:"sentry_dsn"`
Database DatabaseConfig `yaml:"database"`
UpdateInterval int `yaml:"update_interval"`
LogLevel logging.Level `yaml:"log_level"`
Debug bool `yaml:"debug"`
} }
// Info struct. // Info struct.
@ -59,8 +53,8 @@ type Info struct {
Secret string `yaml:"secret"` Secret string `yaml:"secret"`
} }
// ConfigAWS struct. // AWS struct.
type ConfigAWS struct { type AWS struct {
AccessKeyID string `yaml:"access_key_id"` AccessKeyID string `yaml:"access_key_id"`
SecretAccessKey string `yaml:"secret_access_key"` SecretAccessKey string `yaml:"secret_access_key"`
Region string `yaml:"region"` Region string `yaml:"region"`
@ -72,19 +66,19 @@ type ConfigAWS struct {
// DatabaseConfig struct. // DatabaseConfig struct.
type DatabaseConfig struct { type DatabaseConfig struct {
Connection interface{} `yaml:"connection"` Connection interface{} `yaml:"connection"`
Logging bool `yaml:"logging"`
TablePrefix string `yaml:"table_prefix"` TablePrefix string `yaml:"table_prefix"`
MaxOpenConnections int `yaml:"max_open_connections"` MaxOpenConnections int `yaml:"max_open_connections"`
MaxIdleConnections int `yaml:"max_idle_connections"` MaxIdleConnections int `yaml:"max_idle_connections"`
ConnectionLifetime int `yaml:"connection_lifetime"` ConnectionLifetime int `yaml:"connection_lifetime"`
Logging bool `yaml:"logging"`
} }
// HTTPClientConfig struct. // HTTPClientConfig struct.
type HTTPClientConfig struct { type HTTPClientConfig struct {
Timeout time.Duration `yaml:"timeout"`
SSLVerification *bool `yaml:"ssl_verification"` SSLVerification *bool `yaml:"ssl_verification"`
MockAddress string `yaml:"mock_address"` MockAddress string `yaml:"mock_address"`
MockedDomains []string `yaml:"mocked_domains"` MockedDomains []string `yaml:"mocked_domains"`
Timeout time.Duration `yaml:"timeout"`
} }
// HTTPServerConfig struct. // HTTPServerConfig struct.
@ -157,7 +151,7 @@ func (c Config) IsDebug() bool {
} }
// GetAWSConfig AWS configuration. // GetAWSConfig AWS configuration.
func (c Config) GetAWSConfig() ConfigAWS { func (c Config) GetAWSConfig() AWS {
return c.ConfigAWS return c.ConfigAWS
} }

View File

@ -1,4 +1,4 @@
package core package config
import ( import (
"io/ioutil" "io/ioutil"

View File

@ -1,4 +1,4 @@
package core package db
import ( import (
"fmt" "fmt"
@ -16,9 +16,9 @@ var migrations *Migrate
type Migrate struct { type Migrate struct {
db *gorm.DB db *gorm.DB
first *gormigrate.Migration first *gormigrate.Migration
versions []string
migrations map[string]*gormigrate.Migration migrations map[string]*gormigrate.Migration
GORMigrate *gormigrate.Gormigrate GORMigrate *gormigrate.Gormigrate
versions []string
prepared bool prepared bool
} }
@ -123,7 +123,7 @@ func (m *Migrate) MigrateNextTo(version string) error {
case current < next: case current < next:
return m.GORMigrate.MigrateTo(next) return m.GORMigrate.MigrateTo(next)
case current > next: case current > next:
return errors.New(fmt.Sprintf("current migration version '%s' is higher than fetched version '%s'", current, next)) return fmt.Errorf("current migration version '%s' is higher than fetched version '%s'", current, next)
default: default:
return nil return nil
} }
@ -144,7 +144,7 @@ func (m *Migrate) MigratePreviousTo(version string) error {
case current > prev: case current > prev:
return m.GORMigrate.RollbackTo(prev) return m.GORMigrate.RollbackTo(prev)
case current < prev: case current < prev:
return errors.New(fmt.Sprintf("current migration version '%s' is lower than fetched version '%s'", current, prev)) return fmt.Errorf("current migration version '%s' is lower than fetched version '%s'", current, prev)
case prev == "0": case prev == "0":
return m.GORMigrate.RollbackMigration(m.first) return m.GORMigrate.RollbackMigration(m.first)
default: default:
@ -241,8 +241,11 @@ func (m *Migrate) prepareMigrations() error {
return nil return nil
} }
i := 0
keys = make([]string, len(m.migrations))
for key := range m.migrations { for key := range m.migrations {
keys = append(keys, key) keys[i] = key
i++
} }
sort.Strings(keys) sort.Strings(keys)

View File

@ -1,4 +1,4 @@
package core package db
import ( import (
"database/sql" "database/sql"

View File

@ -1,4 +1,4 @@
package core package db
import ( import (
"fmt" "fmt"
@ -32,7 +32,7 @@ func init() {
// NewMigrationCommand struct. // NewMigrationCommand struct.
type NewMigrationCommand struct { type NewMigrationCommand struct {
Directory string `short:"d" long:"directory" default:"./migrations" description:"Directory where migration will be created"` Directory string `short:"d" long:"directory" default:"./migrations" description:"Directory where migration will be created"` // nolint:lll
} }
// FileExists returns true if provided file exist and it's not directory. // FileExists returns true if provided file exist and it's not directory.

View File

@ -1,4 +1,4 @@
package core package db
import ( import (
"fmt" "fmt"
@ -26,7 +26,7 @@ func (s *MigrationGeneratorSuite) SetupSuite() {
func (s *MigrationGeneratorSuite) Test_FileExists() { func (s *MigrationGeneratorSuite) Test_FileExists() {
var ( var (
seededRand *rand.Rand = rand.New(rand.NewSource(time.Now().UnixNano())) seededRand = rand.New(rand.NewSource(time.Now().UnixNano())) // nolint:gosec
notExist = fmt.Sprintf("/tmp/%d", seededRand.Int31()) notExist = fmt.Sprintf("/tmp/%d", seededRand.Int31())
) )

18
core/db/models/account.go Normal file
View File

@ -0,0 +1,18 @@
package models
import "time"
// Account model.
type Account struct {
CreatedAt time.Time
UpdatedAt time.Time
ChannelSettingsHash string `gorm:"column:channel_settings_hash; type:varchar(70)" binding:"max=70"`
Name string `gorm:"column:name; type:varchar(100)" json:"name,omitempty" binding:"max=100"`
Lang string `gorm:"column:lang; type:varchar(2)" json:"lang,omitempty" binding:"max=2"`
Channel uint64 `gorm:"column:channel; not null; unique" json:"channel,omitempty"`
ID int `gorm:"primary_key"`
ConnectionID int `gorm:"column:connection_id" json:"connectionId,omitempty"`
}
// Accounts list.
type Accounts []Account

View File

@ -0,0 +1,17 @@
package models
import "time"
// Connection model.
type Connection struct {
CreatedAt time.Time
UpdatedAt time.Time
Key string `gorm:"column:api_key; type:varchar(100); not null" json:"api_key,omitempty" binding:"required,max=100"` // nolint:lll
URL string `gorm:"column:api_url; type:varchar(255); not null" json:"api_url,omitempty" binding:"required,validateCrmURL,max=255"` // nolint:lll
GateURL string `gorm:"column:mg_url; type:varchar(255); not null;" json:"mg_url,omitempty" binding:"max=255"`
GateToken string `gorm:"column:mg_token; type:varchar(100); not null; unique" json:"mg_token,omitempty" binding:"max=100"` // nolint:lll
ClientID string `gorm:"column:client_id; type:varchar(70); not null; unique" json:"clientId,omitempty"`
Accounts []Account `gorm:"foreignkey:ConnectionID"`
ID int `gorm:"primary_key"`
Active bool `json:"active,omitempty"`
}

24
core/db/models/user.go Normal file
View File

@ -0,0 +1,24 @@
package models
import "time"
// User model.
type User struct {
CreatedAt time.Time
UpdatedAt time.Time
ExternalID string `gorm:"column:external_id; type:varchar(255); not null; unique"`
UserPhotoURL string `gorm:"column:user_photo_url; type:varchar(255)" binding:"max=255"`
UserPhotoID string `gorm:"column:user_photo_id; type:varchar(100)" binding:"max=100"`
ID int `gorm:"primary_key"`
}
// TableName will return table name for User
// It will not work if User is not embedded, but mapped as another type
// type MyUser User // will not work
// but
// type MyUser struct { // will work
// User
// }
func (User) TableName() string {
return "mg_user"
}

View File

@ -1,4 +1,4 @@
package core package models
import ( import (
"testing" "testing"

View File

@ -1,9 +1,12 @@
package core package db
import ( import (
"time" "time"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
"github.com/retailcrm/mg-transport-core/v2/core/config"
// PostgreSQL is an default. // PostgreSQL is an default.
_ "github.com/jinzhu/gorm/dialects/postgres" _ "github.com/jinzhu/gorm/dialects/postgres"
) )
@ -14,13 +17,14 @@ type ORM struct {
} }
// NewORM will init new database connection. // NewORM will init new database connection.
func NewORM(config DatabaseConfig) *ORM { func NewORM(config config.DatabaseConfig) *ORM {
orm := &ORM{} orm := &ORM{}
orm.createDB(config) orm.CreateDB(config)
return orm return orm
} }
func (orm *ORM) createDB(config DatabaseConfig) { // CreateDB connection using provided config.
func (orm *ORM) CreateDB(config config.DatabaseConfig) {
db, err := gorm.Open("postgres", config.Connection) db, err := gorm.Open("postgres", config.Connection)
if err != nil { if err != nil {
panic(err) panic(err)

View File

@ -1,4 +1,4 @@
package core package db
import ( import (
"database/sql" "database/sql"
@ -7,6 +7,8 @@ import (
"github.com/DATA-DOG/go-sqlmock" "github.com/DATA-DOG/go-sqlmock"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/retailcrm/mg-transport-core/v2/core/config"
) )
func TestORM_NewORM(t *testing.T) { func TestORM_NewORM(t *testing.T) {
@ -22,7 +24,7 @@ func TestORM_NewORM(t *testing.T) {
db, _, err = sqlmock.New() db, _, err = sqlmock.New()
require.NoError(t, err) require.NoError(t, err)
config := DatabaseConfig{ config := config.DatabaseConfig{
Connection: db, Connection: db,
Logging: true, Logging: true,
TablePrefix: "", TablePrefix: "",
@ -39,7 +41,7 @@ func TestORM_createDB_Fail(t *testing.T) {
assert.NotNil(t, recover()) assert.NotNil(t, recover())
}() }()
NewORM(DatabaseConfig{Connection: nil}) NewORM(config.DatabaseConfig{Connection: nil})
} }
func TestORM_CloseDB(t *testing.T) { func TestORM_CloseDB(t *testing.T) {
@ -56,7 +58,7 @@ func TestORM_CloseDB(t *testing.T) {
db, dbMock, err = sqlmock.New() db, dbMock, err = sqlmock.New()
require.NoError(t, err) require.NoError(t, err)
config := DatabaseConfig{ config := config.DatabaseConfig{
Connection: db, Connection: db,
Logging: true, Logging: true,
TablePrefix: "", TablePrefix: "",

View File

@ -1,7 +1,6 @@
// Copyright (c) 2019 RetailDriver LLC
// Use of this source code is governed by a MIT
/* /*
Package core provides different functions like error-reporting, logging, localization, etc. in order to make it easier to create transports. Package core provides different functions like error-reporting, logging, localization, etc.
to make it easier to create transports.
Usage: Usage:
package main package main
@ -110,5 +109,7 @@ Migration generator
This library contains helper tool for transports. You can install it via go: This library contains helper tool for transports. You can install it via go:
$ go get -u github.com/retailcrm/mg-transport-core/cmd/transport-core-tool $ go get -u github.com/retailcrm/mg-transport-core/cmd/transport-core-tool
Currently, it only can generate new migrations for your transport. Currently, it only can generate new migrations for your transport.
Copyright (c) 2019 RetailDriver LLC. Usage of this source code is governed by a MIT license.
*/ */
package core package core

View File

@ -12,33 +12,41 @@ import (
"github.com/gorilla/sessions" "github.com/gorilla/sessions"
"github.com/op/go-logging" "github.com/op/go-logging"
"golang.org/x/text/language" "golang.org/x/text/language"
"github.com/retailcrm/mg-transport-core/v2/core/config"
"github.com/retailcrm/mg-transport-core/v2/core/db"
"github.com/retailcrm/mg-transport-core/v2/core/middleware"
"github.com/retailcrm/mg-transport-core/v2/core/util"
"github.com/retailcrm/mg-transport-core/v2/core/util/httputil"
"github.com/retailcrm/mg-transport-core/v2/core/logger"
) )
var boolTrue = true var boolTrue = true
// DefaultHTTPClientConfig is a default config for HTTP client. It will be used by Engine for building HTTP client // DefaultHTTPClientConfig is a default config for HTTP client. It will be used by Engine for building HTTP client
// if HTTP client config is not present in the configuration. // if HTTP client config is not present in the configuration.
var DefaultHTTPClientConfig = &HTTPClientConfig{ var DefaultHTTPClientConfig = &config.HTTPClientConfig{
Timeout: 30, Timeout: 30,
SSLVerification: &boolTrue, SSLVerification: &boolTrue,
} }
// Engine struct. // Engine struct.
type Engine struct { type Engine struct {
Localizer logger logger.Logger
ORM
Sentry
Utils
ginEngine *gin.Engine
httpClient *http.Client
logger LoggerInterface
mutex sync.RWMutex
csrf *CSRF
jobManager *JobManager
PreloadLanguages []language.Tag
Sessions sessions.Store Sessions sessions.Store
Config ConfigInterface
LogFormatter logging.Formatter LogFormatter logging.Formatter
Config config.Configuration
ginEngine *gin.Engine
csrf *middleware.CSRF
httpClient *http.Client
jobManager *JobManager
db.ORM
Localizer
util.Utils
PreloadLanguages []language.Tag
Sentry
mutex sync.RWMutex
prepared bool prepared bool
} }
@ -52,9 +60,9 @@ func New() *Engine {
loadMutex: &sync.RWMutex{}, loadMutex: &sync.RWMutex{},
}, },
PreloadLanguages: []language.Tag{}, PreloadLanguages: []language.Tag{},
ORM: ORM{}, ORM: db.ORM{},
Sentry: Sentry{}, Sentry: Sentry{},
Utils: Utils{}, Utils: util.Utils{},
ginEngine: nil, ginEngine: nil,
logger: nil, logger: nil,
mutex: sync.RWMutex{}, mutex: sync.RWMutex{},
@ -91,7 +99,7 @@ func (e *Engine) Prepare() *Engine {
e.DefaultError = "error" e.DefaultError = "error"
} }
if e.LogFormatter == nil { if e.LogFormatter == nil {
e.LogFormatter = DefaultLogFormatter() e.LogFormatter = logger.DefaultLogFormatter()
} }
if e.LocaleMatcher == nil { if e.LocaleMatcher == nil {
e.LocaleMatcher = DefaultLocalizerMatcher() e.LocaleMatcher = DefaultLocalizerMatcher()
@ -107,10 +115,10 @@ func (e *Engine) Prepare() *Engine {
e.Localizer.Preload(e.PreloadLanguages) e.Localizer.Preload(e.PreloadLanguages)
} }
e.createDB(e.Config.GetDBConfig()) e.CreateDB(e.Config.GetDBConfig())
e.createRavenClient(e.Config.GetSentryDSN()) e.createRavenClient(e.Config.GetSentryDSN())
e.resetUtils(e.Config.GetAWSConfig(), e.Config.IsDebug(), 0) e.ResetUtils(e.Config.GetAWSConfig(), e.Config.IsDebug(), 0)
e.SetLogger(NewLogger(e.Config.GetTransportInfo().GetCode(), e.Config.GetLogLevel(), e.LogFormatter)) e.SetLogger(logger.NewStandard(e.Config.GetTransportInfo().GetCode(), e.Config.GetLogLevel(), e.LogFormatter))
e.Sentry.Localizer = &e.Localizer e.Sentry.Localizer = &e.Localizer
e.Sentry.Stacktrace = true e.Sentry.Stacktrace = true
e.Utils.Logger = e.Logger() e.Utils.Logger = e.Logger()
@ -176,12 +184,12 @@ func (e *Engine) JobManager() *JobManager {
} }
// Logger returns current logger. // Logger returns current logger.
func (e *Engine) Logger() LoggerInterface { func (e *Engine) Logger() logger.Logger {
return e.logger return e.logger
} }
// SetLogger sets provided logger instance to engine. // SetLogger sets provided logger instance to engine.
func (e *Engine) SetLogger(l LoggerInterface) *Engine { func (e *Engine) SetLogger(l logger.Logger) *Engine {
if l == nil { if l == nil {
return e return e
} }
@ -194,11 +202,11 @@ func (e *Engine) SetLogger(l LoggerInterface) *Engine {
// BuildHTTPClient builds HTTP client with provided configuration. // BuildHTTPClient builds HTTP client with provided configuration.
func (e *Engine) BuildHTTPClient(certs *x509.CertPool, replaceDefault ...bool) *Engine { func (e *Engine) BuildHTTPClient(certs *x509.CertPool, replaceDefault ...bool) *Engine {
client, err := NewHTTPClientBuilder(). client, err := httputil.NewHTTPClientBuilder().
WithLogger(e.Logger()). WithLogger(e.Logger()).
SetLogging(e.Config.IsDebug()). SetLogging(e.Config.IsDebug()).
SetCertPool(certs). SetCertPool(certs).
FromEngine(e). FromConfig(e.GetHTTPClientConfig()).
Build(replaceDefault...) Build(replaceDefault...)
if err != nil { if err != nil {
@ -211,7 +219,7 @@ func (e *Engine) BuildHTTPClient(certs *x509.CertPool, replaceDefault ...bool) *
} }
// GetHTTPClientConfig returns configuration for HTTP client. // GetHTTPClientConfig returns configuration for HTTP client.
func (e *Engine) GetHTTPClientConfig() *HTTPClientConfig { func (e *Engine) GetHTTPClientConfig() *config.HTTPClientConfig {
if e.Config.GetHTTPClientConfig() != nil { if e.Config.GetHTTPClientConfig() != nil {
return e.Config.GetHTTPClientConfig() return e.Config.GetHTTPClientConfig()
} }
@ -266,12 +274,13 @@ func (e *Engine) WithFilesystemSessions(path string, keyLength ...int) *Engine {
// InitCSRF initializes CSRF middleware. engine.Sessions must be already initialized, // InitCSRF initializes CSRF middleware. engine.Sessions must be already initialized,
// use engine.WithCookieStore or engine.WithFilesystemStore for that. // use engine.WithCookieStore or engine.WithFilesystemStore for that.
// Syntax is similar to core.NewCSRF, but you shouldn't pass sessionName, store and salt. // Syntax is similar to core.NewCSRF, but you shouldn't pass sessionName, store and salt.
func (e *Engine) InitCSRF(secret string, abortFunc CSRFAbortFunc, getter CSRFTokenGetter) *Engine { func (e *Engine) InitCSRF(
secret string, abortFunc middleware.CSRFAbortFunc, getter middleware.CSRFTokenGetter) *Engine {
if e.Sessions == nil { if e.Sessions == nil {
panic("engine.Sessions must be initialized first") panic("engine.Sessions must be initialized first")
} }
e.csrf = NewCSRF("", secret, "", e.Sessions, abortFunc, getter) e.csrf = middleware.NewCSRF("", secret, "", e.Sessions, abortFunc, getter)
return e return e
} }

View File

@ -17,6 +17,12 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/retailcrm/mg-transport-core/v2/core/config"
"github.com/retailcrm/mg-transport-core/v2/core/middleware"
"github.com/retailcrm/mg-transport-core/v2/core/util/httputil"
"github.com/retailcrm/mg-transport-core/v2/core/logger"
) )
type EngineTest struct { type EngineTest struct {
@ -38,10 +44,10 @@ func (e *EngineTest) SetupTest() {
createTestLangFiles(e.T()) createTestLangFiles(e.T())
e.engine.Config = Config{ e.engine.Config = config.Config{
Version: "1", Version: "1",
LogLevel: 5, LogLevel: 5,
Database: DatabaseConfig{ Database: config.DatabaseConfig{
Connection: db, Connection: db,
Logging: true, Logging: true,
TablePrefix: "", TablePrefix: "",
@ -50,14 +56,14 @@ func (e *EngineTest) SetupTest() {
ConnectionLifetime: 60, ConnectionLifetime: 60,
}, },
SentryDSN: "sentry dsn", SentryDSN: "sentry dsn",
HTTPServer: HTTPServerConfig{ HTTPServer: config.HTTPServerConfig{
Host: "0.0.0.0", Host: "0.0.0.0",
Listen: ":3001", Listen: ":3001",
}, },
Debug: true, Debug: true,
UpdateInterval: 30, UpdateInterval: 30,
ConfigAWS: ConfigAWS{}, ConfigAWS: config.AWS{},
TransportInfo: Info{ TransportInfo: config.Info{
Name: "test", Name: "test",
Code: "test", Code: "test",
LogoPath: "/test.svg", LogoPath: "/test.svg",
@ -113,7 +119,7 @@ func (e *EngineTest) Test_Prepare() {
func (e *EngineTest) Test_initGin_Release() { func (e *EngineTest) Test_initGin_Release() {
engine := New() engine := New()
engine.Config = Config{Debug: false} engine.Config = config.Config{Debug: false}
engine.initGin() engine.initGin()
assert.NotNil(e.T(), engine.ginEngine) assert.NotNil(e.T(), engine.ginEngine)
} }
@ -169,8 +175,8 @@ func (e *EngineTest) Test_ConfigureRouter() {
} }
func (e *EngineTest) Test_BuildHTTPClient() { func (e *EngineTest) Test_BuildHTTPClient() {
e.engine.Config = &Config{ e.engine.Config = &config.Config{
HTTPClientConfig: &HTTPClientConfig{ HTTPClientConfig: &config.HTTPClientConfig{
Timeout: 30, Timeout: 30,
SSLVerification: boolPtr(true), SSLVerification: boolPtr(true),
}, },
@ -186,7 +192,7 @@ func (e *EngineTest) Test_BuildHTTPClient() {
} }
func (e *EngineTest) Test_BuildHTTPClient_NoConfig() { func (e *EngineTest) Test_BuildHTTPClient_NoConfig() {
e.engine.Config = &Config{} e.engine.Config = &config.Config{}
e.engine.BuildHTTPClient(x509.NewCertPool()) e.engine.BuildHTTPClient(x509.NewCertPool())
assert.NotNil(e.T(), e.engine.httpClient) assert.NotNil(e.T(), e.engine.httpClient)
@ -198,11 +204,11 @@ func (e *EngineTest) Test_BuildHTTPClient_NoConfig() {
} }
func (e *EngineTest) Test_GetHTTPClientConfig() { func (e *EngineTest) Test_GetHTTPClientConfig() {
e.engine.Config = &Config{} e.engine.Config = &config.Config{}
assert.Equal(e.T(), DefaultHTTPClientConfig, e.engine.GetHTTPClientConfig()) assert.Equal(e.T(), DefaultHTTPClientConfig, e.engine.GetHTTPClientConfig())
e.engine.Config = &Config{ e.engine.Config = &config.Config{
HTTPClientConfig: &HTTPClientConfig{ HTTPClientConfig: &config.HTTPClientConfig{
Timeout: 10, Timeout: 10,
SSLVerification: boolPtr(true), SSLVerification: boolPtr(true),
}, },
@ -230,7 +236,7 @@ func (e *EngineTest) Test_SetLogger() {
defer func() { defer func() {
e.engine.logger = origLogger e.engine.logger = origLogger
}() }()
e.engine.logger = &Logger{} e.engine.logger = &logger.StandardLogger{}
e.engine.SetLogger(nil) e.engine.SetLogger(nil)
assert.NotNil(e.T(), e.engine.logger) assert.NotNil(e.T(), e.engine.logger)
} }
@ -241,7 +247,7 @@ func (e *EngineTest) Test_SetHTTPClient() {
e.engine.httpClient = origClient e.engine.httpClient = origClient
}() }()
e.engine.httpClient = nil e.engine.httpClient = nil
httpClient, err := NewHTTPClientBuilder().Build() httpClient, err := httputil.NewHTTPClientBuilder().Build()
require.NoError(e.T(), err) require.NoError(e.T(), err)
assert.NotNil(e.T(), httpClient) assert.NotNil(e.T(), httpClient)
e.engine.SetHTTPClient(&http.Client{}) e.engine.SetHTTPClient(&http.Client{})
@ -257,7 +263,7 @@ func (e *EngineTest) Test_HTTPClient() {
}() }()
e.engine.httpClient = nil e.engine.httpClient = nil
require.Same(e.T(), http.DefaultClient, e.engine.HTTPClient()) require.Same(e.T(), http.DefaultClient, e.engine.HTTPClient())
httpClient, err := NewHTTPClientBuilder().Build() httpClient, err := httputil.NewHTTPClientBuilder().Build()
require.NoError(e.T(), err) require.NoError(e.T(), err)
e.engine.httpClient = httpClient e.engine.httpClient = httpClient
assert.Same(e.T(), httpClient, e.engine.HTTPClient()) assert.Same(e.T(), httpClient, e.engine.HTTPClient())
@ -270,7 +276,7 @@ func (e *EngineTest) Test_InitCSRF_Fail() {
e.engine.csrf = nil e.engine.csrf = nil
e.engine.Sessions = nil e.engine.Sessions = nil
e.engine.InitCSRF("test", func(context *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter) e.engine.InitCSRF("test", func(context *gin.Context, r middleware.CSRFErrorReason) {}, middleware.DefaultCSRFTokenGetter)
assert.Nil(e.T(), e.engine.csrf) assert.Nil(e.T(), e.engine.csrf)
} }
@ -281,7 +287,7 @@ func (e *EngineTest) Test_InitCSRF() {
e.engine.csrf = nil e.engine.csrf = nil
e.engine.WithCookieSessions(4) e.engine.WithCookieSessions(4)
e.engine.InitCSRF("test", func(context *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter) e.engine.InitCSRF("test", func(context *gin.Context, r middleware.CSRFErrorReason) {}, middleware.DefaultCSRFTokenGetter)
assert.NotNil(e.T(), e.engine.csrf) assert.NotNil(e.T(), e.engine.csrf)
} }
@ -291,7 +297,7 @@ func (e *EngineTest) Test_VerifyCSRFMiddleware_Fail() {
}() }()
e.engine.csrf = nil e.engine.csrf = nil
e.engine.VerifyCSRFMiddleware(DefaultIgnoredMethods) e.engine.VerifyCSRFMiddleware(middleware.DefaultIgnoredMethods)
} }
func (e *EngineTest) Test_VerifyCSRFMiddleware() { func (e *EngineTest) Test_VerifyCSRFMiddleware() {
@ -301,8 +307,8 @@ func (e *EngineTest) Test_VerifyCSRFMiddleware() {
e.engine.csrf = nil e.engine.csrf = nil
e.engine.WithCookieSessions(4) e.engine.WithCookieSessions(4)
e.engine.InitCSRF("test", func(context *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter) e.engine.InitCSRF("test", func(context *gin.Context, r middleware.CSRFErrorReason) {}, middleware.DefaultCSRFTokenGetter)
e.engine.VerifyCSRFMiddleware(DefaultIgnoredMethods) e.engine.VerifyCSRFMiddleware(middleware.DefaultIgnoredMethods)
} }
func (e *EngineTest) Test_GenerateCSRFMiddleware_Fail() { func (e *EngineTest) Test_GenerateCSRFMiddleware_Fail() {
@ -321,7 +327,7 @@ func (e *EngineTest) Test_GenerateCSRFMiddleware() {
e.engine.csrf = nil e.engine.csrf = nil
e.engine.WithCookieSessions(4) e.engine.WithCookieSessions(4)
e.engine.InitCSRF("test", func(context *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter) e.engine.InitCSRF("test", func(context *gin.Context, r middleware.CSRFErrorReason) {}, middleware.DefaultCSRFTokenGetter)
e.engine.GenerateCSRFMiddleware() e.engine.GenerateCSRFMiddleware()
} }
@ -350,7 +356,7 @@ func (e *EngineTest) Test_GetCSRFToken() {
e.engine.csrf = nil e.engine.csrf = nil
e.engine.WithCookieSessions(4) e.engine.WithCookieSessions(4)
e.engine.InitCSRF("test", func(context *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter) e.engine.InitCSRF("test", func(context *gin.Context, r middleware.CSRFErrorReason) {}, middleware.DefaultCSRFTokenGetter)
assert.NotEmpty(e.T(), e.engine.GetCSRFToken(c)) assert.NotEmpty(e.T(), e.engine.GetCSRFToken(c))
assert.Equal(e.T(), "token", e.engine.GetCSRFToken(c)) assert.Equal(e.T(), "token", e.engine.GetCSRFToken(c))
} }

View File

@ -1,36 +0,0 @@
package core
import "net/http"
// ErrorResponse struct.
type ErrorResponse struct {
Error string `json:"error"`
}
// ErrorsResponse struct.
type ErrorsResponse struct {
Error []string `json:"error"`
}
// GetErrorResponse returns ErrorResponse with specified status code
// Usage (with gin):
// context.JSON(GetErrorResponse(http.StatusPaymentRequired, "Not enough money"))
func GetErrorResponse(statusCode int, error string) (int, interface{}) {
return statusCode, ErrorResponse{
Error: error,
}
}
// BadRequest returns ErrorResponse with code 400
// Usage (with gin):
// context.JSON(BadRequest("invalid data"))
func BadRequest(error string) (int, interface{}) {
return GetErrorResponse(http.StatusBadRequest, error)
}
// InternalServerError returns ErrorResponse with code 500
// Usage (with gin):
// context.JSON(BadRequest("invalid data"))
func InternalServerError(error string) (int, interface{}) {
return GetErrorResponse(http.StatusInternalServerError, error)
}

View File

@ -1,38 +0,0 @@
package core
import (
"github.com/pkg/errors"
)
// ErrorCollector can be used to group several error calls into one call.
// It is mostly useful in GORM migrations, where you should return only one errors, but several can occur.
// Error messages are composed into one message. For example:
// err := core.ErrorCollector(
// errors.New("first"),
// errors.New("second")
// )
//
// // Will output `first < second`
// fmt.Println(err.Error())
// Example with GORM migration, returns one migration error with all error messages:
// return core.ErrorCollector(
// db.CreateTable(models.Account{}, models.Connection{}).Error,
// db.Table("account").AddUniqueIndex("account_key", "channel").Error,
// )
func ErrorCollector(errorsList ...error) error {
var errorMsg string
for _, errItem := range errorsList {
if errItem == nil {
continue
}
errorMsg += "< " + errItem.Error() + " "
}
if errorMsg != "" {
return errors.New(errorMsg[2 : len(errorMsg)-1])
}
return nil
}

View File

@ -1,40 +0,0 @@
package core
import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
)
func TestErrorCollector_NoError(t *testing.T) {
err := ErrorCollector(nil, nil, nil)
assert.NoError(t, err)
assert.Nil(t, err)
}
func TestErrorCollector_SeveralErrors(t *testing.T) {
err := ErrorCollector(nil, errors.New("error text"), nil)
assert.Error(t, err)
assert.Equal(t, "error text", err.Error())
}
func TestErrorCollector_EmptyErrorMessage(t *testing.T) {
err := ErrorCollector(nil, errors.New(""), nil)
assert.Error(t, err)
assert.Equal(t, "", err.Error())
}
func TestErrorCollector_AllErrors(t *testing.T) {
err := ErrorCollector(
errors.New("first"),
errors.New("second"),
errors.New("third"),
)
assert.Error(t, err)
assert.Equal(t, "first < second < third", err.Error())
}

View File

@ -6,31 +6,28 @@ import (
"sync" "sync"
"time" "time"
"github.com/op/go-logging" "github.com/retailcrm/mg-transport-core/v2/core/logger"
) )
// JobFunc is empty func which should be executed in a parallel goroutine. // JobFunc is empty func which should be executed in a parallel goroutine.
type JobFunc func(JobLogFunc) error type JobFunc func(logger.Logger) error
// JobLogFunc is a function which logs data from job.
type JobLogFunc func(string, logging.Level, ...interface{})
// JobErrorHandler is a function to handle jobs errors. First argument is a job name. // JobErrorHandler is a function to handle jobs errors. First argument is a job name.
type JobErrorHandler func(string, error, JobLogFunc) type JobErrorHandler func(string, error, logger.Logger)
// JobPanicHandler is a function to handle jobs panics. First argument is a job name. // JobPanicHandler is a function to handle jobs panics. First argument is a job name.
type JobPanicHandler func(string, interface{}, JobLogFunc) type JobPanicHandler func(string, interface{}, logger.Logger)
// Job represents single job. Regular job will be executed every Interval. // Job represents single job. Regular job will be executed every Interval.
type Job struct { type Job struct {
Command JobFunc Command JobFunc
ErrorHandler JobErrorHandler ErrorHandler JobErrorHandler
PanicHandler JobPanicHandler PanicHandler JobPanicHandler
stopChannel chan bool
Interval time.Duration Interval time.Duration
writeLock sync.RWMutex writeLock sync.RWMutex
Regular bool Regular bool
active bool active bool
stopChannel chan bool
} }
// JobManager controls jobs execution flow. Jobs can be added just for later use (e.g. JobManager can be used as // JobManager controls jobs execution flow. Jobs can be added just for later use (e.g. JobManager can be used as
@ -39,9 +36,9 @@ type Job struct {
// SetLogger(logger). // SetLogger(logger).
// SetLogging(false) // SetLogging(false)
// _ = manager.RegisterJob("updateTokens", &Job{ // _ = manager.RegisterJob("updateTokens", &Job{
// Command: func(logFunc JobLogFunc) error { // Command: func(log logger.Logger) error {
// // logic goes here... // // logic goes here...
// logFunc("All tokens were updated successfully", logging.INFO) // logger.Info("All tokens were updated successfully")
// return nil // return nil
// }, // },
// ErrorHandler: DefaultJobErrorHandler(), // ErrorHandler: DefaultJobErrorHandler(),
@ -51,13 +48,14 @@ type Job struct {
// }) // })
// manager.Start() // manager.Start()
type JobManager struct { type JobManager struct {
logger logger.Logger
nilLogger logger.Logger
jobs *sync.Map jobs *sync.Map
enableLogging bool enableLogging bool
logger LoggerInterface
} }
// getWrappedFunc wraps job into function. // getWrappedFunc wraps job into function.
func (j *Job) getWrappedFunc(name string, log JobLogFunc) func() { func (j *Job) getWrappedFunc(name string, log logger.Logger) func() {
return func() { return func() {
defer func() { defer func() {
if r := recover(); r != nil && j.PanicHandler != nil { if r := recover(); r != nil && j.PanicHandler != nil {
@ -72,7 +70,7 @@ func (j *Job) getWrappedFunc(name string, log JobLogFunc) func() {
} }
// getWrappedTimerFunc returns job timer func to run in the separate goroutine. // getWrappedTimerFunc returns job timer func to run in the separate goroutine.
func (j *Job) getWrappedTimerFunc(name string, log JobLogFunc) func(chan bool) { func (j *Job) getWrappedTimerFunc(name string, log logger.Logger) func(chan bool) {
return func(stopChannel chan bool) { return func(stopChannel chan bool) {
for range time.NewTicker(j.Interval).C { for range time.NewTicker(j.Interval).C {
select { select {
@ -86,7 +84,7 @@ func (j *Job) getWrappedTimerFunc(name string, log JobLogFunc) func(chan bool) {
} }
// run job. // run job.
func (j *Job) run(name string, log JobLogFunc) *Job { func (j *Job) run(name string, log logger.Logger) {
j.writeLock.RLock() j.writeLock.RLock()
if j.Regular && j.Interval > 0 && !j.active { if j.Regular && j.Interval > 0 && !j.active {
@ -100,12 +98,10 @@ func (j *Job) run(name string, log JobLogFunc) *Job {
} else { } else {
j.writeLock.RUnlock() j.writeLock.RUnlock()
} }
return j
} }
// stop running job. // stop running job.
func (j *Job) stop() *Job { func (j *Job) stop() {
j.writeLock.RLock() j.writeLock.RLock()
if j.active && j.stopChannel != nil { if j.active && j.stopChannel != nil {
@ -119,47 +115,43 @@ func (j *Job) stop() *Job {
} else { } else {
j.writeLock.RUnlock() j.writeLock.RUnlock()
} }
return j
} }
// runOnce run job once. // runOnce run job once.
func (j *Job) runOnce(name string, log JobLogFunc) *Job { func (j *Job) runOnce(name string, log logger.Logger) {
go j.getWrappedFunc(name, log)() go j.getWrappedFunc(name, log)()
return j
} }
// runOnceSync run job once in current goroutine. // runOnceSync run job once in current goroutine.
func (j *Job) runOnceSync(name string, log JobLogFunc) *Job { func (j *Job) runOnceSync(name string, log logger.Logger) {
j.getWrappedFunc(name, log)() j.getWrappedFunc(name, log)()
return j
} }
// NewJobManager is a JobManager constructor. // NewJobManager is a JobManager constructor.
func NewJobManager() *JobManager { func NewJobManager() *JobManager {
return &JobManager{jobs: &sync.Map{}} return &JobManager{jobs: &sync.Map{}, nilLogger: logger.NewNil()}
} }
// DefaultJobErrorHandler returns default error handler for a job. // DefaultJobErrorHandler returns default error handler for a job.
func DefaultJobErrorHandler() JobErrorHandler { func DefaultJobErrorHandler() JobErrorHandler {
return func(name string, err error, log JobLogFunc) { return func(name string, err error, log logger.Logger) {
if err != nil && name != "" { if err != nil && name != "" {
log("Job `%s` errored with an error: `%s`", logging.ERROR, name, err.Error()) log.Errorf("Job `%s` errored with an error: `%s`", name, err.Error())
} }
} }
} }
// DefaultJobPanicHandler returns default panic handler for a job. // DefaultJobPanicHandler returns default panic handler for a job.
func DefaultJobPanicHandler() JobPanicHandler { func DefaultJobPanicHandler() JobPanicHandler {
return func(name string, recoverValue interface{}, log JobLogFunc) { return func(name string, recoverValue interface{}, log logger.Logger) {
if recoverValue != nil && name != "" { if recoverValue != nil && name != "" {
log("Job `%s` panicked with value: `%#v`", logging.ERROR, name, recoverValue) log.Errorf("Job `%s` panicked with value: `%#v`", name, recoverValue)
} }
} }
} }
// SetLogger sets logger into JobManager. // SetLogger sets logger into JobManager.
func (j *JobManager) SetLogger(logger LoggerInterface) *JobManager { func (j *JobManager) SetLogger(logger logger.Logger) *JobManager {
if logger != nil { if logger != nil {
j.logger = logger j.logger = logger
} }
@ -167,6 +159,14 @@ func (j *JobManager) SetLogger(logger LoggerInterface) *JobManager {
return j return j
} }
// Logger returns logger.
func (j *JobManager) Logger() logger.Logger {
if !j.enableLogging {
return j.nilLogger
}
return j.logger
}
// SetLogging enables or disables JobManager logging. // SetLogging enables or disables JobManager logging.
func (j *JobManager) SetLogging(enableLogging bool) *JobManager { func (j *JobManager) SetLogging(enableLogging bool) *JobManager {
j.enableLogging = enableLogging j.enableLogging = enableLogging
@ -211,7 +211,7 @@ func (j *JobManager) FetchJob(name string) (value *Job, ok bool) {
// UpdateJob updates job. // UpdateJob updates job.
func (j *JobManager) UpdateJob(name string, job *Job) error { func (j *JobManager) UpdateJob(name string, job *Job) error {
if job, ok := j.FetchJob(name); ok { if _, ok := j.FetchJob(name); ok {
_ = j.UnregisterJob(name) _ = j.UnregisterJob(name)
return j.RegisterJob(name, job) return j.RegisterJob(name, job)
} }
@ -219,10 +219,11 @@ func (j *JobManager) UpdateJob(name string, job *Job) error {
return fmt.Errorf("cannot find job `%s`", name) return fmt.Errorf("cannot find job `%s`", name)
} }
// RunJob starts provided regular job if it's exists. It's async operation and error returns only of job wasn't executed at all. // RunJob starts provided regular job if it's exists.
// It runs asynchronously and error returns only of job wasn't executed at all.
func (j *JobManager) RunJob(name string) error { func (j *JobManager) RunJob(name string) error {
if job, ok := j.FetchJob(name); ok { if job, ok := j.FetchJob(name); ok {
job.run(name, j.log) job.run(name, j.Logger())
return nil return nil
} }
@ -242,7 +243,7 @@ func (j *JobManager) StopJob(name string) error {
// RunJobOnce starts provided job once if it exists. It's also async. // RunJobOnce starts provided job once if it exists. It's also async.
func (j *JobManager) RunJobOnce(name string) error { func (j *JobManager) RunJobOnce(name string) error {
if job, ok := j.FetchJob(name); ok { if job, ok := j.FetchJob(name); ok {
job.runOnce(name, j.log) job.runOnce(name, j.Logger())
return nil return nil
} }
@ -252,7 +253,7 @@ func (j *JobManager) RunJobOnce(name string) error {
// RunJobOnceSync starts provided job once in current goroutine if job exists. Will wait for job to end it's work. // RunJobOnceSync starts provided job once in current goroutine if job exists. Will wait for job to end it's work.
func (j *JobManager) RunJobOnceSync(name string) error { func (j *JobManager) RunJobOnceSync(name string) error {
if job, ok := j.FetchJob(name); ok { if job, ok := j.FetchJob(name); ok {
job.runOnceSync(name, j.log) job.runOnceSync(name, j.Logger())
return nil return nil
} }
@ -264,48 +265,7 @@ func (j *JobManager) Start() {
j.jobs.Range(func(key, value interface{}) bool { j.jobs.Range(func(key, value interface{}) bool {
name := key.(string) name := key.(string)
job := value.(*Job) job := value.(*Job)
job.run(name, j.log) job.run(name, j.Logger())
return true return true
}) })
} }
// log logs via logger or as plaintext.
func (j *JobManager) log(format string, severity logging.Level, args ...interface{}) {
if !j.enableLogging {
return
}
if j.logger != nil {
switch severity {
case logging.CRITICAL:
j.logger.Criticalf(format, args...)
case logging.ERROR:
j.logger.Errorf(format, args...)
case logging.WARNING:
j.logger.Warningf(format, args...)
case logging.NOTICE:
j.logger.Noticef(format, args...)
case logging.INFO:
j.logger.Infof(format, args...)
case logging.DEBUG:
j.logger.Debugf(format, args...)
}
return
}
switch severity {
case logging.CRITICAL:
fmt.Print("[CRITICAL] ", fmt.Sprintf(format, args...))
case logging.ERROR:
fmt.Print("[ERROR] ", fmt.Sprintf(format, args...))
case logging.WARNING:
fmt.Print("[WARNING] ", fmt.Sprintf(format, args...))
case logging.NOTICE:
fmt.Print("[NOTICE] ", fmt.Sprintf(format, args...))
case logging.INFO:
fmt.Print("[INFO] ", fmt.Sprintf(format, args...))
case logging.DEBUG:
fmt.Print("[DEBUG] ", fmt.Sprintf(format, args...))
}
}

View File

@ -4,6 +4,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"math/rand" "math/rand"
"strings"
"sync" "sync"
"testing" "testing"
"time" "time"
@ -12,18 +13,20 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/retailcrm/mg-transport-core/v2/core/logger"
) )
type JobTest struct { type JobTest struct {
suite.Suite suite.Suite
job *Job job *Job
syncBool bool
executedChan chan bool executedChan chan bool
randomNumber chan int randomNumber chan int
executeErr chan error executeErr chan error
panicValue chan interface{} panicValue chan interface{}
lastLog string lastLog string
lastMsgLevel logging.Level lastMsgLevel logging.Level
syncBool bool
} }
type JobManagerTest struct { type JobManagerTest struct {
@ -33,6 +36,70 @@ type JobManagerTest struct {
syncRunnerFlag bool syncRunnerFlag bool
} }
type callbackLoggerFunc func(level logging.Level, format string, args ...interface{})
type callbackLogger struct {
fn callbackLoggerFunc
}
func (n *callbackLogger) Fatal(args ...interface{}) {
n.fn(logging.CRITICAL, "", args...)
}
func (n *callbackLogger) Fatalf(format string, args ...interface{}) {
n.fn(logging.CRITICAL, format, args...)
}
func (n *callbackLogger) Panic(args ...interface{}) {
n.fn(logging.CRITICAL, "", args...)
}
func (n *callbackLogger) Panicf(format string, args ...interface{}) {
n.fn(logging.CRITICAL, format, args...)
}
func (n *callbackLogger) Critical(args ...interface{}) {
n.fn(logging.CRITICAL, "", args...)
}
func (n *callbackLogger) Criticalf(format string, args ...interface{}) {
n.fn(logging.CRITICAL, format, args...)
}
func (n *callbackLogger) Error(args ...interface{}) {
n.fn(logging.ERROR, "", args...)
}
func (n *callbackLogger) Errorf(format string, args ...interface{}) {
n.fn(logging.ERROR, format, args...)
}
func (n *callbackLogger) Warning(args ...interface{}) {
n.fn(logging.WARNING, "", args...)
}
func (n *callbackLogger) Warningf(format string, args ...interface{}) {
n.fn(logging.WARNING, format, args...)
}
func (n *callbackLogger) Notice(args ...interface{}) {
n.fn(logging.NOTICE, "", args...)
}
func (n *callbackLogger) Noticef(format string, args ...interface{}) {
n.fn(logging.NOTICE, format, args...)
}
func (n *callbackLogger) Info(args ...interface{}) {
n.fn(logging.INFO, "", args...)
}
func (n *callbackLogger) Infof(format string, args ...interface{}) {
n.fn(logging.INFO, format, args...)
}
func (n *callbackLogger) Debug(args ...interface{}) {
n.fn(logging.DEBUG, "", args...)
}
func (n *callbackLogger) Debugf(format string, args ...interface{}) {
n.fn(logging.DEBUG, format, args...)
}
func TestJob(t *testing.T) { func TestJob(t *testing.T) {
suite.Run(t, new(JobTest)) suite.Run(t, new(JobTest))
} }
@ -48,10 +115,10 @@ func TestDefaultJobErrorHandler(t *testing.T) {
fn := DefaultJobErrorHandler() fn := DefaultJobErrorHandler()
require.NotNil(t, fn) require.NotNil(t, fn)
fn("job", errors.New("test"), func(s string, level logging.Level, i ...interface{}) { fn("job", errors.New("test"), &callbackLogger{fn: func(level logging.Level, s string, i ...interface{}) {
require.Len(t, i, 2) require.Len(t, i, 2)
assert.Equal(t, fmt.Sprintf("%s", i[1]), "test") assert.Equal(t, fmt.Sprintf("%s", i[1]), "test")
}) }})
} }
func TestDefaultJobPanicHandler(t *testing.T) { func TestDefaultJobPanicHandler(t *testing.T) {
@ -61,41 +128,52 @@ func TestDefaultJobPanicHandler(t *testing.T) {
fn := DefaultJobPanicHandler() fn := DefaultJobPanicHandler()
require.NotNil(t, fn) require.NotNil(t, fn)
fn("job", errors.New("test"), func(s string, level logging.Level, i ...interface{}) { fn("job", errors.New("test"), &callbackLogger{fn: func(level logging.Level, s string, i ...interface{}) {
require.Len(t, i, 2) require.Len(t, i, 2)
assert.Equal(t, fmt.Sprintf("%s", i[1]), "test") assert.Equal(t, fmt.Sprintf("%s", i[1]), "test")
}) }})
} }
func (t *JobTest) testErrorHandler() JobErrorHandler { func (t *JobTest) testErrorHandler() JobErrorHandler {
return func(name string, err error, logFunc JobLogFunc) { return func(name string, err error, log logger.Logger) {
t.executeErr <- err t.executeErr <- err
} }
} }
func (t *JobTest) testPanicHandler() JobPanicHandler { func (t *JobTest) testPanicHandler() JobPanicHandler {
return func(name string, i interface{}, logFunc JobLogFunc) { return func(name string, i interface{}, log logger.Logger) {
t.panicValue <- i t.panicValue <- i
} }
} }
func (t *JobTest) testLogFunc() JobLogFunc { func (t *JobTest) testLogger() logger.Logger {
return func(s string, level logging.Level, i ...interface{}) { return &callbackLogger{fn: func(level logging.Level, format string, args ...interface{}) {
t.lastLog = fmt.Sprintf(s, i...) if format == "" {
t.lastMsgLevel = level var sb strings.Builder
} sb.Grow(3 * len(args)) // nolint:gomnd
for i := 0; i < len(args); i++ {
sb.WriteString("%v ")
} }
func (t *JobTest) executed(wait time.Duration, defaultVal bool) bool { format = strings.TrimRight(sb.String(), " ")
}
t.lastLog = fmt.Sprintf(format, args...)
t.lastMsgLevel = level
}}
}
func (t *JobTest) executed() bool {
if t.executedChan == nil { if t.executedChan == nil {
return defaultVal return false
} }
select { select {
case c := <-t.executedChan: case c := <-t.executedChan:
return c return c
case <-time.After(wait): case <-time.After(time.Millisecond):
return defaultVal return false
} }
} }
@ -140,7 +218,7 @@ func (t *JobTest) clear() {
func (t *JobTest) onceJob() { func (t *JobTest) onceJob() {
t.job = &Job{ t.job = &Job{
Command: func(logFunc JobLogFunc) error { Command: func(log logger.Logger) error {
t.executedChan <- true t.executedChan <- true
return nil return nil
}, },
@ -153,7 +231,7 @@ func (t *JobTest) onceJob() {
func (t *JobTest) onceErrorJob() { func (t *JobTest) onceErrorJob() {
t.job = &Job{ t.job = &Job{
Command: func(logFunc JobLogFunc) error { Command: func(log logger.Logger) error {
t.executedChan <- true t.executedChan <- true
return errors.New("test error") return errors.New("test error")
}, },
@ -166,7 +244,7 @@ func (t *JobTest) onceErrorJob() {
func (t *JobTest) oncePanicJob() { func (t *JobTest) oncePanicJob() {
t.job = &Job{ t.job = &Job{
Command: func(logFunc JobLogFunc) error { Command: func(log logger.Logger) error {
t.executedChan <- true t.executedChan <- true
panic("test panic") panic("test panic")
}, },
@ -180,9 +258,9 @@ func (t *JobTest) oncePanicJob() {
func (t *JobTest) regularJob() { func (t *JobTest) regularJob() {
rand.Seed(time.Now().UnixNano()) rand.Seed(time.Now().UnixNano())
t.job = &Job{ t.job = &Job{
Command: func(logFunc JobLogFunc) error { Command: func(log logger.Logger) error {
t.executedChan <- true t.executedChan <- true
t.randomNumber <- rand.Int() t.randomNumber <- rand.Int() // nolint:gosec
return nil return nil
}, },
ErrorHandler: t.testErrorHandler(), ErrorHandler: t.testErrorHandler(),
@ -195,7 +273,7 @@ func (t *JobTest) regularJob() {
func (t *JobTest) regularSyncJob() { func (t *JobTest) regularSyncJob() {
rand.Seed(time.Now().UnixNano()) rand.Seed(time.Now().UnixNano())
t.job = &Job{ t.job = &Job{
Command: func(logFunc JobLogFunc) error { Command: func(log logger.Logger) error {
t.syncBool = true t.syncBool = true
return nil return nil
}, },
@ -213,10 +291,10 @@ func (t *JobTest) Test_getWrappedFunc() {
t.clear() t.clear()
t.onceJob() t.onceJob()
fn := t.job.getWrappedFunc("job", t.testLogFunc()) fn := t.job.getWrappedFunc("job", t.testLogger())
require.NotNil(t.T(), fn) require.NotNil(t.T(), fn)
go fn() go fn()
assert.True(t.T(), t.executed(time.Millisecond, false)) assert.True(t.T(), t.executed())
assert.False(t.T(), t.errored(time.Millisecond)) assert.False(t.T(), t.errored(time.Millisecond))
assert.False(t.T(), t.panicked(time.Millisecond)) assert.False(t.T(), t.panicked(time.Millisecond))
} }
@ -228,10 +306,10 @@ func (t *JobTest) Test_getWrappedFuncError() {
t.clear() t.clear()
t.onceErrorJob() t.onceErrorJob()
fn := t.job.getWrappedFunc("job", t.testLogFunc()) fn := t.job.getWrappedFunc("job", t.testLogger())
require.NotNil(t.T(), fn) require.NotNil(t.T(), fn)
go fn() go fn()
assert.True(t.T(), t.executed(time.Millisecond, false)) assert.True(t.T(), t.executed())
assert.True(t.T(), t.errored(time.Millisecond)) assert.True(t.T(), t.errored(time.Millisecond))
assert.False(t.T(), t.panicked(time.Millisecond)) assert.False(t.T(), t.panicked(time.Millisecond))
} }
@ -243,10 +321,10 @@ func (t *JobTest) Test_getWrappedFuncPanic() {
t.clear() t.clear()
t.oncePanicJob() t.oncePanicJob()
fn := t.job.getWrappedFunc("job", t.testLogFunc()) fn := t.job.getWrappedFunc("job", t.testLogger())
require.NotNil(t.T(), fn) require.NotNil(t.T(), fn)
go fn() go fn()
assert.True(t.T(), t.executed(time.Millisecond, false)) assert.True(t.T(), t.executed())
assert.False(t.T(), t.errored(time.Millisecond)) assert.False(t.T(), t.errored(time.Millisecond))
assert.True(t.T(), t.panicked(time.Millisecond)) assert.True(t.T(), t.panicked(time.Millisecond))
} }
@ -257,10 +335,10 @@ func (t *JobTest) Test_run() {
}() }()
t.regularJob() t.regularJob()
t.job.run("job", t.testLogFunc()) t.job.run("job", t.testLogger())
time.Sleep(time.Millisecond * 5) time.Sleep(time.Millisecond * 10)
t.job.stop() t.job.stop()
require.True(t.T(), t.executed(time.Millisecond, false)) require.True(t.T(), t.executed())
} }
func (t *JobTest) Test_runOnce() { func (t *JobTest) Test_runOnce() {
@ -269,9 +347,9 @@ func (t *JobTest) Test_runOnce() {
}() }()
t.regularJob() t.regularJob()
t.job.runOnce("job", t.testLogFunc()) t.job.runOnce("job", t.testLogger())
time.Sleep(time.Millisecond * 5) time.Sleep(time.Millisecond * 5)
require.True(t.T(), t.executed(time.Millisecond, false)) require.True(t.T(), t.executed())
first := 0 first := 0
select { select {
@ -301,7 +379,7 @@ func (t *JobTest) Test_runOnceSync() {
t.clear() t.clear()
t.regularSyncJob() t.regularSyncJob()
require.False(t.T(), t.syncBool) require.False(t.T(), t.syncBool)
t.job.runOnceSync("job", t.testLogFunc()) t.job.runOnceSync("job", t.testLogger())
assert.True(t.T(), t.syncBool) assert.True(t.T(), t.syncBool)
} }
@ -326,11 +404,11 @@ func (t *JobManagerTest) WaitForJob() bool {
func (t *JobManagerTest) Test_SetLogger() { func (t *JobManagerTest) Test_SetLogger() {
t.manager.logger = nil t.manager.logger = nil
t.manager.SetLogger(NewLogger("test", logging.ERROR, DefaultLogFormatter())) t.manager.SetLogger(logger.NewStandard("test", logging.ERROR, logger.DefaultLogFormatter()))
assert.IsType(t.T(), &Logger{}, t.manager.logger) assert.IsType(t.T(), &logger.StandardLogger{}, t.manager.logger)
t.manager.SetLogger(nil) t.manager.SetLogger(nil)
assert.IsType(t.T(), &Logger{}, t.manager.logger) assert.IsType(t.T(), &logger.StandardLogger{}, t.manager.logger)
} }
func (t *JobManagerTest) Test_SetLogging() { func (t *JobManagerTest) Test_SetLogging() {
@ -351,7 +429,7 @@ func (t *JobManagerTest) Test_RegisterJobNil() {
func (t *JobManagerTest) Test_RegisterJob() { func (t *JobManagerTest) Test_RegisterJob() {
require.NotNil(t.T(), t.manager.jobs) require.NotNil(t.T(), t.manager.jobs)
err := t.manager.RegisterJob("job", &Job{ err := t.manager.RegisterJob("job", &Job{
Command: func(log JobLogFunc) error { Command: func(log logger.Logger) error {
t.runnerWG.Done() t.runnerWG.Done()
return nil return nil
}, },
@ -360,7 +438,7 @@ func (t *JobManagerTest) Test_RegisterJob() {
}) })
assert.NoError(t.T(), err) assert.NoError(t.T(), err)
err = t.manager.RegisterJob("job_regular", &Job{ err = t.manager.RegisterJob("job_regular", &Job{
Command: func(log JobLogFunc) error { Command: func(log logger.Logger) error {
t.runnerWG.Done() t.runnerWG.Done()
return nil return nil
}, },
@ -371,7 +449,7 @@ func (t *JobManagerTest) Test_RegisterJob() {
}) })
assert.NoError(t.T(), err) assert.NoError(t.T(), err)
err = t.manager.RegisterJob("job_sync", &Job{ err = t.manager.RegisterJob("job_sync", &Job{
Command: func(log JobLogFunc) error { Command: func(log logger.Logger) error {
t.syncRunnerFlag = true t.syncRunnerFlag = true
return nil return nil
}, },
@ -398,7 +476,7 @@ func (t *JobManagerTest) Test_FetchJob() {
require.Nil(t.T(), recover()) require.Nil(t.T(), recover())
}() }()
require.NoError(t.T(), t.manager.RegisterJob("test_fetch", &Job{Command: func(logFunc JobLogFunc) error { require.NoError(t.T(), t.manager.RegisterJob("test_fetch", &Job{Command: func(log logger.Logger) error {
return nil return nil
}})) }}))
require.NotNil(t.T(), t.manager.jobs) require.NotNil(t.T(), t.manager.jobs)
@ -479,8 +557,8 @@ func (t *JobManagerTest) Test_Start() {
manager := NewJobManager() manager := NewJobManager()
_ = manager.RegisterJob("job", &Job{ _ = manager.RegisterJob("job", &Job{
Command: func(logFunc JobLogFunc) error { Command: func(log logger.Logger) error {
logFunc("alive!", logging.INFO) log.Info("alive!")
return nil return nil
}, },
ErrorHandler: DefaultJobErrorHandler(), ErrorHandler: DefaultJobErrorHandler(),
@ -488,25 +566,3 @@ func (t *JobManagerTest) Test_Start() {
}) })
manager.Start() manager.Start()
} }
func (t *JobManagerTest) Test_log() {
defer func() {
require.Nil(t.T(), recover())
}()
testLog := func() {
t.manager.log("test", logging.CRITICAL)
t.manager.log("test", logging.ERROR)
t.manager.log("test", logging.WARNING)
t.manager.log("test", logging.NOTICE)
t.manager.log("test", logging.INFO)
t.manager.log("test", logging.DEBUG)
}
t.manager.SetLogging(false)
testLog()
t.manager.SetLogging(true)
t.manager.logger = nil
testLog()
t.manager.logger = NewLogger("test", logging.DEBUG, DefaultLogFormatter())
testLog()
}

View File

@ -11,6 +11,8 @@ import (
"github.com/nicksnyder/go-i18n/v2/i18n" "github.com/nicksnyder/go-i18n/v2/i18n"
"golang.org/x/text/language" "golang.org/x/text/language"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
"github.com/retailcrm/mg-transport-core/v2/core/util/errorutil"
) )
// DefaultLanguages for transports. // DefaultLanguages for transports.
@ -138,11 +140,6 @@ func (l *Localizer) LocalizationFuncMap() template.FuncMap {
} }
} }
// getLocaleBundle returns current locale bundle and creates it if needed.
func (l *Localizer) getLocaleBundle() *i18n.Bundle {
return l.createLocaleBundleByTag(l.LanguageTag)
}
// createLocaleBundleByTag creates locale bundle by language tag. // createLocaleBundleByTag creates locale bundle by language tag.
func (l *Localizer) createLocaleBundleByTag(tag language.Tag) *i18n.Bundle { func (l *Localizer) createLocaleBundleByTag(tag language.Tag) *i18n.Bundle {
bundle := i18n.NewBundle(tag) bundle := i18n.NewBundle(tag)
@ -279,8 +276,8 @@ func (l *Localizer) GetLocalizedMessage(messageID string) string {
return l.getCurrentLocalizer().MustLocalize(&i18n.LocalizeConfig{MessageID: messageID}) return l.getCurrentLocalizer().MustLocalize(&i18n.LocalizeConfig{MessageID: messageID})
} }
// GetLocalizedTemplateMessage will return localized message with specified data. It doesn't use `Must` prefix in order to keep BC. // GetLocalizedTemplateMessage will return localized message with specified data.
// It uses text/template syntax: https://golang.org/pkg/text/template/ // It doesn't use `Must` prefix in order to keep BC. It uses text/template syntax: https://golang.org/pkg/text/template/
func (l *Localizer) GetLocalizedTemplateMessage(messageID string, templateData map[string]interface{}) string { func (l *Localizer) GetLocalizedTemplateMessage(messageID string, templateData map[string]interface{}) string {
return l.getCurrentLocalizer().MustLocalize(&i18n.LocalizeConfig{ return l.getCurrentLocalizer().MustLocalize(&i18n.LocalizeConfig{
MessageID: messageID, MessageID: messageID,
@ -302,13 +299,52 @@ func (l *Localizer) LocalizeTemplateMessage(messageID string, templateData map[s
}) })
} }
// BadRequestLocalized is same as BadRequest(string), but passed string will be localized. // BadRequestLocalized is same as errorutil.BadRequest(string), but passed string will be localized.
func (l *Localizer) BadRequestLocalized(err string) (int, interface{}) { func (l *Localizer) BadRequestLocalized(err string) (int, interface{}) {
return BadRequest(l.GetLocalizedMessage(err)) return errorutil.BadRequest(l.GetLocalizedMessage(err))
} }
// GetContextLocalizer returns localizer from context if it exist there. // UnauthorizedLocalized is same as errorutil.Unauthorized(string), but passed string will be localized.
func GetContextLocalizer(c *gin.Context) (*Localizer, bool) { func (l *Localizer) UnauthorizedLocalized(err string) (int, interface{}) {
return errorutil.Unauthorized(l.GetLocalizedMessage(err))
}
// ForbiddenLocalized is same as errorutil.Forbidden(string), but passed string will be localized.
func (l *Localizer) ForbiddenLocalized(err string) (int, interface{}) {
return errorutil.Forbidden(l.GetLocalizedMessage(err))
}
// InternalServerErrorLocalized is same as errorutil.InternalServerError(string), but passed string will be localized.
func (l *Localizer) InternalServerErrorLocalized(err string) (int, interface{}) {
return errorutil.InternalServerError(l.GetLocalizedMessage(err))
}
// GetContextLocalizer returns localizer from context if it is present there.
// Language will be set using Accept-Language header and root language tag.
func GetContextLocalizer(c *gin.Context) (loc *Localizer, ok bool) {
loc, ok = extractLocalizerFromContext(c)
if loc != nil {
loc.SetLocale(c.GetHeader("Accept-Language"))
lang := GetRootLanguageTag(loc.LanguageTag)
if lang != loc.LanguageTag {
loc.SetLanguage(lang)
loc.LoadTranslations()
}
}
return
}
// MustGetContextLocalizer returns Localizer instance if it exists in provided context. Panics otherwise.
func MustGetContextLocalizer(c *gin.Context) *Localizer {
if localizer, ok := GetContextLocalizer(c); ok {
return localizer
}
panic("localizer is not present in provided context")
}
// extractLocalizerFromContext returns localizer from context if it exist there.
func extractLocalizerFromContext(c *gin.Context) (*Localizer, bool) {
if c == nil { if c == nil {
return nil, false return nil, false
} }
@ -322,10 +358,14 @@ func GetContextLocalizer(c *gin.Context) (*Localizer, bool) {
return nil, false return nil, false
} }
// MustGetContextLocalizer returns Localizer instance if it exists in provided context. Panics otherwise. // GetRootLanguageTag returns root language tag for country-specific tags (e.g "es" for "es_CA").
func MustGetContextLocalizer(c *gin.Context) *Localizer { // Useful when you don't have country-specific language variations.
if localizer, ok := GetContextLocalizer(c); ok { func GetRootLanguageTag(t language.Tag) language.Tag {
return localizer for {
parent := t.Parent()
if parent == language.Und {
return t
}
t = parent
} }
panic("localizer is not present in provided context")
} }

View File

@ -17,6 +17,8 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"golang.org/x/text/language" "golang.org/x/text/language"
"github.com/retailcrm/mg-transport-core/v2/core/util/errorutil"
) )
var ( var (
@ -120,7 +122,7 @@ func (l *LocalizerTest) Test_LocalizationMiddleware_Httptest() {
wg.Add(1) wg.Add(1)
go func(m map[language.Tag]string, wg *sync.WaitGroup) { go func(m map[language.Tag]string, wg *sync.WaitGroup) {
var tag language.Tag var tag language.Tag
switch rand.Intn(3-1) + 1 { switch rand.Intn(3-1) + 1 { // nolint:gosec
case 1: case 1:
tag = language.English tag = language.English
case 2: case 2:
@ -183,7 +185,7 @@ func (l *LocalizerTest) Test_BadRequestLocalized() {
status, resp := l.localizer.BadRequestLocalized("message") status, resp := l.localizer.BadRequestLocalized("message")
assert.Equal(l.T(), http.StatusBadRequest, status) assert.Equal(l.T(), http.StatusBadRequest, status)
assert.Equal(l.T(), "Test message", resp.(ErrorResponse).Error) assert.Equal(l.T(), "Test message", resp.(errorutil.Response).Error)
} }
// getContextWithLang generates context with Accept-Language header. // getContextWithLang generates context with Accept-Language header.

View File

@ -0,0 +1,96 @@
package logger
import (
"fmt"
"github.com/op/go-logging"
)
// DefaultAccountLoggerFormat contains default prefix format for the AccountLoggerDecorator.
// Its messages will look like this (assuming you will provide the connection URL and account name):
// messageHandler (https://any.simla.com => @tg_account): sent message with id=1
const DefaultAccountLoggerFormat = "%s (%s => %s):"
type ComponentAware interface {
SetComponent(string)
}
type ConnectionAware interface {
SetConnectionIdentifier(string)
}
type AccountAware interface {
SetAccountIdentifier(string)
}
type PrefixFormatAware interface {
SetPrefixFormat(string)
}
type AccountLogger interface {
PrefixedLogger
ComponentAware
ConnectionAware
AccountAware
PrefixFormatAware
}
type AccountLoggerDecorator struct {
format string
component string
connIdentifier string
accIdentifier string
PrefixDecorator
}
func DecorateForAccount(base Logger, component, connIdentifier, accIdentifier string) AccountLogger {
return (&AccountLoggerDecorator{
PrefixDecorator: PrefixDecorator{
backend: base,
},
component: component,
connIdentifier: connIdentifier,
accIdentifier: accIdentifier,
}).updatePrefix()
}
// NewForAccount returns logger for account. It uses StandardLogger under the hood.
func NewForAccount(
transportCode, component, connIdentifier, accIdentifier string,
logLevel logging.Level,
logFormat logging.Formatter) AccountLogger {
return DecorateForAccount(NewStandard(transportCode, logLevel, logFormat),
component, connIdentifier, accIdentifier)
}
func (a *AccountLoggerDecorator) SetComponent(s string) {
a.component = s
a.updatePrefix()
}
func (a *AccountLoggerDecorator) SetConnectionIdentifier(s string) {
a.connIdentifier = s
a.updatePrefix()
}
func (a *AccountLoggerDecorator) SetAccountIdentifier(s string) {
a.accIdentifier = s
a.updatePrefix()
}
func (a *AccountLoggerDecorator) SetPrefixFormat(s string) {
a.format = s
a.updatePrefix()
}
func (a *AccountLoggerDecorator) updatePrefix() AccountLogger {
a.SetPrefix(fmt.Sprintf(a.prefixFormat(), a.component, a.connIdentifier, a.accIdentifier))
return a
}
func (a *AccountLoggerDecorator) prefixFormat() string {
if a.format == "" {
return DefaultAccountLoggerFormat
}
return a.format
}

View File

@ -0,0 +1,98 @@
package logger
import (
"bytes"
"fmt"
"testing"
"github.com/op/go-logging"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
)
const (
testComponent = "ComponentName"
testConnectionID = "https://test.retailcrm.pro"
testAccountID = "@account_name"
)
type AccountLoggerDecoratorTest struct {
suite.Suite
buf *bytes.Buffer
logger AccountLogger
}
func TestAccountLoggerDecorator(t *testing.T) {
suite.Run(t, new(AccountLoggerDecoratorTest))
}
func TestNewForAccount(t *testing.T) {
buf := &bytes.Buffer{}
logger := NewForAccount("code", "component", "conn", "acc", logging.DEBUG, DefaultLogFormatter())
logger.(*AccountLoggerDecorator).backend.(*StandardLogger).
SetBaseLogger(NewBase(buf, "code", logging.DEBUG, DefaultLogFormatter()))
logger.Debugf("message %s", "text")
assert.Contains(t, buf.String(), fmt.Sprintf(DefaultAccountLoggerFormat+" message", "component", "conn", "acc"))
}
func (t *AccountLoggerDecoratorTest) SetupSuite() {
t.buf = &bytes.Buffer{}
t.logger = DecorateForAccount((&StandardLogger{}).
SetBaseLogger(NewBase(t.buf, "code", logging.DEBUG, DefaultLogFormatter())),
testComponent, testConnectionID, testAccountID)
}
func (t *AccountLoggerDecoratorTest) SetupTest() {
t.buf.Reset()
t.logger.SetComponent(testComponent)
t.logger.SetConnectionIdentifier(testConnectionID)
t.logger.SetAccountIdentifier(testAccountID)
t.logger.SetPrefixFormat(DefaultAccountLoggerFormat)
}
func (t *AccountLoggerDecoratorTest) Test_LogWithNewFormat() {
t.logger.SetPrefixFormat("[%s (%s: %s)] =>")
t.logger.Infof("test message")
t.Assert().Contains(t.buf.String(), "INFO")
t.Assert().Contains(t.buf.String(),
fmt.Sprintf("[%s (%s: %s)] =>", testComponent, testConnectionID, testAccountID))
}
func (t *AccountLoggerDecoratorTest) Test_Log() {
t.logger.Infof("test message")
t.Assert().Contains(t.buf.String(), "INFO")
t.Assert().Contains(t.buf.String(),
fmt.Sprintf(DefaultAccountLoggerFormat, testComponent, testConnectionID, testAccountID))
}
func (t *AccountLoggerDecoratorTest) Test_SetComponent() {
t.logger.SetComponent("NewComponent")
t.logger.Infof("test message")
t.Assert().Contains(t.buf.String(), "INFO")
t.Assert().Contains(t.buf.String(),
fmt.Sprintf(DefaultAccountLoggerFormat, "NewComponent", testConnectionID, testAccountID))
}
func (t *AccountLoggerDecoratorTest) Test_SetConnectionIdentifier() {
t.logger.SetComponent("NewComponent")
t.logger.SetConnectionIdentifier("https://test.simla.com")
t.logger.Infof("test message")
t.Assert().Contains(t.buf.String(), "INFO")
t.Assert().Contains(t.buf.String(),
fmt.Sprintf(DefaultAccountLoggerFormat, "NewComponent", "https://test.simla.com", testAccountID))
}
func (t *AccountLoggerDecoratorTest) Test_SetAccountIdentifier() {
t.logger.SetComponent("NewComponent")
t.logger.SetConnectionIdentifier("https://test.simla.com")
t.logger.SetAccountIdentifier("@new_account_name")
t.logger.Infof("test message")
t.Assert().Contains(t.buf.String(), "INFO")
t.Assert().Contains(t.buf.String(),
fmt.Sprintf(DefaultAccountLoggerFormat, "NewComponent", "https://test.simla.com", "@new_account_name"))
}

View File

@ -1,14 +1,15 @@
package core package logger
import ( import (
"io"
"os" "os"
"sync" "sync"
"github.com/op/go-logging" "github.com/op/go-logging"
) )
// LoggerInterface contains methods which should be present in logger implementation. // Logger contains methods which should be present in logger implementation.
type LoggerInterface interface { type Logger interface {
Fatal(args ...interface{}) Fatal(args ...interface{})
Fatalf(format string, args ...interface{}) Fatalf(format string, args ...interface{})
Panic(args ...interface{}) Panic(args ...interface{})
@ -27,30 +28,30 @@ type LoggerInterface interface {
Debugf(format string, args ...interface{}) Debugf(format string, args ...interface{})
} }
// Logger component. Uses github.com/op/go-logging under the hood. // StandardLogger is a default implementation of Logger. Uses github.com/op/go-logging under the hood.
// This logger can prevent any write operations (disabled by default, use .Exclusive() method to enable). // This logger can prevent any write operations (disabled by default, use .Exclusive() method to enable).
type Logger struct { type StandardLogger struct {
logger *logging.Logger logger *logging.Logger
mutex *sync.RWMutex mutex *sync.RWMutex
} }
// NewLogger will create new goroutine-safe logger with specified formatter. // NewStandard will create new StandardLogger with specified formatter.
// Usage: // Usage:
// logger := NewLogger("telegram", logging.ERROR, DefaultLogFormatter()) // logger := NewLogger("telegram", logging.ERROR, DefaultLogFormatter())
func NewLogger(transportCode string, logLevel logging.Level, logFormat logging.Formatter) *Logger { func NewStandard(transportCode string, logLevel logging.Level, logFormat logging.Formatter) *StandardLogger {
return &Logger{ return &StandardLogger{
logger: newInheritedLogger(transportCode, logLevel, logFormat), logger: NewBase(os.Stdout, transportCode, logLevel, logFormat),
} }
} }
// newInheritedLogger is a constructor for underlying logger in Logger struct. // NewBase is a constructor for underlying logger in the StandardLogger struct.
func newInheritedLogger(transportCode string, logLevel logging.Level, logFormat logging.Formatter) *logging.Logger { func NewBase(out io.Writer, transportCode string, logLevel logging.Level, logFormat logging.Formatter) *logging.Logger {
logger := logging.MustGetLogger(transportCode) logger := logging.MustGetLogger(transportCode)
logBackend := logging.NewLogBackend(os.Stdout, "", 0) logBackend := logging.NewLogBackend(out, "", 0)
formatBackend := logging.NewBackendFormatter(logBackend, logFormat) formatBackend := logging.NewBackendFormatter(logBackend, logFormat)
backend1Leveled := logging.AddModuleLevel(logBackend) backend1Leveled := logging.AddModuleLevel(formatBackend)
backend1Leveled.SetLevel(logLevel, "") backend1Leveled.SetLevel(logLevel, "")
logging.SetBackend(formatBackend) logger.SetBackend(backend1Leveled)
return logger return logger
} }
@ -63,7 +64,7 @@ func DefaultLogFormatter() logging.Formatter {
} }
// Exclusive makes logger goroutine-safe. // Exclusive makes logger goroutine-safe.
func (l *Logger) Exclusive() *Logger { func (l *StandardLogger) Exclusive() *StandardLogger {
if l.mutex == nil { if l.mutex == nil {
l.mutex = &sync.RWMutex{} l.mutex = &sync.RWMutex{}
} }
@ -71,127 +72,133 @@ func (l *Logger) Exclusive() *Logger {
return l return l
} }
// SetBaseLogger replaces base logger with the provided instance.
func (l *StandardLogger) SetBaseLogger(logger *logging.Logger) *StandardLogger {
l.logger = logger
return l
}
// lock locks logger. // lock locks logger.
func (l *Logger) lock() { func (l *StandardLogger) lock() {
if l.mutex != nil { if l.mutex != nil {
l.mutex.Lock() l.mutex.Lock()
} }
} }
// unlock unlocks logger. // unlock unlocks logger.
func (l *Logger) unlock() { func (l *StandardLogger) unlock() {
if l.mutex != nil { if l.mutex != nil {
l.mutex.Unlock() l.mutex.Unlock()
} }
} }
// Fatal is equivalent to l.Critical(fmt.Sprint()) followed by a call to os.Exit(1). // Fatal is equivalent to l.Critical(fmt.Sprint()) followed by a call to os.Exit(1).
func (l *Logger) Fatal(args ...interface{}) { func (l *StandardLogger) Fatal(args ...interface{}) {
l.lock() l.lock()
defer l.unlock() defer l.unlock()
l.logger.Fatal(args...) l.logger.Fatal(args...)
} }
// Fatalf is equivalent to l.Critical followed by a call to os.Exit(1). // Fatalf is equivalent to l.Critical followed by a call to os.Exit(1).
func (l *Logger) Fatalf(format string, args ...interface{}) { func (l *StandardLogger) Fatalf(format string, args ...interface{}) {
l.lock() l.lock()
defer l.unlock() defer l.unlock()
l.logger.Fatalf(format, args...) l.logger.Fatalf(format, args...)
} }
// Panic is equivalent to l.Critical(fmt.Sprint()) followed by a call to panic(). // Panic is equivalent to l.Critical(fmt.Sprint()) followed by a call to panic().
func (l *Logger) Panic(args ...interface{}) { func (l *StandardLogger) Panic(args ...interface{}) {
l.lock() l.lock()
defer l.unlock() defer l.unlock()
l.logger.Panic(args...) l.logger.Panic(args...)
} }
// Panicf is equivalent to l.Critical followed by a call to panic(). // Panicf is equivalent to l.Critical followed by a call to panic().
func (l *Logger) Panicf(format string, args ...interface{}) { func (l *StandardLogger) Panicf(format string, args ...interface{}) {
l.lock() l.lock()
defer l.unlock() defer l.unlock()
l.logger.Panicf(format, args...) l.logger.Panicf(format, args...)
} }
// Critical logs a message using CRITICAL as log level. // Critical logs a message using CRITICAL as log level.
func (l *Logger) Critical(args ...interface{}) { func (l *StandardLogger) Critical(args ...interface{}) {
l.lock() l.lock()
defer l.unlock() defer l.unlock()
l.logger.Critical(args...) l.logger.Critical(args...)
} }
// Criticalf logs a message using CRITICAL as log level. // Criticalf logs a message using CRITICAL as log level.
func (l *Logger) Criticalf(format string, args ...interface{}) { func (l *StandardLogger) Criticalf(format string, args ...interface{}) {
l.lock() l.lock()
defer l.unlock() defer l.unlock()
l.logger.Criticalf(format, args...) l.logger.Criticalf(format, args...)
} }
// Error logs a message using ERROR as log level. // Error logs a message using ERROR as log level.
func (l *Logger) Error(args ...interface{}) { func (l *StandardLogger) Error(args ...interface{}) {
l.lock() l.lock()
defer l.unlock() defer l.unlock()
l.logger.Error(args...) l.logger.Error(args...)
} }
// Errorf logs a message using ERROR as log level. // Errorf logs a message using ERROR as log level.
func (l *Logger) Errorf(format string, args ...interface{}) { func (l *StandardLogger) Errorf(format string, args ...interface{}) {
l.lock() l.lock()
defer l.unlock() defer l.unlock()
l.logger.Errorf(format, args...) l.logger.Errorf(format, args...)
} }
// Warning logs a message using WARNING as log level. // Warning logs a message using WARNING as log level.
func (l *Logger) Warning(args ...interface{}) { func (l *StandardLogger) Warning(args ...interface{}) {
l.lock() l.lock()
defer l.unlock() defer l.unlock()
l.logger.Warning(args...) l.logger.Warning(args...)
} }
// Warningf logs a message using WARNING as log level. // Warningf logs a message using WARNING as log level.
func (l *Logger) Warningf(format string, args ...interface{}) { func (l *StandardLogger) Warningf(format string, args ...interface{}) {
l.lock() l.lock()
defer l.unlock() defer l.unlock()
l.logger.Warningf(format, args...) l.logger.Warningf(format, args...)
} }
// Notice logs a message using NOTICE as log level. // Notice logs a message using NOTICE as log level.
func (l *Logger) Notice(args ...interface{}) { func (l *StandardLogger) Notice(args ...interface{}) {
l.lock() l.lock()
defer l.unlock() defer l.unlock()
l.logger.Notice(args...) l.logger.Notice(args...)
} }
// Noticef logs a message using NOTICE as log level. // Noticef logs a message using NOTICE as log level.
func (l *Logger) Noticef(format string, args ...interface{}) { func (l *StandardLogger) Noticef(format string, args ...interface{}) {
l.lock() l.lock()
defer l.unlock() defer l.unlock()
l.logger.Noticef(format, args...) l.logger.Noticef(format, args...)
} }
// Info logs a message using INFO as log level. // Info logs a message using INFO as log level.
func (l *Logger) Info(args ...interface{}) { func (l *StandardLogger) Info(args ...interface{}) {
l.lock() l.lock()
defer l.unlock() defer l.unlock()
l.logger.Info(args...) l.logger.Info(args...)
} }
// Infof logs a message using INFO as log level. // Infof logs a message using INFO as log level.
func (l *Logger) Infof(format string, args ...interface{}) { func (l *StandardLogger) Infof(format string, args ...interface{}) {
l.lock() l.lock()
defer l.unlock() defer l.unlock()
l.logger.Infof(format, args...) l.logger.Infof(format, args...)
} }
// Debug logs a message using DEBUG as log level. // Debug logs a message using DEBUG as log level.
func (l *Logger) Debug(args ...interface{}) { func (l *StandardLogger) Debug(args ...interface{}) {
l.lock() l.lock()
defer l.unlock() defer l.unlock()
l.logger.Debug(args...) l.logger.Debug(args...)
} }
// Debugf logs a message using DEBUG as log level. // Debugf logs a message using DEBUG as log level.
func (l *Logger) Debugf(format string, args ...interface{}) { func (l *StandardLogger) Debugf(format string, args ...interface{}) {
l.lock() l.lock()
defer l.unlock() defer l.unlock()
l.logger.Debugf(format, args...) l.logger.Debugf(format, args...)

168
core/logger/logger_test.go Normal file
View File

@ -0,0 +1,168 @@
package logger
import (
"bytes"
"testing"
"github.com/op/go-logging"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
)
type StandardLoggerTest struct {
suite.Suite
logger *StandardLogger
buf *bytes.Buffer
}
func TestLogger_NewLogger(t *testing.T) {
logger := NewStandard("code", logging.DEBUG, DefaultLogFormatter())
assert.NotNil(t, logger)
}
func TestLogger_DefaultLogFormatter(t *testing.T) {
formatter := DefaultLogFormatter()
assert.NotNil(t, formatter)
assert.IsType(t, logging.MustStringFormatter(`%{message}`), formatter)
}
func Test_Logger(t *testing.T) {
suite.Run(t, new(StandardLoggerTest))
}
func (t *StandardLoggerTest) SetupSuite() {
t.buf = &bytes.Buffer{}
t.logger = (&StandardLogger{}).
Exclusive().
SetBaseLogger(NewBase(t.buf, "code", logging.DEBUG, DefaultLogFormatter()))
}
func (t *StandardLoggerTest) SetupTest() {
t.buf.Reset()
}
// TODO Cover Fatal and Fatalf (implementation below is no-op)
// func (t *StandardLoggerTest) Test_Fatal() {
// if os.Getenv("FLAG") == "1" {
// t.logger.Fatal("test", "fatal")
// return
// }
// cmd := exec.Command(os.Args[0], "-test.run=TestGetConfig")
// cmd.Env = append(os.Environ(), "FLAG=1")
// err := cmd.Run()
// e, ok := err.(*exec.ExitError)
// expectedErrorString := "test fatal"
// t.Assert().Equal(true, ok)
// t.Assert().Equal(expectedErrorString, e.Error())
// }
func (t *StandardLoggerTest) Test_Panic() {
defer func() {
t.Assert().NotNil(recover())
t.Assert().Contains(t.buf.String(), "CRIT")
t.Assert().Contains(t.buf.String(), "panic")
}()
t.logger.Panic("panic")
}
func (t *StandardLoggerTest) Test_Panicf() {
defer func() {
t.Assert().NotNil(recover())
t.Assert().Contains(t.buf.String(), "CRIT")
t.Assert().Contains(t.buf.String(), "panicf")
}()
t.logger.Panicf("panicf")
}
func (t *StandardLoggerTest) Test_Critical() {
defer func() {
t.Require().Nil(recover())
t.Assert().Contains(t.buf.String(), "CRIT")
t.Assert().Contains(t.buf.String(), "critical")
}()
t.logger.Critical("critical")
}
func (t *StandardLoggerTest) Test_Criticalf() {
defer func() {
t.Require().Nil(recover())
t.Assert().Contains(t.buf.String(), "CRIT")
t.Assert().Contains(t.buf.String(), "critical")
}()
t.logger.Criticalf("critical")
}
func (t *StandardLoggerTest) Test_Warning() {
defer func() {
t.Require().Nil(recover())
t.Assert().Contains(t.buf.String(), "WARN")
t.Assert().Contains(t.buf.String(), "warning")
}()
t.logger.Warning("warning")
}
func (t *StandardLoggerTest) Test_Notice() {
defer func() {
t.Require().Nil(recover())
t.Assert().Contains(t.buf.String(), "NOTI")
t.Assert().Contains(t.buf.String(), "notice")
}()
t.logger.Notice("notice")
}
func (t *StandardLoggerTest) Test_Info() {
defer func() {
t.Require().Nil(recover())
t.Assert().Contains(t.buf.String(), "INFO")
t.Assert().Contains(t.buf.String(), "info")
}()
t.logger.Info("info")
}
func (t *StandardLoggerTest) Test_Debug() {
defer func() {
t.Require().Nil(recover())
t.Assert().Contains(t.buf.String(), "DEBU")
t.Assert().Contains(t.buf.String(), "debug")
}()
t.logger.Debug("debug")
}
func (t *StandardLoggerTest) Test_Warningf() {
defer func() {
t.Require().Nil(recover())
t.Assert().Contains(t.buf.String(), "WARN")
t.Assert().Contains(t.buf.String(), "warning")
}()
t.logger.Warningf("%s", "warning")
}
func (t *StandardLoggerTest) Test_Noticef() {
defer func() {
t.Require().Nil(recover())
t.Assert().Contains(t.buf.String(), "NOTI")
t.Assert().Contains(t.buf.String(), "notice")
}()
t.logger.Noticef("%s", "notice")
}
func (t *StandardLoggerTest) Test_Infof() {
defer func() {
t.Require().Nil(recover())
t.Assert().Contains(t.buf.String(), "INFO")
t.Assert().Contains(t.buf.String(), "info")
}()
t.logger.Infof("%s", "info")
}
func (t *StandardLoggerTest) Test_Debugf() {
defer func() {
t.Require().Nil(recover())
t.Assert().Contains(t.buf.String(), "DEBU")
t.Assert().Contains(t.buf.String(), "debug")
}()
t.logger.Debugf("%s", "debug")
}

45
core/logger/nil_logger.go Normal file
View File

@ -0,0 +1,45 @@
package logger
import (
"fmt"
"os"
)
// Nil provides Logger implementation that does almost nothing when called.
// Panic, Panicf, Fatal and Fatalf methods still cause panic and immediate program termination respectively.
// All other methods won't do anything at all.
type Nil struct{}
func (n Nil) Fatal(args ...interface{}) {
os.Exit(1)
}
func (n Nil) Fatalf(format string, args ...interface{}) {
os.Exit(1)
}
func (n Nil) Panic(args ...interface{}) {
panic(fmt.Sprint(args...))
}
func (n Nil) Panicf(format string, args ...interface{}) {
panic(fmt.Sprintf(format, args...))
}
func (n Nil) Critical(args ...interface{}) {}
func (n Nil) Criticalf(format string, args ...interface{}) {}
func (n Nil) Error(args ...interface{}) {}
func (n Nil) Errorf(format string, args ...interface{}) {}
func (n Nil) Warning(args ...interface{}) {}
func (n Nil) Warningf(format string, args ...interface{}) {}
func (n Nil) Notice(args ...interface{}) {}
func (n Nil) Noticef(format string, args ...interface{}) {}
func (n Nil) Info(args ...interface{}) {}
func (n Nil) Infof(format string, args ...interface{}) {}
func (n Nil) Debug(args ...interface{}) {}
func (n Nil) Debugf(format string, args ...interface{}) {}
// NewNil is a Nil logger constructor.
func NewNil() Logger {
return &Nil{}
}

View File

@ -0,0 +1,91 @@
package logger
import (
"bytes"
"io"
"os"
"testing"
"time"
"github.com/stretchr/testify/suite"
)
type NilTest struct {
suite.Suite
logger Logger
realStdout *os.File
r *os.File
w *os.File
}
func TestNilLogger(t *testing.T) {
suite.Run(t, new(NilTest))
}
func (t *NilTest) SetupSuite() {
t.logger = NewNil()
}
func (t *NilTest) SetupTest() {
t.realStdout = os.Stdout
t.r, t.w, _ = os.Pipe()
os.Stdout = t.w
}
func (t *NilTest) TearDownTest() {
if t.realStdout != nil {
t.Require().NoError(t.w.Close())
os.Stdout = t.realStdout
}
}
func (t *NilTest) readStdout() string {
outC := make(chan string)
go func() {
var buf bytes.Buffer
_, err := io.Copy(&buf, t.r)
t.Require().NoError(err)
outC <- buf.String()
close(outC)
}()
t.Require().NoError(t.w.Close())
os.Stdout = t.realStdout
t.realStdout = nil
select {
case c := <-outC:
return c
case <-time.After(time.Second):
return ""
}
}
func (t *NilTest) Test_Noop() {
t.logger.Critical("message")
t.logger.Criticalf("message")
t.logger.Error("message")
t.logger.Errorf("message")
t.logger.Warning("message")
t.logger.Warningf("message")
t.logger.Notice("message")
t.logger.Noticef("message")
t.logger.Info("message")
t.logger.Infof("message")
t.logger.Debug("message")
t.logger.Debugf("message")
t.Assert().Empty(t.readStdout())
}
func (t *NilTest) Test_Panic() {
t.Assert().Panics(func() {
t.logger.Panic("")
})
}
func (t *NilTest) Test_Panicf() {
t.Assert().Panics(func() {
t.logger.Panicf("")
})
}

View File

@ -0,0 +1,106 @@
package logger
import "github.com/op/go-logging"
// PrefixAware is implemented if the logger allows you to change the prefix.
type PrefixAware interface {
SetPrefix(string)
}
// PrefixedLogger is a base interface for the logger with prefix.
type PrefixedLogger interface {
Logger
PrefixAware
}
// PrefixDecorator is an implementation of the PrefixedLogger. It will allow you to decorate any Logger with
// the provided predefined prefix.
type PrefixDecorator struct {
backend Logger
prefix []interface{}
}
// DecorateWithPrefix using provided base logger and provided prefix.
// No internal state of the base logger will be touched.
func DecorateWithPrefix(backend Logger, prefix string) PrefixedLogger {
return &PrefixDecorator{backend: backend, prefix: []interface{}{prefix}}
}
// NewWithPrefix returns logger with prefix. It uses StandardLogger under the hood.
func NewWithPrefix(transportCode, prefix string, logLevel logging.Level, logFormat logging.Formatter) PrefixedLogger {
return DecorateWithPrefix(NewStandard(transportCode, logLevel, logFormat), prefix)
}
// SetPrefix will replace existing prefix with the provided value.
// Use this format for prefixes: "prefix here:" - omit space at the end (it will be inserted automatically).
func (p *PrefixDecorator) SetPrefix(prefix string) {
p.prefix = []interface{}{prefix}
}
func (p *PrefixDecorator) getFormat(fmt string) string {
return p.prefix[0].(string) + " " + fmt
}
func (p *PrefixDecorator) Fatal(args ...interface{}) {
p.backend.Fatal(append(p.prefix, args...)...)
}
func (p *PrefixDecorator) Fatalf(format string, args ...interface{}) {
p.backend.Fatalf(p.getFormat(format), args...)
}
func (p *PrefixDecorator) Panic(args ...interface{}) {
p.backend.Panic(append(p.prefix, args...)...)
}
func (p *PrefixDecorator) Panicf(format string, args ...interface{}) {
p.backend.Panicf(p.getFormat(format), args...)
}
func (p *PrefixDecorator) Critical(args ...interface{}) {
p.backend.Critical(append(p.prefix, args...)...)
}
func (p *PrefixDecorator) Criticalf(format string, args ...interface{}) {
p.backend.Criticalf(p.getFormat(format), args...)
}
func (p *PrefixDecorator) Error(args ...interface{}) {
p.backend.Error(append(p.prefix, args...)...)
}
func (p *PrefixDecorator) Errorf(format string, args ...interface{}) {
p.backend.Errorf(p.getFormat(format), args...)
}
func (p *PrefixDecorator) Warning(args ...interface{}) {
p.backend.Warning(append(p.prefix, args...)...)
}
func (p *PrefixDecorator) Warningf(format string, args ...interface{}) {
p.backend.Warningf(p.getFormat(format), args...)
}
func (p *PrefixDecorator) Notice(args ...interface{}) {
p.backend.Notice(append(p.prefix, args...)...)
}
func (p *PrefixDecorator) Noticef(format string, args ...interface{}) {
p.backend.Noticef(p.getFormat(format), args...)
}
func (p *PrefixDecorator) Info(args ...interface{}) {
p.backend.Info(append(p.prefix, args...)...)
}
func (p *PrefixDecorator) Infof(format string, args ...interface{}) {
p.backend.Infof(p.getFormat(format), args...)
}
func (p *PrefixDecorator) Debug(args ...interface{}) {
p.backend.Debug(append(p.prefix, args...)...)
}
func (p *PrefixDecorator) Debugf(format string, args ...interface{}) {
p.backend.Debugf(p.getFormat(format), args...)
}

View File

@ -0,0 +1,166 @@
package logger
import (
"bytes"
"testing"
"github.com/op/go-logging"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
)
const testPrefix = "TestPrefix:"
type PrefixDecoratorTest struct {
suite.Suite
buf *bytes.Buffer
logger PrefixedLogger
}
func TestPrefixDecorator(t *testing.T) {
suite.Run(t, new(PrefixDecoratorTest))
}
func TestNewWithPrefix(t *testing.T) {
buf := &bytes.Buffer{}
logger := NewWithPrefix("code", "Prefix:", logging.DEBUG, DefaultLogFormatter())
logger.(*PrefixDecorator).backend.(*StandardLogger).
SetBaseLogger(NewBase(buf, "code", logging.DEBUG, DefaultLogFormatter()))
logger.Debugf("message %s", "text")
assert.Contains(t, buf.String(), "Prefix: message text")
}
func (t *PrefixDecoratorTest) SetupSuite() {
t.buf = &bytes.Buffer{}
t.logger = DecorateWithPrefix((&StandardLogger{}).
SetBaseLogger(NewBase(t.buf, "code", logging.DEBUG, DefaultLogFormatter())), testPrefix)
}
func (t *PrefixDecoratorTest) SetupTest() {
t.buf.Reset()
t.logger.SetPrefix(testPrefix)
}
func (t *PrefixDecoratorTest) Test_SetPrefix() {
t.logger.Info("message")
t.Assert().Contains(t.buf.String(), "INFO")
t.Assert().Contains(t.buf.String(), testPrefix+" message")
t.logger.SetPrefix(testPrefix + testPrefix)
t.logger.Info("message")
t.Assert().Contains(t.buf.String(), "INFO")
t.Assert().Contains(t.buf.String(), testPrefix+testPrefix+" message")
}
func (t *PrefixDecoratorTest) Test_Panic() {
t.Require().Panics(func() {
t.logger.Panic("message")
})
t.Assert().Contains(t.buf.String(), "CRIT")
t.Assert().Contains(t.buf.String(), testPrefix+" message")
}
func (t *PrefixDecoratorTest) Test_Panicf() {
t.Require().Panics(func() {
t.logger.Panicf("%s", "message")
})
t.Assert().Contains(t.buf.String(), "CRIT")
t.Assert().Contains(t.buf.String(), testPrefix+" message")
}
func (t *PrefixDecoratorTest) Test_Critical() {
t.Require().NotPanics(func() {
t.logger.Critical("message")
})
t.Assert().Contains(t.buf.String(), "CRIT")
t.Assert().Contains(t.buf.String(), testPrefix+" message")
}
func (t *PrefixDecoratorTest) Test_Criticalf() {
t.Require().NotPanics(func() {
t.logger.Criticalf("%s", "message")
})
t.Assert().Contains(t.buf.String(), "CRIT")
t.Assert().Contains(t.buf.String(), testPrefix+" message")
}
func (t *PrefixDecoratorTest) Test_Error() {
t.Require().NotPanics(func() {
t.logger.Error("message")
})
t.Assert().Contains(t.buf.String(), "ERRO")
t.Assert().Contains(t.buf.String(), testPrefix+" message")
}
func (t *PrefixDecoratorTest) Test_Errorf() {
t.Require().NotPanics(func() {
t.logger.Errorf("%s", "message")
})
t.Assert().Contains(t.buf.String(), "ERRO")
t.Assert().Contains(t.buf.String(), testPrefix+" message")
}
func (t *PrefixDecoratorTest) Test_Warning() {
t.Require().NotPanics(func() {
t.logger.Warning("message")
})
t.Assert().Contains(t.buf.String(), "WARN")
t.Assert().Contains(t.buf.String(), testPrefix+" message")
}
func (t *PrefixDecoratorTest) Test_Warningf() {
t.Require().NotPanics(func() {
t.logger.Warningf("%s", "message")
})
t.Assert().Contains(t.buf.String(), "WARN")
t.Assert().Contains(t.buf.String(), testPrefix+" message")
}
func (t *PrefixDecoratorTest) Test_Notice() {
t.Require().NotPanics(func() {
t.logger.Notice("message")
})
t.Assert().Contains(t.buf.String(), "NOTI")
t.Assert().Contains(t.buf.String(), testPrefix+" message")
}
func (t *PrefixDecoratorTest) Test_Noticef() {
t.Require().NotPanics(func() {
t.logger.Noticef("%s", "message")
})
t.Assert().Contains(t.buf.String(), "NOTI")
t.Assert().Contains(t.buf.String(), testPrefix+" message")
}
func (t *PrefixDecoratorTest) Test_Info() {
t.Require().NotPanics(func() {
t.logger.Info("message")
})
t.Assert().Contains(t.buf.String(), "INFO")
t.Assert().Contains(t.buf.String(), testPrefix+" message")
}
func (t *PrefixDecoratorTest) Test_Infof() {
t.Require().NotPanics(func() {
t.logger.Infof("%s", "message")
})
t.Assert().Contains(t.buf.String(), "INFO")
t.Assert().Contains(t.buf.String(), testPrefix+" message")
}
func (t *PrefixDecoratorTest) Test_Debug() {
t.Require().NotPanics(func() {
t.logger.Debug("message")
})
t.Assert().Contains(t.buf.String(), "DEBU")
t.Assert().Contains(t.buf.String(), testPrefix+" message")
}
func (t *PrefixDecoratorTest) Test_Debugf() {
t.Require().NotPanics(func() {
t.logger.Debugf("%s", "message")
})
t.Assert().Contains(t.buf.String(), "DEBU")
t.Assert().Contains(t.buf.String(), testPrefix+" message")
}

View File

@ -1,121 +0,0 @@
package core
import (
// "os"
// "os/exec".
"testing"
"github.com/op/go-logging"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/suite"
)
type LoggerTest struct {
suite.Suite
logger *Logger
}
func TestLogger_NewLogger(t *testing.T) {
logger := NewLogger("code", logging.DEBUG, DefaultLogFormatter())
assert.NotNil(t, logger)
}
func TestLogger_DefaultLogFormatter(t *testing.T) {
formatter := DefaultLogFormatter()
assert.NotNil(t, formatter)
assert.IsType(t, logging.MustStringFormatter(`%{message}`), formatter)
}
func Test_Logger(t *testing.T) {
suite.Run(t, new(LoggerTest))
}
func (t *LoggerTest) SetupSuite() {
t.logger = NewLogger("code", logging.DEBUG, DefaultLogFormatter()).Exclusive()
}
// TODO Cover Fatal and Fatalf (implementation below is no-op)
// func (t *LoggerTest) Test_Fatal() {
// if os.Getenv("FLAG") == "1" {
// t.logger.Fatal("test", "fatal")
// return
// }
// cmd := exec.Command(os.Args[0], "-test.run=TestGetConfig")
// cmd.Env = append(os.Environ(), "FLAG=1")
// err := cmd.Run()
// e, ok := err.(*exec.ExitError)
// expectedErrorString := "test fatal"
// assert.Equal(t.T(), true, ok)
// assert.Equal(t.T(), expectedErrorString, e.Error())
// }
func (t *LoggerTest) Test_Panic() {
defer func() {
assert.NotNil(t.T(), recover())
}()
t.logger.Panic("panic")
}
func (t *LoggerTest) Test_Panicf() {
defer func() {
assert.NotNil(t.T(), recover())
}()
t.logger.Panicf("panic")
}
func (t *LoggerTest) Test_Critical() {
defer func() {
if v := recover(); v != nil {
t.T().Fatal(v)
}
}()
t.logger.Critical("critical")
}
func (t *LoggerTest) Test_Criticalf() {
defer func() {
if v := recover(); v != nil {
t.T().Fatal(v)
}
}()
t.logger.Criticalf("critical")
}
func (t *LoggerTest) Test_Warning() {
defer func() {
if v := recover(); v != nil {
t.T().Fatal(v)
}
}()
t.logger.Warning("warning")
}
func (t *LoggerTest) Test_Notice() {
defer func() {
if v := recover(); v != nil {
t.T().Fatal(v)
}
}()
t.logger.Notice("notice")
}
func (t *LoggerTest) Test_Info() {
defer func() {
if v := recover(); v != nil {
t.T().Fatal(v)
}
}()
t.logger.Info("info")
}
func (t *LoggerTest) Test_Debug() {
defer func() {
if v := recover(); v != nil {
t.T().Fatal(v)
}
}()
t.logger.Debug("debug")
}

View File

@ -1,4 +1,4 @@
package core package middleware
import ( import (
"bytes" "bytes"
@ -41,6 +41,11 @@ const (
CSRFErrorTokenMismatch CSRFErrorTokenMismatch
) )
const (
keySize = 8
randomStringSize = 64
)
// DefaultCSRFTokenGetter default getter. // DefaultCSRFTokenGetter default getter.
var DefaultCSRFTokenGetter = func(c *gin.Context) string { var DefaultCSRFTokenGetter = func(c *gin.Context) string {
r := c.Request r := c.Request
@ -70,31 +75,43 @@ var DefaultIgnoredMethods = []string{"GET", "HEAD", "OPTIONS"}
// CSRF struct. Provides CSRF token verification. // CSRF struct. Provides CSRF token verification.
type CSRF struct { type CSRF struct {
store sessions.Store
abortFunc CSRFAbortFunc
csrfTokenGetter CSRFTokenGetter
salt string salt string
secret string secret string
sessionName string sessionName string
abortFunc CSRFAbortFunc
csrfTokenGetter CSRFTokenGetter
store sessions.Store
} }
// NewCSRF creates CSRF struct with specified configuration and session store. // NewCSRF creates CSRF struct with specified configuration and session store.
// GenerateCSRFMiddleware and VerifyCSRFMiddleware returns CSRF middlewares. // GenerateCSRFMiddleware and VerifyCSRFMiddleware returns CSRF middlewares.
// Salt must be different every time (pass empty salt to use random), secret must be provided, sessionName is optional - pass empty to use default, // Salt must be different every time (pass empty salt to use random), secret must be provided,
// store will be used to store sessions, abortFunc will be called to return error if token is invalid, csrfTokenGetter will be used to obtain token. // sessionName is optional - pass empty to use default,
// store will be used to store sessions, abortFunc will be called to return error if token is invalid,
// csrfTokenGetter will be used to obtain token.
//
// Usage (with random salt): // Usage (with random salt):
// core.NewCSRF("", "super secret", "csrf_session", store, func (c *gin.Context, reason core.CSRFErrorReason) { // core.NewCSRF("", "super secret", "csrf_session", store, func (c *gin.Context, reason core.CSRFErrorReason) {
// c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "Invalid CSRF token"}) // c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "Invalid CSRF token"})
// }, core.DefaultCSRFTokenGetter) // }, core.DefaultCSRFTokenGetter)
// Note for csrfTokenGetter: if you want to read token from request body (for example, from form field) - don't forget to restore Body data! //
// Note for csrfTokenGetter: if you want to read token from request body (for example, from form field)
// - don't forget to restore Body data!
//
// Body in http.Request is io.ReadCloser instance. Reading CSRF token from form like that: // Body in http.Request is io.ReadCloser instance. Reading CSRF token from form like that:
// if t := r.FormValue("csrf_token"); len(t) > 0 { // if t := r.FormValue("csrf_token"); len(t) > 0 {
// return t // return t
// } // }
// will close body - and all next middlewares won't be able to read body at all! // will close body - and all next middlewares won't be able to read body at all!
//
// Use DefaultCSRFTokenGetter as example to implement your own token getter. // Use DefaultCSRFTokenGetter as example to implement your own token getter.
// CSRFErrorReason will be passed to abortFunc and can be used for better error messages. // CSRFErrorReason will be passed to abortFunc and can be used for better error messages.
func NewCSRF(salt, secret, sessionName string, store sessions.Store, abortFunc CSRFAbortFunc, csrfTokenGetter CSRFTokenGetter) *CSRF { func NewCSRF(
salt, secret, sessionName string,
store sessions.Store,
abortFunc CSRFAbortFunc,
csrfTokenGetter CSRFTokenGetter,
) *CSRF {
if store == nil { if store == nil {
panic("store must not be nil") panic("store must not be nil")
} }
@ -157,10 +174,10 @@ func (x *CSRF) generateCSRFToken() string {
// Default secure salt length: 8 bytes. // Default secure salt length: 8 bytes.
// Default pseudo-random salt length: 64 bytes. // Default pseudo-random salt length: 64 bytes.
func (x *CSRF) generateSalt() string { func (x *CSRF) generateSalt() string {
salt := securecookie.GenerateRandomKey(8) salt := securecookie.GenerateRandomKey(keySize)
if salt == nil { if salt == nil {
return x.pseudoRandomString(64) return x.pseudoRandomString(randomStringSize)
} }
return string(salt) return string(salt)
@ -171,15 +188,15 @@ func (x *CSRF) pseudoRandomString(length int) string {
rand.Seed(time.Now().UnixNano()) rand.Seed(time.Now().UnixNano())
data := make([]byte, length) data := make([]byte, length)
for i := 0; i < length; i++ { for i := 0; i < length; i++ { // it is supposed to use pseudo-random data.
data[i] = byte(65 + rand.Intn(90-65)) data[i] = byte(65 + rand.Intn(90-65)) // nolint:gosec,gomnd
} }
return string(data) return string(data)
} }
// CSRFFromContext returns csrf token or random token. It shouldn't return empty string because it will make csrf protection useless. // CSRFFromContext returns csrf token or random token. It shouldn't return empty string because
// e.g. any request without token will work fine, which is inacceptable. // it will make csrf protection useless. e.g. any request without token will work fine, which is unacceptable.
func (x *CSRF) CSRFFromContext(c *gin.Context) string { func (x *CSRF) CSRFFromContext(c *gin.Context) string {
if i, ok := c.Get("csrf_token"); ok { if i, ok := c.Get("csrf_token"); ok {
if token, ok := i.(string); ok { if token, ok := i.(string); ok {

View File

@ -1,4 +1,4 @@
package core package middleware
import ( import (
"bytes" "bytes"
@ -21,10 +21,10 @@ type CSRFTest struct {
} }
type requestOptions struct { type requestOptions struct {
Body io.Reader
Headers map[string]string
Method string Method string
URL string URL string
Headers map[string]string
Body io.Reader
} }
func TestCSRF_DefaultCSRFTokenGetter_Empty(t *testing.T) { func TestCSRF_DefaultCSRFTokenGetter_Empty(t *testing.T) {

View File

@ -1,4 +1,4 @@
package core package middleware
import ( import (
"encoding/json" "encoding/json"

View File

@ -1,4 +1,4 @@
package core package middleware
import ( import (
"crypto/hmac" "crypto/hmac"

View File

@ -1,53 +0,0 @@
package core
import "time"
// Connection model.
type Connection struct {
ID int `gorm:"primary_key"`
ClientID string `gorm:"column:client_id; type:varchar(70); not null; unique" json:"clientId,omitempty"`
Key string `gorm:"column:api_key; type:varchar(100); not null" json:"api_key,omitempty" binding:"required,max=100"`
URL string `gorm:"column:api_url; type:varchar(255); not null" json:"api_url,omitempty" binding:"required,validateCrmURL,max=255"` // nolint:lll
GateURL string `gorm:"column:mg_url; type:varchar(255); not null;" json:"mg_url,omitempty" binding:"max=255"`
GateToken string `gorm:"column:mg_token; type:varchar(100); not null; unique" json:"mg_token,omitempty" binding:"max=100"`
CreatedAt time.Time
UpdatedAt time.Time
Active bool `json:"active,omitempty"`
Accounts []Account `gorm:"foreignkey:ConnectionID"`
}
// Account model.
type Account struct {
ID int `gorm:"primary_key"`
ConnectionID int `gorm:"column:connection_id" json:"connectionId,omitempty"`
Channel uint64 `gorm:"column:channel; not null; unique" json:"channel,omitempty"`
ChannelSettingsHash string `gorm:"column:channel_settings_hash; type:varchar(70)" binding:"max=70"`
Name string `gorm:"column:name; type:varchar(100)" json:"name,omitempty" binding:"max=100"`
Lang string `gorm:"column:lang; type:varchar(2)" json:"lang,omitempty" binding:"max=2"`
CreatedAt time.Time
UpdatedAt time.Time
}
// User model.
type User struct {
ID int `gorm:"primary_key"`
ExternalID string `gorm:"column:external_id; type:varchar(255); not null; unique"`
UserPhotoURL string `gorm:"column:user_photo_url; type:varchar(255)" binding:"max=255"`
UserPhotoID string `gorm:"column:user_photo_id; type:varchar(100)" binding:"max=100"`
CreatedAt time.Time
UpdatedAt time.Time
}
// TableName will return table name for User
// It will not work if User is not embedded, but mapped as another type
// type MyUser User // will not work
// but
// type MyUser struct { // will work
// User
// }
func (User) TableName() string {
return "mg_user"
}
// Accounts list.
type Accounts []Account

View File

@ -9,7 +9,9 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/retailcrm/mg-transport-core/core/stacktrace" "github.com/retailcrm/mg-transport-core/v2/core/logger"
"github.com/retailcrm/mg-transport-core/v2/core/stacktrace"
"github.com/getsentry/raven-go" "github.com/getsentry/raven-go"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -34,19 +36,19 @@ type SentryTagged interface {
// Sentry struct. Holds SentryTaggedStruct list. // Sentry struct. Holds SentryTaggedStruct list.
type Sentry struct { type Sentry struct {
Logger logger.Logger
Client stacktrace.RavenClientInterface
Localizer *Localizer
DefaultError string
TaggedTypes SentryTaggedTypes TaggedTypes SentryTaggedTypes
Stacktrace bool Stacktrace bool
DefaultError string
Localizer *Localizer
Logger LoggerInterface
Client stacktrace.RavenClientInterface
} }
// SentryTaggedStruct holds information about type, it's key in gin.Context (for middleware), and it's properties. // SentryTaggedStruct holds information about type, it's key in gin.Context (for middleware), and it's properties.
type SentryTaggedStruct struct { type SentryTaggedStruct struct {
Type reflect.Type Type reflect.Type
GinContextKey string
Tags SentryTags Tags SentryTags
GinContextKey string
} }
// SentryTaggedScalar variable from context. // SentryTaggedScalar variable from context.
@ -56,7 +58,13 @@ type SentryTaggedScalar struct {
} }
// NewSentry constructor. // NewSentry constructor.
func NewSentry(sentryDSN string, defaultError string, taggedTypes SentryTaggedTypes, logger LoggerInterface, localizer *Localizer) *Sentry { func NewSentry(
sentryDSN string,
defaultError string,
taggedTypes SentryTaggedTypes,
logger logger.Logger,
localizer *Localizer,
) *Sentry {
sentry := &Sentry{ sentry := &Sentry{
DefaultError: defaultError, DefaultError: defaultError,
TaggedTypes: taggedTypes, TaggedTypes: taggedTypes,
@ -196,7 +204,7 @@ func (s *Sentry) ErrorResponseHandler() ErrorHandlerFunc {
} }
// ErrorCaptureHandler will generate error data and send it to sentry. // ErrorCaptureHandler will generate error data and send it to sentry.
func (s *Sentry) ErrorCaptureHandler() ErrorHandlerFunc { func (s *Sentry) ErrorCaptureHandler() ErrorHandlerFunc { // nolint:gocognit
return func(recovery interface{}, c *gin.Context) { return func(recovery interface{}, c *gin.Context) {
tags := map[string]string{ tags := map[string]string{
"endpoint": c.Request.RequestURI, "endpoint": c.Request.RequestURI,

View File

@ -16,12 +16,14 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/retailcrm/mg-transport-core/v2/core/util/errorutil"
) )
type sampleStruct struct { type sampleStruct struct {
ID int
Pointer *int Pointer *int
Field string Field string
ID int
} }
type ravenPacket struct { type ravenPacket struct {
@ -51,19 +53,9 @@ func (r ravenPacket) getException() (*raven.Exception, bool) {
return nil, false return nil, false
} }
func (r ravenPacket) getRequest() (*raven.Http, bool) {
if i, ok := r.getInterface("request"); ok {
if r, ok := i.(*raven.Http); ok {
return r, true
}
}
return nil, false
}
type ravenClientMock struct { type ravenClientMock struct {
raven.Client
captured []ravenPacket captured []ravenPacket
raven.Client
mu sync.RWMutex mu sync.RWMutex
wg sync.WaitGroup wg sync.WaitGroup
} }
@ -94,7 +86,7 @@ func (r *ravenClientMock) CaptureMessageAndWait(message string, tags map[string]
r.mu.Lock() r.mu.Lock()
defer r.mu.Unlock() defer r.mu.Unlock()
defer r.wg.Done() defer r.wg.Done()
eventID := strconv.FormatUint(rand.Uint64(), 10) eventID := strconv.FormatUint(rand.Uint64(), 10) // nolint:gosec
r.captured = append(r.captured, ravenPacket{ r.captured = append(r.captured, ravenPacket{
EventID: eventID, EventID: eventID,
Message: message, Message: message,
@ -127,8 +119,8 @@ func (n *simpleError) Error() string {
// wrappableError is a simple implementation of wrappable error. // wrappableError is a simple implementation of wrappable error.
type wrappableError struct { type wrappableError struct {
msg string
err error err error
msg string
} }
func newWrappableError(msg string, child error) error { func newWrappableError(msg string, child error) error {
@ -304,7 +296,7 @@ func (s *SentryTest) TestSentry_CaptureRegularError() {
c.Error(newSimpleError("test")) c.Error(newSimpleError("test"))
}) })
var resp ErrorsResponse var resp errorutil.ListResponse
req, err := http.NewRequest(http.MethodGet, "/test_regularError", nil) req, err := http.NewRequest(http.MethodGet, "/test_regularError", nil)
require.NoError(s.T(), err) require.NoError(s.T(), err)
@ -338,7 +330,7 @@ func (s *SentryTest) TestSentry_CaptureWrappedError() {
c.Error(first) c.Error(first)
}) })
var resp ErrorsResponse var resp errorutil.ListResponse
req, err := http.NewRequest(http.MethodGet, "/test_wrappableError", nil) req, err := http.NewRequest(http.MethodGet, "/test_wrappableError", nil)
require.NoError(s.T(), err) require.NoError(s.T(), err)
@ -389,7 +381,7 @@ func (s *SentryTest) TestSentry_CaptureTags() {
}), }),
} }
var resp ErrorsResponse var resp errorutil.ListResponse
req, err := http.NewRequest(http.MethodGet, "/test_taggedError", nil) req, err := http.NewRequest(http.MethodGet, "/test_taggedError", nil)
require.NoError(s.T(), err) require.NoError(s.T(), err)

View File

@ -0,0 +1,57 @@
package stacktrace
import (
"path/filepath"
"github.com/getsentry/raven-go"
"github.com/retailcrm/mg-transport-core/v2/core/util/errorutil"
)
// ErrorNodesList is the interface for the errorutil.errList.
type ErrorNodesList interface {
Iterate() <-chan errorutil.Node
Len() int
}
// ErrCollectorBuilder builds stacktrace from the list of errors collected by errorutil.Collector.
type ErrCollectorBuilder struct {
AbstractStackBuilder
}
// IsErrorNodesList returns true if error contains error nodes.
func IsErrorNodesList(err error) bool {
_, ok := err.(ErrorNodesList) // nolint:errorlint
return ok
}
// AsErrorNodesList returns ErrorNodesList instance from the error.
func AsErrorNodesList(err error) ErrorNodesList {
return err.(ErrorNodesList) // nolint:errorlint
}
// Build stacktrace.
func (b *ErrCollectorBuilder) Build() StackBuilderInterface {
if !IsErrorNodesList(b.err) {
b.buildErr = ErrUnfeasibleBuilder
return b
}
i := 0
errs := AsErrorNodesList(b.err)
frames := make([]*raven.StacktraceFrame, errs.Len())
for err := range errs.Iterate() {
frames[i] = raven.NewStacktraceFrame(
err.PC, filepath.Base(err.File), err.File, err.Line, 3, b.client.IncludePaths())
i++
}
if len(frames) <= 1 {
b.buildErr = ErrUnfeasibleBuilder
return b
}
b.stack = &raven.Stacktrace{Frames: frames}
return b
}

View File

@ -0,0 +1,55 @@
package stacktrace
import (
"errors"
"path/filepath"
"runtime"
"strings"
"testing"
"github.com/getsentry/raven-go"
"github.com/stretchr/testify/suite"
"github.com/retailcrm/mg-transport-core/v2/core/util/errorutil"
)
type ErrCollectorBuilderTest struct {
builder *ErrCollectorBuilder
c *errorutil.Collector
suite.Suite
}
func TestErrCollectorBuilder(t *testing.T) {
suite.Run(t, new(ErrCollectorBuilderTest))
}
func (t *ErrCollectorBuilderTest) SetupTest() {
t.c = errorutil.NewCollector()
client, _ := raven.New("fake dsn")
t.builder = &ErrCollectorBuilder{AbstractStackBuilder{
client: client,
err: t.c,
}}
}
func (t *ErrCollectorBuilderTest) TestBuild() {
t.c.Do(
errors.New("first"),
errors.New("second"),
errors.New("third"))
stack, err := t.builder.Build().GetResult()
_, file, _, _ := runtime.Caller(0)
t.Require().NoError(err)
t.Require().NotZero(stack)
t.Assert().Len(stack.Frames, 3)
for _, frame := range stack.Frames {
t.Assert().Equal(file, frame.Filename)
t.Assert().Equal(file, frame.AbsolutePath)
t.Assert().Equal("go", frame.Function)
t.Assert().Equal(strings.TrimSuffix(filepath.Base(file), ".go"), frame.Module)
t.Assert().NotZero(frame.Lineno)
}
}

View File

@ -14,6 +14,13 @@ type PkgErrorTraceable interface {
StackTrace() pkgErrors.StackTrace StackTrace() pkgErrors.StackTrace
} }
// IsPkgErrorsError returns true if passed error might be github.com/pkg/errors error.
func IsPkgErrorsError(err error) bool {
_, okTraceable := err.(PkgErrorTraceable) // nolint:errorlint
_, okCauseable := err.(PkgErrorCauseable) // nolint:errorlint
return okTraceable || okCauseable
}
// PkgErrorsStackTransformer transforms stack data from github.com/pkg/errors error to stacktrace.Stacktrace. // PkgErrorsStackTransformer transforms stack data from github.com/pkg/errors error to stacktrace.Stacktrace.
type PkgErrorsStackTransformer struct { type PkgErrorsStackTransformer struct {
stack pkgErrors.StackTrace stack pkgErrors.StackTrace
@ -44,7 +51,7 @@ type PkgErrorsBuilder struct {
// Build stacktrace. // Build stacktrace.
func (b *PkgErrorsBuilder) Build() StackBuilderInterface { func (b *PkgErrorsBuilder) Build() StackBuilderInterface {
if !isPkgErrors(b.err) { if !IsPkgErrorsError(b.err) {
b.buildErr = ErrUnfeasibleBuilder b.buildErr = ErrUnfeasibleBuilder
return b return b
} }
@ -71,7 +78,7 @@ func (b *PkgErrorsBuilder) Build() StackBuilderInterface {
// getErrorCause will try to extract original error from wrapper - it is used only if stacktrace is not present. // getErrorCause will try to extract original error from wrapper - it is used only if stacktrace is not present.
func (b *PkgErrorsBuilder) getErrorCause(err error) error { func (b *PkgErrorsBuilder) getErrorCause(err error) error {
causeable, ok := err.(PkgErrorCauseable) causeable, ok := err.(PkgErrorCauseable) // nolint:errorlint
if !ok { if !ok {
return nil return nil
} }
@ -81,7 +88,7 @@ func (b *PkgErrorsBuilder) getErrorCause(err error) error {
// getErrorStackTrace will try to extract stacktrace from error using StackTrace method // getErrorStackTrace will try to extract stacktrace from error using StackTrace method
// (default errors doesn't have it). // (default errors doesn't have it).
func (b *PkgErrorsBuilder) getErrorStack(err error) pkgErrors.StackTrace { func (b *PkgErrorsBuilder) getErrorStack(err error) pkgErrors.StackTrace {
traceable, ok := err.(PkgErrorTraceable) traceable, ok := err.(PkgErrorTraceable) // nolint:errorlint
if !ok { if !ok {
return nil return nil
} }

View File

@ -13,8 +13,8 @@ import (
// errorWithCause has Cause() method, but doesn't have StackTrace() method. // errorWithCause has Cause() method, but doesn't have StackTrace() method.
type errorWithCause struct { type errorWithCause struct {
msg string
cause error cause error
msg string
} }
func newErrorWithCause(msg string, cause error) error { func newErrorWithCause(msg string, cause error) error {
@ -53,7 +53,7 @@ func (s *PkgErrorsStackProviderSuite) Test_Empty() {
func (s *PkgErrorsStackProviderSuite) Test_Full() { func (s *PkgErrorsStackProviderSuite) Test_Full() {
testErr := pkgErrors.New("test") testErr := pkgErrors.New("test")
s.transformer.stack = testErr.(PkgErrorTraceable).StackTrace() s.transformer.stack = testErr.(PkgErrorTraceable).StackTrace() // nolint:errorlint
assert.NotEmpty(s.T(), s.transformer.Stack()) assert.NotEmpty(s.T(), s.transformer.Stack())
} }

View File

@ -55,16 +55,14 @@ func (b *RavenStacktraceBuilder) Build(context int, appPackagePrefixes []string)
} }
// convertFrame converts single generic stacktrace frame to github.com/pkg/errors.Frame. // convertFrame converts single generic stacktrace frame to github.com/pkg/errors.Frame.
func (b *RavenStacktraceBuilder) convertFrame(f Frame, context int, appPackagePrefixes []string) *raven.StacktraceFrame { func (b *RavenStacktraceBuilder) convertFrame(
f Frame, context int, appPackagePrefixes []string) *raven.StacktraceFrame {
// This code is borrowed from github.com/pkg/errors.Frame. // This code is borrowed from github.com/pkg/errors.Frame.
pc := uintptr(f) - 1 pc := uintptr(f) - 1
fn := runtime.FuncForPC(pc) line := 0
var file string file := "unknown"
var line int if fn := runtime.FuncForPC(pc); fn != nil {
if fn != nil {
file, line = fn.FileLine(pc) file, line = fn.FileLine(pc)
} else {
file = "unknown"
} }
return raven.NewStacktraceFrame(pc, path.Dir(file), file, line, context, appPackagePrefixes) return raven.NewStacktraceFrame(pc, path.Dir(file), file, line, context, appPackagePrefixes)
} }

View File

@ -3,20 +3,17 @@ package stacktrace
// GetStackBuilderByErrorType tries to guess which stacktrace builder would be feasible for passed error. // GetStackBuilderByErrorType tries to guess which stacktrace builder would be feasible for passed error.
// For example, errors from github.com/pkg/errors have StackTrace() method, and Go 1.13 errors can be unwrapped. // For example, errors from github.com/pkg/errors have StackTrace() method, and Go 1.13 errors can be unwrapped.
func GetStackBuilderByErrorType(err error) StackBuilderInterface { func GetStackBuilderByErrorType(err error) StackBuilderInterface {
if isPkgErrors(err) { if IsPkgErrorsError(err) {
return &PkgErrorsBuilder{AbstractStackBuilder{err: err}} return &PkgErrorsBuilder{AbstractStackBuilder{err: err}}
} }
if _, ok := err.(Unwrappable); ok { if IsUnwrappableError(err) {
return &UnwrapBuilder{AbstractStackBuilder{err: err}} return &UnwrapBuilder{AbstractStackBuilder{err: err}}
} }
if IsErrorNodesList(err) {
return &ErrCollectorBuilder{AbstractStackBuilder{err: err}}
}
return &GenericStackBuilder{AbstractStackBuilder{err: err}} return &GenericStackBuilder{AbstractStackBuilder{err: err}}
} }
// isPkgErrors returns true if passed error might be github.com/pkg/errors error.
func isPkgErrors(err error) bool {
_, okTraceable := err.(PkgErrorTraceable)
_, okCauseable := err.(PkgErrorCauseable)
return okTraceable || okCauseable
}

View File

@ -6,6 +6,8 @@ import (
pkgErrors "github.com/pkg/errors" pkgErrors "github.com/pkg/errors"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/retailcrm/mg-transport-core/v2/core/util/errorutil"
) )
func TestGetStackBuilderByErrorType_PkgErrors(t *testing.T) { func TestGetStackBuilderByErrorType_PkgErrors(t *testing.T) {
@ -20,6 +22,12 @@ func TestGetStackBuilderByErrorType_UnwrapBuilder(t *testing.T) {
assert.IsType(t, &UnwrapBuilder{}, builder) assert.IsType(t, &UnwrapBuilder{}, builder)
} }
func TestGetStackBuilderByErrorType_ErrCollectorBuilder(t *testing.T) {
testErr := errorutil.NewCollector().Do(errors.New("first"), errors.New("second")).AsError()
builder := GetStackBuilderByErrorType(testErr)
assert.IsType(t, &ErrCollectorBuilder{}, builder)
}
func TestGetStackBuilderByErrorType_Generic(t *testing.T) { func TestGetStackBuilderByErrorType_Generic(t *testing.T) {
defaultErr := errors.New("default err") defaultErr := errors.New("default err")
builder := GetStackBuilderByErrorType(defaultErr) builder := GetStackBuilderByErrorType(defaultErr)

View File

@ -14,15 +14,21 @@ type UnwrapBuilder struct {
AbstractStackBuilder AbstractStackBuilder
} }
// IsUnwrappableError returns true if error can be unwrapped.
func IsUnwrappableError(err error) bool {
_, ok := err.(Unwrappable) // nolint:errorlint
return ok
}
// Build stacktrace. // Build stacktrace.
func (b *UnwrapBuilder) Build() StackBuilderInterface { func (b *UnwrapBuilder) Build() StackBuilderInterface {
if _, ok := b.err.(Unwrappable); !ok { if !IsUnwrappableError(b.err) {
b.buildErr = ErrUnfeasibleBuilder b.buildErr = ErrUnfeasibleBuilder
return b return b
} }
err := b.err err := b.err
frames := []*raven.StacktraceFrame{} var frames []*raven.StacktraceFrame
for err != nil { for err != nil {
frames = append(frames, raven.NewStacktraceFrame( frames = append(frames, raven.NewStacktraceFrame(
@ -34,7 +40,7 @@ func (b *UnwrapBuilder) Build() StackBuilderInterface {
b.client.IncludePaths(), b.client.IncludePaths(),
)) ))
if item, ok := err.(Unwrappable); ok { if item, ok := err.(Unwrappable); ok { // nolint:errorlint
err = item.Unwrap() err = item.Unwrap()
} else { } else {
err = nil err = nil

View File

@ -25,8 +25,8 @@ func (n *simpleError) Error() string {
// wrappableError is a simple implementation of wrappable error. // wrappableError is a simple implementation of wrappable error.
type wrappableError struct { type wrappableError struct {
msg string
err error err error
msg string
} }
func newWrappableError(msg string, child error) error { func newWrappableError(msg string, child error) error {
@ -75,8 +75,7 @@ func (s *UnwrapBuilderSuite) TestBuild_NoUnwrap() {
func (s *UnwrapBuilderSuite) TestBuild_WrappableHasWrapped() { func (s *UnwrapBuilderSuite) TestBuild_WrappableHasWrapped() {
testErr := newWrappableError("first", newWrappableError("second", errors.New("third"))) testErr := newWrappableError("first", newWrappableError("second", errors.New("third")))
_, ok := testErr.(Unwrappable) require.True(s.T(), IsUnwrappableError(testErr))
require.True(s.T(), ok)
s.builder.SetError(testErr) s.builder.SetError(testErr)
stack, buildErr := s.builder.Build().GetResult() stack, buildErr := s.builder.Build().GetResult()

View File

@ -54,15 +54,15 @@ func (r *Renderer) Push(name string, files ...string) *template.Template {
// addFromFS adds embedded template. // addFromFS adds embedded template.
func (r *Renderer) addFromFS(name string, funcMap template.FuncMap, files ...string) *template.Template { func (r *Renderer) addFromFS(name string, funcMap template.FuncMap, files ...string) *template.Template {
var filesData []string filesData := make([]string, len(files))
for _, fileName := range files { for i := 0; i < len(files); i++ {
data, err := fs.ReadFile(r.TemplatesFS, fileName) data, err := fs.ReadFile(r.TemplatesFS, files[i])
if err != nil { if err != nil {
panic(err) panic(err)
} }
filesData = append(filesData, string(data)) filesData[i] = string(data)
} }
return r.AddFromStringsFuncs(name, funcMap, filesData...) return r.AddFromStringsFuncs(name, funcMap, filesData...)

View File

@ -0,0 +1,135 @@
package errorutil
import (
"fmt"
"runtime"
"strings"
)
// Collector is a replacement for the core.ErrorCollector function. It is easier to use and contains more functionality.
// For example, you can iterate over the errors or use Collector.Panic() to immediately panic
// if there are errors in the chain.
//
// Error messages will differ from the ones produced by ErrorCollector. However, it's for the best because
// new error messages contain a lot more useful information and can be even used as a stacktrace.
//
// Collector implements Error() and String() methods. As a result, you can use the collector as an error itself or just
// print out it as a value. However, it is better to use AsError() method if you want to use Collector as an error value
// because AsError() returns nil if there are no errors in the list.
//
// Example:
// err := errorutil.NewCollector().
// Do(errors.New("error 1")).
// Do(errors.New("error 2"), errors.New("error 3"))
// // Will print error message.
// fmt.Println(err)
//
// This code will produce something like this:
// #1 err at /home/user/main.go:62: error 1
// #2 err at /home/user/main.go:63: error 2
// #3 err at /home/user/main.go:64: error 3
//
// You can also iterate over the error to use their data instead of using predefined message:
// err := errorutil.NewCollector().
// Do(errors.New("error 1")).
// Do(errors.New("error 2"), errors.New("error 3"))
//
// for err := range c.Iterate() {
// fmt.Printf("Error at %s:%d: %v\n", err.File, err.Line, err)
// }
//
// This code will produce output that looks like this:
// Error at /home/user/main.go:164: error 0
// Error at /home/user/main.go:164: error 1
// Error at /home/user/main.go:164: error 2
//
// Example with GORM migration (Collector is returned as an error here).
// return errorutil.NewCollector().Do(
// db.CreateTable(models.Account{}, models.Connection{}).Error,
// db.Table("account").AddUniqueIndex("account_key", "channel").Error,
// ).AsError()
type Collector struct {
errors *errList
}
// NewCollector returns new errorutil.Collector instance.
func NewCollector() *Collector {
return &Collector{
errors: &errList{},
}
}
// Collect errors, return one error for all of them (shorthand for errorutil.NewCollector().Do(...).AsError()).
// Returns nil if there are no errors.
func Collect(errs ...error) error {
return NewCollector().Do(errs...).AsError()
}
// Do some operation that returns the error. Supports multiple operations at once.
func (e *Collector) Do(errs ...error) *Collector {
pc, file, line, _ := runtime.Caller(1)
for _, err := range errs {
if err != nil {
e.errors.Push(pc, err, file, line)
}
}
return e
}
// OK returns true if there is no errors in the list.
func (e *Collector) OK() bool {
return e.errors.Len() == 0
}
// Error message.
func (e *Collector) Error() string {
return e.buildErrorMessage()
}
// AsError returns the Collector itself as an error, but only if there are errors in the list.
// It returns nil otherwise. This method should be used if you want to return error to the caller, but only if\
// Collector actually caught something.
func (e *Collector) AsError() error {
if e.OK() {
return nil
}
return e
}
// String with an error message.
func (e *Collector) String() string {
return e.Error()
}
// Panic with the error data if there are errors in the list.
func (e *Collector) Panic() {
if !e.OK() {
panic(e)
}
}
// Iterate over the errors in the list. Every error is represented as an errorutil.Node value.
func (e *Collector) Iterate() <-chan Node {
return e.errors.Iterate()
}
// Len returns the number of the errors in the list.
func (e *Collector) Len() int {
return e.errors.Len()
}
// buildErrorMessage builds error message for the Collector.Error() and Collector.String() methods.
func (e *Collector) buildErrorMessage() string {
i := 0
var sb strings.Builder
sb.Grow(128 * e.errors.Len()) // nolint:gomnd
for node := range e.errors.Iterate() {
i++
sb.WriteString(fmt.Sprintf("#%d err at %s:%d: %v\n", i, node.File, node.Line, node.Err))
}
return strings.TrimRight(sb.String(), "\n")
}

View File

@ -0,0 +1,74 @@
package errorutil
import (
"errors"
"testing"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
type ErrCollectorTest struct {
c *Collector
suite.Suite
}
func TestErrCollector(t *testing.T) {
suite.Run(t, new(ErrCollectorTest))
}
func TestCollect_NoError(t *testing.T) {
require.NoError(t, Collect())
}
func TestCollect_NoError_Nils(t *testing.T) {
require.NoError(t, Collect(nil, nil, nil))
}
func TestCollect_Error(t *testing.T) {
require.Error(t, Collect(errors.New("first"), errors.New("second")))
}
func (t *ErrCollectorTest) SetupTest() {
t.c = NewCollector()
}
func (t *ErrCollectorTest) TestDo() {
t.c.Do(
errors.New("first"),
errors.New("second"),
errors.New("third"))
t.Assert().False(t.c.OK())
t.Assert().NotEmpty(t.c.String())
t.Assert().Error(t.c.AsError())
t.Assert().Equal(3, t.c.Len())
t.Assert().Panics(func() {
t.c.Panic()
})
i := 0
for err := range t.c.Iterate() {
t.Assert().Error(err.Err)
t.Assert().NotEmpty(err.File)
t.Assert().NotZero(err.Line)
t.Assert().NotZero(err.PC)
switch i {
case 0:
t.Assert().Equal("first", err.Error())
case 1:
t.Assert().Equal("second", err.Error())
case 2:
t.Assert().Equal("third", err.Error())
}
i++
}
}
func (t *ErrCollectorTest) Test_PanicNone() {
t.Assert().NotPanics(func() {
t.c.Panic()
})
}

View File

@ -0,0 +1,68 @@
package errorutil
// Node contains information about error in the list.
type Node struct {
Err error
next *Node
File string
PC uintptr
Line int
}
// Error returns error message from the Node's Err.
func (e Node) Error() string {
if e.Err == nil {
return ""
}
return e.Err.Error()
}
// errList is an immutable linked list of error.
type errList struct {
head *Node
tail *Node
len int
}
// Push an error into the list.
func (l *errList) Push(pc uintptr, err error, file string, line int) {
item := &Node{
PC: pc,
Err: err,
File: file,
Line: line,
}
if l.head == nil {
l.head = item
}
if l.tail != nil {
l.tail.next = item
}
l.tail = item
l.len++
}
// Iterate over error list.
func (l *errList) Iterate() <-chan Node {
c := make(chan Node)
go func() {
item := l.head
for {
if item == nil {
close(c)
return
}
c <- *item
item = item.next
}
}()
return c
}
// Len returns length of the list.
func (l *errList) Len() int {
return l.len
}

View File

@ -0,0 +1,50 @@
package errorutil
import "net/http"
// Response with the error message.
type Response struct {
Error string `json:"error"`
}
// ListResponse contains multiple errors in the list.
type ListResponse struct {
Error []string `json:"error"`
}
// GetErrorResponse returns ErrorResponse with specified status code
// Usage (with gin):
// context.JSON(GetErrorResponse(http.StatusPaymentRequired, "Not enough money"))
func GetErrorResponse(statusCode int, err string) (int, interface{}) {
return statusCode, Response{
Error: err,
}
}
// BadRequest returns ErrorResponse with code 400
// Usage (with gin):
// context.JSON(BadRequest("invalid data"))
func BadRequest(err string) (int, interface{}) {
return GetErrorResponse(http.StatusBadRequest, err)
}
// Unauthorized returns ErrorResponse with code 401
// Usage (with gin):
// context.JSON(Unauthorized("invalid credentials"))
func Unauthorized(err string) (int, interface{}) {
return GetErrorResponse(http.StatusUnauthorized, err)
}
// Forbidden returns ErrorResponse with code 403
// Usage (with gin):
// context.JSON(Forbidden("forbidden"))
func Forbidden(err string) (int, interface{}) {
return GetErrorResponse(http.StatusForbidden, err)
}
// InternalServerError returns ErrorResponse with code 500
// Usage (with gin):
// context.JSON(BadRequest("invalid data"))
func InternalServerError(err string) (int, interface{}) {
return GetErrorResponse(http.StatusInternalServerError, err)
}

View File

@ -1,4 +1,4 @@
package core package errorutil
import ( import (
"net/http" "net/http"
@ -11,19 +11,19 @@ func TestError_GetErrorResponse(t *testing.T) {
code, resp := GetErrorResponse(http.StatusBadRequest, "error string") code, resp := GetErrorResponse(http.StatusBadRequest, "error string")
assert.Equal(t, http.StatusBadRequest, code) assert.Equal(t, http.StatusBadRequest, code)
assert.Equal(t, "error string", resp.(ErrorResponse).Error) assert.Equal(t, "error string", resp.(Response).Error)
} }
func TestError_BadRequest(t *testing.T) { func TestError_BadRequest(t *testing.T) {
code, resp := BadRequest("error string") code, resp := BadRequest("error string")
assert.Equal(t, http.StatusBadRequest, code) assert.Equal(t, http.StatusBadRequest, code)
assert.Equal(t, "error string", resp.(ErrorResponse).Error) assert.Equal(t, "error string", resp.(Response).Error)
} }
func TestError_InternalServerError(t *testing.T) { func TestError_InternalServerError(t *testing.T) {
code, resp := InternalServerError("error string") code, resp := InternalServerError("error string")
assert.Equal(t, http.StatusInternalServerError, code) assert.Equal(t, http.StatusInternalServerError, code)
assert.Equal(t, "error string", resp.(ErrorResponse).Error) assert.Equal(t, "error string", resp.(Response).Error)
} }

View File

@ -1,4 +1,4 @@
package core package errorutil
import ( import (
"errors" "errors"
@ -17,8 +17,8 @@ type ScopesList interface {
// insufficientScopesErr contains information about missing auth scopes. // insufficientScopesErr contains information about missing auth scopes.
type insufficientScopesErr struct { type insufficientScopesErr struct {
scopes []string
wrapped error wrapped error
scopes []string
} }
// Error message. // Error message.
@ -48,3 +48,8 @@ func NewInsufficientScopesErr(scopes []string) error {
wrapped: ErrInsufficientScopes, wrapped: ErrInsufficientScopes,
} }
} }
// AsInsufficientScopesErr returns ScopesList instance.
func AsInsufficientScopesErr(err error) ScopesList {
return err.(ScopesList) // nolint:errorlint
}

View File

@ -1,4 +1,4 @@
package core package errorutil
import ( import (
"errors" "errors"
@ -12,4 +12,5 @@ func TestError_NewScopesError(t *testing.T) {
scopesError := NewInsufficientScopesErr(scopes) scopesError := NewInsufficientScopesErr(scopes)
assert.True(t, errors.Is(scopesError, ErrInsufficientScopes)) assert.True(t, errors.Is(scopesError, ErrInsufficientScopes))
assert.Equal(t, scopes, AsInsufficientScopesErr(scopesError).Scopes())
} }

View File

@ -1,4 +1,4 @@
package core package httputil
import ( import (
"context" "context"
@ -10,6 +10,10 @@ import (
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/retailcrm/mg-transport-core/v2/core/config"
"github.com/retailcrm/mg-transport-core/v2/core/logger"
) )
// DefaultClient stores original http.DefaultClient. // DefaultClient stores original http.DefaultClient.
@ -43,18 +47,18 @@ var DefaultTransport = http.DefaultTransport
// fmt.Print(err) // fmt.Print(err)
// } // }
type HTTPClientBuilder struct { type HTTPClientBuilder struct {
logger logger.Logger
httpClient *http.Client httpClient *http.Client
httpTransport *http.Transport httpTransport *http.Transport
certsPool *x509.CertPool
dialer *net.Dialer dialer *net.Dialer
logger LoggerInterface
built bool
logging bool
timeout time.Duration
mockAddress string mockAddress string
mockHost string mockHost string
mockPort string mockPort string
mockedDomains []string mockedDomains []string
timeout time.Duration
tlsVersion uint16
logging bool
built bool
} }
// NewHTTPClientBuilder returns HTTPClientBuilder with default values. // NewHTTPClientBuilder returns HTTPClientBuilder with default values.
@ -63,6 +67,7 @@ func NewHTTPClientBuilder() *HTTPClientBuilder {
built: false, built: false,
httpClient: &http.Client{}, httpClient: &http.Client{},
httpTransport: &http.Transport{}, httpTransport: &http.Transport{},
tlsVersion: tls.VersionTLS12,
timeout: 30 * time.Second, timeout: 30 * time.Second,
mockAddress: "", mockAddress: "",
mockedDomains: []string{}, mockedDomains: []string{},
@ -71,7 +76,7 @@ func NewHTTPClientBuilder() *HTTPClientBuilder {
} }
// WithLogger sets provided logger into HTTPClientBuilder. // WithLogger sets provided logger into HTTPClientBuilder.
func (b *HTTPClientBuilder) WithLogger(logger LoggerInterface) *HTTPClientBuilder { func (b *HTTPClientBuilder) WithLogger(logger logger.Logger) *HTTPClientBuilder {
if logger != nil { if logger != nil {
b.logger = logger b.logger = logger
} }
@ -81,7 +86,7 @@ func (b *HTTPClientBuilder) WithLogger(logger LoggerInterface) *HTTPClientBuilde
// SetTimeout sets timeout for http client. // SetTimeout sets timeout for http client.
func (b *HTTPClientBuilder) SetTimeout(seconds time.Duration) *HTTPClientBuilder { func (b *HTTPClientBuilder) SetTimeout(seconds time.Duration) *HTTPClientBuilder {
seconds = seconds * time.Second seconds *= time.Second
b.timeout = seconds b.timeout = seconds
b.httpClient.Timeout = seconds b.httpClient.Timeout = seconds
return b return b
@ -108,7 +113,7 @@ func (b *HTTPClientBuilder) SetMockedDomains(domains []string) *HTTPClientBuilde
// SetSSLVerification enables or disables SSL certificates verification in client. // SetSSLVerification enables or disables SSL certificates verification in client.
func (b *HTTPClientBuilder) SetSSLVerification(enabled bool) *HTTPClientBuilder { func (b *HTTPClientBuilder) SetSSLVerification(enabled bool) *HTTPClientBuilder {
if b.httpTransport.TLSClientConfig == nil { if b.httpTransport.TLSClientConfig == nil {
b.httpTransport.TLSClientConfig = &tls.Config{} b.httpTransport.TLSClientConfig = b.baseTLSConfig()
} }
b.httpTransport.TLSClientConfig.InsecureSkipVerify = !enabled b.httpTransport.TLSClientConfig.InsecureSkipVerify = !enabled
@ -116,10 +121,19 @@ func (b *HTTPClientBuilder) SetSSLVerification(enabled bool) *HTTPClientBuilder
return b return b
} }
// SetSSLVerification enables or disables SSL certificates verification in client. // UseTLS10 restores TLS 1.0 as a minimal supported TLS version.
func (b *HTTPClientBuilder) UseTLS10() *HTTPClientBuilder {
b.tlsVersion = tls.VersionTLS10
if b.httpTransport.TLSClientConfig != nil {
b.httpTransport.TLSClientConfig.MinVersion = b.tlsVersion
}
return b
}
// SetCertPool sets provided TLS certificates pool into the client.
func (b *HTTPClientBuilder) SetCertPool(pool *x509.CertPool) *HTTPClientBuilder { func (b *HTTPClientBuilder) SetCertPool(pool *x509.CertPool) *HTTPClientBuilder {
if b.httpTransport.TLSClientConfig == nil { if b.httpTransport.TLSClientConfig == nil {
b.httpTransport.TLSClientConfig = &tls.Config{} b.httpTransport.TLSClientConfig = b.baseTLSConfig()
} }
b.httpTransport.TLSClientConfig.RootCAs = pool b.httpTransport.TLSClientConfig.RootCAs = pool
@ -134,7 +148,7 @@ func (b *HTTPClientBuilder) SetLogging(flag bool) *HTTPClientBuilder {
} }
// FromConfig fulfills mock configuration from HTTPClientConfig. // FromConfig fulfills mock configuration from HTTPClientConfig.
func (b *HTTPClientBuilder) FromConfig(config *HTTPClientConfig) *HTTPClientBuilder { func (b *HTTPClientBuilder) FromConfig(config *config.HTTPClientConfig) *HTTPClientBuilder {
if config == nil { if config == nil {
return b return b
} }
@ -153,9 +167,9 @@ func (b *HTTPClientBuilder) FromConfig(config *HTTPClientConfig) *HTTPClientBuil
return b return b
} }
// FromEngine fulfills mock configuration from ConfigInterface inside Engine. // baseTLSConfig returns *tls.Config with TLS 1.2 as a minimal supported version.
func (b *HTTPClientBuilder) FromEngine(engine *Engine) *HTTPClientBuilder { func (b *HTTPClientBuilder) baseTLSConfig() *tls.Config {
return b.FromConfig(engine.GetHTTPClientConfig()) return &tls.Config{MinVersion: b.tlsVersion} // nolint:gosec
} }
// buildDialer initializes dialer with provided timeout. // buildDialer initializes dialer with provided timeout.

View File

@ -1,8 +1,10 @@
package core package httputil
import ( import (
"context" "context"
"crypto/tls"
"crypto/x509" "crypto/x509"
"errors"
"fmt" "fmt"
"io" "io"
"io/ioutil" "io/ioutil"
@ -18,6 +20,12 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/retailcrm/mg-transport-core/v2/core/config"
"github.com/retailcrm/mg-transport-core/v2/core/logger"
"github.com/retailcrm/mg-transport-core/v2/core/util/errorutil"
) )
type HTTPClientBuilderTest struct { type HTTPClientBuilderTest struct {
@ -92,7 +100,7 @@ func (t *HTTPClientBuilderTest) Test_FromConfigNil() {
} }
func (t *HTTPClientBuilderTest) Test_FromConfig() { func (t *HTTPClientBuilderTest) Test_FromConfig() {
config := &HTTPClientConfig{ config := &config.HTTPClientConfig{
SSLVerification: boolPtr(true), SSLVerification: boolPtr(true),
MockAddress: "anothermock.local:3004", MockAddress: "anothermock.local:3004",
MockedDomains: []string{"example.gov"}, MockedDomains: []string{"example.gov"},
@ -107,22 +115,6 @@ func (t *HTTPClientBuilderTest) Test_FromConfig() {
assert.Equal(t.T(), config.Timeout*time.Second, t.builder.httpClient.Timeout) assert.Equal(t.T(), config.Timeout*time.Second, t.builder.httpClient.Timeout)
} }
func (t *HTTPClientBuilderTest) Test_FromEngine() {
engine := &Engine{
Config: Config{
HTTPClientConfig: &HTTPClientConfig{
SSLVerification: boolPtr(true),
MockAddress: "anothermock.local:3004",
MockedDomains: []string{"example.gov"},
},
Debug: false,
},
}
t.builder.FromEngine(engine)
assert.Equal(t.T(), engine.Config.GetHTTPClientConfig().MockAddress, t.builder.mockAddress)
}
func (t *HTTPClientBuilderTest) Test_buildDialer() { func (t *HTTPClientBuilderTest) Test_buildDialer() {
t.builder.buildDialer() t.builder.buildDialer()
@ -138,7 +130,7 @@ func (t *HTTPClientBuilderTest) Test_buildMocks() {
} }
func (t *HTTPClientBuilderTest) Test_WithLogger() { func (t *HTTPClientBuilderTest) Test_WithLogger() {
logger := NewLogger("telegram", logging.ERROR, DefaultLogFormatter()) logger := logger.NewStandard("telegram", logging.ERROR, logger.DefaultLogFormatter())
builder := NewHTTPClientBuilder() builder := NewHTTPClientBuilder()
require.Nil(t.T(), builder.logger) require.Nil(t.T(), builder.logger)
@ -248,9 +240,9 @@ uf/TQPpjrGW5nxOf94qn6FzV2WSype9BcM5MD7z7rk202Fs7Zqc=
_, err = keyFile.WriteString(keyFileData) _, err = keyFile.WriteString(keyFileData)
require.NoError(t.T(), err, "cannot write temp key file") require.NoError(t.T(), err, "cannot write temp key file")
require.NoError(t.T(), require.NoError(t.T(),
ErrorCollector(certFile.Sync(), certFile.Close()), "cannot sync and close temp cert file") errorutil.Collect(certFile.Sync(), certFile.Close()), "cannot sync and close temp cert file")
require.NoError(t.T(), require.NoError(t.T(),
ErrorCollector(keyFile.Sync(), keyFile.Close()), "cannot sync and close temp key file") errorutil.Collect(keyFile.Sync(), keyFile.Close()), "cannot sync and close temp key file")
mux := &http.ServeMux{} mux := &http.ServeMux{}
srv := &http.Server{Addr: mockServerAddr, Handler: mux} srv := &http.Server{Addr: mockServerAddr, Handler: mux}
@ -261,8 +253,8 @@ uf/TQPpjrGW5nxOf94qn6FzV2WSype9BcM5MD7z7rk202Fs7Zqc=
testSkipChan := make(chan error, 1) testSkipChan := make(chan error, 1)
go func(skip chan error) { go func(skip chan error) {
if err := srv.ListenAndServeTLS(certFile.Name(), keyFile.Name()); err != nil && err != http.ErrServerClosed { if err := srv.ListenAndServeTLS(certFile.Name(), keyFile.Name()); err != nil && !errors.Is(err, http.ErrServerClosed) {
skip <- fmt.Errorf("skipping test because server won't start: %s", err.Error()) skip <- fmt.Errorf("skipping test because server won't start: %w", err)
} }
}(testSkipChan) }(testSkipChan)
@ -314,6 +306,16 @@ uf/TQPpjrGW5nxOf94qn6FzV2WSype9BcM5MD7z7rk202Fs7Zqc=
assert.Equal(t.T(), "ok", string(data), "invalid body contents") assert.Equal(t.T(), "ok", string(data), "invalid body contents")
} }
func (t *HTTPClientBuilderTest) Test_UseTLS10() {
client, err := NewHTTPClientBuilder().SetSSLVerification(true).UseTLS10().Build()
t.Require().NoError(err)
t.Require().NotNil(client)
t.Require().NotNil(client.Transport)
t.Require().NotNil(client.Transport.(*http.Transport).TLSClientConfig)
t.Assert().Equal(uint16(tls.VersionTLS10), client.Transport.(*http.Transport).TLSClientConfig.MinVersion)
}
// taken from https://stackoverflow.com/questions/23558425/how-do-i-get-the-local-ip-address-in-go // taken from https://stackoverflow.com/questions/23558425/how-do-i-get-the-local-ip-address-in-go
func getOutboundIP() net.IP { func getOutboundIP() net.IP {
conn, err := net.Dial("udp", "8.8.8.8:80") conn, err := net.Dial("udp", "8.8.8.8:80")
@ -330,3 +332,8 @@ func getOutboundIP() net.IP {
func Test_HTTPClientBuilder(t *testing.T) { func Test_HTTPClientBuilder(t *testing.T) {
suite.Run(t, new(HTTPClientBuilderTest)) suite.Run(t, new(HTTPClientBuilderTest))
} }
func boolPtr(val bool) *bool {
b := val
return &b
}

View File

@ -0,0 +1,133 @@
package testutil
import (
"bytes"
"fmt"
"io"
"os"
"github.com/op/go-logging"
"github.com/retailcrm/mg-transport-core/v2/core/logger"
)
// ReadBuffer is implemented by the BufferLogger.
// Its methods give access to the buffer contents and ability to read buffer as an io.Reader or reset its contents.
type ReadBuffer interface {
io.Reader
fmt.Stringer
Bytes() []byte
Reset()
}
// BufferedLogger is a logger that can return the data written to it.
type BufferedLogger interface {
ReadBuffer
logger.Logger
}
// BufferLogger is an implementation of the BufferedLogger.
type BufferLogger struct {
buf bytes.Buffer
}
// NewBufferedLogger returns new BufferedLogger instance.
func NewBufferedLogger() BufferedLogger {
return &BufferLogger{}
}
// Read bytes from the logger buffer. io.Reader implementation.
func (l *BufferLogger) Read(p []byte) (n int, err error) {
return l.buf.Read(p)
}
// String contents of the logger buffer. fmt.Stringer implementation.
func (l *BufferLogger) String() string {
return l.buf.String()
}
// Bytes is a shorthand for the underlying bytes.Buffer method. Returns byte slice with the buffer contents.
func (l *BufferLogger) Bytes() []byte {
return l.buf.Bytes()
}
// Reset is a shorthand for the underlying bytes.Buffer method. It will reset buffer contents.
func (l *BufferLogger) Reset() {
l.buf.Reset()
}
func (l *BufferLogger) write(level logging.Level, args ...interface{}) {
l.buf.WriteString(fmt.Sprintln(append([]interface{}{level.String(), "=>"}, args...)...))
}
func (l *BufferLogger) writef(level logging.Level, format string, args ...interface{}) {
l.buf.WriteString(fmt.Sprintf(level.String()+" => "+format, args...))
}
func (l *BufferLogger) Fatal(args ...interface{}) {
l.write(logging.CRITICAL, args...)
os.Exit(1)
}
func (l *BufferLogger) Fatalf(format string, args ...interface{}) {
l.writef(logging.CRITICAL, format, args...)
os.Exit(1)
}
func (l *BufferLogger) Panic(args ...interface{}) {
l.write(logging.CRITICAL, args...)
panic(fmt.Sprint(args...))
}
func (l *BufferLogger) Panicf(format string, args ...interface{}) {
l.writef(logging.CRITICAL, format, args...)
panic(fmt.Sprintf(format, args...))
}
func (l *BufferLogger) Critical(args ...interface{}) {
l.write(logging.CRITICAL, args...)
}
func (l *BufferLogger) Criticalf(format string, args ...interface{}) {
l.writef(logging.CRITICAL, format, args...)
}
func (l *BufferLogger) Error(args ...interface{}) {
l.write(logging.ERROR, args...)
}
func (l *BufferLogger) Errorf(format string, args ...interface{}) {
l.writef(logging.ERROR, format, args...)
}
func (l *BufferLogger) Warning(args ...interface{}) {
l.write(logging.WARNING, args...)
}
func (l *BufferLogger) Warningf(format string, args ...interface{}) {
l.writef(logging.WARNING, format, args...)
}
func (l *BufferLogger) Notice(args ...interface{}) {
l.write(logging.NOTICE, args...)
}
func (l *BufferLogger) Noticef(format string, args ...interface{}) {
l.writef(logging.NOTICE, format, args...)
}
func (l *BufferLogger) Info(args ...interface{}) {
l.write(logging.INFO, args...)
}
func (l *BufferLogger) Infof(format string, args ...interface{}) {
l.writef(logging.INFO, format, args...)
}
func (l *BufferLogger) Debug(args ...interface{}) {
l.write(logging.DEBUG, args...)
}
func (l *BufferLogger) Debugf(format string, args ...interface{}) {
l.writef(logging.DEBUG, format, args...)
}

View File

@ -0,0 +1,43 @@
package testutil
import (
"io"
"testing"
"github.com/op/go-logging"
"github.com/stretchr/testify/suite"
)
type BufferLoggerTest struct {
suite.Suite
logger BufferedLogger
}
func TestBufferLogger(t *testing.T) {
suite.Run(t, new(BufferLoggerTest))
}
func (t *BufferLoggerTest) SetupSuite() {
t.logger = NewBufferedLogger()
}
func (t *BufferLoggerTest) SetupTest() {
t.logger.Reset()
}
func (t *BufferLoggerTest) Log() string {
return t.logger.String()
}
func (t *BufferLoggerTest) Test_Read() {
t.logger.Debug("test")
data, err := io.ReadAll(t.logger)
t.Require().NoError(err)
t.Assert().Equal([]byte(logging.DEBUG.String()+" => test\n"), data)
}
func (t *BufferLoggerTest) Test_Bytes() {
t.logger.Debug("test")
t.Assert().Equal([]byte(logging.DEBUG.String()+" => test\n"), t.logger.Bytes())
}

View File

@ -0,0 +1,69 @@
package testutil
import (
"fmt"
"io"
"net/http"
"gopkg.in/h2non/gock.v1"
)
// UnmatchedRequestsTestingT contains all of *testing.T methods which are needed for AssertNoUnmatchedRequests.
type UnmatchedRequestsTestingT interface {
Log(...interface{})
Logf(string, ...interface{})
FailNow()
}
// AssertNoUnmatchedRequests check that gock didn't receive any request that it was not able to match.
// It will print out an entire request data for every unmatched request.
func AssertNoUnmatchedRequests(t UnmatchedRequestsTestingT) {
if gock.HasUnmatchedRequest() { // nolint:nestif
t.Log("gock has unmatched requests. their contents will be dumped here.\n")
for _, r := range gock.GetUnmatchedRequests() {
printRequestData(t, r)
fmt.Println()
}
t.FailNow()
}
}
func printRequestData(t UnmatchedRequestsTestingT, r *http.Request) {
t.Logf("%s %s %s\n", r.Proto, r.Method, r.URL.String())
t.Logf(" > RemoteAddr: %s\n", r.RemoteAddr)
t.Logf(" > Host: %s\n", r.Host)
t.Logf(" > Length: %d\n", r.ContentLength)
for _, encoding := range r.TransferEncoding {
t.Logf(" > Transfer-Encoding: %s\n", encoding)
}
for header, values := range r.Header {
for _, value := range values {
t.Logf("[header] %s: %s\n", header, value)
}
}
if r.Body == nil {
t.Log("No body is present.")
} else {
data, err := io.ReadAll(r.Body)
if err != nil {
t.Logf("Cannot read body: %s\n", err)
}
if len(data) == 0 {
t.Log("Body is empty.")
} else {
t.Logf("Body:\n%s\n", string(data))
}
}
for header, values := range r.Trailer {
for _, value := range values {
t.Logf("[trailer header] %s: %s\n", header, value)
}
}
}

View File

@ -0,0 +1,92 @@
package testutil
import (
"bytes"
"fmt"
"net/http"
"testing"
"github.com/stretchr/testify/suite"
"gopkg.in/h2non/gock.v1"
)
type testingTMock struct {
logs *bytes.Buffer
failed bool
}
func (t *testingTMock) Log(args ...interface{}) {
t.logs.WriteString(fmt.Sprintln(append([]interface{}{"=>"}, args...)...))
}
func (t *testingTMock) Logf(format string, args ...interface{}) {
t.logs.WriteString(fmt.Sprintf(" => "+format, args...))
}
func (t *testingTMock) FailNow() {
t.failed = true
}
func (t *testingTMock) Reset() {
t.logs.Reset()
t.failed = false
}
func (t *testingTMock) Logs() string {
return t.logs.String()
}
func (t *testingTMock) Failed() bool {
return t.failed
}
type AssertNoUnmatchedRequestsTest struct {
suite.Suite
tmock *testingTMock
}
func TestAssertNoUnmatchedRequests(t *testing.T) {
suite.Run(t, new(AssertNoUnmatchedRequestsTest))
}
func (t *AssertNoUnmatchedRequestsTest) SetupSuite() {
t.tmock = &testingTMock{logs: &bytes.Buffer{}}
}
func (t *AssertNoUnmatchedRequestsTest) SetupTest() {
t.tmock.Reset()
gock.CleanUnmatchedRequest()
}
func (t *AssertNoUnmatchedRequestsTest) Test_OK() {
AssertNoUnmatchedRequests(t.tmock)
t.Assert().Empty(t.tmock.Logs())
t.Assert().False(t.tmock.Failed())
}
func (t *AssertNoUnmatchedRequestsTest) Test_HasUnmatchedRequests() {
defer gock.Off()
gock.New("https://example.com").
Post("/dial").
MatchHeader("X-Client-Data", "something").
BodyString("something in body").
Reply(http.StatusOK)
_, _ = http.Get("https://example.com/nil")
AssertNoUnmatchedRequests(t.tmock)
t.Assert().True(gock.HasUnmatchedRequest())
t.Assert().NotEmpty(t.tmock.Logs())
t.Assert().Equal(`=> gock has unmatched requests. their contents will be dumped here.
=> HTTP/1.1 GET https://example.com/nil
=> > RemoteAddr:
=> > Host: example.com
=> > Length: 0
=> No body is present.
`, t.tmock.Logs())
t.Assert().True(t.tmock.Failed())
}

View File

@ -0,0 +1,55 @@
package testutil
import (
"database/sql"
"fmt"
"github.com/jinzhu/gorm"
)
// DeleteCreatedEntities sets up GORM `onCreate` hook and return a function that can be deferred to
// remove all the entities created after the hook was set up.
// You can use it like this:
//
// func TestSomething(t *testing.T){
// db, _ := gorm.Open(...)
// cleaner := DeleteCreatedEntities(db)
// defer cleaner()
// }.
func DeleteCreatedEntities(db *gorm.DB) func() { // nolint
type entity struct {
key interface{}
table string
keyname string
}
var entries []entity
hookName := "cleanupHook"
db.Callback().Create().After("gorm:create").Register(hookName, func(scope *gorm.Scope) {
fmt.Printf("Inserted entities of %s with %s=%v\n", scope.TableName(), scope.PrimaryKey(), scope.PrimaryKeyValue())
entries = append(entries, entity{table: scope.TableName(), keyname: scope.PrimaryKey(), key: scope.PrimaryKeyValue()})
})
return func() {
// Remove the hook once we're done
defer db.Callback().Create().Remove(hookName)
// Find out if the current db object is already a transaction
_, inTransaction := db.CommonDB().(*sql.Tx)
tx := db
if !inTransaction {
tx = db.Begin()
}
// Loop from the end. It is important that we delete the entries in the
// reverse order of their insertion
for i := len(entries) - 1; i >= 0; i-- {
entry := entries[i]
fmt.Printf("Deleting entities from '%s' table with key %v\n", entry.table, entry.key)
if err := tx.Table(entry.table).Where(entry.keyname+" = ?", entry.key).Delete("").Error; err != nil {
panic(err)
}
}
if !inTransaction {
tx.Commit()
}
}
}

View File

@ -0,0 +1,49 @@
package testutil
import (
"regexp"
"testing"
"github.com/DATA-DOG/go-sqlmock"
"github.com/jinzhu/gorm"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type testUser struct {
Username string `gorm:"column:username; type:varchar(255); not null;"`
ID int `gorm:"primary_key"`
}
func TestDeleteCreatedEntities(t *testing.T) {
sqlDB, mock, err := sqlmock.New()
require.NoError(t, err)
db, err := gorm.Open("postgres", sqlDB)
require.NoError(t, err)
mock.
ExpectExec(regexp.QuoteMeta(`CREATE TABLE "test_users" ("username" varchar(255) NOT NULL,"id" serial , PRIMARY KEY ("id"))`)).
WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectBegin()
mock.
ExpectQuery(regexp.QuoteMeta(`INSERT INTO "test_users" ("username") VALUES ($1) RETURNING "test_users"."id"`)).
WithArgs("user").
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("1"))
mock.ExpectCommit()
mock.ExpectBegin()
mock.
ExpectExec(regexp.QuoteMeta(`DELETE FROM "test_users" WHERE (id = $1)`)).
WithArgs(1).
WillReturnResult(sqlmock.NewResult(1, 1))
mock.ExpectCommit()
cleaner := DeleteCreatedEntities(db)
require.NoError(t, db.AutoMigrate(&testUser{}).Error)
require.NoError(t, db.Create(&testUser{Username: "user"}).Error)
cleaner()
assert.NoError(t, mock.ExpectationsWereMet())
}

View File

@ -1,4 +1,4 @@
package core package testutil
import ( import (
"errors" "errors"
@ -92,9 +92,8 @@ func (t *TranslationsExtractor) loadYAMLFile(fileName string) (map[string]interf
func (t *TranslationsExtractor) loadYAML(fileName string) (map[string]interface{}, error) { func (t *TranslationsExtractor) loadYAML(fileName string) (map[string]interface{}, error) {
if t.TranslationsPath != "" { if t.TranslationsPath != "" {
return t.loadYAMLFile(filepath.Join(t.TranslationsPath, fileName)) return t.loadYAMLFile(filepath.Join(t.TranslationsPath, fileName))
} else {
return t.loadYAMLFromFS(fileName)
} }
return t.loadYAMLFromFS(fileName)
} }
// GetMapKeys returns sorted map keys from map[string]interface{} - useful to check keys in several translation files. // GetMapKeys returns sorted map keys from map[string]interface{} - useful to check keys in several translation files.

View File

@ -1,4 +1,4 @@
package core package testutil
import ( import (
"io/ioutil" "io/ioutil"

View File

@ -1,4 +1,4 @@
package core package util
import ( import (
// nolint:gosec // nolint:gosec
@ -18,6 +18,17 @@ import (
"github.com/aws/aws-sdk-go/service/s3/s3manager" "github.com/aws/aws-sdk-go/service/s3/s3manager"
retailcrm "github.com/retailcrm/api-client-go/v2" retailcrm "github.com/retailcrm/api-client-go/v2"
v1 "github.com/retailcrm/mg-transport-api-client-go/v1" v1 "github.com/retailcrm/mg-transport-api-client-go/v1"
"github.com/retailcrm/mg-transport-core/v2/core/config"
"github.com/retailcrm/mg-transport-core/v2/core/logger"
"github.com/retailcrm/mg-transport-core/v2/core/util/errorutil"
)
var (
markdownSymbols = []string{"*", "_", "`", "["}
slashRegex = regexp.MustCompile(`/+$`)
) )
var DefaultScopes = []string{ var DefaultScopes = []string{
@ -78,28 +89,28 @@ var defaultCurrencies = map[string]string{
// Utils service object. // Utils service object.
type Utils struct { type Utils struct {
IsDebug bool Logger logger.Logger
TokenCounter uint32
ConfigAWS ConfigAWS
Logger LoggerInterface
slashRegex *regexp.Regexp slashRegex *regexp.Regexp
AWS config.AWS
TokenCounter uint32
IsDebug bool
} }
// NewUtils will create new Utils instance. // NewUtils will create new Utils instance.
func NewUtils(awsConfig ConfigAWS, logger LoggerInterface, debug bool) *Utils { func NewUtils(awsConfig config.AWS, logger logger.Logger, debug bool) *Utils {
return &Utils{ return &Utils{
IsDebug: debug, IsDebug: debug,
ConfigAWS: awsConfig, AWS: awsConfig,
Logger: logger, Logger: logger,
TokenCounter: 0, TokenCounter: 0,
slashRegex: slashRegex, slashRegex: slashRegex,
} }
} }
// resetUtils. // ResetUtils resets the utils inner state.
func (u *Utils) resetUtils(awsConfig ConfigAWS, debug bool, tokenCounter uint32) { func (u *Utils) ResetUtils(awsConfig config.AWS, debug bool, tokenCounter uint32) {
u.TokenCounter = tokenCounter u.TokenCounter = tokenCounter
u.ConfigAWS = awsConfig u.AWS = awsConfig
u.IsDebug = debug u.IsDebug = debug
u.slashRegex = slashRegex u.slashRegex = slashRegex
} }
@ -124,7 +135,7 @@ func (u *Utils) GetAPIClient(url, key string, scopes []string) (*retailcrm.Clien
if res := u.checkScopes(cr.Scopes, scopes); len(res) != 0 { if res := u.checkScopes(cr.Scopes, scopes); len(res) != 0 {
u.Logger.Error(url, status, res) u.Logger.Error(url, status, res)
return nil, http.StatusBadRequest, NewInsufficientScopesErr(res) return nil, http.StatusBadRequest, errorutil.NewInsufficientScopesErr(res)
} }
return client, 0, nil return client, 0, nil
@ -153,10 +164,10 @@ func (u *Utils) checkScopes(scopes []string, scopesRequired []string) []string {
func (u *Utils) UploadUserAvatar(url string) (picURLs3 string, err error) { func (u *Utils) UploadUserAvatar(url string) (picURLs3 string, err error) {
s3Config := &aws.Config{ s3Config := &aws.Config{
Credentials: credentials.NewStaticCredentials( Credentials: credentials.NewStaticCredentials(
u.ConfigAWS.AccessKeyID, u.AWS.AccessKeyID,
u.ConfigAWS.SecretAccessKey, u.AWS.SecretAccessKey,
""), ""),
Region: aws.String(u.ConfigAWS.Region), Region: aws.String(u.AWS.Region),
} }
s := session.Must(session.NewSession(s3Config)) s := session.Must(session.NewSession(s3Config))
@ -174,10 +185,10 @@ func (u *Utils) UploadUserAvatar(url string) (picURLs3 string, err error) {
} }
result, err := uploader.Upload(&s3manager.UploadInput{ result, err := uploader.Upload(&s3manager.UploadInput{
Bucket: aws.String(u.ConfigAWS.Bucket), Bucket: aws.String(u.AWS.Bucket),
Key: aws.String(fmt.Sprintf("%v/%v.jpg", u.ConfigAWS.FolderName, u.GenerateToken())), Key: aws.String(fmt.Sprintf("%v/%v.jpg", u.AWS.FolderName, u.GenerateToken())),
Body: resp.Body, Body: resp.Body,
ContentType: aws.String(u.ConfigAWS.ContentType), ContentType: aws.String(u.AWS.ContentType),
ACL: aws.String("public-read"), ACL: aws.String("public-read"),
}) })
if err != nil { if err != nil {
@ -228,7 +239,7 @@ func GetEntitySHA1(v interface{}) (hash string, err error) {
// ReplaceMarkdownSymbols will remove markdown symbols from text. // ReplaceMarkdownSymbols will remove markdown symbols from text.
func ReplaceMarkdownSymbols(s string) string { func ReplaceMarkdownSymbols(s string) string {
for _, v := range markdownSymbols { for _, v := range markdownSymbols {
s = strings.Replace(s, v, "\\"+v, -1) s = strings.ReplaceAll(s, v, "\\"+v)
} }
return s return s

View File

@ -1,4 +1,4 @@
package core package util
import ( import (
"encoding/json" "encoding/json"
@ -8,13 +8,19 @@ import (
"testing" "testing"
"time" "time"
"github.com/h2non/gock"
"github.com/op/go-logging" "github.com/op/go-logging"
retailcrm "github.com/retailcrm/api-client-go/v2" retailcrm "github.com/retailcrm/api-client-go/v2"
v1 "github.com/retailcrm/mg-transport-api-client-go/v1" v1 "github.com/retailcrm/mg-transport-api-client-go/v1"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"gopkg.in/h2non/gock.v1"
"github.com/retailcrm/mg-transport-core/v2/core/config"
"github.com/retailcrm/mg-transport-core/v2/core/logger"
"github.com/retailcrm/mg-transport-core/v2/core/util/errorutil"
) )
var ( var (
@ -32,8 +38,8 @@ func mgClient() *v1.MgClient {
} }
func (u *UtilsTest) SetupSuite() { func (u *UtilsTest) SetupSuite() {
logger := NewLogger("code", logging.DEBUG, DefaultLogFormatter()) logger := logger.NewStandard("code", logging.DEBUG, logger.DefaultLogFormatter())
awsConfig := ConfigAWS{ awsConfig := config.AWS{
AccessKeyID: "access key id (will be removed)", AccessKeyID: "access key id (will be removed)",
SecretAccessKey: "secret access key", SecretAccessKey: "secret access key",
Region: "region", Region: "region",
@ -47,15 +53,15 @@ func (u *UtilsTest) SetupSuite() {
} }
func (u *UtilsTest) Test_ResetUtils() { func (u *UtilsTest) Test_ResetUtils() {
assert.Equal(u.T(), "access key id (will be removed)", u.utils.ConfigAWS.AccessKeyID) assert.Equal(u.T(), "access key id (will be removed)", u.utils.AWS.AccessKeyID)
assert.Equal(u.T(), uint32(12346), u.utils.TokenCounter) assert.Equal(u.T(), uint32(12346), u.utils.TokenCounter)
assert.False(u.T(), u.utils.IsDebug) assert.False(u.T(), u.utils.IsDebug)
awsConfig := u.utils.ConfigAWS awsConfig := u.utils.AWS
awsConfig.AccessKeyID = "access key id" awsConfig.AccessKeyID = "access key id"
u.utils.resetUtils(awsConfig, true, 0) u.utils.ResetUtils(awsConfig, true, 0)
assert.Equal(u.T(), "access key id", u.utils.ConfigAWS.AccessKeyID) assert.Equal(u.T(), "access key id", u.utils.AWS.AccessKeyID)
assert.Equal(u.T(), uint32(0), u.utils.TokenCounter) assert.Equal(u.T(), uint32(0), u.utils.TokenCounter)
assert.True(u.T(), u.utils.IsDebug) assert.True(u.T(), u.utils.IsDebug)
} }
@ -110,7 +116,7 @@ func (u *UtilsTest) Test_GetAPIClient_FailAPICredentials() {
_, status, err := u.utils.GetAPIClient(testCRMURL, "key", DefaultScopes) _, status, err := u.utils.GetAPIClient(testCRMURL, "key", DefaultScopes)
assert.Equal(u.T(), http.StatusBadRequest, status) assert.Equal(u.T(), http.StatusBadRequest, status)
if assert.NotNil(u.T(), err) { if assert.NotNil(u.T(), err) {
assert.True(u.T(), errors.Is(err, ErrInsufficientScopes)) assert.True(u.T(), errors.Is(err, errorutil.ErrInsufficientScopes))
} }
} }

View File

@ -8,6 +8,8 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite" "github.com/stretchr/testify/suite"
"github.com/retailcrm/mg-transport-core/v2/core/db/models"
) )
type ValidatorSuite struct { type ValidatorSuite struct {
@ -49,7 +51,7 @@ func (s *ValidatorSuite) Test_ValidationFails() {
} }
for _, domain := range crmDomains { for _, domain := range crmDomains {
conn := Connection{ conn := models.Connection{
Key: "key", Key: "key",
URL: domain, URL: domain,
} }
@ -81,7 +83,7 @@ func (s *ValidatorSuite) Test_ValidationSuccess() {
} }
for _, domain := range crmDomains { for _, domain := range crmDomains {
conn := Connection{ conn := models.Connection{
Key: "key", Key: "key",
URL: domain, URL: domain,
} }

4
go.mod
View File

@ -1,4 +1,4 @@
module github.com/retailcrm/mg-transport-core module github.com/retailcrm/mg-transport-core/v2
go 1.12 go 1.12
@ -16,7 +16,6 @@ require (
github.com/golang/protobuf v1.5.2 // indirect github.com/golang/protobuf v1.5.2 // indirect
github.com/gorilla/securecookie v1.1.1 github.com/gorilla/securecookie v1.1.1
github.com/gorilla/sessions v1.2.0 github.com/gorilla/sessions v1.2.0
github.com/h2non/gock v1.0.10
github.com/jessevdk/go-flags v1.4.0 github.com/jessevdk/go-flags v1.4.0
github.com/jinzhu/gorm v1.9.11 github.com/jinzhu/gorm v1.9.11
github.com/json-iterator/go v1.1.11 // indirect github.com/json-iterator/go v1.1.11 // indirect
@ -36,6 +35,7 @@ require (
google.golang.org/protobuf v1.27.1 // indirect google.golang.org/protobuf v1.27.1 // indirect
gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect
gopkg.in/gormigrate.v1 v1.6.0 gopkg.in/gormigrate.v1 v1.6.0
gopkg.in/h2non/gock.v1 v1.1.2
gopkg.in/yaml.v2 v2.4.0 gopkg.in/yaml.v2 v2.4.0
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
) )

6
go.sum
View File

@ -159,8 +159,6 @@ github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyC
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
github.com/gorilla/sessions v1.2.0 h1:S7P+1Hm5V/AT9cjEcUD5uDaQSX0OE577aCXgoaKpYbQ= github.com/gorilla/sessions v1.2.0 h1:S7P+1Hm5V/AT9cjEcUD5uDaQSX0OE577aCXgoaKpYbQ=
github.com/gorilla/sessions v1.2.0/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= github.com/gorilla/sessions v1.2.0/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
github.com/h2non/gock v1.0.10 h1:EzHYzKKSLN4xk0w193uAy3tp8I3+L1jmaI2Mjg4lCgU=
github.com/h2non/gock v1.0.10/go.mod h1:CZMcB0Lg5IWnr9bF79pPMg9WeV6WumxQiUJ1UvdO1iE=
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw= github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw=
github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542/go.mod h1:Ow0tF8D4Kplbc8s8sSb3V2oUCygFHVp8gC3Dn6U4MNI= github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542/go.mod h1:Ow0tF8D4Kplbc8s8sSb3V2oUCygFHVp8gC3Dn6U4MNI=
github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
@ -249,10 +247,6 @@ github.com/prometheus/common v0.2.0/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y8
github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
github.com/prometheus/procfs v0.0.0-20190117184657-bf6a532e95b1/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/prometheus/procfs v0.0.0-20190117184657-bf6a532e95b1/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
github.com/retailcrm/api-client-go/v2 v2.0.1 h1:wyM0F1VTSJPO8PVEXB0u7s6ZEs0xRCnUu7YdQ1E4UZ8=
github.com/retailcrm/api-client-go/v2 v2.0.1/go.mod h1:1yTZl9+gd3+/k0kAJe7sYvC+mL4fqMwIwtnSgSWZlkQ=
github.com/retailcrm/api-client-go/v2 v2.0.2 h1:oFQycGqwcvfgW2JrbeWmPjxH7Wmh9j762c4FRxCDGNs=
github.com/retailcrm/api-client-go/v2 v2.0.2/go.mod h1:1yTZl9+gd3+/k0kAJe7sYvC+mL4fqMwIwtnSgSWZlkQ=
github.com/retailcrm/api-client-go/v2 v2.0.3 h1:7oKwOZgRLM7eEJUvFNhzfnyIJVomy80ffOEBdYpQRIs= github.com/retailcrm/api-client-go/v2 v2.0.3 h1:7oKwOZgRLM7eEJUvFNhzfnyIJVomy80ffOEBdYpQRIs=
github.com/retailcrm/api-client-go/v2 v2.0.3/go.mod h1:1yTZl9+gd3+/k0kAJe7sYvC+mL4fqMwIwtnSgSWZlkQ= github.com/retailcrm/api-client-go/v2 v2.0.3/go.mod h1:1yTZl9+gd3+/k0kAJe7sYvC+mL4fqMwIwtnSgSWZlkQ=
github.com/retailcrm/mg-transport-api-client-go v1.1.32 h1:IBPltSoD5q2PPZJbNC/prK5F9rEVPXVx/ZzDpi7HKhs= github.com/retailcrm/mg-transport-api-client-go v1.1.32 h1:IBPltSoD5q2PPZJbNC/prK5F9rEVPXVx/ZzDpi7HKhs=