diff --git a/cmd/root.go b/cmd/root.go index be1a781..f588601 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -5,7 +5,6 @@ import ( "os" "os/signal" "syscall" - "time" "github.com/Neur0toxine/sshpoke/internal/api/plugin" "github.com/Neur0toxine/sshpoke/internal/api/rest" @@ -46,12 +45,8 @@ var rootCmd = &cobra.Command{ config.Rehash() logger.Default.Info("configuration has been updated, restarting the app...") cancel() - go server.DefaultManager.WaitForShutdown() - ctx, innerCancel := context.WithTimeout(context.Background(), 2*time.Second) - defer innerCancel() - select { - case <-ctx.Done(): - } + server.DefaultManager.Wait() + docker.Default.Wait() initApp() }) @@ -179,7 +174,8 @@ func runDockerEventListener(ctx context.Context) { func makeShutdownFunc(cancel func()) func(os.Signal) { return func(sig os.Signal) { cancel() - server.DefaultManager.WaitForShutdown() + server.DefaultManager.Wait() + docker.Default.Wait() logger.Sugar.Infof("received %s, exiting...", sig) os.Exit(0) } diff --git a/internal/docker/api.go b/internal/docker/api.go index fabe9ff..3f4314e 100644 --- a/internal/docker/api.go +++ b/internal/docker/api.go @@ -19,6 +19,7 @@ var Default *Docker type Docker struct { cli *client.Client ctx context.Context + wait chan struct{} services map[smarttypes.MatchableString]config.ServiceLabels defaultServer string } @@ -35,6 +36,7 @@ func New(ctx context.Context, services []config.Service, defaultServer string) ( return &Docker{ cli: cli, ctx: ctx, + wait: make(chan struct{}), services: servicesMap, defaultServer: defaultServer, }, nil @@ -104,7 +106,7 @@ func (d *Docker) Listen() (chan dto.Event, error) { select { case event := <-eventSource: eventType := dto.TypeFromAction(event.Action) - if (eventType != dto.EventStart && eventType != dto.EventStop) || !actorEnabled(event.Actor) { + if (eventType != dto.EventStart && eventType != dto.EventStop) || !actorEnabled(event.Actor).Bool() { continue } container, err := d.cli.ContainerList(d.ctx, types.ContainerListOptions{ @@ -139,6 +141,7 @@ func (d *Docker) Listen() (chan dto.Event, error) { output <- newEvent case err := <-errSource: if errors.Is(err, context.Canceled) { + d.wait <- struct{}{} logger.Sugar.Debug("stopping docker event listener...") return } @@ -151,3 +154,7 @@ func (d *Docker) Listen() (chan dto.Event, error) { return output, nil } + +func (d *Docker) Wait() { + <-d.wait +} diff --git a/internal/docker/convert.go b/internal/docker/convert.go index 80c8a47..98661dd 100644 --- a/internal/docker/convert.go +++ b/internal/docker/convert.go @@ -26,14 +26,14 @@ type labelsConfig struct { func actorEnabled(actor events.Actor) smarttypes.BoolStr { label, ok := actor.Attributes["sshpoke.enable"] if !ok { - return false + return smarttypes.AsBoolStr(false) } - return smarttypes.BoolFromStr(label) + return smarttypes.BoolStr(label) } func populateLabelsFromConfig(labels *labelsConfig, config *config.ServiceLabels) { - if labels.Enable != config.Enable { - labels.Enable = config.Enable + if labels.Enable.Bool() != config.Enable.Bool() { + labels.Enable = smarttypes.AsBoolStr(config.Enable.Bool()) } if labels.Server != config.Server { labels.Server = config.Server @@ -59,7 +59,7 @@ func dockerContainerToInternal( if configLabels != nil { populateLabelsFromConfig(&labels, configLabels) } - if !labels.Enable { + if !labels.Enable.Bool() { logger.Sugar.Debugf("skipping container %s because sshpoke is not enabled for it", container.ID) return result, false } diff --git a/internal/server/manager.go b/internal/server/manager.go index 9440751..1589594 100644 --- a/internal/server/manager.go +++ b/internal/server/manager.go @@ -203,7 +203,7 @@ func (m *Manager) PluginByToken(token string) plugin.Plugin { return server } -func (m *Manager) WaitForShutdown() { +func (m *Manager) Wait() { defer m.rw.RUnlock() m.rw.RLock() for _, srv := range m.servers { diff --git a/pkg/smarttypes/bool_str.go b/pkg/smarttypes/bool_str.go index c4c9d41..c376dc8 100644 --- a/pkg/smarttypes/bool_str.go +++ b/pkg/smarttypes/bool_str.go @@ -1,19 +1,16 @@ package smarttypes -type BoolStr bool +import ( + "strconv" +) -func BoolFromStr(str string) BoolStr { - return str == "true" || str == "1" +type BoolStr string + +func AsBoolStr(val bool) BoolStr { + return BoolStr(strconv.FormatBool(val)) } -func (b BoolStr) MarshalText() ([]byte, error) { - if b { - return []byte("true"), nil - } - return []byte("false"), nil -} - -func (b *BoolStr) UnmarshalText(src []byte) error { - *b = BoolFromStr(string(src)) - return nil +func (b BoolStr) Bool() bool { + val, _ := strconv.ParseBool(string(b)) + return val }