diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b928a72..198c5de 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -22,14 +22,14 @@ jobs: - name: Lint code with golangci-lint uses: golangci/golangci-lint-action@v2 with: - version: v1.36 + version: v1.42.1 only-new-issues: true tests: name: Tests runs-on: ubuntu-latest strategy: matrix: - go-version: ['1.16'] + go-version: ['1.16', '1.17'] steps: - name: Set up Go ${{ matrix.go-version }} uses: actions/setup-go@v2 diff --git a/.golangci.yml b/.golangci.yml index 656e05b..306ec41 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -32,19 +32,15 @@ linters: - gocyclo - godot - goimports - - golint - - gomnd + - revive - gosec - ifshort - - interfacer - lll - makezero - - maligned - misspell - nestif - prealloc - predeclared - - scopelint - sqlclosecheck - unconvert - whitespace @@ -56,9 +52,11 @@ linters-settings: enable: - assign - atomic + - atomicalign - bools - buildtag - copylocks + - fieldalignment - httpresponse - loopclosure - lostcancel @@ -152,12 +150,10 @@ linters-settings: local-prefixes: github.com/retailcrm/mg-transport-core lll: line-length: 120 - maligned: - suggest-new: true misspell: locale: US nestif: - min-complexity: 4 + min-complexity: 6 whitespace: multi-if: false multi-func: false @@ -166,7 +162,6 @@ issues: exclude-rules: - path: _test\.go linters: - - gomnd - lll - bodyclose - errcheck @@ -175,7 +170,6 @@ issues: - ineffassign - whitespace - makezero - - maligned - ifshort - errcheck - funlen diff --git a/cmd/transport-core-tool/main.go b/cmd/transport-core-tool/main.go index 871ec1e..2fdcfd1 100644 --- a/cmd/transport-core-tool/main.go +++ b/cmd/transport-core-tool/main.go @@ -5,7 +5,7 @@ import ( "github.com/jessevdk/go-flags" - "github.com/retailcrm/mg-transport-core/core" + "github.com/retailcrm/mg-transport-core/v2/core/db" ) // Options for tool command. @@ -20,7 +20,7 @@ func init() { _, err := parser.AddCommand("migration", "Create new empty migration in specified directory.", "Create new empty migration in specified directory.", - &core.NewMigrationCommand{}, + &db.NewMigrationCommand{}, ) if err != nil { @@ -30,7 +30,7 @@ func init() { func main() { if _, err := parser.Parse(); err != nil { - if flagsErr, ok := err.(*flags.Error); ok && flagsErr.Type == flags.ErrHelp { + if flagsErr, ok := err.(*flags.Error); ok && flagsErr.Type == flags.ErrHelp { // nolint:errorlint os.Exit(0) } else { os.Exit(1) diff --git a/core/config.go b/core/config/config.go similarity index 92% rename from core/config.go rename to core/config/config.go index a1e2578..beb21b7 100644 --- a/core/config.go +++ b/core/config/config.go @@ -1,28 +1,22 @@ -package core +package config import ( "io/ioutil" "path/filepath" - "regexp" "time" "github.com/op/go-logging" "gopkg.in/yaml.v2" ) -var ( - markdownSymbols = []string{"*", "_", "`", "["} - slashRegex = regexp.MustCompile(`/+$`) -) - -// ConfigInterface settings data structure. -type ConfigInterface interface { +// Configuration settings data structure. +type Configuration interface { GetVersion() string GetSentryDSN() string GetLogLevel() logging.Level GetHTTPConfig() HTTPServerConfig GetDBConfig() DatabaseConfig - GetAWSConfig() ConfigAWS + GetAWSConfig() AWS GetTransportInfo() InfoInterface GetHTTPClientConfig() *HTTPClientConfig GetUpdateInterval() int @@ -39,16 +33,16 @@ type InfoInterface interface { // Config struct. type Config struct { - Version string `yaml:"version"` - LogLevel logging.Level `yaml:"log_level"` - Database DatabaseConfig `yaml:"database"` - SentryDSN string `yaml:"sentry_dsn"` - HTTPServer HTTPServerConfig `yaml:"http_server"` - Debug bool `yaml:"debug"` - UpdateInterval int `yaml:"update_interval"` - ConfigAWS ConfigAWS `yaml:"config_aws"` - TransportInfo Info `yaml:"transport_info"` HTTPClientConfig *HTTPClientConfig `yaml:"http_client"` + ConfigAWS AWS `yaml:"config_aws"` + TransportInfo Info `yaml:"transport_info"` + HTTPServer HTTPServerConfig `yaml:"http_server"` + Version string `yaml:"version"` + SentryDSN string `yaml:"sentry_dsn"` + Database DatabaseConfig `yaml:"database"` + UpdateInterval int `yaml:"update_interval"` + LogLevel logging.Level `yaml:"log_level"` + Debug bool `yaml:"debug"` } // Info struct. @@ -59,8 +53,8 @@ type Info struct { Secret string `yaml:"secret"` } -// ConfigAWS struct. -type ConfigAWS struct { +// AWS struct. +type AWS struct { AccessKeyID string `yaml:"access_key_id"` SecretAccessKey string `yaml:"secret_access_key"` Region string `yaml:"region"` @@ -72,19 +66,19 @@ type ConfigAWS struct { // DatabaseConfig struct. type DatabaseConfig struct { Connection interface{} `yaml:"connection"` - Logging bool `yaml:"logging"` TablePrefix string `yaml:"table_prefix"` MaxOpenConnections int `yaml:"max_open_connections"` MaxIdleConnections int `yaml:"max_idle_connections"` ConnectionLifetime int `yaml:"connection_lifetime"` + Logging bool `yaml:"logging"` } // HTTPClientConfig struct. type HTTPClientConfig struct { - Timeout time.Duration `yaml:"timeout"` SSLVerification *bool `yaml:"ssl_verification"` MockAddress string `yaml:"mock_address"` MockedDomains []string `yaml:"mocked_domains"` + Timeout time.Duration `yaml:"timeout"` } // HTTPServerConfig struct. @@ -157,7 +151,7 @@ func (c Config) IsDebug() bool { } // GetAWSConfig AWS configuration. -func (c Config) GetAWSConfig() ConfigAWS { +func (c Config) GetAWSConfig() AWS { return c.ConfigAWS } diff --git a/core/config_test.go b/core/config/config_test.go similarity index 99% rename from core/config_test.go rename to core/config/config_test.go index 047b79a..9d8794f 100644 --- a/core/config_test.go +++ b/core/config/config_test.go @@ -1,4 +1,4 @@ -package core +package config import ( "io/ioutil" diff --git a/core/migrate.go b/core/db/migrate.go similarity index 94% rename from core/migrate.go rename to core/db/migrate.go index 962c3b9..2f86a88 100644 --- a/core/migrate.go +++ b/core/db/migrate.go @@ -1,4 +1,4 @@ -package core +package db import ( "fmt" @@ -16,9 +16,9 @@ var migrations *Migrate type Migrate struct { db *gorm.DB first *gormigrate.Migration - versions []string migrations map[string]*gormigrate.Migration GORMigrate *gormigrate.Gormigrate + versions []string prepared bool } @@ -123,7 +123,7 @@ func (m *Migrate) MigrateNextTo(version string) error { case current < next: return m.GORMigrate.MigrateTo(next) case current > next: - return errors.New(fmt.Sprintf("current migration version '%s' is higher than fetched version '%s'", current, next)) + return fmt.Errorf("current migration version '%s' is higher than fetched version '%s'", current, next) default: return nil } @@ -144,7 +144,7 @@ func (m *Migrate) MigratePreviousTo(version string) error { case current > prev: return m.GORMigrate.RollbackTo(prev) case current < prev: - return errors.New(fmt.Sprintf("current migration version '%s' is lower than fetched version '%s'", current, prev)) + return fmt.Errorf("current migration version '%s' is lower than fetched version '%s'", current, prev) case prev == "0": return m.GORMigrate.RollbackMigration(m.first) default: @@ -241,8 +241,11 @@ func (m *Migrate) prepareMigrations() error { return nil } + i := 0 + keys = make([]string, len(m.migrations)) for key := range m.migrations { - keys = append(keys, key) + keys[i] = key + i++ } sort.Strings(keys) diff --git a/core/migrate_test.go b/core/db/migrate_test.go similarity index 99% rename from core/migrate_test.go rename to core/db/migrate_test.go index 1533cb8..d908ce1 100644 --- a/core/migrate_test.go +++ b/core/db/migrate_test.go @@ -1,4 +1,4 @@ -package core +package db import ( "database/sql" diff --git a/core/migration_generator.go b/core/db/migration_generator.go similarity index 95% rename from core/migration_generator.go rename to core/db/migration_generator.go index 7e66ddd..8d6d20a 100644 --- a/core/migration_generator.go +++ b/core/db/migration_generator.go @@ -1,4 +1,4 @@ -package core +package db import ( "fmt" @@ -32,7 +32,7 @@ func init() { // NewMigrationCommand struct. type NewMigrationCommand struct { - Directory string `short:"d" long:"directory" default:"./migrations" description:"Directory where migration will be created"` + Directory string `short:"d" long:"directory" default:"./migrations" description:"Directory where migration will be created"` // nolint:lll } // FileExists returns true if provided file exist and it's not directory. diff --git a/core/migration_generator_test.go b/core/db/migration_generator_test.go similarity index 86% rename from core/migration_generator_test.go rename to core/db/migration_generator_test.go index ea28944..1d86ac5 100644 --- a/core/migration_generator_test.go +++ b/core/db/migration_generator_test.go @@ -1,4 +1,4 @@ -package core +package db import ( "fmt" @@ -26,8 +26,8 @@ func (s *MigrationGeneratorSuite) SetupSuite() { func (s *MigrationGeneratorSuite) Test_FileExists() { var ( - seededRand *rand.Rand = rand.New(rand.NewSource(time.Now().UnixNano())) - notExist = fmt.Sprintf("/tmp/%d", seededRand.Int31()) + seededRand = rand.New(rand.NewSource(time.Now().UnixNano())) // nolint:gosec + notExist = fmt.Sprintf("/tmp/%d", seededRand.Int31()) ) assert.False(s.T(), s.command.FileExists(notExist)) diff --git a/core/db/models/account.go b/core/db/models/account.go new file mode 100644 index 0000000..6bd3e0c --- /dev/null +++ b/core/db/models/account.go @@ -0,0 +1,18 @@ +package models + +import "time" + +// Account model. +type Account struct { + CreatedAt time.Time + UpdatedAt time.Time + ChannelSettingsHash string `gorm:"column:channel_settings_hash; type:varchar(70)" binding:"max=70"` + Name string `gorm:"column:name; type:varchar(100)" json:"name,omitempty" binding:"max=100"` + Lang string `gorm:"column:lang; type:varchar(2)" json:"lang,omitempty" binding:"max=2"` + Channel uint64 `gorm:"column:channel; not null; unique" json:"channel,omitempty"` + ID int `gorm:"primary_key"` + ConnectionID int `gorm:"column:connection_id" json:"connectionId,omitempty"` +} + +// Accounts list. +type Accounts []Account diff --git a/core/db/models/connection.go b/core/db/models/connection.go new file mode 100644 index 0000000..3bc8512 --- /dev/null +++ b/core/db/models/connection.go @@ -0,0 +1,17 @@ +package models + +import "time" + +// Connection model. +type Connection struct { + CreatedAt time.Time + UpdatedAt time.Time + Key string `gorm:"column:api_key; type:varchar(100); not null" json:"api_key,omitempty" binding:"required,max=100"` // nolint:lll + URL string `gorm:"column:api_url; type:varchar(255); not null" json:"api_url,omitempty" binding:"required,validateCrmURL,max=255"` // nolint:lll + GateURL string `gorm:"column:mg_url; type:varchar(255); not null;" json:"mg_url,omitempty" binding:"max=255"` + GateToken string `gorm:"column:mg_token; type:varchar(100); not null; unique" json:"mg_token,omitempty" binding:"max=100"` // nolint:lll + ClientID string `gorm:"column:client_id; type:varchar(70); not null; unique" json:"clientId,omitempty"` + Accounts []Account `gorm:"foreignkey:ConnectionID"` + ID int `gorm:"primary_key"` + Active bool `json:"active,omitempty"` +} diff --git a/core/db/models/user.go b/core/db/models/user.go new file mode 100644 index 0000000..5a7b4b0 --- /dev/null +++ b/core/db/models/user.go @@ -0,0 +1,24 @@ +package models + +import "time" + +// User model. +type User struct { + CreatedAt time.Time + UpdatedAt time.Time + ExternalID string `gorm:"column:external_id; type:varchar(255); not null; unique"` + UserPhotoURL string `gorm:"column:user_photo_url; type:varchar(255)" binding:"max=255"` + UserPhotoID string `gorm:"column:user_photo_id; type:varchar(100)" binding:"max=100"` + ID int `gorm:"primary_key"` +} + +// TableName will return table name for User +// It will not work if User is not embedded, but mapped as another type +// type MyUser User // will not work +// but +// type MyUser struct { // will work +// User +// } +func (User) TableName() string { + return "mg_user" +} diff --git a/core/models_test.go b/core/db/models/user_test.go similarity index 90% rename from core/models_test.go rename to core/db/models/user_test.go index 6c26666..7f54a10 100644 --- a/core/models_test.go +++ b/core/db/models/user_test.go @@ -1,4 +1,4 @@ -package core +package models import ( "testing" diff --git a/core/orm.go b/core/db/orm.go similarity index 73% rename from core/orm.go rename to core/db/orm.go index 9637172..3eea504 100644 --- a/core/orm.go +++ b/core/db/orm.go @@ -1,9 +1,12 @@ -package core +package db import ( "time" "github.com/jinzhu/gorm" + + "github.com/retailcrm/mg-transport-core/v2/core/config" + // PostgreSQL is an default. _ "github.com/jinzhu/gorm/dialects/postgres" ) @@ -14,13 +17,14 @@ type ORM struct { } // NewORM will init new database connection. -func NewORM(config DatabaseConfig) *ORM { +func NewORM(config config.DatabaseConfig) *ORM { orm := &ORM{} - orm.createDB(config) + orm.CreateDB(config) return orm } -func (orm *ORM) createDB(config DatabaseConfig) { +// CreateDB connection using provided config. +func (orm *ORM) CreateDB(config config.DatabaseConfig) { db, err := gorm.Open("postgres", config.Connection) if err != nil { panic(err) diff --git a/core/orm_test.go b/core/db/orm_test.go similarity index 85% rename from core/orm_test.go rename to core/db/orm_test.go index 404c947..00f67fa 100644 --- a/core/orm_test.go +++ b/core/db/orm_test.go @@ -1,4 +1,4 @@ -package core +package db import ( "database/sql" @@ -7,6 +7,8 @@ import ( "github.com/DATA-DOG/go-sqlmock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/retailcrm/mg-transport-core/v2/core/config" ) func TestORM_NewORM(t *testing.T) { @@ -22,7 +24,7 @@ func TestORM_NewORM(t *testing.T) { db, _, err = sqlmock.New() require.NoError(t, err) - config := DatabaseConfig{ + config := config.DatabaseConfig{ Connection: db, Logging: true, TablePrefix: "", @@ -39,7 +41,7 @@ func TestORM_createDB_Fail(t *testing.T) { assert.NotNil(t, recover()) }() - NewORM(DatabaseConfig{Connection: nil}) + NewORM(config.DatabaseConfig{Connection: nil}) } func TestORM_CloseDB(t *testing.T) { @@ -56,7 +58,7 @@ func TestORM_CloseDB(t *testing.T) { db, dbMock, err = sqlmock.New() require.NoError(t, err) - config := DatabaseConfig{ + config := config.DatabaseConfig{ Connection: db, Logging: true, TablePrefix: "", diff --git a/core/doc.go b/core/doc.go index 204fc7f..b1916a4 100644 --- a/core/doc.go +++ b/core/doc.go @@ -1,7 +1,6 @@ -// Copyright (c) 2019 RetailDriver LLC -// Use of this source code is governed by a MIT /* -Package core provides different functions like error-reporting, logging, localization, etc. in order to make it easier to create transports. +Package core provides different functions like error-reporting, logging, localization, etc. +to make it easier to create transports. Usage: package main @@ -110,5 +109,7 @@ Migration generator This library contains helper tool for transports. You can install it via go: $ go get -u github.com/retailcrm/mg-transport-core/cmd/transport-core-tool Currently, it only can generate new migrations for your transport. + +Copyright (c) 2019 RetailDriver LLC. Usage of this source code is governed by a MIT license. */ package core diff --git a/core/engine.go b/core/engine.go index 862905e..b577e4c 100644 --- a/core/engine.go +++ b/core/engine.go @@ -12,34 +12,42 @@ import ( "github.com/gorilla/sessions" "github.com/op/go-logging" "golang.org/x/text/language" + + "github.com/retailcrm/mg-transport-core/v2/core/config" + "github.com/retailcrm/mg-transport-core/v2/core/db" + "github.com/retailcrm/mg-transport-core/v2/core/middleware" + "github.com/retailcrm/mg-transport-core/v2/core/util" + "github.com/retailcrm/mg-transport-core/v2/core/util/httputil" + + "github.com/retailcrm/mg-transport-core/v2/core/logger" ) var boolTrue = true // DefaultHTTPClientConfig is a default config for HTTP client. It will be used by Engine for building HTTP client // if HTTP client config is not present in the configuration. -var DefaultHTTPClientConfig = &HTTPClientConfig{ +var DefaultHTTPClientConfig = &config.HTTPClientConfig{ Timeout: 30, SSLVerification: &boolTrue, } // Engine struct. type Engine struct { + logger logger.Logger + Sessions sessions.Store + LogFormatter logging.Formatter + Config config.Configuration + ginEngine *gin.Engine + csrf *middleware.CSRF + httpClient *http.Client + jobManager *JobManager + db.ORM Localizer - ORM - Sentry - Utils - ginEngine *gin.Engine - httpClient *http.Client - logger LoggerInterface - mutex sync.RWMutex - csrf *CSRF - jobManager *JobManager + util.Utils PreloadLanguages []language.Tag - Sessions sessions.Store - Config ConfigInterface - LogFormatter logging.Formatter - prepared bool + Sentry + mutex sync.RWMutex + prepared bool } // New Engine instance (must be configured manually, gin can be accessed via engine.Router() directly or @@ -52,9 +60,9 @@ func New() *Engine { loadMutex: &sync.RWMutex{}, }, PreloadLanguages: []language.Tag{}, - ORM: ORM{}, + ORM: db.ORM{}, Sentry: Sentry{}, - Utils: Utils{}, + Utils: util.Utils{}, ginEngine: nil, logger: nil, mutex: sync.RWMutex{}, @@ -91,7 +99,7 @@ func (e *Engine) Prepare() *Engine { e.DefaultError = "error" } if e.LogFormatter == nil { - e.LogFormatter = DefaultLogFormatter() + e.LogFormatter = logger.DefaultLogFormatter() } if e.LocaleMatcher == nil { e.LocaleMatcher = DefaultLocalizerMatcher() @@ -107,10 +115,10 @@ func (e *Engine) Prepare() *Engine { e.Localizer.Preload(e.PreloadLanguages) } - e.createDB(e.Config.GetDBConfig()) + e.CreateDB(e.Config.GetDBConfig()) e.createRavenClient(e.Config.GetSentryDSN()) - e.resetUtils(e.Config.GetAWSConfig(), e.Config.IsDebug(), 0) - e.SetLogger(NewLogger(e.Config.GetTransportInfo().GetCode(), e.Config.GetLogLevel(), e.LogFormatter)) + e.ResetUtils(e.Config.GetAWSConfig(), e.Config.IsDebug(), 0) + e.SetLogger(logger.NewStandard(e.Config.GetTransportInfo().GetCode(), e.Config.GetLogLevel(), e.LogFormatter)) e.Sentry.Localizer = &e.Localizer e.Sentry.Stacktrace = true e.Utils.Logger = e.Logger() @@ -176,12 +184,12 @@ func (e *Engine) JobManager() *JobManager { } // Logger returns current logger. -func (e *Engine) Logger() LoggerInterface { +func (e *Engine) Logger() logger.Logger { return e.logger } // SetLogger sets provided logger instance to engine. -func (e *Engine) SetLogger(l LoggerInterface) *Engine { +func (e *Engine) SetLogger(l logger.Logger) *Engine { if l == nil { return e } @@ -194,11 +202,11 @@ func (e *Engine) SetLogger(l LoggerInterface) *Engine { // BuildHTTPClient builds HTTP client with provided configuration. func (e *Engine) BuildHTTPClient(certs *x509.CertPool, replaceDefault ...bool) *Engine { - client, err := NewHTTPClientBuilder(). + client, err := httputil.NewHTTPClientBuilder(). WithLogger(e.Logger()). SetLogging(e.Config.IsDebug()). SetCertPool(certs). - FromEngine(e). + FromConfig(e.GetHTTPClientConfig()). Build(replaceDefault...) if err != nil { @@ -211,7 +219,7 @@ func (e *Engine) BuildHTTPClient(certs *x509.CertPool, replaceDefault ...bool) * } // GetHTTPClientConfig returns configuration for HTTP client. -func (e *Engine) GetHTTPClientConfig() *HTTPClientConfig { +func (e *Engine) GetHTTPClientConfig() *config.HTTPClientConfig { if e.Config.GetHTTPClientConfig() != nil { return e.Config.GetHTTPClientConfig() } @@ -266,12 +274,13 @@ func (e *Engine) WithFilesystemSessions(path string, keyLength ...int) *Engine { // InitCSRF initializes CSRF middleware. engine.Sessions must be already initialized, // use engine.WithCookieStore or engine.WithFilesystemStore for that. // Syntax is similar to core.NewCSRF, but you shouldn't pass sessionName, store and salt. -func (e *Engine) InitCSRF(secret string, abortFunc CSRFAbortFunc, getter CSRFTokenGetter) *Engine { +func (e *Engine) InitCSRF( + secret string, abortFunc middleware.CSRFAbortFunc, getter middleware.CSRFTokenGetter) *Engine { if e.Sessions == nil { panic("engine.Sessions must be initialized first") } - e.csrf = NewCSRF("", secret, "", e.Sessions, abortFunc, getter) + e.csrf = middleware.NewCSRF("", secret, "", e.Sessions, abortFunc, getter) return e } diff --git a/core/engine_test.go b/core/engine_test.go index 0885b7d..487e640 100644 --- a/core/engine_test.go +++ b/core/engine_test.go @@ -17,6 +17,12 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + + "github.com/retailcrm/mg-transport-core/v2/core/config" + "github.com/retailcrm/mg-transport-core/v2/core/middleware" + "github.com/retailcrm/mg-transport-core/v2/core/util/httputil" + + "github.com/retailcrm/mg-transport-core/v2/core/logger" ) type EngineTest struct { @@ -38,10 +44,10 @@ func (e *EngineTest) SetupTest() { createTestLangFiles(e.T()) - e.engine.Config = Config{ + e.engine.Config = config.Config{ Version: "1", LogLevel: 5, - Database: DatabaseConfig{ + Database: config.DatabaseConfig{ Connection: db, Logging: true, TablePrefix: "", @@ -50,14 +56,14 @@ func (e *EngineTest) SetupTest() { ConnectionLifetime: 60, }, SentryDSN: "sentry dsn", - HTTPServer: HTTPServerConfig{ + HTTPServer: config.HTTPServerConfig{ Host: "0.0.0.0", Listen: ":3001", }, Debug: true, UpdateInterval: 30, - ConfigAWS: ConfigAWS{}, - TransportInfo: Info{ + ConfigAWS: config.AWS{}, + TransportInfo: config.Info{ Name: "test", Code: "test", LogoPath: "/test.svg", @@ -113,7 +119,7 @@ func (e *EngineTest) Test_Prepare() { func (e *EngineTest) Test_initGin_Release() { engine := New() - engine.Config = Config{Debug: false} + engine.Config = config.Config{Debug: false} engine.initGin() assert.NotNil(e.T(), engine.ginEngine) } @@ -169,8 +175,8 @@ func (e *EngineTest) Test_ConfigureRouter() { } func (e *EngineTest) Test_BuildHTTPClient() { - e.engine.Config = &Config{ - HTTPClientConfig: &HTTPClientConfig{ + e.engine.Config = &config.Config{ + HTTPClientConfig: &config.HTTPClientConfig{ Timeout: 30, SSLVerification: boolPtr(true), }, @@ -186,7 +192,7 @@ func (e *EngineTest) Test_BuildHTTPClient() { } func (e *EngineTest) Test_BuildHTTPClient_NoConfig() { - e.engine.Config = &Config{} + e.engine.Config = &config.Config{} e.engine.BuildHTTPClient(x509.NewCertPool()) assert.NotNil(e.T(), e.engine.httpClient) @@ -198,11 +204,11 @@ func (e *EngineTest) Test_BuildHTTPClient_NoConfig() { } func (e *EngineTest) Test_GetHTTPClientConfig() { - e.engine.Config = &Config{} + e.engine.Config = &config.Config{} assert.Equal(e.T(), DefaultHTTPClientConfig, e.engine.GetHTTPClientConfig()) - e.engine.Config = &Config{ - HTTPClientConfig: &HTTPClientConfig{ + e.engine.Config = &config.Config{ + HTTPClientConfig: &config.HTTPClientConfig{ Timeout: 10, SSLVerification: boolPtr(true), }, @@ -230,7 +236,7 @@ func (e *EngineTest) Test_SetLogger() { defer func() { e.engine.logger = origLogger }() - e.engine.logger = &Logger{} + e.engine.logger = &logger.StandardLogger{} e.engine.SetLogger(nil) assert.NotNil(e.T(), e.engine.logger) } @@ -241,7 +247,7 @@ func (e *EngineTest) Test_SetHTTPClient() { e.engine.httpClient = origClient }() e.engine.httpClient = nil - httpClient, err := NewHTTPClientBuilder().Build() + httpClient, err := httputil.NewHTTPClientBuilder().Build() require.NoError(e.T(), err) assert.NotNil(e.T(), httpClient) e.engine.SetHTTPClient(&http.Client{}) @@ -257,7 +263,7 @@ func (e *EngineTest) Test_HTTPClient() { }() e.engine.httpClient = nil require.Same(e.T(), http.DefaultClient, e.engine.HTTPClient()) - httpClient, err := NewHTTPClientBuilder().Build() + httpClient, err := httputil.NewHTTPClientBuilder().Build() require.NoError(e.T(), err) e.engine.httpClient = httpClient assert.Same(e.T(), httpClient, e.engine.HTTPClient()) @@ -270,7 +276,7 @@ func (e *EngineTest) Test_InitCSRF_Fail() { e.engine.csrf = nil e.engine.Sessions = nil - e.engine.InitCSRF("test", func(context *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter) + e.engine.InitCSRF("test", func(context *gin.Context, r middleware.CSRFErrorReason) {}, middleware.DefaultCSRFTokenGetter) assert.Nil(e.T(), e.engine.csrf) } @@ -281,7 +287,7 @@ func (e *EngineTest) Test_InitCSRF() { e.engine.csrf = nil e.engine.WithCookieSessions(4) - e.engine.InitCSRF("test", func(context *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter) + e.engine.InitCSRF("test", func(context *gin.Context, r middleware.CSRFErrorReason) {}, middleware.DefaultCSRFTokenGetter) assert.NotNil(e.T(), e.engine.csrf) } @@ -291,7 +297,7 @@ func (e *EngineTest) Test_VerifyCSRFMiddleware_Fail() { }() e.engine.csrf = nil - e.engine.VerifyCSRFMiddleware(DefaultIgnoredMethods) + e.engine.VerifyCSRFMiddleware(middleware.DefaultIgnoredMethods) } func (e *EngineTest) Test_VerifyCSRFMiddleware() { @@ -301,8 +307,8 @@ func (e *EngineTest) Test_VerifyCSRFMiddleware() { e.engine.csrf = nil e.engine.WithCookieSessions(4) - e.engine.InitCSRF("test", func(context *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter) - e.engine.VerifyCSRFMiddleware(DefaultIgnoredMethods) + e.engine.InitCSRF("test", func(context *gin.Context, r middleware.CSRFErrorReason) {}, middleware.DefaultCSRFTokenGetter) + e.engine.VerifyCSRFMiddleware(middleware.DefaultIgnoredMethods) } func (e *EngineTest) Test_GenerateCSRFMiddleware_Fail() { @@ -321,7 +327,7 @@ func (e *EngineTest) Test_GenerateCSRFMiddleware() { e.engine.csrf = nil e.engine.WithCookieSessions(4) - e.engine.InitCSRF("test", func(context *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter) + e.engine.InitCSRF("test", func(context *gin.Context, r middleware.CSRFErrorReason) {}, middleware.DefaultCSRFTokenGetter) e.engine.GenerateCSRFMiddleware() } @@ -350,7 +356,7 @@ func (e *EngineTest) Test_GetCSRFToken() { e.engine.csrf = nil e.engine.WithCookieSessions(4) - e.engine.InitCSRF("test", func(context *gin.Context, r CSRFErrorReason) {}, DefaultCSRFTokenGetter) + e.engine.InitCSRF("test", func(context *gin.Context, r middleware.CSRFErrorReason) {}, middleware.DefaultCSRFTokenGetter) assert.NotEmpty(e.T(), e.engine.GetCSRFToken(c)) assert.Equal(e.T(), "token", e.engine.GetCSRFToken(c)) } diff --git a/core/error.go b/core/error.go deleted file mode 100644 index 1488c03..0000000 --- a/core/error.go +++ /dev/null @@ -1,36 +0,0 @@ -package core - -import "net/http" - -// ErrorResponse struct. -type ErrorResponse struct { - Error string `json:"error"` -} - -// ErrorsResponse struct. -type ErrorsResponse struct { - Error []string `json:"error"` -} - -// GetErrorResponse returns ErrorResponse with specified status code -// Usage (with gin): -// context.JSON(GetErrorResponse(http.StatusPaymentRequired, "Not enough money")) -func GetErrorResponse(statusCode int, error string) (int, interface{}) { - return statusCode, ErrorResponse{ - Error: error, - } -} - -// BadRequest returns ErrorResponse with code 400 -// Usage (with gin): -// context.JSON(BadRequest("invalid data")) -func BadRequest(error string) (int, interface{}) { - return GetErrorResponse(http.StatusBadRequest, error) -} - -// InternalServerError returns ErrorResponse with code 500 -// Usage (with gin): -// context.JSON(BadRequest("invalid data")) -func InternalServerError(error string) (int, interface{}) { - return GetErrorResponse(http.StatusInternalServerError, error) -} diff --git a/core/error_collector.go b/core/error_collector.go deleted file mode 100644 index f494ebb..0000000 --- a/core/error_collector.go +++ /dev/null @@ -1,38 +0,0 @@ -package core - -import ( - "github.com/pkg/errors" -) - -// ErrorCollector can be used to group several error calls into one call. -// It is mostly useful in GORM migrations, where you should return only one errors, but several can occur. -// Error messages are composed into one message. For example: -// err := core.ErrorCollector( -// errors.New("first"), -// errors.New("second") -// ) -// -// // Will output `first < second` -// fmt.Println(err.Error()) -// Example with GORM migration, returns one migration error with all error messages: -// return core.ErrorCollector( -// db.CreateTable(models.Account{}, models.Connection{}).Error, -// db.Table("account").AddUniqueIndex("account_key", "channel").Error, -// ) -func ErrorCollector(errorsList ...error) error { - var errorMsg string - - for _, errItem := range errorsList { - if errItem == nil { - continue - } - - errorMsg += "< " + errItem.Error() + " " - } - - if errorMsg != "" { - return errors.New(errorMsg[2 : len(errorMsg)-1]) - } - - return nil -} diff --git a/core/error_collector_test.go b/core/error_collector_test.go deleted file mode 100644 index 5eaf6ea..0000000 --- a/core/error_collector_test.go +++ /dev/null @@ -1,40 +0,0 @@ -package core - -import ( - "errors" - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestErrorCollector_NoError(t *testing.T) { - err := ErrorCollector(nil, nil, nil) - - assert.NoError(t, err) - assert.Nil(t, err) -} - -func TestErrorCollector_SeveralErrors(t *testing.T) { - err := ErrorCollector(nil, errors.New("error text"), nil) - - assert.Error(t, err) - assert.Equal(t, "error text", err.Error()) -} - -func TestErrorCollector_EmptyErrorMessage(t *testing.T) { - err := ErrorCollector(nil, errors.New(""), nil) - - assert.Error(t, err) - assert.Equal(t, "", err.Error()) -} - -func TestErrorCollector_AllErrors(t *testing.T) { - err := ErrorCollector( - errors.New("first"), - errors.New("second"), - errors.New("third"), - ) - - assert.Error(t, err) - assert.Equal(t, "first < second < third", err.Error()) -} diff --git a/core/job_manager.go b/core/job_manager.go index 2d9bd7a..d6fc49a 100644 --- a/core/job_manager.go +++ b/core/job_manager.go @@ -6,31 +6,28 @@ import ( "sync" "time" - "github.com/op/go-logging" + "github.com/retailcrm/mg-transport-core/v2/core/logger" ) // JobFunc is empty func which should be executed in a parallel goroutine. -type JobFunc func(JobLogFunc) error - -// JobLogFunc is a function which logs data from job. -type JobLogFunc func(string, logging.Level, ...interface{}) +type JobFunc func(logger.Logger) error // JobErrorHandler is a function to handle jobs errors. First argument is a job name. -type JobErrorHandler func(string, error, JobLogFunc) +type JobErrorHandler func(string, error, logger.Logger) // JobPanicHandler is a function to handle jobs panics. First argument is a job name. -type JobPanicHandler func(string, interface{}, JobLogFunc) +type JobPanicHandler func(string, interface{}, logger.Logger) // Job represents single job. Regular job will be executed every Interval. type Job struct { Command JobFunc ErrorHandler JobErrorHandler PanicHandler JobPanicHandler + stopChannel chan bool Interval time.Duration writeLock sync.RWMutex Regular bool active bool - stopChannel chan bool } // JobManager controls jobs execution flow. Jobs can be added just for later use (e.g. JobManager can be used as @@ -39,9 +36,9 @@ type Job struct { // SetLogger(logger). // SetLogging(false) // _ = manager.RegisterJob("updateTokens", &Job{ -// Command: func(logFunc JobLogFunc) error { +// Command: func(log logger.Logger) error { // // logic goes here... -// logFunc("All tokens were updated successfully", logging.INFO) +// logger.Info("All tokens were updated successfully") // return nil // }, // ErrorHandler: DefaultJobErrorHandler(), @@ -51,13 +48,14 @@ type Job struct { // }) // manager.Start() type JobManager struct { + logger logger.Logger + nilLogger logger.Logger jobs *sync.Map enableLogging bool - logger LoggerInterface } // getWrappedFunc wraps job into function. -func (j *Job) getWrappedFunc(name string, log JobLogFunc) func() { +func (j *Job) getWrappedFunc(name string, log logger.Logger) func() { return func() { defer func() { if r := recover(); r != nil && j.PanicHandler != nil { @@ -72,7 +70,7 @@ func (j *Job) getWrappedFunc(name string, log JobLogFunc) func() { } // getWrappedTimerFunc returns job timer func to run in the separate goroutine. -func (j *Job) getWrappedTimerFunc(name string, log JobLogFunc) func(chan bool) { +func (j *Job) getWrappedTimerFunc(name string, log logger.Logger) func(chan bool) { return func(stopChannel chan bool) { for range time.NewTicker(j.Interval).C { select { @@ -86,7 +84,7 @@ func (j *Job) getWrappedTimerFunc(name string, log JobLogFunc) func(chan bool) { } // run job. -func (j *Job) run(name string, log JobLogFunc) *Job { +func (j *Job) run(name string, log logger.Logger) { j.writeLock.RLock() if j.Regular && j.Interval > 0 && !j.active { @@ -100,12 +98,10 @@ func (j *Job) run(name string, log JobLogFunc) *Job { } else { j.writeLock.RUnlock() } - - return j } // stop running job. -func (j *Job) stop() *Job { +func (j *Job) stop() { j.writeLock.RLock() if j.active && j.stopChannel != nil { @@ -119,47 +115,43 @@ func (j *Job) stop() *Job { } else { j.writeLock.RUnlock() } - - return j } // runOnce run job once. -func (j *Job) runOnce(name string, log JobLogFunc) *Job { +func (j *Job) runOnce(name string, log logger.Logger) { go j.getWrappedFunc(name, log)() - return j } // runOnceSync run job once in current goroutine. -func (j *Job) runOnceSync(name string, log JobLogFunc) *Job { +func (j *Job) runOnceSync(name string, log logger.Logger) { j.getWrappedFunc(name, log)() - return j } // NewJobManager is a JobManager constructor. func NewJobManager() *JobManager { - return &JobManager{jobs: &sync.Map{}} + return &JobManager{jobs: &sync.Map{}, nilLogger: logger.NewNil()} } // DefaultJobErrorHandler returns default error handler for a job. func DefaultJobErrorHandler() JobErrorHandler { - return func(name string, err error, log JobLogFunc) { + return func(name string, err error, log logger.Logger) { if err != nil && name != "" { - log("Job `%s` errored with an error: `%s`", logging.ERROR, name, err.Error()) + log.Errorf("Job `%s` errored with an error: `%s`", name, err.Error()) } } } // DefaultJobPanicHandler returns default panic handler for a job. func DefaultJobPanicHandler() JobPanicHandler { - return func(name string, recoverValue interface{}, log JobLogFunc) { + return func(name string, recoverValue interface{}, log logger.Logger) { if recoverValue != nil && name != "" { - log("Job `%s` panicked with value: `%#v`", logging.ERROR, name, recoverValue) + log.Errorf("Job `%s` panicked with value: `%#v`", name, recoverValue) } } } // SetLogger sets logger into JobManager. -func (j *JobManager) SetLogger(logger LoggerInterface) *JobManager { +func (j *JobManager) SetLogger(logger logger.Logger) *JobManager { if logger != nil { j.logger = logger } @@ -167,6 +159,14 @@ func (j *JobManager) SetLogger(logger LoggerInterface) *JobManager { return j } +// Logger returns logger. +func (j *JobManager) Logger() logger.Logger { + if !j.enableLogging { + return j.nilLogger + } + return j.logger +} + // SetLogging enables or disables JobManager logging. func (j *JobManager) SetLogging(enableLogging bool) *JobManager { j.enableLogging = enableLogging @@ -211,7 +211,7 @@ func (j *JobManager) FetchJob(name string) (value *Job, ok bool) { // UpdateJob updates job. func (j *JobManager) UpdateJob(name string, job *Job) error { - if job, ok := j.FetchJob(name); ok { + if _, ok := j.FetchJob(name); ok { _ = j.UnregisterJob(name) return j.RegisterJob(name, job) } @@ -219,10 +219,11 @@ func (j *JobManager) UpdateJob(name string, job *Job) error { return fmt.Errorf("cannot find job `%s`", name) } -// RunJob starts provided regular job if it's exists. It's async operation and error returns only of job wasn't executed at all. +// RunJob starts provided regular job if it's exists. +// It runs asynchronously and error returns only of job wasn't executed at all. func (j *JobManager) RunJob(name string) error { if job, ok := j.FetchJob(name); ok { - job.run(name, j.log) + job.run(name, j.Logger()) return nil } @@ -242,7 +243,7 @@ func (j *JobManager) StopJob(name string) error { // RunJobOnce starts provided job once if it exists. It's also async. func (j *JobManager) RunJobOnce(name string) error { if job, ok := j.FetchJob(name); ok { - job.runOnce(name, j.log) + job.runOnce(name, j.Logger()) return nil } @@ -252,7 +253,7 @@ func (j *JobManager) RunJobOnce(name string) error { // RunJobOnceSync starts provided job once in current goroutine if job exists. Will wait for job to end it's work. func (j *JobManager) RunJobOnceSync(name string) error { if job, ok := j.FetchJob(name); ok { - job.runOnceSync(name, j.log) + job.runOnceSync(name, j.Logger()) return nil } @@ -264,48 +265,7 @@ func (j *JobManager) Start() { j.jobs.Range(func(key, value interface{}) bool { name := key.(string) job := value.(*Job) - job.run(name, j.log) + job.run(name, j.Logger()) return true }) } - -// log logs via logger or as plaintext. -func (j *JobManager) log(format string, severity logging.Level, args ...interface{}) { - if !j.enableLogging { - return - } - - if j.logger != nil { - switch severity { - case logging.CRITICAL: - j.logger.Criticalf(format, args...) - case logging.ERROR: - j.logger.Errorf(format, args...) - case logging.WARNING: - j.logger.Warningf(format, args...) - case logging.NOTICE: - j.logger.Noticef(format, args...) - case logging.INFO: - j.logger.Infof(format, args...) - case logging.DEBUG: - j.logger.Debugf(format, args...) - } - - return - } - - switch severity { - case logging.CRITICAL: - fmt.Print("[CRITICAL] ", fmt.Sprintf(format, args...)) - case logging.ERROR: - fmt.Print("[ERROR] ", fmt.Sprintf(format, args...)) - case logging.WARNING: - fmt.Print("[WARNING] ", fmt.Sprintf(format, args...)) - case logging.NOTICE: - fmt.Print("[NOTICE] ", fmt.Sprintf(format, args...)) - case logging.INFO: - fmt.Print("[INFO] ", fmt.Sprintf(format, args...)) - case logging.DEBUG: - fmt.Print("[DEBUG] ", fmt.Sprintf(format, args...)) - } -} diff --git a/core/job_manager_test.go b/core/job_manager_test.go index 6ae79bb..b71a86a 100644 --- a/core/job_manager_test.go +++ b/core/job_manager_test.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "math/rand" + "strings" "sync" "testing" "time" @@ -12,18 +13,20 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + + "github.com/retailcrm/mg-transport-core/v2/core/logger" ) type JobTest struct { suite.Suite job *Job - syncBool bool executedChan chan bool randomNumber chan int executeErr chan error panicValue chan interface{} lastLog string lastMsgLevel logging.Level + syncBool bool } type JobManagerTest struct { @@ -33,6 +36,70 @@ type JobManagerTest struct { syncRunnerFlag bool } +type callbackLoggerFunc func(level logging.Level, format string, args ...interface{}) + +type callbackLogger struct { + fn callbackLoggerFunc +} + +func (n *callbackLogger) Fatal(args ...interface{}) { + n.fn(logging.CRITICAL, "", args...) +} + +func (n *callbackLogger) Fatalf(format string, args ...interface{}) { + n.fn(logging.CRITICAL, format, args...) +} + +func (n *callbackLogger) Panic(args ...interface{}) { + n.fn(logging.CRITICAL, "", args...) +} +func (n *callbackLogger) Panicf(format string, args ...interface{}) { + n.fn(logging.CRITICAL, format, args...) +} + +func (n *callbackLogger) Critical(args ...interface{}) { + n.fn(logging.CRITICAL, "", args...) +} + +func (n *callbackLogger) Criticalf(format string, args ...interface{}) { + n.fn(logging.CRITICAL, format, args...) +} + +func (n *callbackLogger) Error(args ...interface{}) { + n.fn(logging.ERROR, "", args...) +} +func (n *callbackLogger) Errorf(format string, args ...interface{}) { + n.fn(logging.ERROR, format, args...) +} + +func (n *callbackLogger) Warning(args ...interface{}) { + n.fn(logging.WARNING, "", args...) +} +func (n *callbackLogger) Warningf(format string, args ...interface{}) { + n.fn(logging.WARNING, format, args...) +} + +func (n *callbackLogger) Notice(args ...interface{}) { + n.fn(logging.NOTICE, "", args...) +} +func (n *callbackLogger) Noticef(format string, args ...interface{}) { + n.fn(logging.NOTICE, format, args...) +} + +func (n *callbackLogger) Info(args ...interface{}) { + n.fn(logging.INFO, "", args...) +} +func (n *callbackLogger) Infof(format string, args ...interface{}) { + n.fn(logging.INFO, format, args...) +} + +func (n *callbackLogger) Debug(args ...interface{}) { + n.fn(logging.DEBUG, "", args...) +} +func (n *callbackLogger) Debugf(format string, args ...interface{}) { + n.fn(logging.DEBUG, format, args...) +} + func TestJob(t *testing.T) { suite.Run(t, new(JobTest)) } @@ -48,10 +115,10 @@ func TestDefaultJobErrorHandler(t *testing.T) { fn := DefaultJobErrorHandler() require.NotNil(t, fn) - fn("job", errors.New("test"), func(s string, level logging.Level, i ...interface{}) { + fn("job", errors.New("test"), &callbackLogger{fn: func(level logging.Level, s string, i ...interface{}) { require.Len(t, i, 2) assert.Equal(t, fmt.Sprintf("%s", i[1]), "test") - }) + }}) } func TestDefaultJobPanicHandler(t *testing.T) { @@ -61,41 +128,52 @@ func TestDefaultJobPanicHandler(t *testing.T) { fn := DefaultJobPanicHandler() require.NotNil(t, fn) - fn("job", errors.New("test"), func(s string, level logging.Level, i ...interface{}) { + fn("job", errors.New("test"), &callbackLogger{fn: func(level logging.Level, s string, i ...interface{}) { require.Len(t, i, 2) assert.Equal(t, fmt.Sprintf("%s", i[1]), "test") - }) + }}) } func (t *JobTest) testErrorHandler() JobErrorHandler { - return func(name string, err error, logFunc JobLogFunc) { + return func(name string, err error, log logger.Logger) { t.executeErr <- err } } func (t *JobTest) testPanicHandler() JobPanicHandler { - return func(name string, i interface{}, logFunc JobLogFunc) { + return func(name string, i interface{}, log logger.Logger) { t.panicValue <- i } } -func (t *JobTest) testLogFunc() JobLogFunc { - return func(s string, level logging.Level, i ...interface{}) { - t.lastLog = fmt.Sprintf(s, i...) +func (t *JobTest) testLogger() logger.Logger { + return &callbackLogger{fn: func(level logging.Level, format string, args ...interface{}) { + if format == "" { + var sb strings.Builder + sb.Grow(3 * len(args)) // nolint:gomnd + + for i := 0; i < len(args); i++ { + sb.WriteString("%v ") + } + + format = strings.TrimRight(sb.String(), " ") + } + + t.lastLog = fmt.Sprintf(format, args...) t.lastMsgLevel = level - } + }} } -func (t *JobTest) executed(wait time.Duration, defaultVal bool) bool { +func (t *JobTest) executed() bool { if t.executedChan == nil { - return defaultVal + return false } select { case c := <-t.executedChan: return c - case <-time.After(wait): - return defaultVal + case <-time.After(time.Millisecond): + return false } } @@ -140,7 +218,7 @@ func (t *JobTest) clear() { func (t *JobTest) onceJob() { t.job = &Job{ - Command: func(logFunc JobLogFunc) error { + Command: func(log logger.Logger) error { t.executedChan <- true return nil }, @@ -153,7 +231,7 @@ func (t *JobTest) onceJob() { func (t *JobTest) onceErrorJob() { t.job = &Job{ - Command: func(logFunc JobLogFunc) error { + Command: func(log logger.Logger) error { t.executedChan <- true return errors.New("test error") }, @@ -166,7 +244,7 @@ func (t *JobTest) onceErrorJob() { func (t *JobTest) oncePanicJob() { t.job = &Job{ - Command: func(logFunc JobLogFunc) error { + Command: func(log logger.Logger) error { t.executedChan <- true panic("test panic") }, @@ -180,9 +258,9 @@ func (t *JobTest) oncePanicJob() { func (t *JobTest) regularJob() { rand.Seed(time.Now().UnixNano()) t.job = &Job{ - Command: func(logFunc JobLogFunc) error { + Command: func(log logger.Logger) error { t.executedChan <- true - t.randomNumber <- rand.Int() + t.randomNumber <- rand.Int() // nolint:gosec return nil }, ErrorHandler: t.testErrorHandler(), @@ -195,7 +273,7 @@ func (t *JobTest) regularJob() { func (t *JobTest) regularSyncJob() { rand.Seed(time.Now().UnixNano()) t.job = &Job{ - Command: func(logFunc JobLogFunc) error { + Command: func(log logger.Logger) error { t.syncBool = true return nil }, @@ -213,10 +291,10 @@ func (t *JobTest) Test_getWrappedFunc() { t.clear() t.onceJob() - fn := t.job.getWrappedFunc("job", t.testLogFunc()) + fn := t.job.getWrappedFunc("job", t.testLogger()) require.NotNil(t.T(), fn) go fn() - assert.True(t.T(), t.executed(time.Millisecond, false)) + assert.True(t.T(), t.executed()) assert.False(t.T(), t.errored(time.Millisecond)) assert.False(t.T(), t.panicked(time.Millisecond)) } @@ -228,10 +306,10 @@ func (t *JobTest) Test_getWrappedFuncError() { t.clear() t.onceErrorJob() - fn := t.job.getWrappedFunc("job", t.testLogFunc()) + fn := t.job.getWrappedFunc("job", t.testLogger()) require.NotNil(t.T(), fn) go fn() - assert.True(t.T(), t.executed(time.Millisecond, false)) + assert.True(t.T(), t.executed()) assert.True(t.T(), t.errored(time.Millisecond)) assert.False(t.T(), t.panicked(time.Millisecond)) } @@ -243,10 +321,10 @@ func (t *JobTest) Test_getWrappedFuncPanic() { t.clear() t.oncePanicJob() - fn := t.job.getWrappedFunc("job", t.testLogFunc()) + fn := t.job.getWrappedFunc("job", t.testLogger()) require.NotNil(t.T(), fn) go fn() - assert.True(t.T(), t.executed(time.Millisecond, false)) + assert.True(t.T(), t.executed()) assert.False(t.T(), t.errored(time.Millisecond)) assert.True(t.T(), t.panicked(time.Millisecond)) } @@ -257,10 +335,10 @@ func (t *JobTest) Test_run() { }() t.regularJob() - t.job.run("job", t.testLogFunc()) - time.Sleep(time.Millisecond * 5) + t.job.run("job", t.testLogger()) + time.Sleep(time.Millisecond * 10) t.job.stop() - require.True(t.T(), t.executed(time.Millisecond, false)) + require.True(t.T(), t.executed()) } func (t *JobTest) Test_runOnce() { @@ -269,9 +347,9 @@ func (t *JobTest) Test_runOnce() { }() t.regularJob() - t.job.runOnce("job", t.testLogFunc()) + t.job.runOnce("job", t.testLogger()) time.Sleep(time.Millisecond * 5) - require.True(t.T(), t.executed(time.Millisecond, false)) + require.True(t.T(), t.executed()) first := 0 select { @@ -301,7 +379,7 @@ func (t *JobTest) Test_runOnceSync() { t.clear() t.regularSyncJob() require.False(t.T(), t.syncBool) - t.job.runOnceSync("job", t.testLogFunc()) + t.job.runOnceSync("job", t.testLogger()) assert.True(t.T(), t.syncBool) } @@ -326,11 +404,11 @@ func (t *JobManagerTest) WaitForJob() bool { func (t *JobManagerTest) Test_SetLogger() { t.manager.logger = nil - t.manager.SetLogger(NewLogger("test", logging.ERROR, DefaultLogFormatter())) - assert.IsType(t.T(), &Logger{}, t.manager.logger) + t.manager.SetLogger(logger.NewStandard("test", logging.ERROR, logger.DefaultLogFormatter())) + assert.IsType(t.T(), &logger.StandardLogger{}, t.manager.logger) t.manager.SetLogger(nil) - assert.IsType(t.T(), &Logger{}, t.manager.logger) + assert.IsType(t.T(), &logger.StandardLogger{}, t.manager.logger) } func (t *JobManagerTest) Test_SetLogging() { @@ -351,7 +429,7 @@ func (t *JobManagerTest) Test_RegisterJobNil() { func (t *JobManagerTest) Test_RegisterJob() { require.NotNil(t.T(), t.manager.jobs) err := t.manager.RegisterJob("job", &Job{ - Command: func(log JobLogFunc) error { + Command: func(log logger.Logger) error { t.runnerWG.Done() return nil }, @@ -360,7 +438,7 @@ func (t *JobManagerTest) Test_RegisterJob() { }) assert.NoError(t.T(), err) err = t.manager.RegisterJob("job_regular", &Job{ - Command: func(log JobLogFunc) error { + Command: func(log logger.Logger) error { t.runnerWG.Done() return nil }, @@ -371,7 +449,7 @@ func (t *JobManagerTest) Test_RegisterJob() { }) assert.NoError(t.T(), err) err = t.manager.RegisterJob("job_sync", &Job{ - Command: func(log JobLogFunc) error { + Command: func(log logger.Logger) error { t.syncRunnerFlag = true return nil }, @@ -398,7 +476,7 @@ func (t *JobManagerTest) Test_FetchJob() { require.Nil(t.T(), recover()) }() - require.NoError(t.T(), t.manager.RegisterJob("test_fetch", &Job{Command: func(logFunc JobLogFunc) error { + require.NoError(t.T(), t.manager.RegisterJob("test_fetch", &Job{Command: func(log logger.Logger) error { return nil }})) require.NotNil(t.T(), t.manager.jobs) @@ -479,8 +557,8 @@ func (t *JobManagerTest) Test_Start() { manager := NewJobManager() _ = manager.RegisterJob("job", &Job{ - Command: func(logFunc JobLogFunc) error { - logFunc("alive!", logging.INFO) + Command: func(log logger.Logger) error { + log.Info("alive!") return nil }, ErrorHandler: DefaultJobErrorHandler(), @@ -488,25 +566,3 @@ func (t *JobManagerTest) Test_Start() { }) manager.Start() } - -func (t *JobManagerTest) Test_log() { - defer func() { - require.Nil(t.T(), recover()) - }() - - testLog := func() { - t.manager.log("test", logging.CRITICAL) - t.manager.log("test", logging.ERROR) - t.manager.log("test", logging.WARNING) - t.manager.log("test", logging.NOTICE) - t.manager.log("test", logging.INFO) - t.manager.log("test", logging.DEBUG) - } - t.manager.SetLogging(false) - testLog() - t.manager.SetLogging(true) - t.manager.logger = nil - testLog() - t.manager.logger = NewLogger("test", logging.DEBUG, DefaultLogFormatter()) - testLog() -} diff --git a/core/localizer.go b/core/localizer.go index b27c831..278c0cc 100644 --- a/core/localizer.go +++ b/core/localizer.go @@ -11,6 +11,8 @@ import ( "github.com/nicksnyder/go-i18n/v2/i18n" "golang.org/x/text/language" "gopkg.in/yaml.v2" + + "github.com/retailcrm/mg-transport-core/v2/core/util/errorutil" ) // DefaultLanguages for transports. @@ -138,11 +140,6 @@ func (l *Localizer) LocalizationFuncMap() template.FuncMap { } } -// getLocaleBundle returns current locale bundle and creates it if needed. -func (l *Localizer) getLocaleBundle() *i18n.Bundle { - return l.createLocaleBundleByTag(l.LanguageTag) -} - // createLocaleBundleByTag creates locale bundle by language tag. func (l *Localizer) createLocaleBundleByTag(tag language.Tag) *i18n.Bundle { bundle := i18n.NewBundle(tag) @@ -279,8 +276,8 @@ func (l *Localizer) GetLocalizedMessage(messageID string) string { return l.getCurrentLocalizer().MustLocalize(&i18n.LocalizeConfig{MessageID: messageID}) } -// GetLocalizedTemplateMessage will return localized message with specified data. It doesn't use `Must` prefix in order to keep BC. -// It uses text/template syntax: https://golang.org/pkg/text/template/ +// GetLocalizedTemplateMessage will return localized message with specified data. +// It doesn't use `Must` prefix in order to keep BC. It uses text/template syntax: https://golang.org/pkg/text/template/ func (l *Localizer) GetLocalizedTemplateMessage(messageID string, templateData map[string]interface{}) string { return l.getCurrentLocalizer().MustLocalize(&i18n.LocalizeConfig{ MessageID: messageID, @@ -302,13 +299,52 @@ func (l *Localizer) LocalizeTemplateMessage(messageID string, templateData map[s }) } -// BadRequestLocalized is same as BadRequest(string), but passed string will be localized. +// BadRequestLocalized is same as errorutil.BadRequest(string), but passed string will be localized. func (l *Localizer) BadRequestLocalized(err string) (int, interface{}) { - return BadRequest(l.GetLocalizedMessage(err)) + return errorutil.BadRequest(l.GetLocalizedMessage(err)) } -// GetContextLocalizer returns localizer from context if it exist there. -func GetContextLocalizer(c *gin.Context) (*Localizer, bool) { +// UnauthorizedLocalized is same as errorutil.Unauthorized(string), but passed string will be localized. +func (l *Localizer) UnauthorizedLocalized(err string) (int, interface{}) { + return errorutil.Unauthorized(l.GetLocalizedMessage(err)) +} + +// ForbiddenLocalized is same as errorutil.Forbidden(string), but passed string will be localized. +func (l *Localizer) ForbiddenLocalized(err string) (int, interface{}) { + return errorutil.Forbidden(l.GetLocalizedMessage(err)) +} + +// InternalServerErrorLocalized is same as errorutil.InternalServerError(string), but passed string will be localized. +func (l *Localizer) InternalServerErrorLocalized(err string) (int, interface{}) { + return errorutil.InternalServerError(l.GetLocalizedMessage(err)) +} + +// GetContextLocalizer returns localizer from context if it is present there. +// Language will be set using Accept-Language header and root language tag. +func GetContextLocalizer(c *gin.Context) (loc *Localizer, ok bool) { + loc, ok = extractLocalizerFromContext(c) + if loc != nil { + loc.SetLocale(c.GetHeader("Accept-Language")) + + lang := GetRootLanguageTag(loc.LanguageTag) + if lang != loc.LanguageTag { + loc.SetLanguage(lang) + loc.LoadTranslations() + } + } + return +} + +// MustGetContextLocalizer returns Localizer instance if it exists in provided context. Panics otherwise. +func MustGetContextLocalizer(c *gin.Context) *Localizer { + if localizer, ok := GetContextLocalizer(c); ok { + return localizer + } + panic("localizer is not present in provided context") +} + +// extractLocalizerFromContext returns localizer from context if it exist there. +func extractLocalizerFromContext(c *gin.Context) (*Localizer, bool) { if c == nil { return nil, false } @@ -322,10 +358,14 @@ func GetContextLocalizer(c *gin.Context) (*Localizer, bool) { return nil, false } -// MustGetContextLocalizer returns Localizer instance if it exists in provided context. Panics otherwise. -func MustGetContextLocalizer(c *gin.Context) *Localizer { - if localizer, ok := GetContextLocalizer(c); ok { - return localizer +// GetRootLanguageTag returns root language tag for country-specific tags (e.g "es" for "es_CA"). +// Useful when you don't have country-specific language variations. +func GetRootLanguageTag(t language.Tag) language.Tag { + for { + parent := t.Parent() + if parent == language.Und { + return t + } + t = parent } - panic("localizer is not present in provided context") } diff --git a/core/localizer_test.go b/core/localizer_test.go index 943667e..ee627a1 100644 --- a/core/localizer_test.go +++ b/core/localizer_test.go @@ -17,6 +17,8 @@ import ( "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "golang.org/x/text/language" + + "github.com/retailcrm/mg-transport-core/v2/core/util/errorutil" ) var ( @@ -120,7 +122,7 @@ func (l *LocalizerTest) Test_LocalizationMiddleware_Httptest() { wg.Add(1) go func(m map[language.Tag]string, wg *sync.WaitGroup) { var tag language.Tag - switch rand.Intn(3-1) + 1 { + switch rand.Intn(3-1) + 1 { // nolint:gosec case 1: tag = language.English case 2: @@ -183,7 +185,7 @@ func (l *LocalizerTest) Test_BadRequestLocalized() { status, resp := l.localizer.BadRequestLocalized("message") assert.Equal(l.T(), http.StatusBadRequest, status) - assert.Equal(l.T(), "Test message", resp.(ErrorResponse).Error) + assert.Equal(l.T(), "Test message", resp.(errorutil.Response).Error) } // getContextWithLang generates context with Accept-Language header. diff --git a/core/logger/account_logger_decorator.go b/core/logger/account_logger_decorator.go new file mode 100644 index 0000000..e35d8e9 --- /dev/null +++ b/core/logger/account_logger_decorator.go @@ -0,0 +1,96 @@ +package logger + +import ( + "fmt" + + "github.com/op/go-logging" +) + +// DefaultAccountLoggerFormat contains default prefix format for the AccountLoggerDecorator. +// Its messages will look like this (assuming you will provide the connection URL and account name): +// messageHandler (https://any.simla.com => @tg_account): sent message with id=1 +const DefaultAccountLoggerFormat = "%s (%s => %s):" + +type ComponentAware interface { + SetComponent(string) +} + +type ConnectionAware interface { + SetConnectionIdentifier(string) +} + +type AccountAware interface { + SetAccountIdentifier(string) +} + +type PrefixFormatAware interface { + SetPrefixFormat(string) +} + +type AccountLogger interface { + PrefixedLogger + ComponentAware + ConnectionAware + AccountAware + PrefixFormatAware +} + +type AccountLoggerDecorator struct { + format string + component string + connIdentifier string + accIdentifier string + PrefixDecorator +} + +func DecorateForAccount(base Logger, component, connIdentifier, accIdentifier string) AccountLogger { + return (&AccountLoggerDecorator{ + PrefixDecorator: PrefixDecorator{ + backend: base, + }, + component: component, + connIdentifier: connIdentifier, + accIdentifier: accIdentifier, + }).updatePrefix() +} + +// NewForAccount returns logger for account. It uses StandardLogger under the hood. +func NewForAccount( + transportCode, component, connIdentifier, accIdentifier string, + logLevel logging.Level, + logFormat logging.Formatter) AccountLogger { + return DecorateForAccount(NewStandard(transportCode, logLevel, logFormat), + component, connIdentifier, accIdentifier) +} + +func (a *AccountLoggerDecorator) SetComponent(s string) { + a.component = s + a.updatePrefix() +} + +func (a *AccountLoggerDecorator) SetConnectionIdentifier(s string) { + a.connIdentifier = s + a.updatePrefix() +} + +func (a *AccountLoggerDecorator) SetAccountIdentifier(s string) { + a.accIdentifier = s + a.updatePrefix() +} + +func (a *AccountLoggerDecorator) SetPrefixFormat(s string) { + a.format = s + a.updatePrefix() +} + +func (a *AccountLoggerDecorator) updatePrefix() AccountLogger { + a.SetPrefix(fmt.Sprintf(a.prefixFormat(), a.component, a.connIdentifier, a.accIdentifier)) + return a +} + +func (a *AccountLoggerDecorator) prefixFormat() string { + if a.format == "" { + return DefaultAccountLoggerFormat + } + return a.format +} diff --git a/core/logger/account_logger_decorator_test.go b/core/logger/account_logger_decorator_test.go new file mode 100644 index 0000000..ab6a8b9 --- /dev/null +++ b/core/logger/account_logger_decorator_test.go @@ -0,0 +1,98 @@ +package logger + +import ( + "bytes" + "fmt" + "testing" + + "github.com/op/go-logging" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +const ( + testComponent = "ComponentName" + testConnectionID = "https://test.retailcrm.pro" + testAccountID = "@account_name" +) + +type AccountLoggerDecoratorTest struct { + suite.Suite + buf *bytes.Buffer + logger AccountLogger +} + +func TestAccountLoggerDecorator(t *testing.T) { + suite.Run(t, new(AccountLoggerDecoratorTest)) +} + +func TestNewForAccount(t *testing.T) { + buf := &bytes.Buffer{} + logger := NewForAccount("code", "component", "conn", "acc", logging.DEBUG, DefaultLogFormatter()) + logger.(*AccountLoggerDecorator).backend.(*StandardLogger). + SetBaseLogger(NewBase(buf, "code", logging.DEBUG, DefaultLogFormatter())) + logger.Debugf("message %s", "text") + + assert.Contains(t, buf.String(), fmt.Sprintf(DefaultAccountLoggerFormat+" message", "component", "conn", "acc")) +} + +func (t *AccountLoggerDecoratorTest) SetupSuite() { + t.buf = &bytes.Buffer{} + t.logger = DecorateForAccount((&StandardLogger{}). + SetBaseLogger(NewBase(t.buf, "code", logging.DEBUG, DefaultLogFormatter())), + testComponent, testConnectionID, testAccountID) +} + +func (t *AccountLoggerDecoratorTest) SetupTest() { + t.buf.Reset() + t.logger.SetComponent(testComponent) + t.logger.SetConnectionIdentifier(testConnectionID) + t.logger.SetAccountIdentifier(testAccountID) + t.logger.SetPrefixFormat(DefaultAccountLoggerFormat) +} + +func (t *AccountLoggerDecoratorTest) Test_LogWithNewFormat() { + t.logger.SetPrefixFormat("[%s (%s: %s)] =>") + t.logger.Infof("test message") + + t.Assert().Contains(t.buf.String(), "INFO") + t.Assert().Contains(t.buf.String(), + fmt.Sprintf("[%s (%s: %s)] =>", testComponent, testConnectionID, testAccountID)) +} + +func (t *AccountLoggerDecoratorTest) Test_Log() { + t.logger.Infof("test message") + t.Assert().Contains(t.buf.String(), "INFO") + t.Assert().Contains(t.buf.String(), + fmt.Sprintf(DefaultAccountLoggerFormat, testComponent, testConnectionID, testAccountID)) +} + +func (t *AccountLoggerDecoratorTest) Test_SetComponent() { + t.logger.SetComponent("NewComponent") + t.logger.Infof("test message") + + t.Assert().Contains(t.buf.String(), "INFO") + t.Assert().Contains(t.buf.String(), + fmt.Sprintf(DefaultAccountLoggerFormat, "NewComponent", testConnectionID, testAccountID)) +} + +func (t *AccountLoggerDecoratorTest) Test_SetConnectionIdentifier() { + t.logger.SetComponent("NewComponent") + t.logger.SetConnectionIdentifier("https://test.simla.com") + t.logger.Infof("test message") + + t.Assert().Contains(t.buf.String(), "INFO") + t.Assert().Contains(t.buf.String(), + fmt.Sprintf(DefaultAccountLoggerFormat, "NewComponent", "https://test.simla.com", testAccountID)) +} + +func (t *AccountLoggerDecoratorTest) Test_SetAccountIdentifier() { + t.logger.SetComponent("NewComponent") + t.logger.SetConnectionIdentifier("https://test.simla.com") + t.logger.SetAccountIdentifier("@new_account_name") + t.logger.Infof("test message") + + t.Assert().Contains(t.buf.String(), "INFO") + t.Assert().Contains(t.buf.String(), + fmt.Sprintf(DefaultAccountLoggerFormat, "NewComponent", "https://test.simla.com", "@new_account_name")) +} diff --git a/core/logger.go b/core/logger/logger.go similarity index 61% rename from core/logger.go rename to core/logger/logger.go index a7dcd40..32445f9 100644 --- a/core/logger.go +++ b/core/logger/logger.go @@ -1,14 +1,15 @@ -package core +package logger import ( + "io" "os" "sync" "github.com/op/go-logging" ) -// LoggerInterface contains methods which should be present in logger implementation. -type LoggerInterface interface { +// Logger contains methods which should be present in logger implementation. +type Logger interface { Fatal(args ...interface{}) Fatalf(format string, args ...interface{}) Panic(args ...interface{}) @@ -27,30 +28,30 @@ type LoggerInterface interface { Debugf(format string, args ...interface{}) } -// Logger component. Uses github.com/op/go-logging under the hood. +// StandardLogger is a default implementation of Logger. Uses github.com/op/go-logging under the hood. // This logger can prevent any write operations (disabled by default, use .Exclusive() method to enable). -type Logger struct { +type StandardLogger struct { logger *logging.Logger mutex *sync.RWMutex } -// NewLogger will create new goroutine-safe logger with specified formatter. +// NewStandard will create new StandardLogger with specified formatter. // Usage: // logger := NewLogger("telegram", logging.ERROR, DefaultLogFormatter()) -func NewLogger(transportCode string, logLevel logging.Level, logFormat logging.Formatter) *Logger { - return &Logger{ - logger: newInheritedLogger(transportCode, logLevel, logFormat), +func NewStandard(transportCode string, logLevel logging.Level, logFormat logging.Formatter) *StandardLogger { + return &StandardLogger{ + logger: NewBase(os.Stdout, transportCode, logLevel, logFormat), } } -// newInheritedLogger is a constructor for underlying logger in Logger struct. -func newInheritedLogger(transportCode string, logLevel logging.Level, logFormat logging.Formatter) *logging.Logger { +// NewBase is a constructor for underlying logger in the StandardLogger struct. +func NewBase(out io.Writer, transportCode string, logLevel logging.Level, logFormat logging.Formatter) *logging.Logger { logger := logging.MustGetLogger(transportCode) - logBackend := logging.NewLogBackend(os.Stdout, "", 0) + logBackend := logging.NewLogBackend(out, "", 0) formatBackend := logging.NewBackendFormatter(logBackend, logFormat) - backend1Leveled := logging.AddModuleLevel(logBackend) + backend1Leveled := logging.AddModuleLevel(formatBackend) backend1Leveled.SetLevel(logLevel, "") - logging.SetBackend(formatBackend) + logger.SetBackend(backend1Leveled) return logger } @@ -63,7 +64,7 @@ func DefaultLogFormatter() logging.Formatter { } // Exclusive makes logger goroutine-safe. -func (l *Logger) Exclusive() *Logger { +func (l *StandardLogger) Exclusive() *StandardLogger { if l.mutex == nil { l.mutex = &sync.RWMutex{} } @@ -71,127 +72,133 @@ func (l *Logger) Exclusive() *Logger { return l } +// SetBaseLogger replaces base logger with the provided instance. +func (l *StandardLogger) SetBaseLogger(logger *logging.Logger) *StandardLogger { + l.logger = logger + return l +} + // lock locks logger. -func (l *Logger) lock() { +func (l *StandardLogger) lock() { if l.mutex != nil { l.mutex.Lock() } } // unlock unlocks logger. -func (l *Logger) unlock() { +func (l *StandardLogger) unlock() { if l.mutex != nil { l.mutex.Unlock() } } // Fatal is equivalent to l.Critical(fmt.Sprint()) followed by a call to os.Exit(1). -func (l *Logger) Fatal(args ...interface{}) { +func (l *StandardLogger) Fatal(args ...interface{}) { l.lock() defer l.unlock() l.logger.Fatal(args...) } // Fatalf is equivalent to l.Critical followed by a call to os.Exit(1). -func (l *Logger) Fatalf(format string, args ...interface{}) { +func (l *StandardLogger) Fatalf(format string, args ...interface{}) { l.lock() defer l.unlock() l.logger.Fatalf(format, args...) } // Panic is equivalent to l.Critical(fmt.Sprint()) followed by a call to panic(). -func (l *Logger) Panic(args ...interface{}) { +func (l *StandardLogger) Panic(args ...interface{}) { l.lock() defer l.unlock() l.logger.Panic(args...) } // Panicf is equivalent to l.Critical followed by a call to panic(). -func (l *Logger) Panicf(format string, args ...interface{}) { +func (l *StandardLogger) Panicf(format string, args ...interface{}) { l.lock() defer l.unlock() l.logger.Panicf(format, args...) } // Critical logs a message using CRITICAL as log level. -func (l *Logger) Critical(args ...interface{}) { +func (l *StandardLogger) Critical(args ...interface{}) { l.lock() defer l.unlock() l.logger.Critical(args...) } // Criticalf logs a message using CRITICAL as log level. -func (l *Logger) Criticalf(format string, args ...interface{}) { +func (l *StandardLogger) Criticalf(format string, args ...interface{}) { l.lock() defer l.unlock() l.logger.Criticalf(format, args...) } // Error logs a message using ERROR as log level. -func (l *Logger) Error(args ...interface{}) { +func (l *StandardLogger) Error(args ...interface{}) { l.lock() defer l.unlock() l.logger.Error(args...) } // Errorf logs a message using ERROR as log level. -func (l *Logger) Errorf(format string, args ...interface{}) { +func (l *StandardLogger) Errorf(format string, args ...interface{}) { l.lock() defer l.unlock() l.logger.Errorf(format, args...) } // Warning logs a message using WARNING as log level. -func (l *Logger) Warning(args ...interface{}) { +func (l *StandardLogger) Warning(args ...interface{}) { l.lock() defer l.unlock() l.logger.Warning(args...) } // Warningf logs a message using WARNING as log level. -func (l *Logger) Warningf(format string, args ...interface{}) { +func (l *StandardLogger) Warningf(format string, args ...interface{}) { l.lock() defer l.unlock() l.logger.Warningf(format, args...) } // Notice logs a message using NOTICE as log level. -func (l *Logger) Notice(args ...interface{}) { +func (l *StandardLogger) Notice(args ...interface{}) { l.lock() defer l.unlock() l.logger.Notice(args...) } // Noticef logs a message using NOTICE as log level. -func (l *Logger) Noticef(format string, args ...interface{}) { +func (l *StandardLogger) Noticef(format string, args ...interface{}) { l.lock() defer l.unlock() l.logger.Noticef(format, args...) } // Info logs a message using INFO as log level. -func (l *Logger) Info(args ...interface{}) { +func (l *StandardLogger) Info(args ...interface{}) { l.lock() defer l.unlock() l.logger.Info(args...) } // Infof logs a message using INFO as log level. -func (l *Logger) Infof(format string, args ...interface{}) { +func (l *StandardLogger) Infof(format string, args ...interface{}) { l.lock() defer l.unlock() l.logger.Infof(format, args...) } // Debug logs a message using DEBUG as log level. -func (l *Logger) Debug(args ...interface{}) { +func (l *StandardLogger) Debug(args ...interface{}) { l.lock() defer l.unlock() l.logger.Debug(args...) } // Debugf logs a message using DEBUG as log level. -func (l *Logger) Debugf(format string, args ...interface{}) { +func (l *StandardLogger) Debugf(format string, args ...interface{}) { l.lock() defer l.unlock() l.logger.Debugf(format, args...) diff --git a/core/logger/logger_test.go b/core/logger/logger_test.go new file mode 100644 index 0000000..cb9c371 --- /dev/null +++ b/core/logger/logger_test.go @@ -0,0 +1,168 @@ +package logger + +import ( + "bytes" + "testing" + + "github.com/op/go-logging" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +type StandardLoggerTest struct { + suite.Suite + logger *StandardLogger + buf *bytes.Buffer +} + +func TestLogger_NewLogger(t *testing.T) { + logger := NewStandard("code", logging.DEBUG, DefaultLogFormatter()) + assert.NotNil(t, logger) +} + +func TestLogger_DefaultLogFormatter(t *testing.T) { + formatter := DefaultLogFormatter() + + assert.NotNil(t, formatter) + assert.IsType(t, logging.MustStringFormatter(`%{message}`), formatter) +} + +func Test_Logger(t *testing.T) { + suite.Run(t, new(StandardLoggerTest)) +} + +func (t *StandardLoggerTest) SetupSuite() { + t.buf = &bytes.Buffer{} + t.logger = (&StandardLogger{}). + Exclusive(). + SetBaseLogger(NewBase(t.buf, "code", logging.DEBUG, DefaultLogFormatter())) +} + +func (t *StandardLoggerTest) SetupTest() { + t.buf.Reset() +} + +// TODO Cover Fatal and Fatalf (implementation below is no-op) +// func (t *StandardLoggerTest) Test_Fatal() { +// if os.Getenv("FLAG") == "1" { +// t.logger.Fatal("test", "fatal") +// return +// } + +// cmd := exec.Command(os.Args[0], "-test.run=TestGetConfig") +// cmd.Env = append(os.Environ(), "FLAG=1") +// err := cmd.Run() + +// e, ok := err.(*exec.ExitError) +// expectedErrorString := "test fatal" +// t.Assert().Equal(true, ok) +// t.Assert().Equal(expectedErrorString, e.Error()) +// } + +func (t *StandardLoggerTest) Test_Panic() { + defer func() { + t.Assert().NotNil(recover()) + t.Assert().Contains(t.buf.String(), "CRIT") + t.Assert().Contains(t.buf.String(), "panic") + }() + t.logger.Panic("panic") +} + +func (t *StandardLoggerTest) Test_Panicf() { + defer func() { + t.Assert().NotNil(recover()) + t.Assert().Contains(t.buf.String(), "CRIT") + t.Assert().Contains(t.buf.String(), "panicf") + }() + t.logger.Panicf("panicf") +} + +func (t *StandardLoggerTest) Test_Critical() { + defer func() { + t.Require().Nil(recover()) + t.Assert().Contains(t.buf.String(), "CRIT") + t.Assert().Contains(t.buf.String(), "critical") + }() + t.logger.Critical("critical") +} + +func (t *StandardLoggerTest) Test_Criticalf() { + defer func() { + t.Require().Nil(recover()) + t.Assert().Contains(t.buf.String(), "CRIT") + t.Assert().Contains(t.buf.String(), "critical") + }() + t.logger.Criticalf("critical") +} + +func (t *StandardLoggerTest) Test_Warning() { + defer func() { + t.Require().Nil(recover()) + t.Assert().Contains(t.buf.String(), "WARN") + t.Assert().Contains(t.buf.String(), "warning") + }() + t.logger.Warning("warning") +} + +func (t *StandardLoggerTest) Test_Notice() { + defer func() { + t.Require().Nil(recover()) + t.Assert().Contains(t.buf.String(), "NOTI") + t.Assert().Contains(t.buf.String(), "notice") + }() + t.logger.Notice("notice") +} + +func (t *StandardLoggerTest) Test_Info() { + defer func() { + t.Require().Nil(recover()) + t.Assert().Contains(t.buf.String(), "INFO") + t.Assert().Contains(t.buf.String(), "info") + }() + t.logger.Info("info") +} + +func (t *StandardLoggerTest) Test_Debug() { + defer func() { + t.Require().Nil(recover()) + t.Assert().Contains(t.buf.String(), "DEBU") + t.Assert().Contains(t.buf.String(), "debug") + }() + t.logger.Debug("debug") +} + +func (t *StandardLoggerTest) Test_Warningf() { + defer func() { + t.Require().Nil(recover()) + t.Assert().Contains(t.buf.String(), "WARN") + t.Assert().Contains(t.buf.String(), "warning") + }() + t.logger.Warningf("%s", "warning") +} + +func (t *StandardLoggerTest) Test_Noticef() { + defer func() { + t.Require().Nil(recover()) + t.Assert().Contains(t.buf.String(), "NOTI") + t.Assert().Contains(t.buf.String(), "notice") + }() + t.logger.Noticef("%s", "notice") +} + +func (t *StandardLoggerTest) Test_Infof() { + defer func() { + t.Require().Nil(recover()) + t.Assert().Contains(t.buf.String(), "INFO") + t.Assert().Contains(t.buf.String(), "info") + }() + t.logger.Infof("%s", "info") +} + +func (t *StandardLoggerTest) Test_Debugf() { + defer func() { + t.Require().Nil(recover()) + t.Assert().Contains(t.buf.String(), "DEBU") + t.Assert().Contains(t.buf.String(), "debug") + }() + t.logger.Debugf("%s", "debug") +} diff --git a/core/logger/nil_logger.go b/core/logger/nil_logger.go new file mode 100644 index 0000000..440b4d0 --- /dev/null +++ b/core/logger/nil_logger.go @@ -0,0 +1,45 @@ +package logger + +import ( + "fmt" + "os" +) + +// Nil provides Logger implementation that does almost nothing when called. +// Panic, Panicf, Fatal and Fatalf methods still cause panic and immediate program termination respectively. +// All other methods won't do anything at all. +type Nil struct{} + +func (n Nil) Fatal(args ...interface{}) { + os.Exit(1) +} + +func (n Nil) Fatalf(format string, args ...interface{}) { + os.Exit(1) +} + +func (n Nil) Panic(args ...interface{}) { + panic(fmt.Sprint(args...)) +} + +func (n Nil) Panicf(format string, args ...interface{}) { + panic(fmt.Sprintf(format, args...)) +} + +func (n Nil) Critical(args ...interface{}) {} +func (n Nil) Criticalf(format string, args ...interface{}) {} +func (n Nil) Error(args ...interface{}) {} +func (n Nil) Errorf(format string, args ...interface{}) {} +func (n Nil) Warning(args ...interface{}) {} +func (n Nil) Warningf(format string, args ...interface{}) {} +func (n Nil) Notice(args ...interface{}) {} +func (n Nil) Noticef(format string, args ...interface{}) {} +func (n Nil) Info(args ...interface{}) {} +func (n Nil) Infof(format string, args ...interface{}) {} +func (n Nil) Debug(args ...interface{}) {} +func (n Nil) Debugf(format string, args ...interface{}) {} + +// NewNil is a Nil logger constructor. +func NewNil() Logger { + return &Nil{} +} diff --git a/core/logger/nil_logger_test.go b/core/logger/nil_logger_test.go new file mode 100644 index 0000000..f41c03e --- /dev/null +++ b/core/logger/nil_logger_test.go @@ -0,0 +1,91 @@ +package logger + +import ( + "bytes" + "io" + "os" + "testing" + "time" + + "github.com/stretchr/testify/suite" +) + +type NilTest struct { + suite.Suite + logger Logger + realStdout *os.File + r *os.File + w *os.File +} + +func TestNilLogger(t *testing.T) { + suite.Run(t, new(NilTest)) +} + +func (t *NilTest) SetupSuite() { + t.logger = NewNil() +} + +func (t *NilTest) SetupTest() { + t.realStdout = os.Stdout + t.r, t.w, _ = os.Pipe() + os.Stdout = t.w +} + +func (t *NilTest) TearDownTest() { + if t.realStdout != nil { + t.Require().NoError(t.w.Close()) + os.Stdout = t.realStdout + } +} + +func (t *NilTest) readStdout() string { + outC := make(chan string) + go func() { + var buf bytes.Buffer + _, err := io.Copy(&buf, t.r) + t.Require().NoError(err) + outC <- buf.String() + close(outC) + }() + + t.Require().NoError(t.w.Close()) + os.Stdout = t.realStdout + t.realStdout = nil + + select { + case c := <-outC: + return c + case <-time.After(time.Second): + return "" + } +} + +func (t *NilTest) Test_Noop() { + t.logger.Critical("message") + t.logger.Criticalf("message") + t.logger.Error("message") + t.logger.Errorf("message") + t.logger.Warning("message") + t.logger.Warningf("message") + t.logger.Notice("message") + t.logger.Noticef("message") + t.logger.Info("message") + t.logger.Infof("message") + t.logger.Debug("message") + t.logger.Debugf("message") + + t.Assert().Empty(t.readStdout()) +} + +func (t *NilTest) Test_Panic() { + t.Assert().Panics(func() { + t.logger.Panic("") + }) +} + +func (t *NilTest) Test_Panicf() { + t.Assert().Panics(func() { + t.logger.Panicf("") + }) +} diff --git a/core/logger/prefix_decorator.go b/core/logger/prefix_decorator.go new file mode 100644 index 0000000..d9d9572 --- /dev/null +++ b/core/logger/prefix_decorator.go @@ -0,0 +1,106 @@ +package logger + +import "github.com/op/go-logging" + +// PrefixAware is implemented if the logger allows you to change the prefix. +type PrefixAware interface { + SetPrefix(string) +} + +// PrefixedLogger is a base interface for the logger with prefix. +type PrefixedLogger interface { + Logger + PrefixAware +} + +// PrefixDecorator is an implementation of the PrefixedLogger. It will allow you to decorate any Logger with +// the provided predefined prefix. +type PrefixDecorator struct { + backend Logger + prefix []interface{} +} + +// DecorateWithPrefix using provided base logger and provided prefix. +// No internal state of the base logger will be touched. +func DecorateWithPrefix(backend Logger, prefix string) PrefixedLogger { + return &PrefixDecorator{backend: backend, prefix: []interface{}{prefix}} +} + +// NewWithPrefix returns logger with prefix. It uses StandardLogger under the hood. +func NewWithPrefix(transportCode, prefix string, logLevel logging.Level, logFormat logging.Formatter) PrefixedLogger { + return DecorateWithPrefix(NewStandard(transportCode, logLevel, logFormat), prefix) +} + +// SetPrefix will replace existing prefix with the provided value. +// Use this format for prefixes: "prefix here:" - omit space at the end (it will be inserted automatically). +func (p *PrefixDecorator) SetPrefix(prefix string) { + p.prefix = []interface{}{prefix} +} + +func (p *PrefixDecorator) getFormat(fmt string) string { + return p.prefix[0].(string) + " " + fmt +} + +func (p *PrefixDecorator) Fatal(args ...interface{}) { + p.backend.Fatal(append(p.prefix, args...)...) +} + +func (p *PrefixDecorator) Fatalf(format string, args ...interface{}) { + p.backend.Fatalf(p.getFormat(format), args...) +} + +func (p *PrefixDecorator) Panic(args ...interface{}) { + p.backend.Panic(append(p.prefix, args...)...) +} + +func (p *PrefixDecorator) Panicf(format string, args ...interface{}) { + p.backend.Panicf(p.getFormat(format), args...) +} + +func (p *PrefixDecorator) Critical(args ...interface{}) { + p.backend.Critical(append(p.prefix, args...)...) +} + +func (p *PrefixDecorator) Criticalf(format string, args ...interface{}) { + p.backend.Criticalf(p.getFormat(format), args...) +} + +func (p *PrefixDecorator) Error(args ...interface{}) { + p.backend.Error(append(p.prefix, args...)...) +} + +func (p *PrefixDecorator) Errorf(format string, args ...interface{}) { + p.backend.Errorf(p.getFormat(format), args...) +} + +func (p *PrefixDecorator) Warning(args ...interface{}) { + p.backend.Warning(append(p.prefix, args...)...) +} + +func (p *PrefixDecorator) Warningf(format string, args ...interface{}) { + p.backend.Warningf(p.getFormat(format), args...) +} + +func (p *PrefixDecorator) Notice(args ...interface{}) { + p.backend.Notice(append(p.prefix, args...)...) +} + +func (p *PrefixDecorator) Noticef(format string, args ...interface{}) { + p.backend.Noticef(p.getFormat(format), args...) +} + +func (p *PrefixDecorator) Info(args ...interface{}) { + p.backend.Info(append(p.prefix, args...)...) +} + +func (p *PrefixDecorator) Infof(format string, args ...interface{}) { + p.backend.Infof(p.getFormat(format), args...) +} + +func (p *PrefixDecorator) Debug(args ...interface{}) { + p.backend.Debug(append(p.prefix, args...)...) +} + +func (p *PrefixDecorator) Debugf(format string, args ...interface{}) { + p.backend.Debugf(p.getFormat(format), args...) +} diff --git a/core/logger/prefix_decorator_test.go b/core/logger/prefix_decorator_test.go new file mode 100644 index 0000000..274b727 --- /dev/null +++ b/core/logger/prefix_decorator_test.go @@ -0,0 +1,166 @@ +package logger + +import ( + "bytes" + "testing" + + "github.com/op/go-logging" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +const testPrefix = "TestPrefix:" + +type PrefixDecoratorTest struct { + suite.Suite + buf *bytes.Buffer + logger PrefixedLogger +} + +func TestPrefixDecorator(t *testing.T) { + suite.Run(t, new(PrefixDecoratorTest)) +} + +func TestNewWithPrefix(t *testing.T) { + buf := &bytes.Buffer{} + logger := NewWithPrefix("code", "Prefix:", logging.DEBUG, DefaultLogFormatter()) + logger.(*PrefixDecorator).backend.(*StandardLogger). + SetBaseLogger(NewBase(buf, "code", logging.DEBUG, DefaultLogFormatter())) + logger.Debugf("message %s", "text") + + assert.Contains(t, buf.String(), "Prefix: message text") +} + +func (t *PrefixDecoratorTest) SetupSuite() { + t.buf = &bytes.Buffer{} + t.logger = DecorateWithPrefix((&StandardLogger{}). + SetBaseLogger(NewBase(t.buf, "code", logging.DEBUG, DefaultLogFormatter())), testPrefix) +} + +func (t *PrefixDecoratorTest) SetupTest() { + t.buf.Reset() + t.logger.SetPrefix(testPrefix) +} + +func (t *PrefixDecoratorTest) Test_SetPrefix() { + t.logger.Info("message") + t.Assert().Contains(t.buf.String(), "INFO") + t.Assert().Contains(t.buf.String(), testPrefix+" message") + + t.logger.SetPrefix(testPrefix + testPrefix) + t.logger.Info("message") + t.Assert().Contains(t.buf.String(), "INFO") + t.Assert().Contains(t.buf.String(), testPrefix+testPrefix+" message") +} + +func (t *PrefixDecoratorTest) Test_Panic() { + t.Require().Panics(func() { + t.logger.Panic("message") + }) + t.Assert().Contains(t.buf.String(), "CRIT") + t.Assert().Contains(t.buf.String(), testPrefix+" message") +} + +func (t *PrefixDecoratorTest) Test_Panicf() { + t.Require().Panics(func() { + t.logger.Panicf("%s", "message") + }) + t.Assert().Contains(t.buf.String(), "CRIT") + t.Assert().Contains(t.buf.String(), testPrefix+" message") +} + +func (t *PrefixDecoratorTest) Test_Critical() { + t.Require().NotPanics(func() { + t.logger.Critical("message") + }) + t.Assert().Contains(t.buf.String(), "CRIT") + t.Assert().Contains(t.buf.String(), testPrefix+" message") +} + +func (t *PrefixDecoratorTest) Test_Criticalf() { + t.Require().NotPanics(func() { + t.logger.Criticalf("%s", "message") + }) + t.Assert().Contains(t.buf.String(), "CRIT") + t.Assert().Contains(t.buf.String(), testPrefix+" message") +} + +func (t *PrefixDecoratorTest) Test_Error() { + t.Require().NotPanics(func() { + t.logger.Error("message") + }) + t.Assert().Contains(t.buf.String(), "ERRO") + t.Assert().Contains(t.buf.String(), testPrefix+" message") +} + +func (t *PrefixDecoratorTest) Test_Errorf() { + t.Require().NotPanics(func() { + t.logger.Errorf("%s", "message") + }) + t.Assert().Contains(t.buf.String(), "ERRO") + t.Assert().Contains(t.buf.String(), testPrefix+" message") +} + +func (t *PrefixDecoratorTest) Test_Warning() { + t.Require().NotPanics(func() { + t.logger.Warning("message") + }) + t.Assert().Contains(t.buf.String(), "WARN") + t.Assert().Contains(t.buf.String(), testPrefix+" message") +} + +func (t *PrefixDecoratorTest) Test_Warningf() { + t.Require().NotPanics(func() { + t.logger.Warningf("%s", "message") + }) + t.Assert().Contains(t.buf.String(), "WARN") + t.Assert().Contains(t.buf.String(), testPrefix+" message") +} + +func (t *PrefixDecoratorTest) Test_Notice() { + t.Require().NotPanics(func() { + t.logger.Notice("message") + }) + t.Assert().Contains(t.buf.String(), "NOTI") + t.Assert().Contains(t.buf.String(), testPrefix+" message") +} + +func (t *PrefixDecoratorTest) Test_Noticef() { + t.Require().NotPanics(func() { + t.logger.Noticef("%s", "message") + }) + t.Assert().Contains(t.buf.String(), "NOTI") + t.Assert().Contains(t.buf.String(), testPrefix+" message") +} + +func (t *PrefixDecoratorTest) Test_Info() { + t.Require().NotPanics(func() { + t.logger.Info("message") + }) + t.Assert().Contains(t.buf.String(), "INFO") + t.Assert().Contains(t.buf.String(), testPrefix+" message") +} + +func (t *PrefixDecoratorTest) Test_Infof() { + t.Require().NotPanics(func() { + t.logger.Infof("%s", "message") + }) + t.Assert().Contains(t.buf.String(), "INFO") + t.Assert().Contains(t.buf.String(), testPrefix+" message") +} + +func (t *PrefixDecoratorTest) Test_Debug() { + t.Require().NotPanics(func() { + t.logger.Debug("message") + }) + t.Assert().Contains(t.buf.String(), "DEBU") + t.Assert().Contains(t.buf.String(), testPrefix+" message") +} + +func (t *PrefixDecoratorTest) Test_Debugf() { + t.Require().NotPanics(func() { + t.logger.Debugf("%s", "message") + }) + t.Assert().Contains(t.buf.String(), "DEBU") + t.Assert().Contains(t.buf.String(), testPrefix+" message") +} diff --git a/core/logger_test.go b/core/logger_test.go deleted file mode 100644 index 8894087..0000000 --- a/core/logger_test.go +++ /dev/null @@ -1,121 +0,0 @@ -package core - -import ( - // "os" - // "os/exec". - "testing" - - "github.com/op/go-logging" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/suite" -) - -type LoggerTest struct { - suite.Suite - logger *Logger -} - -func TestLogger_NewLogger(t *testing.T) { - logger := NewLogger("code", logging.DEBUG, DefaultLogFormatter()) - assert.NotNil(t, logger) -} - -func TestLogger_DefaultLogFormatter(t *testing.T) { - formatter := DefaultLogFormatter() - - assert.NotNil(t, formatter) - assert.IsType(t, logging.MustStringFormatter(`%{message}`), formatter) -} - -func Test_Logger(t *testing.T) { - suite.Run(t, new(LoggerTest)) -} - -func (t *LoggerTest) SetupSuite() { - t.logger = NewLogger("code", logging.DEBUG, DefaultLogFormatter()).Exclusive() -} - -// TODO Cover Fatal and Fatalf (implementation below is no-op) -// func (t *LoggerTest) Test_Fatal() { -// if os.Getenv("FLAG") == "1" { -// t.logger.Fatal("test", "fatal") -// return -// } - -// cmd := exec.Command(os.Args[0], "-test.run=TestGetConfig") -// cmd.Env = append(os.Environ(), "FLAG=1") -// err := cmd.Run() - -// e, ok := err.(*exec.ExitError) -// expectedErrorString := "test fatal" -// assert.Equal(t.T(), true, ok) -// assert.Equal(t.T(), expectedErrorString, e.Error()) -// } - -func (t *LoggerTest) Test_Panic() { - defer func() { - assert.NotNil(t.T(), recover()) - }() - t.logger.Panic("panic") -} - -func (t *LoggerTest) Test_Panicf() { - defer func() { - assert.NotNil(t.T(), recover()) - }() - t.logger.Panicf("panic") -} - -func (t *LoggerTest) Test_Critical() { - defer func() { - if v := recover(); v != nil { - t.T().Fatal(v) - } - }() - t.logger.Critical("critical") -} - -func (t *LoggerTest) Test_Criticalf() { - defer func() { - if v := recover(); v != nil { - t.T().Fatal(v) - } - }() - t.logger.Criticalf("critical") -} - -func (t *LoggerTest) Test_Warning() { - defer func() { - if v := recover(); v != nil { - t.T().Fatal(v) - } - }() - t.logger.Warning("warning") -} - -func (t *LoggerTest) Test_Notice() { - defer func() { - if v := recover(); v != nil { - t.T().Fatal(v) - } - }() - t.logger.Notice("notice") -} - -func (t *LoggerTest) Test_Info() { - defer func() { - if v := recover(); v != nil { - t.T().Fatal(v) - } - }() - t.logger.Info("info") -} - -func (t *LoggerTest) Test_Debug() { - defer func() { - if v := recover(); v != nil { - t.T().Fatal(v) - } - }() - t.logger.Debug("debug") -} diff --git a/core/csrf.go b/core/middleware/csrf.go similarity index 89% rename from core/csrf.go rename to core/middleware/csrf.go index e340cbd..3fe4716 100644 --- a/core/csrf.go +++ b/core/middleware/csrf.go @@ -1,4 +1,4 @@ -package core +package middleware import ( "bytes" @@ -41,6 +41,11 @@ const ( CSRFErrorTokenMismatch ) +const ( + keySize = 8 + randomStringSize = 64 +) + // DefaultCSRFTokenGetter default getter. var DefaultCSRFTokenGetter = func(c *gin.Context) string { r := c.Request @@ -70,31 +75,43 @@ var DefaultIgnoredMethods = []string{"GET", "HEAD", "OPTIONS"} // CSRF struct. Provides CSRF token verification. type CSRF struct { + store sessions.Store + abortFunc CSRFAbortFunc + csrfTokenGetter CSRFTokenGetter salt string secret string sessionName string - abortFunc CSRFAbortFunc - csrfTokenGetter CSRFTokenGetter - store sessions.Store } // NewCSRF creates CSRF struct with specified configuration and session store. // GenerateCSRFMiddleware and VerifyCSRFMiddleware returns CSRF middlewares. -// Salt must be different every time (pass empty salt to use random), secret must be provided, sessionName is optional - pass empty to use default, -// store will be used to store sessions, abortFunc will be called to return error if token is invalid, csrfTokenGetter will be used to obtain token. +// Salt must be different every time (pass empty salt to use random), secret must be provided, +// sessionName is optional - pass empty to use default, +// store will be used to store sessions, abortFunc will be called to return error if token is invalid, +// csrfTokenGetter will be used to obtain token. +// // Usage (with random salt): // core.NewCSRF("", "super secret", "csrf_session", store, func (c *gin.Context, reason core.CSRFErrorReason) { // c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "Invalid CSRF token"}) // }, core.DefaultCSRFTokenGetter) -// Note for csrfTokenGetter: if you want to read token from request body (for example, from form field) - don't forget to restore Body data! +// +// Note for csrfTokenGetter: if you want to read token from request body (for example, from form field) +// - don't forget to restore Body data! +// // Body in http.Request is io.ReadCloser instance. Reading CSRF token from form like that: // if t := r.FormValue("csrf_token"); len(t) > 0 { // return t // } // will close body - and all next middlewares won't be able to read body at all! +// // Use DefaultCSRFTokenGetter as example to implement your own token getter. // CSRFErrorReason will be passed to abortFunc and can be used for better error messages. -func NewCSRF(salt, secret, sessionName string, store sessions.Store, abortFunc CSRFAbortFunc, csrfTokenGetter CSRFTokenGetter) *CSRF { +func NewCSRF( + salt, secret, sessionName string, + store sessions.Store, + abortFunc CSRFAbortFunc, + csrfTokenGetter CSRFTokenGetter, +) *CSRF { if store == nil { panic("store must not be nil") } @@ -157,10 +174,10 @@ func (x *CSRF) generateCSRFToken() string { // Default secure salt length: 8 bytes. // Default pseudo-random salt length: 64 bytes. func (x *CSRF) generateSalt() string { - salt := securecookie.GenerateRandomKey(8) + salt := securecookie.GenerateRandomKey(keySize) if salt == nil { - return x.pseudoRandomString(64) + return x.pseudoRandomString(randomStringSize) } return string(salt) @@ -171,15 +188,15 @@ func (x *CSRF) pseudoRandomString(length int) string { rand.Seed(time.Now().UnixNano()) data := make([]byte, length) - for i := 0; i < length; i++ { - data[i] = byte(65 + rand.Intn(90-65)) + for i := 0; i < length; i++ { // it is supposed to use pseudo-random data. + data[i] = byte(65 + rand.Intn(90-65)) // nolint:gosec,gomnd } return string(data) } -// CSRFFromContext returns csrf token or random token. It shouldn't return empty string because it will make csrf protection useless. -// e.g. any request without token will work fine, which is inacceptable. +// CSRFFromContext returns csrf token or random token. It shouldn't return empty string because +// it will make csrf protection useless. e.g. any request without token will work fine, which is unacceptable. func (x *CSRF) CSRFFromContext(c *gin.Context) string { if i, ok := c.Get("csrf_token"); ok { if token, ok := i.(string); ok { diff --git a/core/csrf_test.go b/core/middleware/csrf_test.go similarity index 99% rename from core/csrf_test.go rename to core/middleware/csrf_test.go index 28e8467..5f7b831 100644 --- a/core/csrf_test.go +++ b/core/middleware/csrf_test.go @@ -1,4 +1,4 @@ -package core +package middleware import ( "bytes" @@ -21,10 +21,10 @@ type CSRFTest struct { } type requestOptions struct { + Body io.Reader + Headers map[string]string Method string URL string - Headers map[string]string - Body io.Reader } func TestCSRF_DefaultCSRFTokenGetter_Empty(t *testing.T) { diff --git a/core/one_step_connection.go b/core/middleware/one_step_connection.go similarity index 98% rename from core/one_step_connection.go rename to core/middleware/one_step_connection.go index 1c073e7..1cc0b64 100644 --- a/core/one_step_connection.go +++ b/core/middleware/one_step_connection.go @@ -1,4 +1,4 @@ -package core +package middleware import ( "encoding/json" diff --git a/core/one_step_connection_test.go b/core/middleware/one_step_connection_test.go similarity index 99% rename from core/one_step_connection_test.go rename to core/middleware/one_step_connection_test.go index 70b7af7..ae06cc5 100644 --- a/core/one_step_connection_test.go +++ b/core/middleware/one_step_connection_test.go @@ -1,4 +1,4 @@ -package core +package middleware import ( "crypto/hmac" diff --git a/core/models.go b/core/models.go deleted file mode 100644 index 150a927..0000000 --- a/core/models.go +++ /dev/null @@ -1,53 +0,0 @@ -package core - -import "time" - -// Connection model. -type Connection struct { - ID int `gorm:"primary_key"` - ClientID string `gorm:"column:client_id; type:varchar(70); not null; unique" json:"clientId,omitempty"` - Key string `gorm:"column:api_key; type:varchar(100); not null" json:"api_key,omitempty" binding:"required,max=100"` - URL string `gorm:"column:api_url; type:varchar(255); not null" json:"api_url,omitempty" binding:"required,validateCrmURL,max=255"` // nolint:lll - GateURL string `gorm:"column:mg_url; type:varchar(255); not null;" json:"mg_url,omitempty" binding:"max=255"` - GateToken string `gorm:"column:mg_token; type:varchar(100); not null; unique" json:"mg_token,omitempty" binding:"max=100"` - CreatedAt time.Time - UpdatedAt time.Time - Active bool `json:"active,omitempty"` - Accounts []Account `gorm:"foreignkey:ConnectionID"` -} - -// Account model. -type Account struct { - ID int `gorm:"primary_key"` - ConnectionID int `gorm:"column:connection_id" json:"connectionId,omitempty"` - Channel uint64 `gorm:"column:channel; not null; unique" json:"channel,omitempty"` - ChannelSettingsHash string `gorm:"column:channel_settings_hash; type:varchar(70)" binding:"max=70"` - Name string `gorm:"column:name; type:varchar(100)" json:"name,omitempty" binding:"max=100"` - Lang string `gorm:"column:lang; type:varchar(2)" json:"lang,omitempty" binding:"max=2"` - CreatedAt time.Time - UpdatedAt time.Time -} - -// User model. -type User struct { - ID int `gorm:"primary_key"` - ExternalID string `gorm:"column:external_id; type:varchar(255); not null; unique"` - UserPhotoURL string `gorm:"column:user_photo_url; type:varchar(255)" binding:"max=255"` - UserPhotoID string `gorm:"column:user_photo_id; type:varchar(100)" binding:"max=100"` - CreatedAt time.Time - UpdatedAt time.Time -} - -// TableName will return table name for User -// It will not work if User is not embedded, but mapped as another type -// type MyUser User // will not work -// but -// type MyUser struct { // will work -// User -// } -func (User) TableName() string { - return "mg_user" -} - -// Accounts list. -type Accounts []Account diff --git a/core/sentry.go b/core/sentry.go index 964209f..ad9c5ea 100644 --- a/core/sentry.go +++ b/core/sentry.go @@ -9,7 +9,9 @@ import ( "github.com/pkg/errors" - "github.com/retailcrm/mg-transport-core/core/stacktrace" + "github.com/retailcrm/mg-transport-core/v2/core/logger" + + "github.com/retailcrm/mg-transport-core/v2/core/stacktrace" "github.com/getsentry/raven-go" "github.com/gin-gonic/gin" @@ -34,19 +36,19 @@ type SentryTagged interface { // Sentry struct. Holds SentryTaggedStruct list. type Sentry struct { + Logger logger.Logger + Client stacktrace.RavenClientInterface + Localizer *Localizer + DefaultError string TaggedTypes SentryTaggedTypes Stacktrace bool - DefaultError string - Localizer *Localizer - Logger LoggerInterface - Client stacktrace.RavenClientInterface } // SentryTaggedStruct holds information about type, it's key in gin.Context (for middleware), and it's properties. type SentryTaggedStruct struct { Type reflect.Type - GinContextKey string Tags SentryTags + GinContextKey string } // SentryTaggedScalar variable from context. @@ -56,7 +58,13 @@ type SentryTaggedScalar struct { } // NewSentry constructor. -func NewSentry(sentryDSN string, defaultError string, taggedTypes SentryTaggedTypes, logger LoggerInterface, localizer *Localizer) *Sentry { +func NewSentry( + sentryDSN string, + defaultError string, + taggedTypes SentryTaggedTypes, + logger logger.Logger, + localizer *Localizer, +) *Sentry { sentry := &Sentry{ DefaultError: defaultError, TaggedTypes: taggedTypes, @@ -196,7 +204,7 @@ func (s *Sentry) ErrorResponseHandler() ErrorHandlerFunc { } // ErrorCaptureHandler will generate error data and send it to sentry. -func (s *Sentry) ErrorCaptureHandler() ErrorHandlerFunc { +func (s *Sentry) ErrorCaptureHandler() ErrorHandlerFunc { // nolint:gocognit return func(recovery interface{}, c *gin.Context) { tags := map[string]string{ "endpoint": c.Request.RequestURI, diff --git a/core/sentry_test.go b/core/sentry_test.go index f3ad5fd..40f3d34 100644 --- a/core/sentry_test.go +++ b/core/sentry_test.go @@ -16,12 +16,14 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + + "github.com/retailcrm/mg-transport-core/v2/core/util/errorutil" ) type sampleStruct struct { - ID int Pointer *int Field string + ID int } type ravenPacket struct { @@ -51,21 +53,11 @@ func (r ravenPacket) getException() (*raven.Exception, bool) { return nil, false } -func (r ravenPacket) getRequest() (*raven.Http, bool) { - if i, ok := r.getInterface("request"); ok { - if r, ok := i.(*raven.Http); ok { - return r, true - } - } - - return nil, false -} - type ravenClientMock struct { - raven.Client captured []ravenPacket - mu sync.RWMutex - wg sync.WaitGroup + raven.Client + mu sync.RWMutex + wg sync.WaitGroup } func newRavenMock() *ravenClientMock { @@ -94,7 +86,7 @@ func (r *ravenClientMock) CaptureMessageAndWait(message string, tags map[string] r.mu.Lock() defer r.mu.Unlock() defer r.wg.Done() - eventID := strconv.FormatUint(rand.Uint64(), 10) + eventID := strconv.FormatUint(rand.Uint64(), 10) // nolint:gosec r.captured = append(r.captured, ravenPacket{ EventID: eventID, Message: message, @@ -127,8 +119,8 @@ func (n *simpleError) Error() string { // wrappableError is a simple implementation of wrappable error. type wrappableError struct { - msg string err error + msg string } func newWrappableError(msg string, child error) error { @@ -304,7 +296,7 @@ func (s *SentryTest) TestSentry_CaptureRegularError() { c.Error(newSimpleError("test")) }) - var resp ErrorsResponse + var resp errorutil.ListResponse req, err := http.NewRequest(http.MethodGet, "/test_regularError", nil) require.NoError(s.T(), err) @@ -338,7 +330,7 @@ func (s *SentryTest) TestSentry_CaptureWrappedError() { c.Error(first) }) - var resp ErrorsResponse + var resp errorutil.ListResponse req, err := http.NewRequest(http.MethodGet, "/test_wrappableError", nil) require.NoError(s.T(), err) @@ -389,7 +381,7 @@ func (s *SentryTest) TestSentry_CaptureTags() { }), } - var resp ErrorsResponse + var resp errorutil.ListResponse req, err := http.NewRequest(http.MethodGet, "/test_taggedError", nil) require.NoError(s.T(), err) diff --git a/core/stacktrace/err_collector_builder.go b/core/stacktrace/err_collector_builder.go new file mode 100644 index 0000000..5861061 --- /dev/null +++ b/core/stacktrace/err_collector_builder.go @@ -0,0 +1,57 @@ +package stacktrace + +import ( + "path/filepath" + + "github.com/getsentry/raven-go" + + "github.com/retailcrm/mg-transport-core/v2/core/util/errorutil" +) + +// ErrorNodesList is the interface for the errorutil.errList. +type ErrorNodesList interface { + Iterate() <-chan errorutil.Node + Len() int +} + +// ErrCollectorBuilder builds stacktrace from the list of errors collected by errorutil.Collector. +type ErrCollectorBuilder struct { + AbstractStackBuilder +} + +// IsErrorNodesList returns true if error contains error nodes. +func IsErrorNodesList(err error) bool { + _, ok := err.(ErrorNodesList) // nolint:errorlint + return ok +} + +// AsErrorNodesList returns ErrorNodesList instance from the error. +func AsErrorNodesList(err error) ErrorNodesList { + return err.(ErrorNodesList) // nolint:errorlint +} + +// Build stacktrace. +func (b *ErrCollectorBuilder) Build() StackBuilderInterface { + if !IsErrorNodesList(b.err) { + b.buildErr = ErrUnfeasibleBuilder + return b + } + + i := 0 + errs := AsErrorNodesList(b.err) + frames := make([]*raven.StacktraceFrame, errs.Len()) + + for err := range errs.Iterate() { + frames[i] = raven.NewStacktraceFrame( + err.PC, filepath.Base(err.File), err.File, err.Line, 3, b.client.IncludePaths()) + i++ + } + + if len(frames) <= 1 { + b.buildErr = ErrUnfeasibleBuilder + return b + } + + b.stack = &raven.Stacktrace{Frames: frames} + return b +} diff --git a/core/stacktrace/err_collector_builder_test.go b/core/stacktrace/err_collector_builder_test.go new file mode 100644 index 0000000..a7f6b6b --- /dev/null +++ b/core/stacktrace/err_collector_builder_test.go @@ -0,0 +1,55 @@ +package stacktrace + +import ( + "errors" + "path/filepath" + "runtime" + "strings" + "testing" + + "github.com/getsentry/raven-go" + "github.com/stretchr/testify/suite" + + "github.com/retailcrm/mg-transport-core/v2/core/util/errorutil" +) + +type ErrCollectorBuilderTest struct { + builder *ErrCollectorBuilder + c *errorutil.Collector + suite.Suite +} + +func TestErrCollectorBuilder(t *testing.T) { + suite.Run(t, new(ErrCollectorBuilderTest)) +} + +func (t *ErrCollectorBuilderTest) SetupTest() { + t.c = errorutil.NewCollector() + client, _ := raven.New("fake dsn") + t.builder = &ErrCollectorBuilder{AbstractStackBuilder{ + client: client, + err: t.c, + }} +} + +func (t *ErrCollectorBuilderTest) TestBuild() { + t.c.Do( + errors.New("first"), + errors.New("second"), + errors.New("third")) + + stack, err := t.builder.Build().GetResult() + _, file, _, _ := runtime.Caller(0) + + t.Require().NoError(err) + t.Require().NotZero(stack) + t.Assert().Len(stack.Frames, 3) + + for _, frame := range stack.Frames { + t.Assert().Equal(file, frame.Filename) + t.Assert().Equal(file, frame.AbsolutePath) + t.Assert().Equal("go", frame.Function) + t.Assert().Equal(strings.TrimSuffix(filepath.Base(file), ".go"), frame.Module) + t.Assert().NotZero(frame.Lineno) + } +} diff --git a/core/stacktrace/pkg_errors_builder.go b/core/stacktrace/pkg_errors_builder.go index 18536ff..996e73f 100644 --- a/core/stacktrace/pkg_errors_builder.go +++ b/core/stacktrace/pkg_errors_builder.go @@ -14,6 +14,13 @@ type PkgErrorTraceable interface { StackTrace() pkgErrors.StackTrace } +// IsPkgErrorsError returns true if passed error might be github.com/pkg/errors error. +func IsPkgErrorsError(err error) bool { + _, okTraceable := err.(PkgErrorTraceable) // nolint:errorlint + _, okCauseable := err.(PkgErrorCauseable) // nolint:errorlint + return okTraceable || okCauseable +} + // PkgErrorsStackTransformer transforms stack data from github.com/pkg/errors error to stacktrace.Stacktrace. type PkgErrorsStackTransformer struct { stack pkgErrors.StackTrace @@ -44,7 +51,7 @@ type PkgErrorsBuilder struct { // Build stacktrace. func (b *PkgErrorsBuilder) Build() StackBuilderInterface { - if !isPkgErrors(b.err) { + if !IsPkgErrorsError(b.err) { b.buildErr = ErrUnfeasibleBuilder return b } @@ -71,7 +78,7 @@ func (b *PkgErrorsBuilder) Build() StackBuilderInterface { // getErrorCause will try to extract original error from wrapper - it is used only if stacktrace is not present. func (b *PkgErrorsBuilder) getErrorCause(err error) error { - causeable, ok := err.(PkgErrorCauseable) + causeable, ok := err.(PkgErrorCauseable) // nolint:errorlint if !ok { return nil } @@ -81,7 +88,7 @@ func (b *PkgErrorsBuilder) getErrorCause(err error) error { // getErrorStackTrace will try to extract stacktrace from error using StackTrace method // (default errors doesn't have it). func (b *PkgErrorsBuilder) getErrorStack(err error) pkgErrors.StackTrace { - traceable, ok := err.(PkgErrorTraceable) + traceable, ok := err.(PkgErrorTraceable) // nolint:errorlint if !ok { return nil } diff --git a/core/stacktrace/pkg_errors_builder_test.go b/core/stacktrace/pkg_errors_builder_test.go index 4309eab..7379afe 100644 --- a/core/stacktrace/pkg_errors_builder_test.go +++ b/core/stacktrace/pkg_errors_builder_test.go @@ -13,8 +13,8 @@ import ( // errorWithCause has Cause() method, but doesn't have StackTrace() method. type errorWithCause struct { - msg string cause error + msg string } func newErrorWithCause(msg string, cause error) error { @@ -53,7 +53,7 @@ func (s *PkgErrorsStackProviderSuite) Test_Empty() { func (s *PkgErrorsStackProviderSuite) Test_Full() { testErr := pkgErrors.New("test") - s.transformer.stack = testErr.(PkgErrorTraceable).StackTrace() + s.transformer.stack = testErr.(PkgErrorTraceable).StackTrace() // nolint:errorlint assert.NotEmpty(s.T(), s.transformer.Stack()) } diff --git a/core/stacktrace/raven_stacktrace_builder.go b/core/stacktrace/raven_stacktrace_builder.go index b283c46..44978e1 100644 --- a/core/stacktrace/raven_stacktrace_builder.go +++ b/core/stacktrace/raven_stacktrace_builder.go @@ -55,16 +55,14 @@ func (b *RavenStacktraceBuilder) Build(context int, appPackagePrefixes []string) } // convertFrame converts single generic stacktrace frame to github.com/pkg/errors.Frame. -func (b *RavenStacktraceBuilder) convertFrame(f Frame, context int, appPackagePrefixes []string) *raven.StacktraceFrame { +func (b *RavenStacktraceBuilder) convertFrame( + f Frame, context int, appPackagePrefixes []string) *raven.StacktraceFrame { // This code is borrowed from github.com/pkg/errors.Frame. pc := uintptr(f) - 1 - fn := runtime.FuncForPC(pc) - var file string - var line int - if fn != nil { + line := 0 + file := "unknown" + if fn := runtime.FuncForPC(pc); fn != nil { file, line = fn.FileLine(pc) - } else { - file = "unknown" } return raven.NewStacktraceFrame(pc, path.Dir(file), file, line, context, appPackagePrefixes) } diff --git a/core/stacktrace/stack_builder_factory.go b/core/stacktrace/stack_builder_factory.go index 565c042..0d58623 100644 --- a/core/stacktrace/stack_builder_factory.go +++ b/core/stacktrace/stack_builder_factory.go @@ -3,20 +3,17 @@ package stacktrace // GetStackBuilderByErrorType tries to guess which stacktrace builder would be feasible for passed error. // For example, errors from github.com/pkg/errors have StackTrace() method, and Go 1.13 errors can be unwrapped. func GetStackBuilderByErrorType(err error) StackBuilderInterface { - if isPkgErrors(err) { + if IsPkgErrorsError(err) { return &PkgErrorsBuilder{AbstractStackBuilder{err: err}} } - if _, ok := err.(Unwrappable); ok { + if IsUnwrappableError(err) { return &UnwrapBuilder{AbstractStackBuilder{err: err}} } + if IsErrorNodesList(err) { + return &ErrCollectorBuilder{AbstractStackBuilder{err: err}} + } + return &GenericStackBuilder{AbstractStackBuilder{err: err}} } - -// isPkgErrors returns true if passed error might be github.com/pkg/errors error. -func isPkgErrors(err error) bool { - _, okTraceable := err.(PkgErrorTraceable) - _, okCauseable := err.(PkgErrorCauseable) - return okTraceable || okCauseable -} diff --git a/core/stacktrace/stack_builder_factory_test.go b/core/stacktrace/stack_builder_factory_test.go index 4edc3c3..a543c90 100644 --- a/core/stacktrace/stack_builder_factory_test.go +++ b/core/stacktrace/stack_builder_factory_test.go @@ -6,6 +6,8 @@ import ( pkgErrors "github.com/pkg/errors" "github.com/stretchr/testify/assert" + + "github.com/retailcrm/mg-transport-core/v2/core/util/errorutil" ) func TestGetStackBuilderByErrorType_PkgErrors(t *testing.T) { @@ -20,6 +22,12 @@ func TestGetStackBuilderByErrorType_UnwrapBuilder(t *testing.T) { assert.IsType(t, &UnwrapBuilder{}, builder) } +func TestGetStackBuilderByErrorType_ErrCollectorBuilder(t *testing.T) { + testErr := errorutil.NewCollector().Do(errors.New("first"), errors.New("second")).AsError() + builder := GetStackBuilderByErrorType(testErr) + assert.IsType(t, &ErrCollectorBuilder{}, builder) +} + func TestGetStackBuilderByErrorType_Generic(t *testing.T) { defaultErr := errors.New("default err") builder := GetStackBuilderByErrorType(defaultErr) diff --git a/core/stacktrace/unwrap_builder.go b/core/stacktrace/unwrap_builder.go index 4785a3d..737cff6 100644 --- a/core/stacktrace/unwrap_builder.go +++ b/core/stacktrace/unwrap_builder.go @@ -14,15 +14,21 @@ type UnwrapBuilder struct { AbstractStackBuilder } +// IsUnwrappableError returns true if error can be unwrapped. +func IsUnwrappableError(err error) bool { + _, ok := err.(Unwrappable) // nolint:errorlint + return ok +} + // Build stacktrace. func (b *UnwrapBuilder) Build() StackBuilderInterface { - if _, ok := b.err.(Unwrappable); !ok { + if !IsUnwrappableError(b.err) { b.buildErr = ErrUnfeasibleBuilder return b } err := b.err - frames := []*raven.StacktraceFrame{} + var frames []*raven.StacktraceFrame for err != nil { frames = append(frames, raven.NewStacktraceFrame( @@ -34,7 +40,7 @@ func (b *UnwrapBuilder) Build() StackBuilderInterface { b.client.IncludePaths(), )) - if item, ok := err.(Unwrappable); ok { + if item, ok := err.(Unwrappable); ok { // nolint:errorlint err = item.Unwrap() } else { err = nil diff --git a/core/stacktrace/unwrap_builder_test.go b/core/stacktrace/unwrap_builder_test.go index ebe6934..ddfae71 100644 --- a/core/stacktrace/unwrap_builder_test.go +++ b/core/stacktrace/unwrap_builder_test.go @@ -25,8 +25,8 @@ func (n *simpleError) Error() string { // wrappableError is a simple implementation of wrappable error. type wrappableError struct { - msg string err error + msg string } func newWrappableError(msg string, child error) error { @@ -75,8 +75,7 @@ func (s *UnwrapBuilderSuite) TestBuild_NoUnwrap() { func (s *UnwrapBuilderSuite) TestBuild_WrappableHasWrapped() { testErr := newWrappableError("first", newWrappableError("second", errors.New("third"))) - _, ok := testErr.(Unwrappable) - require.True(s.T(), ok) + require.True(s.T(), IsUnwrappableError(testErr)) s.builder.SetError(testErr) stack, buildErr := s.builder.Build().GetResult() diff --git a/core/template.go b/core/template.go index 51256a4..96ae5c5 100644 --- a/core/template.go +++ b/core/template.go @@ -54,15 +54,15 @@ func (r *Renderer) Push(name string, files ...string) *template.Template { // addFromFS adds embedded template. func (r *Renderer) addFromFS(name string, funcMap template.FuncMap, files ...string) *template.Template { - var filesData []string + filesData := make([]string, len(files)) - for _, fileName := range files { - data, err := fs.ReadFile(r.TemplatesFS, fileName) + for i := 0; i < len(files); i++ { + data, err := fs.ReadFile(r.TemplatesFS, files[i]) if err != nil { panic(err) } - filesData = append(filesData, string(data)) + filesData[i] = string(data) } return r.AddFromStringsFuncs(name, funcMap, filesData...) diff --git a/core/util/errorutil/err_collector.go b/core/util/errorutil/err_collector.go new file mode 100644 index 0000000..bc6b669 --- /dev/null +++ b/core/util/errorutil/err_collector.go @@ -0,0 +1,135 @@ +package errorutil + +import ( + "fmt" + "runtime" + "strings" +) + +// Collector is a replacement for the core.ErrorCollector function. It is easier to use and contains more functionality. +// For example, you can iterate over the errors or use Collector.Panic() to immediately panic +// if there are errors in the chain. +// +// Error messages will differ from the ones produced by ErrorCollector. However, it's for the best because +// new error messages contain a lot more useful information and can be even used as a stacktrace. +// +// Collector implements Error() and String() methods. As a result, you can use the collector as an error itself or just +// print out it as a value. However, it is better to use AsError() method if you want to use Collector as an error value +// because AsError() returns nil if there are no errors in the list. +// +// Example: +// err := errorutil.NewCollector(). +// Do(errors.New("error 1")). +// Do(errors.New("error 2"), errors.New("error 3")) +// // Will print error message. +// fmt.Println(err) +// +// This code will produce something like this: +// #1 err at /home/user/main.go:62: error 1 +// #2 err at /home/user/main.go:63: error 2 +// #3 err at /home/user/main.go:64: error 3 +// +// You can also iterate over the error to use their data instead of using predefined message: +// err := errorutil.NewCollector(). +// Do(errors.New("error 1")). +// Do(errors.New("error 2"), errors.New("error 3")) +// +// for err := range c.Iterate() { +// fmt.Printf("Error at %s:%d: %v\n", err.File, err.Line, err) +// } +// +// This code will produce output that looks like this: +// Error at /home/user/main.go:164: error 0 +// Error at /home/user/main.go:164: error 1 +// Error at /home/user/main.go:164: error 2 +// +// Example with GORM migration (Collector is returned as an error here). +// return errorutil.NewCollector().Do( +// db.CreateTable(models.Account{}, models.Connection{}).Error, +// db.Table("account").AddUniqueIndex("account_key", "channel").Error, +// ).AsError() +type Collector struct { + errors *errList +} + +// NewCollector returns new errorutil.Collector instance. +func NewCollector() *Collector { + return &Collector{ + errors: &errList{}, + } +} + +// Collect errors, return one error for all of them (shorthand for errorutil.NewCollector().Do(...).AsError()). +// Returns nil if there are no errors. +func Collect(errs ...error) error { + return NewCollector().Do(errs...).AsError() +} + +// Do some operation that returns the error. Supports multiple operations at once. +func (e *Collector) Do(errs ...error) *Collector { + pc, file, line, _ := runtime.Caller(1) + + for _, err := range errs { + if err != nil { + e.errors.Push(pc, err, file, line) + } + } + + return e +} + +// OK returns true if there is no errors in the list. +func (e *Collector) OK() bool { + return e.errors.Len() == 0 +} + +// Error message. +func (e *Collector) Error() string { + return e.buildErrorMessage() +} + +// AsError returns the Collector itself as an error, but only if there are errors in the list. +// It returns nil otherwise. This method should be used if you want to return error to the caller, but only if\ +// Collector actually caught something. +func (e *Collector) AsError() error { + if e.OK() { + return nil + } + return e +} + +// String with an error message. +func (e *Collector) String() string { + return e.Error() +} + +// Panic with the error data if there are errors in the list. +func (e *Collector) Panic() { + if !e.OK() { + panic(e) + } +} + +// Iterate over the errors in the list. Every error is represented as an errorutil.Node value. +func (e *Collector) Iterate() <-chan Node { + return e.errors.Iterate() +} + +// Len returns the number of the errors in the list. +func (e *Collector) Len() int { + return e.errors.Len() +} + +// buildErrorMessage builds error message for the Collector.Error() and Collector.String() methods. +func (e *Collector) buildErrorMessage() string { + i := 0 + var sb strings.Builder + sb.Grow(128 * e.errors.Len()) // nolint:gomnd + + for node := range e.errors.Iterate() { + i++ + sb.WriteString(fmt.Sprintf("#%d err at %s:%d: %v\n", i, node.File, node.Line, node.Err)) + } + + return strings.TrimRight(sb.String(), "\n") +} diff --git a/core/util/errorutil/err_collector_test.go b/core/util/errorutil/err_collector_test.go new file mode 100644 index 0000000..59b557e --- /dev/null +++ b/core/util/errorutil/err_collector_test.go @@ -0,0 +1,74 @@ +package errorutil + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +type ErrCollectorTest struct { + c *Collector + suite.Suite +} + +func TestErrCollector(t *testing.T) { + suite.Run(t, new(ErrCollectorTest)) +} + +func TestCollect_NoError(t *testing.T) { + require.NoError(t, Collect()) +} + +func TestCollect_NoError_Nils(t *testing.T) { + require.NoError(t, Collect(nil, nil, nil)) +} + +func TestCollect_Error(t *testing.T) { + require.Error(t, Collect(errors.New("first"), errors.New("second"))) +} + +func (t *ErrCollectorTest) SetupTest() { + t.c = NewCollector() +} + +func (t *ErrCollectorTest) TestDo() { + t.c.Do( + errors.New("first"), + errors.New("second"), + errors.New("third")) + + t.Assert().False(t.c.OK()) + t.Assert().NotEmpty(t.c.String()) + t.Assert().Error(t.c.AsError()) + t.Assert().Equal(3, t.c.Len()) + t.Assert().Panics(func() { + t.c.Panic() + }) + + i := 0 + for err := range t.c.Iterate() { + t.Assert().Error(err.Err) + t.Assert().NotEmpty(err.File) + t.Assert().NotZero(err.Line) + t.Assert().NotZero(err.PC) + + switch i { + case 0: + t.Assert().Equal("first", err.Error()) + case 1: + t.Assert().Equal("second", err.Error()) + case 2: + t.Assert().Equal("third", err.Error()) + } + + i++ + } +} + +func (t *ErrCollectorTest) Test_PanicNone() { + t.Assert().NotPanics(func() { + t.c.Panic() + }) +} diff --git a/core/util/errorutil/err_list.go b/core/util/errorutil/err_list.go new file mode 100644 index 0000000..9f52e99 --- /dev/null +++ b/core/util/errorutil/err_list.go @@ -0,0 +1,68 @@ +package errorutil + +// Node contains information about error in the list. +type Node struct { + Err error + next *Node + File string + PC uintptr + Line int +} + +// Error returns error message from the Node's Err. +func (e Node) Error() string { + if e.Err == nil { + return "" + } + return e.Err.Error() +} + +// errList is an immutable linked list of error. +type errList struct { + head *Node + tail *Node + len int +} + +// Push an error into the list. +func (l *errList) Push(pc uintptr, err error, file string, line int) { + item := &Node{ + PC: pc, + Err: err, + File: file, + Line: line, + } + + if l.head == nil { + l.head = item + } + + if l.tail != nil { + l.tail.next = item + } + + l.tail = item + l.len++ +} + +// Iterate over error list. +func (l *errList) Iterate() <-chan Node { + c := make(chan Node) + go func() { + item := l.head + for { + if item == nil { + close(c) + return + } + c <- *item + item = item.next + } + }() + return c +} + +// Len returns length of the list. +func (l *errList) Len() int { + return l.len +} diff --git a/core/util/errorutil/handler_errors.go b/core/util/errorutil/handler_errors.go new file mode 100644 index 0000000..a04a42b --- /dev/null +++ b/core/util/errorutil/handler_errors.go @@ -0,0 +1,50 @@ +package errorutil + +import "net/http" + +// Response with the error message. +type Response struct { + Error string `json:"error"` +} + +// ListResponse contains multiple errors in the list. +type ListResponse struct { + Error []string `json:"error"` +} + +// GetErrorResponse returns ErrorResponse with specified status code +// Usage (with gin): +// context.JSON(GetErrorResponse(http.StatusPaymentRequired, "Not enough money")) +func GetErrorResponse(statusCode int, err string) (int, interface{}) { + return statusCode, Response{ + Error: err, + } +} + +// BadRequest returns ErrorResponse with code 400 +// Usage (with gin): +// context.JSON(BadRequest("invalid data")) +func BadRequest(err string) (int, interface{}) { + return GetErrorResponse(http.StatusBadRequest, err) +} + +// Unauthorized returns ErrorResponse with code 401 +// Usage (with gin): +// context.JSON(Unauthorized("invalid credentials")) +func Unauthorized(err string) (int, interface{}) { + return GetErrorResponse(http.StatusUnauthorized, err) +} + +// Forbidden returns ErrorResponse with code 403 +// Usage (with gin): +// context.JSON(Forbidden("forbidden")) +func Forbidden(err string) (int, interface{}) { + return GetErrorResponse(http.StatusForbidden, err) +} + +// InternalServerError returns ErrorResponse with code 500 +// Usage (with gin): +// context.JSON(BadRequest("invalid data")) +func InternalServerError(err string) (int, interface{}) { + return GetErrorResponse(http.StatusInternalServerError, err) +} diff --git a/core/error_test.go b/core/util/errorutil/handler_errors_test.go similarity index 73% rename from core/error_test.go rename to core/util/errorutil/handler_errors_test.go index 8915dd8..68d1f46 100644 --- a/core/error_test.go +++ b/core/util/errorutil/handler_errors_test.go @@ -1,4 +1,4 @@ -package core +package errorutil import ( "net/http" @@ -11,19 +11,19 @@ func TestError_GetErrorResponse(t *testing.T) { code, resp := GetErrorResponse(http.StatusBadRequest, "error string") assert.Equal(t, http.StatusBadRequest, code) - assert.Equal(t, "error string", resp.(ErrorResponse).Error) + assert.Equal(t, "error string", resp.(Response).Error) } func TestError_BadRequest(t *testing.T) { code, resp := BadRequest("error string") assert.Equal(t, http.StatusBadRequest, code) - assert.Equal(t, "error string", resp.(ErrorResponse).Error) + assert.Equal(t, "error string", resp.(Response).Error) } func TestError_InternalServerError(t *testing.T) { code, resp := InternalServerError("error string") assert.Equal(t, http.StatusInternalServerError, code) - assert.Equal(t, "error string", resp.(ErrorResponse).Error) + assert.Equal(t, "error string", resp.(Response).Error) } diff --git a/core/error_scopes.go b/core/util/errorutil/scopes_err.go similarity index 87% rename from core/error_scopes.go rename to core/util/errorutil/scopes_err.go index 2c295e3..29f45ed 100644 --- a/core/error_scopes.go +++ b/core/util/errorutil/scopes_err.go @@ -1,4 +1,4 @@ -package core +package errorutil import ( "errors" @@ -17,8 +17,8 @@ type ScopesList interface { // insufficientScopesErr contains information about missing auth scopes. type insufficientScopesErr struct { - scopes []string wrapped error + scopes []string } // Error message. @@ -48,3 +48,8 @@ func NewInsufficientScopesErr(scopes []string) error { wrapped: ErrInsufficientScopes, } } + +// AsInsufficientScopesErr returns ScopesList instance. +func AsInsufficientScopesErr(err error) ScopesList { + return err.(ScopesList) // nolint:errorlint +} diff --git a/core/error_scopes_test.go b/core/util/errorutil/scopes_err_test.go similarity index 75% rename from core/error_scopes_test.go rename to core/util/errorutil/scopes_err_test.go index 05697a2..2a336bf 100644 --- a/core/error_scopes_test.go +++ b/core/util/errorutil/scopes_err_test.go @@ -1,4 +1,4 @@ -package core +package errorutil import ( "errors" @@ -12,4 +12,5 @@ func TestError_NewScopesError(t *testing.T) { scopesError := NewInsufficientScopesErr(scopes) assert.True(t, errors.Is(scopesError, ErrInsufficientScopes)) + assert.Equal(t, scopes, AsInsufficientScopesErr(scopesError).Scopes()) } diff --git a/core/http_client_builder.go b/core/util/httputil/http_client_builder.go similarity index 85% rename from core/http_client_builder.go rename to core/util/httputil/http_client_builder.go index 5f9e5de..5a7d4f3 100644 --- a/core/http_client_builder.go +++ b/core/util/httputil/http_client_builder.go @@ -1,4 +1,4 @@ -package core +package httputil import ( "context" @@ -10,6 +10,10 @@ import ( "time" "github.com/pkg/errors" + + "github.com/retailcrm/mg-transport-core/v2/core/config" + + "github.com/retailcrm/mg-transport-core/v2/core/logger" ) // DefaultClient stores original http.DefaultClient. @@ -43,18 +47,18 @@ var DefaultTransport = http.DefaultTransport // fmt.Print(err) // } type HTTPClientBuilder struct { + logger logger.Logger httpClient *http.Client httpTransport *http.Transport - certsPool *x509.CertPool dialer *net.Dialer - logger LoggerInterface - built bool - logging bool - timeout time.Duration mockAddress string mockHost string mockPort string mockedDomains []string + timeout time.Duration + tlsVersion uint16 + logging bool + built bool } // NewHTTPClientBuilder returns HTTPClientBuilder with default values. @@ -63,6 +67,7 @@ func NewHTTPClientBuilder() *HTTPClientBuilder { built: false, httpClient: &http.Client{}, httpTransport: &http.Transport{}, + tlsVersion: tls.VersionTLS12, timeout: 30 * time.Second, mockAddress: "", mockedDomains: []string{}, @@ -71,7 +76,7 @@ func NewHTTPClientBuilder() *HTTPClientBuilder { } // WithLogger sets provided logger into HTTPClientBuilder. -func (b *HTTPClientBuilder) WithLogger(logger LoggerInterface) *HTTPClientBuilder { +func (b *HTTPClientBuilder) WithLogger(logger logger.Logger) *HTTPClientBuilder { if logger != nil { b.logger = logger } @@ -81,7 +86,7 @@ func (b *HTTPClientBuilder) WithLogger(logger LoggerInterface) *HTTPClientBuilde // SetTimeout sets timeout for http client. func (b *HTTPClientBuilder) SetTimeout(seconds time.Duration) *HTTPClientBuilder { - seconds = seconds * time.Second + seconds *= time.Second b.timeout = seconds b.httpClient.Timeout = seconds return b @@ -108,7 +113,7 @@ func (b *HTTPClientBuilder) SetMockedDomains(domains []string) *HTTPClientBuilde // SetSSLVerification enables or disables SSL certificates verification in client. func (b *HTTPClientBuilder) SetSSLVerification(enabled bool) *HTTPClientBuilder { if b.httpTransport.TLSClientConfig == nil { - b.httpTransport.TLSClientConfig = &tls.Config{} + b.httpTransport.TLSClientConfig = b.baseTLSConfig() } b.httpTransport.TLSClientConfig.InsecureSkipVerify = !enabled @@ -116,10 +121,19 @@ func (b *HTTPClientBuilder) SetSSLVerification(enabled bool) *HTTPClientBuilder return b } -// SetSSLVerification enables or disables SSL certificates verification in client. +// UseTLS10 restores TLS 1.0 as a minimal supported TLS version. +func (b *HTTPClientBuilder) UseTLS10() *HTTPClientBuilder { + b.tlsVersion = tls.VersionTLS10 + if b.httpTransport.TLSClientConfig != nil { + b.httpTransport.TLSClientConfig.MinVersion = b.tlsVersion + } + return b +} + +// SetCertPool sets provided TLS certificates pool into the client. func (b *HTTPClientBuilder) SetCertPool(pool *x509.CertPool) *HTTPClientBuilder { if b.httpTransport.TLSClientConfig == nil { - b.httpTransport.TLSClientConfig = &tls.Config{} + b.httpTransport.TLSClientConfig = b.baseTLSConfig() } b.httpTransport.TLSClientConfig.RootCAs = pool @@ -134,7 +148,7 @@ func (b *HTTPClientBuilder) SetLogging(flag bool) *HTTPClientBuilder { } // FromConfig fulfills mock configuration from HTTPClientConfig. -func (b *HTTPClientBuilder) FromConfig(config *HTTPClientConfig) *HTTPClientBuilder { +func (b *HTTPClientBuilder) FromConfig(config *config.HTTPClientConfig) *HTTPClientBuilder { if config == nil { return b } @@ -153,9 +167,9 @@ func (b *HTTPClientBuilder) FromConfig(config *HTTPClientConfig) *HTTPClientBuil return b } -// FromEngine fulfills mock configuration from ConfigInterface inside Engine. -func (b *HTTPClientBuilder) FromEngine(engine *Engine) *HTTPClientBuilder { - return b.FromConfig(engine.GetHTTPClientConfig()) +// baseTLSConfig returns *tls.Config with TLS 1.2 as a minimal supported version. +func (b *HTTPClientBuilder) baseTLSConfig() *tls.Config { + return &tls.Config{MinVersion: b.tlsVersion} // nolint:gosec } // buildDialer initializes dialer with provided timeout. diff --git a/core/http_client_builder_test.go b/core/util/httputil/http_client_builder_test.go similarity index 90% rename from core/http_client_builder_test.go rename to core/util/httputil/http_client_builder_test.go index 7d0418c..75e5125 100644 --- a/core/http_client_builder_test.go +++ b/core/util/httputil/http_client_builder_test.go @@ -1,8 +1,10 @@ -package core +package httputil import ( "context" + "crypto/tls" "crypto/x509" + "errors" "fmt" "io" "io/ioutil" @@ -18,6 +20,12 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + + "github.com/retailcrm/mg-transport-core/v2/core/config" + + "github.com/retailcrm/mg-transport-core/v2/core/logger" + + "github.com/retailcrm/mg-transport-core/v2/core/util/errorutil" ) type HTTPClientBuilderTest struct { @@ -92,7 +100,7 @@ func (t *HTTPClientBuilderTest) Test_FromConfigNil() { } func (t *HTTPClientBuilderTest) Test_FromConfig() { - config := &HTTPClientConfig{ + config := &config.HTTPClientConfig{ SSLVerification: boolPtr(true), MockAddress: "anothermock.local:3004", MockedDomains: []string{"example.gov"}, @@ -107,22 +115,6 @@ func (t *HTTPClientBuilderTest) Test_FromConfig() { assert.Equal(t.T(), config.Timeout*time.Second, t.builder.httpClient.Timeout) } -func (t *HTTPClientBuilderTest) Test_FromEngine() { - engine := &Engine{ - Config: Config{ - HTTPClientConfig: &HTTPClientConfig{ - SSLVerification: boolPtr(true), - MockAddress: "anothermock.local:3004", - MockedDomains: []string{"example.gov"}, - }, - Debug: false, - }, - } - - t.builder.FromEngine(engine) - assert.Equal(t.T(), engine.Config.GetHTTPClientConfig().MockAddress, t.builder.mockAddress) -} - func (t *HTTPClientBuilderTest) Test_buildDialer() { t.builder.buildDialer() @@ -138,7 +130,7 @@ func (t *HTTPClientBuilderTest) Test_buildMocks() { } func (t *HTTPClientBuilderTest) Test_WithLogger() { - logger := NewLogger("telegram", logging.ERROR, DefaultLogFormatter()) + logger := logger.NewStandard("telegram", logging.ERROR, logger.DefaultLogFormatter()) builder := NewHTTPClientBuilder() require.Nil(t.T(), builder.logger) @@ -248,9 +240,9 @@ uf/TQPpjrGW5nxOf94qn6FzV2WSype9BcM5MD7z7rk202Fs7Zqc= _, err = keyFile.WriteString(keyFileData) require.NoError(t.T(), err, "cannot write temp key file") require.NoError(t.T(), - ErrorCollector(certFile.Sync(), certFile.Close()), "cannot sync and close temp cert file") + errorutil.Collect(certFile.Sync(), certFile.Close()), "cannot sync and close temp cert file") require.NoError(t.T(), - ErrorCollector(keyFile.Sync(), keyFile.Close()), "cannot sync and close temp key file") + errorutil.Collect(keyFile.Sync(), keyFile.Close()), "cannot sync and close temp key file") mux := &http.ServeMux{} srv := &http.Server{Addr: mockServerAddr, Handler: mux} @@ -261,8 +253,8 @@ uf/TQPpjrGW5nxOf94qn6FzV2WSype9BcM5MD7z7rk202Fs7Zqc= testSkipChan := make(chan error, 1) go func(skip chan error) { - if err := srv.ListenAndServeTLS(certFile.Name(), keyFile.Name()); err != nil && err != http.ErrServerClosed { - skip <- fmt.Errorf("skipping test because server won't start: %s", err.Error()) + if err := srv.ListenAndServeTLS(certFile.Name(), keyFile.Name()); err != nil && !errors.Is(err, http.ErrServerClosed) { + skip <- fmt.Errorf("skipping test because server won't start: %w", err) } }(testSkipChan) @@ -314,6 +306,16 @@ uf/TQPpjrGW5nxOf94qn6FzV2WSype9BcM5MD7z7rk202Fs7Zqc= assert.Equal(t.T(), "ok", string(data), "invalid body contents") } +func (t *HTTPClientBuilderTest) Test_UseTLS10() { + client, err := NewHTTPClientBuilder().SetSSLVerification(true).UseTLS10().Build() + + t.Require().NoError(err) + t.Require().NotNil(client) + t.Require().NotNil(client.Transport) + t.Require().NotNil(client.Transport.(*http.Transport).TLSClientConfig) + t.Assert().Equal(uint16(tls.VersionTLS10), client.Transport.(*http.Transport).TLSClientConfig.MinVersion) +} + // taken from https://stackoverflow.com/questions/23558425/how-do-i-get-the-local-ip-address-in-go func getOutboundIP() net.IP { conn, err := net.Dial("udp", "8.8.8.8:80") @@ -330,3 +332,8 @@ func getOutboundIP() net.IP { func Test_HTTPClientBuilder(t *testing.T) { suite.Run(t, new(HTTPClientBuilderTest)) } + +func boolPtr(val bool) *bool { + b := val + return &b +} diff --git a/core/util/testutil/buffer_logger.go b/core/util/testutil/buffer_logger.go new file mode 100644 index 0000000..2a6a9ca --- /dev/null +++ b/core/util/testutil/buffer_logger.go @@ -0,0 +1,133 @@ +package testutil + +import ( + "bytes" + "fmt" + "io" + "os" + + "github.com/op/go-logging" + + "github.com/retailcrm/mg-transport-core/v2/core/logger" +) + +// ReadBuffer is implemented by the BufferLogger. +// Its methods give access to the buffer contents and ability to read buffer as an io.Reader or reset its contents. +type ReadBuffer interface { + io.Reader + fmt.Stringer + Bytes() []byte + Reset() +} + +// BufferedLogger is a logger that can return the data written to it. +type BufferedLogger interface { + ReadBuffer + logger.Logger +} + +// BufferLogger is an implementation of the BufferedLogger. +type BufferLogger struct { + buf bytes.Buffer +} + +// NewBufferedLogger returns new BufferedLogger instance. +func NewBufferedLogger() BufferedLogger { + return &BufferLogger{} +} + +// Read bytes from the logger buffer. io.Reader implementation. +func (l *BufferLogger) Read(p []byte) (n int, err error) { + return l.buf.Read(p) +} + +// String contents of the logger buffer. fmt.Stringer implementation. +func (l *BufferLogger) String() string { + return l.buf.String() +} + +// Bytes is a shorthand for the underlying bytes.Buffer method. Returns byte slice with the buffer contents. +func (l *BufferLogger) Bytes() []byte { + return l.buf.Bytes() +} + +// Reset is a shorthand for the underlying bytes.Buffer method. It will reset buffer contents. +func (l *BufferLogger) Reset() { + l.buf.Reset() +} + +func (l *BufferLogger) write(level logging.Level, args ...interface{}) { + l.buf.WriteString(fmt.Sprintln(append([]interface{}{level.String(), "=>"}, args...)...)) +} + +func (l *BufferLogger) writef(level logging.Level, format string, args ...interface{}) { + l.buf.WriteString(fmt.Sprintf(level.String()+" => "+format, args...)) +} + +func (l *BufferLogger) Fatal(args ...interface{}) { + l.write(logging.CRITICAL, args...) + os.Exit(1) +} + +func (l *BufferLogger) Fatalf(format string, args ...interface{}) { + l.writef(logging.CRITICAL, format, args...) + os.Exit(1) +} + +func (l *BufferLogger) Panic(args ...interface{}) { + l.write(logging.CRITICAL, args...) + panic(fmt.Sprint(args...)) +} + +func (l *BufferLogger) Panicf(format string, args ...interface{}) { + l.writef(logging.CRITICAL, format, args...) + panic(fmt.Sprintf(format, args...)) +} + +func (l *BufferLogger) Critical(args ...interface{}) { + l.write(logging.CRITICAL, args...) +} + +func (l *BufferLogger) Criticalf(format string, args ...interface{}) { + l.writef(logging.CRITICAL, format, args...) +} + +func (l *BufferLogger) Error(args ...interface{}) { + l.write(logging.ERROR, args...) +} + +func (l *BufferLogger) Errorf(format string, args ...interface{}) { + l.writef(logging.ERROR, format, args...) +} + +func (l *BufferLogger) Warning(args ...interface{}) { + l.write(logging.WARNING, args...) +} + +func (l *BufferLogger) Warningf(format string, args ...interface{}) { + l.writef(logging.WARNING, format, args...) +} + +func (l *BufferLogger) Notice(args ...interface{}) { + l.write(logging.NOTICE, args...) +} + +func (l *BufferLogger) Noticef(format string, args ...interface{}) { + l.writef(logging.NOTICE, format, args...) +} + +func (l *BufferLogger) Info(args ...interface{}) { + l.write(logging.INFO, args...) +} + +func (l *BufferLogger) Infof(format string, args ...interface{}) { + l.writef(logging.INFO, format, args...) +} + +func (l *BufferLogger) Debug(args ...interface{}) { + l.write(logging.DEBUG, args...) +} + +func (l *BufferLogger) Debugf(format string, args ...interface{}) { + l.writef(logging.DEBUG, format, args...) +} diff --git a/core/util/testutil/buffer_logger_test.go b/core/util/testutil/buffer_logger_test.go new file mode 100644 index 0000000..55dee6d --- /dev/null +++ b/core/util/testutil/buffer_logger_test.go @@ -0,0 +1,43 @@ +package testutil + +import ( + "io" + "testing" + + "github.com/op/go-logging" + "github.com/stretchr/testify/suite" +) + +type BufferLoggerTest struct { + suite.Suite + logger BufferedLogger +} + +func TestBufferLogger(t *testing.T) { + suite.Run(t, new(BufferLoggerTest)) +} + +func (t *BufferLoggerTest) SetupSuite() { + t.logger = NewBufferedLogger() +} + +func (t *BufferLoggerTest) SetupTest() { + t.logger.Reset() +} + +func (t *BufferLoggerTest) Log() string { + return t.logger.String() +} + +func (t *BufferLoggerTest) Test_Read() { + t.logger.Debug("test") + + data, err := io.ReadAll(t.logger) + t.Require().NoError(err) + t.Assert().Equal([]byte(logging.DEBUG.String()+" => test\n"), data) +} + +func (t *BufferLoggerTest) Test_Bytes() { + t.logger.Debug("test") + t.Assert().Equal([]byte(logging.DEBUG.String()+" => test\n"), t.logger.Bytes()) +} diff --git a/core/util/testutil/gock.go b/core/util/testutil/gock.go new file mode 100644 index 0000000..923cab9 --- /dev/null +++ b/core/util/testutil/gock.go @@ -0,0 +1,69 @@ +package testutil + +import ( + "fmt" + "io" + "net/http" + + "gopkg.in/h2non/gock.v1" +) + +// UnmatchedRequestsTestingT contains all of *testing.T methods which are needed for AssertNoUnmatchedRequests. +type UnmatchedRequestsTestingT interface { + Log(...interface{}) + Logf(string, ...interface{}) + FailNow() +} + +// AssertNoUnmatchedRequests check that gock didn't receive any request that it was not able to match. +// It will print out an entire request data for every unmatched request. +func AssertNoUnmatchedRequests(t UnmatchedRequestsTestingT) { + if gock.HasUnmatchedRequest() { // nolint:nestif + t.Log("gock has unmatched requests. their contents will be dumped here.\n") + + for _, r := range gock.GetUnmatchedRequests() { + printRequestData(t, r) + fmt.Println() + } + + t.FailNow() + } +} + +func printRequestData(t UnmatchedRequestsTestingT, r *http.Request) { + t.Logf("%s %s %s\n", r.Proto, r.Method, r.URL.String()) + t.Logf(" > RemoteAddr: %s\n", r.RemoteAddr) + t.Logf(" > Host: %s\n", r.Host) + t.Logf(" > Length: %d\n", r.ContentLength) + + for _, encoding := range r.TransferEncoding { + t.Logf(" > Transfer-Encoding: %s\n", encoding) + } + + for header, values := range r.Header { + for _, value := range values { + t.Logf("[header] %s: %s\n", header, value) + } + } + + if r.Body == nil { + t.Log("No body is present.") + } else { + data, err := io.ReadAll(r.Body) + if err != nil { + t.Logf("Cannot read body: %s\n", err) + } + + if len(data) == 0 { + t.Log("Body is empty.") + } else { + t.Logf("Body:\n%s\n", string(data)) + } + } + + for header, values := range r.Trailer { + for _, value := range values { + t.Logf("[trailer header] %s: %s\n", header, value) + } + } +} diff --git a/core/util/testutil/gock_test.go b/core/util/testutil/gock_test.go new file mode 100644 index 0000000..23cd421 --- /dev/null +++ b/core/util/testutil/gock_test.go @@ -0,0 +1,92 @@ +package testutil + +import ( + "bytes" + "fmt" + "net/http" + "testing" + + "github.com/stretchr/testify/suite" + "gopkg.in/h2non/gock.v1" +) + +type testingTMock struct { + logs *bytes.Buffer + failed bool +} + +func (t *testingTMock) Log(args ...interface{}) { + t.logs.WriteString(fmt.Sprintln(append([]interface{}{"=>"}, args...)...)) +} + +func (t *testingTMock) Logf(format string, args ...interface{}) { + t.logs.WriteString(fmt.Sprintf(" => "+format, args...)) +} + +func (t *testingTMock) FailNow() { + t.failed = true +} + +func (t *testingTMock) Reset() { + t.logs.Reset() + t.failed = false +} + +func (t *testingTMock) Logs() string { + return t.logs.String() +} + +func (t *testingTMock) Failed() bool { + return t.failed +} + +type AssertNoUnmatchedRequestsTest struct { + suite.Suite + tmock *testingTMock +} + +func TestAssertNoUnmatchedRequests(t *testing.T) { + suite.Run(t, new(AssertNoUnmatchedRequestsTest)) +} + +func (t *AssertNoUnmatchedRequestsTest) SetupSuite() { + t.tmock = &testingTMock{logs: &bytes.Buffer{}} +} + +func (t *AssertNoUnmatchedRequestsTest) SetupTest() { + t.tmock.Reset() + gock.CleanUnmatchedRequest() +} + +func (t *AssertNoUnmatchedRequestsTest) Test_OK() { + AssertNoUnmatchedRequests(t.tmock) + + t.Assert().Empty(t.tmock.Logs()) + t.Assert().False(t.tmock.Failed()) +} + +func (t *AssertNoUnmatchedRequestsTest) Test_HasUnmatchedRequests() { + defer gock.Off() + + gock.New("https://example.com"). + Post("/dial"). + MatchHeader("X-Client-Data", "something"). + BodyString("something in body"). + Reply(http.StatusOK) + + _, _ = http.Get("https://example.com/nil") + + AssertNoUnmatchedRequests(t.tmock) + + t.Assert().True(gock.HasUnmatchedRequest()) + t.Assert().NotEmpty(t.tmock.Logs()) + t.Assert().Equal(`=> gock has unmatched requests. their contents will be dumped here. + + => HTTP/1.1 GET https://example.com/nil + => > RemoteAddr: + => > Host: example.com + => > Length: 0 +=> No body is present. +`, t.tmock.Logs()) + t.Assert().True(t.tmock.Failed()) +} diff --git a/core/util/testutil/gorm.go b/core/util/testutil/gorm.go new file mode 100644 index 0000000..b53217f --- /dev/null +++ b/core/util/testutil/gorm.go @@ -0,0 +1,55 @@ +package testutil + +import ( + "database/sql" + "fmt" + + "github.com/jinzhu/gorm" +) + +// DeleteCreatedEntities sets up GORM `onCreate` hook and return a function that can be deferred to +// remove all the entities created after the hook was set up. +// You can use it like this: +// +// func TestSomething(t *testing.T){ +// db, _ := gorm.Open(...) +// cleaner := DeleteCreatedEntities(db) +// defer cleaner() +// }. +func DeleteCreatedEntities(db *gorm.DB) func() { // nolint + type entity struct { + key interface{} + table string + keyname string + } + var entries []entity + hookName := "cleanupHook" + + db.Callback().Create().After("gorm:create").Register(hookName, func(scope *gorm.Scope) { + fmt.Printf("Inserted entities of %s with %s=%v\n", scope.TableName(), scope.PrimaryKey(), scope.PrimaryKeyValue()) + entries = append(entries, entity{table: scope.TableName(), keyname: scope.PrimaryKey(), key: scope.PrimaryKeyValue()}) + }) + return func() { + // Remove the hook once we're done + defer db.Callback().Create().Remove(hookName) + // Find out if the current db object is already a transaction + _, inTransaction := db.CommonDB().(*sql.Tx) + tx := db + if !inTransaction { + tx = db.Begin() + } + // Loop from the end. It is important that we delete the entries in the + // reverse order of their insertion + for i := len(entries) - 1; i >= 0; i-- { + entry := entries[i] + fmt.Printf("Deleting entities from '%s' table with key %v\n", entry.table, entry.key) + if err := tx.Table(entry.table).Where(entry.keyname+" = ?", entry.key).Delete("").Error; err != nil { + panic(err) + } + } + + if !inTransaction { + tx.Commit() + } + } +} diff --git a/core/util/testutil/gorm_test.go b/core/util/testutil/gorm_test.go new file mode 100644 index 0000000..5aa48f9 --- /dev/null +++ b/core/util/testutil/gorm_test.go @@ -0,0 +1,49 @@ +package testutil + +import ( + "regexp" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/jinzhu/gorm" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type testUser struct { + Username string `gorm:"column:username; type:varchar(255); not null;"` + ID int `gorm:"primary_key"` +} + +func TestDeleteCreatedEntities(t *testing.T) { + sqlDB, mock, err := sqlmock.New() + require.NoError(t, err) + + db, err := gorm.Open("postgres", sqlDB) + require.NoError(t, err) + + mock. + ExpectExec(regexp.QuoteMeta(`CREATE TABLE "test_users" ("username" varchar(255) NOT NULL,"id" serial , PRIMARY KEY ("id"))`)). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectBegin() + mock. + ExpectQuery(regexp.QuoteMeta(`INSERT INTO "test_users" ("username") VALUES ($1) RETURNING "test_users"."id"`)). + WithArgs("user"). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow("1")) + mock.ExpectCommit() + mock.ExpectBegin() + mock. + ExpectExec(regexp.QuoteMeta(`DELETE FROM "test_users" WHERE (id = $1)`)). + WithArgs(1). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + + cleaner := DeleteCreatedEntities(db) + + require.NoError(t, db.AutoMigrate(&testUser{}).Error) + require.NoError(t, db.Create(&testUser{Username: "user"}).Error) + + cleaner() + + assert.NoError(t, mock.ExpectationsWereMet()) +} diff --git a/core/translations_extractor.go b/core/util/testutil/translations_extractor.go similarity index 98% rename from core/translations_extractor.go rename to core/util/testutil/translations_extractor.go index 9fef5d3..4abb1a8 100644 --- a/core/translations_extractor.go +++ b/core/util/testutil/translations_extractor.go @@ -1,4 +1,4 @@ -package core +package testutil import ( "errors" @@ -92,9 +92,8 @@ func (t *TranslationsExtractor) loadYAMLFile(fileName string) (map[string]interf func (t *TranslationsExtractor) loadYAML(fileName string) (map[string]interface{}, error) { if t.TranslationsPath != "" { return t.loadYAMLFile(filepath.Join(t.TranslationsPath, fileName)) - } else { - return t.loadYAMLFromFS(fileName) } + return t.loadYAMLFromFS(fileName) } // GetMapKeys returns sorted map keys from map[string]interface{} - useful to check keys in several translation files. diff --git a/core/translations_extractor_test.go b/core/util/testutil/translations_extractor_test.go similarity index 99% rename from core/translations_extractor_test.go rename to core/util/testutil/translations_extractor_test.go index 13bac08..11de554 100644 --- a/core/translations_extractor_test.go +++ b/core/util/testutil/translations_extractor_test.go @@ -1,4 +1,4 @@ -package core +package testutil import ( "io/ioutil" diff --git a/core/utils.go b/core/util/utils.go similarity index 83% rename from core/utils.go rename to core/util/utils.go index ac2cddd..f26f5a6 100644 --- a/core/utils.go +++ b/core/util/utils.go @@ -1,4 +1,4 @@ -package core +package util import ( // nolint:gosec @@ -18,6 +18,17 @@ import ( "github.com/aws/aws-sdk-go/service/s3/s3manager" retailcrm "github.com/retailcrm/api-client-go/v2" v1 "github.com/retailcrm/mg-transport-api-client-go/v1" + + "github.com/retailcrm/mg-transport-core/v2/core/config" + + "github.com/retailcrm/mg-transport-core/v2/core/logger" + + "github.com/retailcrm/mg-transport-core/v2/core/util/errorutil" +) + +var ( + markdownSymbols = []string{"*", "_", "`", "["} + slashRegex = regexp.MustCompile(`/+$`) ) var DefaultScopes = []string{ @@ -78,28 +89,28 @@ var defaultCurrencies = map[string]string{ // Utils service object. type Utils struct { - IsDebug bool - TokenCounter uint32 - ConfigAWS ConfigAWS - Logger LoggerInterface + Logger logger.Logger slashRegex *regexp.Regexp + AWS config.AWS + TokenCounter uint32 + IsDebug bool } // NewUtils will create new Utils instance. -func NewUtils(awsConfig ConfigAWS, logger LoggerInterface, debug bool) *Utils { +func NewUtils(awsConfig config.AWS, logger logger.Logger, debug bool) *Utils { return &Utils{ IsDebug: debug, - ConfigAWS: awsConfig, + AWS: awsConfig, Logger: logger, TokenCounter: 0, slashRegex: slashRegex, } } -// resetUtils. -func (u *Utils) resetUtils(awsConfig ConfigAWS, debug bool, tokenCounter uint32) { +// ResetUtils resets the utils inner state. +func (u *Utils) ResetUtils(awsConfig config.AWS, debug bool, tokenCounter uint32) { u.TokenCounter = tokenCounter - u.ConfigAWS = awsConfig + u.AWS = awsConfig u.IsDebug = debug u.slashRegex = slashRegex } @@ -124,7 +135,7 @@ func (u *Utils) GetAPIClient(url, key string, scopes []string) (*retailcrm.Clien if res := u.checkScopes(cr.Scopes, scopes); len(res) != 0 { u.Logger.Error(url, status, res) - return nil, http.StatusBadRequest, NewInsufficientScopesErr(res) + return nil, http.StatusBadRequest, errorutil.NewInsufficientScopesErr(res) } return client, 0, nil @@ -153,10 +164,10 @@ func (u *Utils) checkScopes(scopes []string, scopesRequired []string) []string { func (u *Utils) UploadUserAvatar(url string) (picURLs3 string, err error) { s3Config := &aws.Config{ Credentials: credentials.NewStaticCredentials( - u.ConfigAWS.AccessKeyID, - u.ConfigAWS.SecretAccessKey, + u.AWS.AccessKeyID, + u.AWS.SecretAccessKey, ""), - Region: aws.String(u.ConfigAWS.Region), + Region: aws.String(u.AWS.Region), } s := session.Must(session.NewSession(s3Config)) @@ -174,10 +185,10 @@ func (u *Utils) UploadUserAvatar(url string) (picURLs3 string, err error) { } result, err := uploader.Upload(&s3manager.UploadInput{ - Bucket: aws.String(u.ConfigAWS.Bucket), - Key: aws.String(fmt.Sprintf("%v/%v.jpg", u.ConfigAWS.FolderName, u.GenerateToken())), + Bucket: aws.String(u.AWS.Bucket), + Key: aws.String(fmt.Sprintf("%v/%v.jpg", u.AWS.FolderName, u.GenerateToken())), Body: resp.Body, - ContentType: aws.String(u.ConfigAWS.ContentType), + ContentType: aws.String(u.AWS.ContentType), ACL: aws.String("public-read"), }) if err != nil { @@ -228,7 +239,7 @@ func GetEntitySHA1(v interface{}) (hash string, err error) { // ReplaceMarkdownSymbols will remove markdown symbols from text. func ReplaceMarkdownSymbols(s string) string { for _, v := range markdownSymbols { - s = strings.Replace(s, v, "\\"+v, -1) + s = strings.ReplaceAll(s, v, "\\"+v) } return s diff --git a/core/utils_test.go b/core/util/utils_test.go similarity index 92% rename from core/utils_test.go rename to core/util/utils_test.go index d77c122..c3868eb 100644 --- a/core/utils_test.go +++ b/core/util/utils_test.go @@ -1,4 +1,4 @@ -package core +package util import ( "encoding/json" @@ -8,13 +8,19 @@ import ( "testing" "time" - "github.com/h2non/gock" "github.com/op/go-logging" retailcrm "github.com/retailcrm/api-client-go/v2" v1 "github.com/retailcrm/mg-transport-api-client-go/v1" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + "gopkg.in/h2non/gock.v1" + + "github.com/retailcrm/mg-transport-core/v2/core/config" + + "github.com/retailcrm/mg-transport-core/v2/core/logger" + + "github.com/retailcrm/mg-transport-core/v2/core/util/errorutil" ) var ( @@ -32,8 +38,8 @@ func mgClient() *v1.MgClient { } func (u *UtilsTest) SetupSuite() { - logger := NewLogger("code", logging.DEBUG, DefaultLogFormatter()) - awsConfig := ConfigAWS{ + logger := logger.NewStandard("code", logging.DEBUG, logger.DefaultLogFormatter()) + awsConfig := config.AWS{ AccessKeyID: "access key id (will be removed)", SecretAccessKey: "secret access key", Region: "region", @@ -47,15 +53,15 @@ func (u *UtilsTest) SetupSuite() { } func (u *UtilsTest) Test_ResetUtils() { - assert.Equal(u.T(), "access key id (will be removed)", u.utils.ConfigAWS.AccessKeyID) + assert.Equal(u.T(), "access key id (will be removed)", u.utils.AWS.AccessKeyID) assert.Equal(u.T(), uint32(12346), u.utils.TokenCounter) assert.False(u.T(), u.utils.IsDebug) - awsConfig := u.utils.ConfigAWS + awsConfig := u.utils.AWS awsConfig.AccessKeyID = "access key id" - u.utils.resetUtils(awsConfig, true, 0) + u.utils.ResetUtils(awsConfig, true, 0) - assert.Equal(u.T(), "access key id", u.utils.ConfigAWS.AccessKeyID) + assert.Equal(u.T(), "access key id", u.utils.AWS.AccessKeyID) assert.Equal(u.T(), uint32(0), u.utils.TokenCounter) assert.True(u.T(), u.utils.IsDebug) } @@ -110,7 +116,7 @@ func (u *UtilsTest) Test_GetAPIClient_FailAPICredentials() { _, status, err := u.utils.GetAPIClient(testCRMURL, "key", DefaultScopes) assert.Equal(u.T(), http.StatusBadRequest, status) if assert.NotNil(u.T(), err) { - assert.True(u.T(), errors.Is(err, ErrInsufficientScopes)) + assert.True(u.T(), errors.Is(err, errorutil.ErrInsufficientScopes)) } } diff --git a/core/validator_test.go b/core/validator_test.go index acceb4a..9b8df61 100644 --- a/core/validator_test.go +++ b/core/validator_test.go @@ -8,6 +8,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" + + "github.com/retailcrm/mg-transport-core/v2/core/db/models" ) type ValidatorSuite struct { @@ -49,7 +51,7 @@ func (s *ValidatorSuite) Test_ValidationFails() { } for _, domain := range crmDomains { - conn := Connection{ + conn := models.Connection{ Key: "key", URL: domain, } @@ -81,7 +83,7 @@ func (s *ValidatorSuite) Test_ValidationSuccess() { } for _, domain := range crmDomains { - conn := Connection{ + conn := models.Connection{ Key: "key", URL: domain, } diff --git a/go.mod b/go.mod index 4b51bb6..77479ba 100644 --- a/go.mod +++ b/go.mod @@ -1,4 +1,4 @@ -module github.com/retailcrm/mg-transport-core +module github.com/retailcrm/mg-transport-core/v2 go 1.12 @@ -16,7 +16,6 @@ require ( github.com/golang/protobuf v1.5.2 // indirect github.com/gorilla/securecookie v1.1.1 github.com/gorilla/sessions v1.2.0 - github.com/h2non/gock v1.0.10 github.com/jessevdk/go-flags v1.4.0 github.com/jinzhu/gorm v1.9.11 github.com/json-iterator/go v1.1.11 // indirect @@ -36,6 +35,7 @@ require ( google.golang.org/protobuf v1.27.1 // indirect gopkg.in/check.v1 v1.0.0-20200227125254-8fa46927fb4f // indirect gopkg.in/gormigrate.v1 v1.6.0 + gopkg.in/h2non/gock.v1 v1.1.2 gopkg.in/yaml.v2 v2.4.0 gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect ) diff --git a/go.sum b/go.sum index 42274f1..b6e4bf0 100644 --- a/go.sum +++ b/go.sum @@ -159,8 +159,6 @@ github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyC github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= github.com/gorilla/sessions v1.2.0 h1:S7P+1Hm5V/AT9cjEcUD5uDaQSX0OE577aCXgoaKpYbQ= github.com/gorilla/sessions v1.2.0/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= -github.com/h2non/gock v1.0.10 h1:EzHYzKKSLN4xk0w193uAy3tp8I3+L1jmaI2Mjg4lCgU= -github.com/h2non/gock v1.0.10/go.mod h1:CZMcB0Lg5IWnr9bF79pPMg9WeV6WumxQiUJ1UvdO1iE= github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw= github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542/go.mod h1:Ow0tF8D4Kplbc8s8sSb3V2oUCygFHVp8gC3Dn6U4MNI= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= @@ -249,10 +247,6 @@ github.com/prometheus/common v0.2.0/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y8 github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/prometheus/procfs v0.0.0-20190117184657-bf6a532e95b1/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk= github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= -github.com/retailcrm/api-client-go/v2 v2.0.1 h1:wyM0F1VTSJPO8PVEXB0u7s6ZEs0xRCnUu7YdQ1E4UZ8= -github.com/retailcrm/api-client-go/v2 v2.0.1/go.mod h1:1yTZl9+gd3+/k0kAJe7sYvC+mL4fqMwIwtnSgSWZlkQ= -github.com/retailcrm/api-client-go/v2 v2.0.2 h1:oFQycGqwcvfgW2JrbeWmPjxH7Wmh9j762c4FRxCDGNs= -github.com/retailcrm/api-client-go/v2 v2.0.2/go.mod h1:1yTZl9+gd3+/k0kAJe7sYvC+mL4fqMwIwtnSgSWZlkQ= github.com/retailcrm/api-client-go/v2 v2.0.3 h1:7oKwOZgRLM7eEJUvFNhzfnyIJVomy80ffOEBdYpQRIs= github.com/retailcrm/api-client-go/v2 v2.0.3/go.mod h1:1yTZl9+gd3+/k0kAJe7sYvC+mL4fqMwIwtnSgSWZlkQ= github.com/retailcrm/mg-transport-api-client-go v1.1.32 h1:IBPltSoD5q2PPZJbNC/prK5F9rEVPXVx/ZzDpi7HKhs=