package server

import (
    "context"
    "errors"
    "strings"
    "sync"
    "time"

    "github.com/Neur0toxine/sshpoke/internal/config"
    "github.com/Neur0toxine/sshpoke/internal/docker"
    "github.com/Neur0toxine/sshpoke/internal/logger"
    "github.com/Neur0toxine/sshpoke/internal/server/driver"
    "github.com/Neur0toxine/sshpoke/internal/server/driver/base"
    "github.com/Neur0toxine/sshpoke/internal/server/driver/plugin"
    "github.com/Neur0toxine/sshpoke/pkg/dto"
    "go.uber.org/zap"
)
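
// Manager owns the configured tunnel servers, routes container events to
// them, and tracks per-server connection status plus the forwards registry
// used by the mark-and-sweep job.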
type Manager struct {
    rw            sync.RWMutex             // guards servers
    forwardsLock  sync.Mutex               // guards forwards
    servers       map[string]base.Driver   // configured servers by name
    plugins       map[string]plugin.Plugin // plugin servers by token
    statusMap     map[string]ServerStatus  // per-server status; Connections guarded by statusLock
    forwards      map[string]bool          // forwardID -> marked-for-sweep flag
    statusLock    sync.RWMutex
    ctx           context.Context
    defaultServer string
    log           *zap.SugaredLogger
}
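
// ServerStatus is the JSON-serializable status of a single server and its
// active connections.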
type ServerStatus struct {
    Name        string      `json:"name"`
    Connections Connections `json:"connections"`
}
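
// DefaultManager is the process-wide Manager instance.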
var DefaultManager *Manager

var (
    ErrNoServer     = errors.New("server is not specified")
    ErrNoSuchServer = errors.New("server does not exist")
)
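
// NewManager builds a Manager from the configured servers, skipping any
// server that fails to initialize, and starts the background mark-and-sweep
// job for forwards.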
func NewManager(ctx context.Context, servers []config.Server, defaultServer string) *Manager {
    m := &Manager{
        ctx:           ctx,
        servers:       make(map[string]base.Driver),
        plugins:       make(map[string]plugin.Plugin),
        statusMap:     make(map[string]ServerStatus),
        forwards:      make(map[string]bool),
        defaultServer: defaultServer,
        log:           logger.Sugar.With("component", "manager"),
    }
    for _, serverConfig := range servers {
        server, err := driver.New(ctx, serverConfig.Name, serverConfig.Driver, serverConfig.Params)
        if err != nil {
            m.log.Errorf("cannot initialize server '%s': %s", serverConfig.Name, err)
            continue
        }
        server.SetEventStatusCallback(m.eventStatusCallback(server.Name()))
        if server.Driver() == config.DriverPlugin {
            pl := server.(plugin.Plugin)
            if pl.Token() == "" {
                m.log.Warnf("server '%s' will not work because it doesn't have a token", pl.Name())
                continue
            }
            existing, found := m.plugins[pl.Token()]
            if found {
                m.log.Fatalw("two plugins cannot have the same token",
                    "plugin1", existing.Name(), "plugin2", pl.Name(), "token", pl.Token())
            }
            m.plugins[pl.Token()] = pl
        }
        m.servers[serverConfig.Name] = server
        m.statusMap[serverConfig.Name] = ServerStatus{Name: serverConfig.Name, Connections: make(Connections)}
    }
    go m.runMarkAndSweepForwards()
    return m
}
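
// ProcessEvent routes a container event to the server named in the event
// (falling back to the default server) and updates the forwards registry.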
func (m *Manager) ProcessEvent(event dto.Event) error {
    serverName := event.Container.Server
    if serverName == "" {
        serverName = m.defaultServer
    }
    if serverName == "" {
        return ErrNoServer
    }
    m.rw.RLock()
    defer m.rw.RUnlock()
    srv, ok := m.servers[serverName] // look up by the resolved name, not the raw event field
    if !ok {
        return ErrNoSuchServer
    }
    if err := srv.Handle(event); err != nil {
        return err
    }
    m.forwardsLock.Lock()
    defer m.forwardsLock.Unlock()
    switch event.Type {
    case dto.EventStart:
        m.forwards[m.forwardID(serverName, event.Container.ID)] = false
    case dto.EventStop, dto.EventError, dto.EventShutdown:
        delete(m.forwards, m.forwardID(serverName, event.Container.ID))
    }
    return nil
}
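
// forwardID builds the forwards-map key for a container on a given server.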
func (m *Manager) forwardID(serverName, containerID string) string {
    return serverName + ":" + containerID
}
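
// eventStatusCallback adapts processEventStatus into the per-server callback
// signature expected by drivers.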
func (m *Manager) eventStatusCallback(serverName string) base.EventStatusCallback {
    return func(status dto.EventStatus) {
        m.processEventStatus(serverName, status)
    }
}
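
// processEventStatus applies a status change reported by a server driver to
// the forwards registry and to the per-server connection map.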
func (m *Manager) processEventStatus(serverName string, event dto.EventStatus) {
    m.log.Debugw("received EventStatus from server",
        "serverName", serverName, "eventStatus", event)
    item, found := docker.Default.GetContainer(event.ID, true)
    if !found {
        return
    }
    m.forwardsLock.Lock()
    defer m.forwardsLock.Unlock()
    switch event.Type {
    case dto.EventStart:
        m.statusLock.Lock()
        defer m.statusLock.Unlock()
        item.Domain = event.Domain
        m.forwards[m.forwardID(serverName, item.ID)] = false
        m.statusMap[serverName].Connections[item.ID] = item
    case dto.EventStop, dto.EventShutdown, dto.EventError:
        m.statusLock.Lock()
        defer m.statusLock.Unlock()
        item.Domain = ""
        delete(m.forwards, m.forwardID(serverName, item.ID))
        delete(m.statusMap[serverName].Connections, item.ID)
    default:
        return
    }
}
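
// StatusMap returns a point-in-time copy of every server's status and
// connections that the caller can read without further locking.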
func (m *Manager) StatusMap() map[string]ServerStatus {
    m.statusLock.RLock()
    defer m.statusLock.RUnlock()
    // Copy the Connections maps too: processEventStatus mutates them, so
    // handing out the live maps would let callers race with it.
    snapshot := make(map[string]ServerStatus, len(m.statusMap))
    for name, status := range m.statusMap {
        conns := make(Connections, len(status.Connections))
        for id, conn := range status.Connections {
            conns[id] = conn
        }
        status.Connections = conns
        snapshot[name] = status
    }
    return snapshot
}

// runMarkAndSweepForwards runs mark-and-sweep on the started forwards every 10 seconds.
// This job is necessary because Docker sometimes forgets to notify us that containers
// were stopped (usually happens when spamming Ctrl+C after `docker compose run`).
func (m *Manager) runMarkAndSweepForwards() {
    ticker := time.NewTicker(time.Second * 10)
    defer ticker.Stop()
    for {
        select {
        case <-m.ctx.Done():
            return
        case <-ticker.C:
            m.markAndSweepForwards()
        }
    }
}

// markAndSweepForwards marks stopped containers for removal on the first run and
// removes them from forwards on the next one. This sweep removes containers from
// forwards even when Docker never notified us that they stopped.
func (m *Manager) markAndSweepForwards() {
    // ProcessEvent takes forwardsLock itself, so collect the forwards that need
    // a synthetic stop event and emit them only after releasing the lock;
    // calling ProcessEvent while holding the lock would deadlock.
    var stopped []dto.Event

    m.forwardsLock.Lock()
    for id, marked := range m.forwards {
        forwardIDs := strings.Split(id, ":")
        serverName, containerID := forwardIDs[0], forwardIDs[1]
        if _, found := docker.Default.GetContainer(containerID, false); found {
            m.forwards[id] = false // unmark
            continue
        }
        if marked {
            // Gone for two consecutive runs: sweep it via a synthetic stop event.
            stopped = append(stopped, dto.Event{
                Type: dto.EventStop,
                Container: dto.Container{
                    ID:     containerID,
                    Server: serverName,
                },
            })
            continue
        }
        m.forwards[id] = true // mark
    }
    m.forwardsLock.Unlock()

    for _, event := range stopped {
        if err := m.ProcessEvent(event); err != nil {
            m.log.Warnf("cannot process mark-and-sweep event: %s", err)
        }
    }
}
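
// PluginByToken returns the plugin server registered under token, or nil if no
// plugin uses that token. The plugins map is only written during NewManager, so
// it is safe to read without locking.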
func (m *Manager) PluginByToken(token string) plugin.Plugin {
    server, ok := m.plugins[token]
    if !ok {
        return nil
    }
    return server
}
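
// Wait blocks until every configured server has shut down.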
func (m *Manager) Wait() {
    m.rw.RLock()
    defer m.rw.RUnlock()
    for _, srv := range m.servers {
        srv.WaitForShutdown()
    }
}
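
// Usage sketch. This wiring is hypothetical and not taken from this
// repository's main package: `cfg` (with Servers/DefaultServer fields) and
// `dockerEvents` (a channel of dto.Event) are assumed for illustration only;
// only Manager's own API is real.
//
//	ctx, cancel := context.WithCancel(context.Background())
//	defer cancel()
//	DefaultManager = NewManager(ctx, cfg.Servers, cfg.DefaultServer)
//	go func() {
//		for event := range dockerEvents { // assumed event source
//			if err := DefaultManager.ProcessEvent(event); err != nil {
//				logger.Sugar.Warnf("event dropped: %s", err)
//			}
//		}
//	}()
//	DefaultManager.Wait() // block until all servers shut down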