mg-transport-core/core/db/migrate.go

278 lines
5.7 KiB
Go
Raw Permalink Normal View History

package db
2019-09-18 13:40:36 +03:00
import (
2019-09-19 14:16:52 +03:00
"fmt"
2019-09-18 13:40:36 +03:00
"sort"
"github.com/jinzhu/gorm"
"github.com/pkg/errors"
"gopkg.in/gormigrate.v1"
)
// migrations default GORMigrate tool.
2019-09-18 13:40:36 +03:00
var migrations *Migrate
// Migrate tool, decorates gormigrate.Migration in order to provide better interface & versioning.
2019-09-18 13:40:36 +03:00
type Migrate struct {
db *gorm.DB
first *gormigrate.Migration
migrations map[string]*gormigrate.Migration
GORMigrate *gormigrate.Gormigrate
versions []string
2019-09-18 13:40:36 +03:00
prepared bool
}
// MigrationInfo is the model of a single row of the migrations table.
type MigrationInfo struct {
	// ID is the identifier of the migration stored in the row.
	ID string `gorm:"column:id; type:varchar(255)"`
}
// TableName for MigrationInfo.
2019-09-18 13:40:36 +03:00
func (MigrationInfo) TableName() string {
return "migrations"
}
// Migrations returns default migrate.
2019-09-18 13:40:36 +03:00
func Migrations() *Migrate {
if migrations == nil {
migrations = &Migrate{
db: nil,
prepared: false,
migrations: map[string]*gormigrate.Migration{},
}
}
return migrations
}
// Add GORMigrate to migrate.
2019-09-18 13:40:36 +03:00
func (m *Migrate) Add(migration *gormigrate.Migration) {
if migration == nil {
return
}
m.migrations[migration.ID] = migration
}
// SetDB to migrate.
2019-09-18 13:40:36 +03:00
func (m *Migrate) SetDB(db *gorm.DB) *Migrate {
m.db = db
return m
}
// Migrate all, including schema initialization.
2019-09-18 13:40:36 +03:00
func (m *Migrate) Migrate() error {
if err := m.prepareMigrations(); err != nil {
return err
}
if len(m.migrations) > 0 {
return m.GORMigrate.Migrate()
}
return nil
}
// Rollback all migrations.
2019-09-18 13:40:36 +03:00
func (m *Migrate) Rollback() error {
if err := m.prepareMigrations(); err != nil {
return err
}
2019-09-19 14:16:52 +03:00
if m.first == nil {
return errors.New("abnormal termination: first migration is nil")
}
2019-12-12 10:08:26 +03:00
if err := m.GORMigrate.RollbackTo(m.first.ID); err != nil {
2019-12-12 09:35:05 +03:00
return err
2019-12-12 10:08:26 +03:00
}
if err := m.GORMigrate.RollbackMigration(m.first); err != nil {
2019-09-18 13:40:36 +03:00
return err
}
2019-12-12 10:08:26 +03:00
return nil
2019-09-18 13:40:36 +03:00
}
// MigrateTo specified version.
2019-09-18 13:40:36 +03:00
func (m *Migrate) MigrateTo(version string) error {
if err := m.prepareMigrations(); err != nil {
return err
}
current := m.Current()
switch {
case current > version:
return m.GORMigrate.RollbackTo(version)
case current < version:
return m.GORMigrate.MigrateTo(version)
default:
return nil
}
}
// MigrateNextTo migrate to next version from specified version.
2019-09-18 13:40:36 +03:00
func (m *Migrate) MigrateNextTo(version string) error {
if err := m.prepareMigrations(); err != nil {
return err
}
if next, err := m.NextFrom(version); err == nil {
current := m.Current()
switch {
case current < next:
return m.GORMigrate.MigrateTo(next)
2019-09-19 14:16:52 +03:00
case current > next:
return fmt.Errorf("current migration version '%s' is higher than fetched version '%s'", current, next)
2019-09-18 13:40:36 +03:00
default:
return nil
}
} else {
return nil
}
}
// MigratePreviousTo migrate to previous version from specified version.
2019-09-18 13:40:36 +03:00
func (m *Migrate) MigratePreviousTo(version string) error {
if err := m.prepareMigrations(); err != nil {
return err
}
if prev, err := m.PreviousFrom(version); err == nil {
current := m.Current()
switch {
case current > prev:
return m.GORMigrate.RollbackTo(prev)
case current < prev:
return fmt.Errorf("current migration version '%s' is lower than fetched version '%s'", current, prev)
2019-09-19 14:16:52 +03:00
case prev == "0":
return m.GORMigrate.RollbackMigration(m.first)
2019-09-18 13:40:36 +03:00
default:
return nil
}
} else {
return nil
}
}
// RollbackTo specified version.
2019-09-18 13:40:36 +03:00
func (m *Migrate) RollbackTo(version string) error {
if err := m.prepareMigrations(); err != nil {
return err
}
return m.GORMigrate.RollbackTo(version)
}
// Current migration version.
2019-09-18 13:40:36 +03:00
func (m *Migrate) Current() string {
var migrationInfo MigrationInfo
if m.db == nil {
2019-09-19 14:16:52 +03:00
fmt.Println("warning => db is nil - cannot return migration version")
2019-09-18 13:40:36 +03:00
return "0"
}
if !m.db.HasTable(MigrationInfo{}) {
2019-09-19 14:16:52 +03:00
if err := m.db.CreateTable(MigrationInfo{}).Error; err == nil {
fmt.Println("info => created migrations table")
} else {
panic(err.Error())
}
2019-09-18 13:40:36 +03:00
return "0"
}
2019-12-12 10:08:26 +03:00
if err := m.db.Last(&migrationInfo).Error; err != nil {
fmt.Printf("warning => cannot fetch migration version: %s\n", err.Error())
return "0"
2019-09-18 13:40:36 +03:00
}
2019-12-12 09:35:05 +03:00
2019-12-12 10:08:26 +03:00
return migrationInfo.ID
2019-09-18 13:40:36 +03:00
}
// NextFrom returns next version from passed version.
2019-09-18 13:40:36 +03:00
func (m *Migrate) NextFrom(version string) (string, error) {
for key, ver := range m.versions {
if ver == version {
if key < (len(m.versions) - 1) {
return m.versions[key+1], nil
}
2019-12-12 09:35:05 +03:00
return "", errors.New("this is last migration")
2019-09-18 13:40:36 +03:00
}
}
return "", errors.New("cannot find specified migration")
}
// PreviousFrom returns previous version from passed version.
2019-09-18 13:40:36 +03:00
func (m *Migrate) PreviousFrom(version string) (string, error) {
for key, ver := range m.versions {
if ver == version {
if key > 0 {
return m.versions[key-1], nil
}
2019-12-12 09:35:05 +03:00
return "0", nil
2019-09-18 13:40:36 +03:00
}
}
return "", errors.New("cannot find specified migration")
}
// Close db connection.
2019-09-18 13:40:36 +03:00
func (m *Migrate) Close() error {
return m.db.Close()
}
// prepareMigrations prepare migrate.
2019-09-18 13:40:36 +03:00
func (m *Migrate) prepareMigrations() error {
var (
keys []string
migrations []*gormigrate.Migration
)
if m.db == nil {
return errors.New("db must not be nil")
}
if m.prepared {
return nil
}
i := 0
keys = make([]string, len(m.migrations))
2019-09-18 13:40:36 +03:00
for key := range m.migrations {
keys[i] = key
i++
2019-09-18 13:40:36 +03:00
}
sort.Strings(keys)
m.versions = keys
if len(keys) > 0 {
if i, ok := m.migrations[keys[0]]; ok {
m.first = i
}
}
for _, key := range keys {
if i, ok := m.migrations[key]; ok {
migrations = append(migrations, i)
}
}
options := &gormigrate.Options{
TableName: gormigrate.DefaultOptions.TableName,
IDColumnName: gormigrate.DefaultOptions.IDColumnName,
IDColumnSize: gormigrate.DefaultOptions.IDColumnSize,
UseTransaction: true,
ValidateUnknownMigrations: true,
}
m.GORMigrate = gormigrate.New(m.db, options, migrations)
m.prepared = true
return nil
}