// Package ssh provides a driver that opens an sshtun tunnel for every started
// container and closes it again when the container stops.
package ssh

import (
	"context"
	"net"
	"path"
	"strconv"
	"sync"

	"github.com/Neur0toxine/sshpoke/internal/config"
	"github.com/Neur0toxine/sshpoke/internal/server/driver/base"
	"github.com/Neur0toxine/sshpoke/internal/server/driver/ssh/sshtun"
	"github.com/Neur0toxine/sshpoke/internal/server/driver/ssh/types"
	"github.com/Neur0toxine/sshpoke/internal/server/driver/util"
	"github.com/Neur0toxine/sshpoke/pkg/dto"
	"github.com/Neur0toxine/sshpoke/pkg/proto/ssh"
)

// SSH is a driver that keeps one tunnel per container ID.
type SSH struct {
	base.Base
	params Params
	auth   []ssh.AuthMethod
	conns  map[string]conn // active tunnels keyed by container ID
	rw     sync.RWMutex    // guards conns
	wg     sync.WaitGroup  // holds exactly one count per entry in conns
}

// conn bundles a running tunnel with the context that controls its lifetime.
type conn struct {
	ctx    context.Context
	cancel func()
	tun    *sshtun.Tunnel
}

// New builds an SSH driver called name from params.
//
// It returns an error when params cannot be unmarshalled into the driver's
// parameters. After unmarshalling, values found in the SSH config file (if an
// auth directory is configured) are applied and the auth methods are prepared.
func New(ctx context.Context, name string, params config.DriverParams) (base.Driver, error) {
	drv := &SSH{
		Base:  base.New(ctx, name),
		conns: make(map[string]conn),
	}
	if err := util.UnmarshalParams(params, &drv.params); err != nil {
		return nil, err
	}
	drv.populateFromSSHConfig()
	drv.auth = drv.authenticators()
	return drv, nil
}

// forward creates a tunnel for val and starts connecting it in a new goroutine.
// The returned conn carries the cancel function that stops that goroutine's
// context; the context is derived from the driver's own context.
func (d *SSH) forward(val sshtun.Forward) conn {
	tun := sshtun.New(
		d.params.Address,
		d.params.Auth.User,
		d.params.DisableRemoteHostResolve,
		val,
		d.auth,
		d.Log(),
	)
	ctx, cancel := context.WithCancel(d.Context())
	go tun.Connect(ctx, sshtun.StdoutPrinterSessionCallback(d.Log().With("ssh-output", val.Remote.String())))
	return conn{ctx: ctx, cancel: cancel, tun: tun}
}

// populateFromSSHConfig overrides auth parameters with values from the
// "config" file inside the configured auth directory.
//
// This is best-effort: a missing auth directory, an unparsable config file or
// an unresolvable identity file leaves the corresponding parameters untouched.
func (d *SSH) populateFromSSHConfig() {
	if d.params.Auth.Directory == "" {
		return
	}
	cfg, err := parseSSHConfig(types.SmartPath(path.Join(string(d.params.Auth.Directory), "config")))
	if err != nil {
		return
	}
	if user, err := cfg.Get(d.params.Address, "User"); err == nil && user != "" {
		d.params.Auth.User = user
	}
	if usePass, err := cfg.Get(d.params.Address, "PasswordAuthentication"); err == nil && usePass == "yes" {
		d.params.Auth.Type = types.AuthTypePassword
	}
	// IdentityFile is checked last, so key auth wins over password auth when
	// both are present in the config.
	if keyfile, err := cfg.Get(d.params.Address, "IdentityFile"); err == nil && keyfile != "" {
		resolvedKeyFile, err := types.SmartPath(keyfile).Resolve(false)
		if err == nil {
			d.params.Auth.Type = types.AuthTypeKey
			d.params.Auth.Keyfile = resolvedKeyFile
		}
	}
}

// Handle reacts to container lifecycle events.
//
//   - EventStart opens a tunnel for the container. If a tunnel is already
//     registered under the same container ID it is cancelled and replaced.
//   - EventStop cancels and forgets the container's tunnel; unknown IDs are
//     ignored.
//   - EventShutdown cancels and forgets every tunnel.
//
// It always returns nil.
func (d *SSH) Handle(event dto.Event) error {
	d.rw.Lock()
	defer d.rw.Unlock()

	switch event.Type {
	case dto.EventStart:
		id := event.Container.ID
		if prev, found := d.conns[id]; found {
			// A repeated start must not leak the previous tunnel. The wait
			// group already holds a count for this ID, so it is not bumped
			// again; otherwise WaitForShutdown would never return.
			prev.cancel()
		} else {
			d.wg.Add(1)
		}
		d.conns[id] = d.forward(sshtun.Forward{
			Local: sshtun.AddrToEndpoint(
				net.JoinHostPort(event.Container.IP.String(), strconv.Itoa(int(event.Container.Port)))),
			Remote: d.remoteEndpoint(event.Container.RemoteHost),
		})
	case dto.EventStop:
		id := event.Container.ID
		active, found := d.conns[id]
		if !found {
			return nil
		}
		active.cancel()
		delete(d.conns, id)
		d.wg.Done()
	case dto.EventShutdown:
		for id, active := range d.conns {
			active.cancel()
			delete(d.conns, id)
			d.wg.Done()
		}
	}
	return nil
}

// remoteEndpoint returns the remote side of a tunnel for remoteHost, using the
// configured forward port or 80 when none is set.
func (d *SSH) remoteEndpoint(remoteHost string) sshtun.Endpoint {
	port := int(d.params.ForwardPort)
	if port == 0 {
		port = 80
	}
	return sshtun.Endpoint{
		Host: remoteHost,
		Port: port,
	}
}

// Driver reports the driver type, which is always config.DriverSSH.
func (d *SSH) Driver() config.DriverType {
	return config.DriverSSH
}

// WaitForShutdown cancels every registered tunnel and returns once all of them
// have been removed from the driver.
func (d *SSH) WaitForShutdown() {
	// Handle returns nil for every event type, so there is no error to report.
	_ = d.Handle(dto.Event{Type: dto.EventShutdown})
	d.wg.Wait()
}

// authenticators wraps the configured auth method in a slice, or returns nil
// when no method could be built.
func (d *SSH) authenticators() []ssh.AuthMethod {
	auth := d.authenticator()
	if auth == nil {
		return nil
	}
	return []ssh.AuthMethod{auth}
}

// authenticator builds the auth method selected by the auth type parameter.
//
// It returns nil for an unknown auth type and when key material cannot be
// loaded.
func (d *SSH) authenticator() ssh.AuthMethod {
	switch d.params.Auth.Type {
	case types.AuthTypePasswordless:
		return sshtun.AuthPassword("")
	case types.AuthTypePassword:
		return sshtun.AuthPassword(d.params.Auth.Password)
	case types.AuthTypeKey:
		if d.params.Auth.Keyfile != "" {
			// NOTE(review): populateFromSSHConfig stores an already resolved
			// IdentityFile path in Keyfile; if that path is absolute, joining
			// it with Directory nests it under Directory. Confirm this is
			// intended.
			keyAuth, err := sshtun.AuthKeyFile(
				types.SmartPath(path.Join(d.params.Auth.Directory.String(), d.params.Auth.Keyfile)))
			if err != nil {
				return nil
			}
			return keyAuth
		}
		// No explicit key file: fall back to the keys in the auth directory.
		dirAuth, err := sshtun.AuthKeyDir(d.params.Auth.Directory)
		if err != nil {
			return nil
		}
		return dirAuth
	}
	return nil
}