From 666daabe958d845e49a6a39708dfa6a0fa46b033 Mon Sep 17 00:00:00 2001 From: Neur0toxine Date: Sun, 19 Nov 2023 13:06:38 +0300 Subject: [PATCH] ssh: pty, keepAlive, domain matcher, refactor --- internal/api/server.go | 2 +- internal/config/model.go | 7 +- internal/server/driver/ssh/driver.go | 67 +++++- internal/server/driver/ssh/params.go | 9 +- internal/server/driver/ssh/regexp.go | 28 +++ internal/server/driver/ssh/sshtun/connect.go | 171 +++++++++++++ internal/server/driver/ssh/sshtun/forward.go | 78 ++++++ .../server/driver/ssh/sshtun/keepalive.go | 62 +++++ internal/server/driver/ssh/sshtun/printer.go | 6 +- internal/server/driver/ssh/sshtun/ssh.go | 224 ------------------ internal/server/driver/ssh/types/commands.go | 6 + internal/server/driver/util/validator.go | 2 +- internal/server/manager.go | 2 + pkg/dto/models.go | 16 ++ 14 files changed, 437 insertions(+), 243 deletions(-) create mode 100644 internal/server/driver/ssh/regexp.go create mode 100644 internal/server/driver/ssh/sshtun/connect.go create mode 100644 internal/server/driver/ssh/sshtun/keepalive.go delete mode 100644 internal/server/driver/ssh/sshtun/ssh.go create mode 100644 internal/server/driver/ssh/types/commands.go diff --git a/internal/api/server.go b/internal/api/server.go index 45f3bb1..e348bac 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -61,7 +61,7 @@ func (p *pluginAPI) receiverForContext(ctx context.Context) plugin.Plugin { } func StartPluginAPI() { - port := config.Default.PluginAPIPort + port := config.Default.API.PluginPort if port == 0 { port = plugin2.DefaultPort } diff --git a/internal/config/model.go b/internal/config/model.go index 138b84d..9d7a377 100644 --- a/internal/config/model.go +++ b/internal/config/model.go @@ -12,12 +12,17 @@ var Default Config type Config struct { Debug bool `mapstructure:"debug"` - PluginAPIPort int `mapstructure:"plugin_api_port" validate:"gte=0,lte=65535"` + API API `mapstructure:"api"` Docker DockerConfig `mapstructure:"docker"` DefaultServer string `mapstructure:"default_server"` Servers []Server `mapstructure:"servers"` } +type API struct { + WebPort int `mapstructure:"web_port" validate:"gte=0,lte=65535"` + PluginPort int `mapstructure:"plugin_port" validate:"gte=0,lte=65535"` +} + type DockerConfig struct { FromEnv *bool `mapstructure:"from_env,omitempty"` CertPath string `mapstructure:"cert_path"` diff --git a/internal/server/driver/ssh/driver.go b/internal/server/driver/ssh/driver.go index 41754d2..b2c1de1 100644 --- a/internal/server/driver/ssh/driver.go +++ b/internal/server/driver/ssh/driver.go @@ -3,8 +3,10 @@ package ssh import ( "context" "errors" + "fmt" "net" "path" + "regexp" "strconv" "sync" @@ -21,11 +23,12 @@ var ErrAlreadyInUse = errors.New("domain is already in use") type SSH struct { base.Base - params Params - auth []ssh.AuthMethod - conns map[string]conn - rw sync.RWMutex - wg sync.WaitGroup + params Params + auth []ssh.AuthMethod + conns map[string]conn + rw sync.RWMutex + wg sync.WaitGroup + domainRegExp *regexp.Regexp } type conn struct { @@ -42,26 +45,58 @@ func New(ctx context.Context, name string, params config.DriverParams) (base.Dri if err := util.UnmarshalParams(params, &drv.params); err != nil { return nil, err } + matcher, err := makeDomainCatchRegExp(drv.params.DomainExtractRegex) + if err != nil { + return nil, fmt.Errorf("invalid domain_extract_regex: %w", err) + } + drv.domainRegExp = matcher drv.populateFromSSHConfig() drv.auth = drv.authenticators() return drv, nil } -func (d *SSH) forward(val sshtun.Forward) conn { +func (d *SSH) forward(val sshtun.Forward, domainMatcher func(string)) conn { tun := sshtun.New(d.params.Address, d.params.Auth.User, - d.params.FakeRemoteHost, val, d.auth, + sshtun.SessionConfig{ + NoPTY: d.params.NoPTY, + FakeRemoteHost: d.params.FakeRemoteHost, + KeepAliveInterval: uint(d.params.KeepAlive.Interval), + KeepAliveMax: uint(d.params.KeepAlive.MaxAttempts), + }, d.Log()) ctx, cancel := context.WithCancel(d.Context()) tunDbgLog := d.Log().With("ssh-output", val.Remote.String()) go tun.Connect(ctx, - sshtun.StdoutPrinterBannerCallback(tunDbgLog), - sshtun.StdoutPrinterSessionCallback(tunDbgLog)) + sshtun.BannerDebugLogCallback(tunDbgLog), + sshtun.OutputReaderCallback(func(msg string) { + d.Log().Debug(msg) + if domainMatcher != nil { + domainMatcher(msg) + } + })) return conn{ctx: ctx, cancel: cancel, tun: tun} } +func (d *SSH) makeDomainMatcherFunc(containerID string) func(string) { + if d.domainRegExp == nil { + return nil + } + return func(msg string) { + domain := d.domainRegExp.FindString(msg) + if domain == "" { + return + } + d.PushEventStatus(dto.EventStatus{ + Type: dto.EventStart, + ID: containerID, + Domain: domain, + }) + } +} + func (d *SSH) populateFromSSHConfig() { if d.params.Auth.Directory == "" { return @@ -94,9 +129,9 @@ func (d *SSH) Handle(event dto.Event) error { return ErrAlreadyInUse } conn := d.forward(sshtun.Forward{ - Local: sshtun.AddrToEndpoint(net.JoinHostPort(event.Container.IP.String(), strconv.Itoa(int(event.Container.Port)))), + Local: d.localEndpoint(event.Container.IP, event.Container.Port), Remote: d.remoteEndpoint(event.Container.RemoteHost), - }) + }, d.makeDomainMatcherFunc(event.Container.ID)) d.conns[event.Container.ID] = conn d.wg.Add(1) case dto.EventStop: @@ -106,17 +141,27 @@ func (d *SSH) Handle(event dto.Event) error { } conn.cancel() delete(d.conns, event.Container.ID) + d.propagateStop(event.Container.ID) d.wg.Done() case dto.EventShutdown: for id, conn := range d.conns { conn.cancel() delete(d.conns, id) + d.propagateStop(id) d.wg.Done() } } return nil } +func (d *SSH) propagateStop(containerID string) { + d.PushEventStatus(dto.EventStatus{Type: dto.EventStop, ID: containerID}) +} + +func (d *SSH) localEndpoint(ip net.IP, port uint16) sshtun.Endpoint { + return sshtun.AddrToEndpoint(net.JoinHostPort(ip.String(), strconv.Itoa(int(port)))) +} + func (d *SSH) remoteEndpoint(remoteHost string) sshtun.Endpoint { port := int(d.params.ForwardPort) if port == 0 { diff --git a/internal/server/driver/ssh/params.go b/internal/server/driver/ssh/params.go index 452d346..1403b09 100644 --- a/internal/server/driver/ssh/params.go +++ b/internal/server/driver/ssh/params.go @@ -1,6 +1,8 @@ package ssh import ( + "errors" + "github.com/Neur0toxine/sshpoke/internal/server/driver/ssh/types" "github.com/Neur0toxine/sshpoke/internal/server/driver/util" ) @@ -11,16 +13,19 @@ type Params struct { ForwardPort uint16 `mapstructure:"forward_port"` Auth types.Auth `mapstructure:"auth"` KeepAlive types.KeepAlive `mapstructure:"keepalive"` - Domain string `mapstructure:"domain"` - DomainProto string `mapstructure:"domain_proto"` DomainExtractRegex string `mapstructure:"domain_extract_regex" validate:"validregexp"` Mode types.DomainMode `mapstructure:"mode" validate:"required,oneof=single multi"` FakeRemoteHost bool `mapstructure:"fake_remote_host"` + NoPTY bool `mapstructure:"nopty"` + Commands types.Commands `mapstructure:"commands"` } func (p *Params) Validate() error { if err := util.Validator.Struct(p); err != nil { return err } + if p.NoPTY && (len(p.Commands.OnConnect) > 0 || len(p.Commands.OnDisconnect) > 0) { + return errors.New("commands aren't available without PTY (nopty = true)") + } return p.Auth.Validate() } diff --git a/internal/server/driver/ssh/regexp.go b/internal/server/driver/ssh/regexp.go new file mode 100644 index 0000000..7555808 --- /dev/null +++ b/internal/server/driver/ssh/regexp.go @@ -0,0 +1,28 @@ +package ssh + +import ( + "fmt" + "regexp" + "strings" +) + +var builtInRegExps = map[string]*regexp.Regexp{ + "webUrl": regexp.MustCompile(`(?m)https?:\/\/(?:www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b(?:[-a-zA-Z0-9()@:%_\+.~#?&\/=]*)`), + "httpUrl": regexp.MustCompile(`(?m)http:\/\/(?:www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b(?:[-a-zA-Z0-9()@:%_\+.~#?&\/=]*)`), + "httpsUrl": regexp.MustCompile(`(?m)https:\/\/(?:www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b(?:[-a-zA-Z0-9()@:%_\+.~#?&\/=]*)`), +} + +func makeDomainCatchRegExp(expr string) (*regexp.Regexp, error) { + if expr == "" { + return nil, nil + } + if strings.HasPrefix(expr, "!") { + exprName := expr[1:] + builtIn, found := builtInRegExps[exprName] + if !found { + return nil, fmt.Errorf("no builtin regexp with name '%s'", exprName) + } + return builtIn, nil + } + return regexp.Compile(expr) +} diff --git a/internal/server/driver/ssh/sshtun/connect.go b/internal/server/driver/ssh/sshtun/connect.go new file mode 100644 index 0000000..fdad0af --- /dev/null +++ b/internal/server/driver/ssh/sshtun/connect.go @@ -0,0 +1,171 @@ +package sshtun + +import ( + "context" + "net" + "sync" + "sync/atomic" + "time" + + "github.com/Neur0toxine/sshpoke/pkg/proto/ssh" + "github.com/function61/gokit/app/backoff" + "go.uber.org/zap" +) + +type SessionCallback func(*ssh.Session) + +var NoopSessionCallback SessionCallback = func(*ssh.Session) {} + +type Tunnel struct { + user string + address Endpoint + forward Forward + authMethods []ssh.AuthMethod + log *zap.SugaredLogger + sessConfig SessionConfig + connected atomic.Bool +} + +type SessionConfig struct { + NoPTY bool + FakeRemoteHost bool + KeepAliveInterval uint + KeepAliveMax uint +} + +func New(address, user string, forward Forward, auth []ssh.AuthMethod, sc SessionConfig, log *zap.SugaredLogger) *Tunnel { + return &Tunnel{ + address: AddrToEndpoint(address), + user: user, + forward: forward, + authMethods: auth, + sessConfig: sc, + log: log.With(zap.String("sshServer", address)), + } +} + +func (t *Tunnel) Connect(ctx context.Context, bannerCb ssh.BannerCallback, sessionCb SessionCallback) { + if t.connected.Load() { + return + } + + defer t.connected.Store(false) + backoffTime := backoff.ExponentialWithCappedMax(100*time.Millisecond, 5*time.Second) + for { + t.connected.Store(true) + err := t.connect(ctx, bannerCb, sessionCb) + if err != nil { + t.log.Error("connect error: ", err) + } + + select { + case <-ctx.Done(): + return + default: + } + + time.Sleep(backoffTime()) + } +} + +// connect once to the SSH server. if the connection breaks, we return error and the caller +// will try to re-connect +func (t *Tunnel) connect(ctx context.Context, bannerCb ssh.BannerCallback, sessionCb SessionCallback) error { + t.log.Debug("connecting") + sshConfig := &ssh.ClientConfig{ + User: t.user, + Auth: t.authMethods, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + BannerCallback: bannerCb, + } + + var sshClient *ssh.Client + var errConnect error + + sshClient, errConnect = dialSSH(ctx, t.address.String(), sshConfig) + if errConnect != nil { + return errConnect + } + + defer sshClient.Close() + defer t.log.Debug("disconnecting") + + t.log.Debug("connected") + listenerStopped := make(chan error) + + var wg sync.WaitGroup + sess, err := sshClient.NewSession() + if err != nil { + t.log.Errorf("session error: %s", err) + return err + } + defer sess.Close() + + if sessionCb == nil { + sessionCb = func(*ssh.Session) {} + } + + if !t.sessConfig.NoPTY { + err = sess.RequestPty("xterm", 80, 40, ssh.TerminalModes{ + ssh.ECHO: 0, + ssh.IGNCR: 1, + }) + if err != nil { + t.log.Warnf("PTY allocation failed: %s", err.Error()) + } else { + if err := sess.Shell(); err != nil { + t.log.Warnf("failed to start shell: %s", err.Error()) + } + } + } + + wg.Add(1) + go func() { + defer wg.Done() + sessionCb(sess) + }() + + wg.Add(1) + reverseErr := make(chan error) + go func() { + defer wg.Done() + reverseErr <- t.reverseForwardOnePort(sshClient, listenerStopped) + }() + + wg.Add(1) + go t.keepAlive(ctx, sshClient, &wg) + + if err := <-reverseErr; err != nil { + return err + } + + select { + case <-ctx.Done(): + return nil + case listenerFirstErr := <-listenerStopped: + select { + case <-ctx.Done(): + return nil + default: + return listenerFirstErr + } + } +} + +func dialSSH(ctx context.Context, addr string, sshConfig *ssh.ClientConfig) (*ssh.Client, error) { + dialer := net.Dialer{ + Timeout: 10 * time.Second, + } + + conn, err := dialer.DialContext(ctx, "tcp", addr) + if err != nil { + return nil, err + } + + clConn, newChan, reqChan, err := ssh.NewClientConn(conn, addr, sshConfig) + if err != nil { + return nil, err + } + + return ssh.NewClient(clConn, newChan, reqChan), nil +} diff --git a/internal/server/driver/ssh/sshtun/forward.go b/internal/server/driver/ssh/sshtun/forward.go index 05716f9..4cfa959 100644 --- a/internal/server/driver/ssh/sshtun/forward.go +++ b/internal/server/driver/ssh/sshtun/forward.go @@ -4,8 +4,12 @@ import ( "fmt" "net" "strconv" + "strings" "github.com/Neur0toxine/sshpoke/pkg/errtools" + "github.com/Neur0toxine/sshpoke/pkg/proto/ssh" + "github.com/function61/gokit/io/bidipipe" + "go.uber.org/zap" ) type Forward struct { @@ -35,3 +39,77 @@ type Endpoint struct { func (endpoint *Endpoint) String() string { return fmt.Sprintf("%s:%d", endpoint.Host, endpoint.Port) } + +func (t *Tunnel) ipFromAddr(addr net.Addr) net.IP { + host, _, _ := net.SplitHostPort(addr.String()) + return net.ParseIP(host) +} + +// blocking flow: calls Listen() on the SSH connection, and if succeeds returns non-nil error +// +// nonblocking flow: if Accept() call fails, stops goroutine and returns error on ch listenerStopped +func (t *Tunnel) reverseForwardOnePort(sshClient *ssh.Client, listenerStopped chan<- error) error { + var ( + listener net.Listener + err error + ) + if t.sessConfig.FakeRemoteHost { + listener, err = sshClient.ListenTCP(&net.TCPAddr{ + IP: t.ipFromAddr(sshClient.Conn.RemoteAddr()), + Port: t.forward.Remote.Port, + }, t.forward.Remote.Host) + } else { + listener, err = sshClient.Listen("tcp", t.forward.Remote.String()) + } + if err != nil { + return err + } + + go func() { + defer listener.Close() + t.log.Debugf("forwarding %s <- %s", t.forward.Local.String(), t.forward.Remote.String()) + + for { + client, err := listener.Accept() + if err != nil { + listenerStopped <- fmt.Errorf("error on Accept(): %w", err) + return + } + + go handleReverseForwardConn(client, t.forward, t.log) + } + }() + + return nil +} + +func handleReverseForwardConn(client net.Conn, forward Forward, log *zap.SugaredLogger) { + defer client.Close() + + remote, err := net.Dial("tcp", forward.Local.String()) + if err != nil { + log.Errorf("cannot dial local service: %s", err.Error()) + return + } + + log.Debugf("proxying %s <-> %s", forward.Local.String(), client.RemoteAddr()) + + // pipe data in both directions: + // - client => remote + // - remote => client + // + // - in effect, we act as a proxy between the reverse tunnel's client and locally-dialed + // remote endpoint. + // - the "client" and "remote" strings we give Pipe() is just for error&log messages + // - this blocks until either of the parties' socket closes (or breaks) + if err := bidipipe.Pipe( + bidipipe.WithName("client", client), + bidipipe.WithName("remote", remote), + ); err != nil { + // we can safely ignore those errors. + if strings.Contains(err.Error(), "use of closed network connection") { + return + } + log.Error(err) + } +} diff --git a/internal/server/driver/ssh/sshtun/keepalive.go b/internal/server/driver/ssh/sshtun/keepalive.go new file mode 100644 index 0000000..a6de49c --- /dev/null +++ b/internal/server/driver/ssh/sshtun/keepalive.go @@ -0,0 +1,62 @@ +package sshtun + +import ( + "context" + "io" + "sync" + "sync/atomic" + "time" + + "github.com/Neur0toxine/sshpoke/pkg/proto/ssh" +) + +// keepAlive periodically sends messages to invoke a response. +// If the server does not respond after some period of time, +// assume that the underlying net.Conn abruptly died. +func (t *Tunnel) keepAlive(ctx context.Context, client *ssh.Client, wg *sync.WaitGroup) { + defer wg.Done() + if t.sessConfig.KeepAliveInterval == 0 || t.sessConfig.KeepAliveMax == 0 { + return + } + + // Detect when the SSH connection is closed. + wait := make(chan error, 1) + wg.Add(1) + go func() { + defer wg.Done() + wait <- client.Wait() + }() + + // Repeatedly check if the remote server is still alive. + var aliveCount int32 + ticker := time.NewTicker(time.Duration(t.sessConfig.KeepAliveInterval) * time.Second) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + t.log.Debug("stopping keep-alive...") + _ = client.Close() + return + case err := <-wait: + if err != nil && err != io.EOF { + t.log.Error("ssh error:", err) + } + return + case <-ticker.C: + if n := atomic.AddInt32(&aliveCount, 1); n > int32(t.sessConfig.KeepAliveMax) { + t.log.Error("keep-alive failed, closing connection...") + _ = client.Close() + return + } + } + + wg.Add(1) + go func() { + defer wg.Done() + _, _, err := client.SendRequest("keepalive@openssh.com", true, nil) + if err == nil { + atomic.StoreInt32(&aliveCount, 0) + } + }() + } +} diff --git a/internal/server/driver/ssh/sshtun/printer.go b/internal/server/driver/ssh/sshtun/printer.go index 0805ed6..b241183 100644 --- a/internal/server/driver/ssh/sshtun/printer.go +++ b/internal/server/driver/ssh/sshtun/printer.go @@ -9,7 +9,7 @@ import ( "go.uber.org/zap" ) -func StdoutPrinterSessionCallback(log *zap.SugaredLogger) SessionCallback { +func OutputReaderCallback(callback func(string)) SessionCallback { return func(session *ssh.Session) { stdout, err := session.StdoutPipe() if err != nil { @@ -27,7 +27,7 @@ func StdoutPrinterSessionCallback(log *zap.SugaredLogger) SessionCallback { combined.Read(func(r io.Reader) error { scan := bufio.NewScanner(r) for scan.Scan() { - log.Debug(scan.Text()) + callback(scan.Text()) } return nil }) @@ -35,7 +35,7 @@ func StdoutPrinterSessionCallback(log *zap.SugaredLogger) SessionCallback { } } -func StdoutPrinterBannerCallback(log *zap.SugaredLogger) ssh.BannerCallback { +func BannerDebugLogCallback(log *zap.SugaredLogger) ssh.BannerCallback { return func(msg string) error { log.Debug(msg) return nil diff --git a/internal/server/driver/ssh/sshtun/ssh.go b/internal/server/driver/ssh/sshtun/ssh.go deleted file mode 100644 index 1c13011..0000000 --- a/internal/server/driver/ssh/sshtun/ssh.go +++ /dev/null @@ -1,224 +0,0 @@ -package sshtun - -import ( - "context" - "fmt" - "net" - "strings" - "sync" - "sync/atomic" - "time" - - "github.com/Neur0toxine/sshpoke/pkg/proto/ssh" - "github.com/function61/gokit/app/backoff" - "github.com/function61/gokit/io/bidipipe" - "go.uber.org/zap" -) - -type SessionCallback func(*ssh.Session) - -var NoopSessionCallback SessionCallback = func(*ssh.Session) {} - -type Tunnel struct { - user string - address Endpoint - forward Forward - authMethods []ssh.AuthMethod - log *zap.SugaredLogger - connected atomic.Bool - fakeRemoteHost bool -} - -func New(address, user string, fakeRemoteHost bool, - forward Forward, auth []ssh.AuthMethod, log *zap.SugaredLogger) *Tunnel { - return &Tunnel{ - address: AddrToEndpoint(address), - user: user, - fakeRemoteHost: fakeRemoteHost, - forward: forward, - authMethods: auth, - log: log.With(zap.String("sshServer", address)), - } -} - -func (c *Tunnel) Connect(ctx context.Context, bannerCb ssh.BannerCallback, sessionCb SessionCallback) { - if c.connected.Load() { - return - } - - defer c.connected.Store(false) - backoffTime := backoff.ExponentialWithCappedMax(100*time.Millisecond, 5*time.Second) - for { - c.connected.Store(true) - err := c.connect(ctx, bannerCb, sessionCb) - if err != nil { - c.log.Error("connect error:", err) - } - - select { - case <-ctx.Done(): - return - default: - } - - time.Sleep(backoffTime()) - } -} - -// connect once to the SSH server. if the connection breaks, we return error and the caller -// will try to re-connect -func (c *Tunnel) connect(ctx context.Context, bannerCb ssh.BannerCallback, sessionCb SessionCallback) error { - c.log.Debug("connecting") - sshConfig := &ssh.ClientConfig{ - User: c.user, - Auth: c.authMethods, - HostKeyCallback: ssh.InsecureIgnoreHostKey(), - BannerCallback: bannerCb, - } - - var sshClient *ssh.Client - var errConnect error - - sshClient, errConnect = dialSSH(ctx, c.address.String(), sshConfig) - if errConnect != nil { - return errConnect - } - - defer sshClient.Close() - defer c.log.Debug("disconnecting") - - c.log.Debug("connected") - listenerStopped := make(chan error) - - var wg sync.WaitGroup - sess, err := sshClient.NewSession() - if err != nil { - c.log.Errorf("session error: %s", err) - return err - } - defer sess.Close() - - if sessionCb == nil { - sessionCb = func(*ssh.Session) {} - } - wg.Add(1) - go func() { - defer wg.Done() - sessionCb(sess) - }() - - wg.Add(1) - reverseErr := make(chan error) - go func() { - defer wg.Done() - reverseErr <- c.reverseForwardOnePort(sshClient, listenerStopped) - }() - - if err := <-reverseErr; err != nil { - return err - } - - select { - case <-ctx.Done(): - return nil - case listenerFirstErr := <-listenerStopped: - select { - case <-ctx.Done(): - return nil - default: - return listenerFirstErr - } - } -} - -// blocking flow: calls Listen() on the SSH connection, and if succeeds returns non-nil error -// -// nonblocking flow: if Accept() call fails, stops goroutine and returns error on ch listenerStopped -func (c *Tunnel) reverseForwardOnePort(sshClient *ssh.Client, listenerStopped chan<- error) error { - var ( - listener net.Listener - err error - ) - if c.fakeRemoteHost { - listener, err = sshClient.ListenTCP(&net.TCPAddr{ - IP: c.ipFromAddr(sshClient.Conn.RemoteAddr()), - Port: c.forward.Remote.Port, - }, c.forward.Remote.Host) - } else { - listener, err = sshClient.Listen("tcp", c.forward.Remote.String()) - } - if err != nil { - return err - } - - go func() { - defer listener.Close() - c.log.Debugf("forwarding %s <- %s", c.forward.Local.String(), c.forward.Remote.String()) - - for { - client, err := listener.Accept() - if err != nil { - listenerStopped <- fmt.Errorf("error on Accept(): %w", err) - return - } - - go handleReverseForwardConn(client, c.forward, c.log) - } - }() - - return nil -} - -func (c *Tunnel) ipFromAddr(addr net.Addr) net.IP { - host, _, _ := net.SplitHostPort(addr.String()) - return net.ParseIP(host) -} - -func handleReverseForwardConn(client net.Conn, forward Forward, log *zap.SugaredLogger) { - defer client.Close() - - remote, err := net.Dial("tcp", forward.Local.String()) - if err != nil { - log.Errorf("cannot dial local service: %s", err.Error()) - return - } - - log.Debugf("proxying %s <-> %s", forward.Local.String(), client.RemoteAddr()) - - // pipe data in both directions: - // - client => remote - // - remote => client - // - // - in effect, we act as a proxy between the reverse tunnel's client and locally-dialed - // remote endpoint. - // - the "client" and "remote" strings we give Pipe() is just for error&log messages - // - this blocks until either of the parties' socket closes (or breaks) - if err := bidipipe.Pipe( - bidipipe.WithName("client", client), - bidipipe.WithName("remote", remote), - ); err != nil { - // we can safely ignore those errors. - if strings.Contains(err.Error(), "use of closed network connection") { - return - } - log.Error(err) - } -} - -func dialSSH(ctx context.Context, addr string, sshConfig *ssh.ClientConfig) (*ssh.Client, error) { - dialer := net.Dialer{ - Timeout: 10 * time.Second, - } - - conn, err := dialer.DialContext(ctx, "tcp", addr) - if err != nil { - return nil, err - } - - clConn, newChan, reqChan, err := ssh.NewClientConn(conn, addr, sshConfig) - if err != nil { - return nil, err - } - - return ssh.NewClient(clConn, newChan, reqChan), nil -} diff --git a/internal/server/driver/ssh/types/commands.go b/internal/server/driver/ssh/types/commands.go new file mode 100644 index 0000000..1475de2 --- /dev/null +++ b/internal/server/driver/ssh/types/commands.go @@ -0,0 +1,6 @@ +package types + +type Commands struct { + OnConnect []string `mapstructure:"on_connect"` + OnDisconnect []string `mapstructure:"on_disconnect"` +} diff --git a/internal/server/driver/util/validator.go b/internal/server/driver/util/validator.go index 2d49764..a5f6395 100644 --- a/internal/server/driver/util/validator.go +++ b/internal/server/driver/util/validator.go @@ -9,7 +9,7 @@ import ( var Validator *validator.Validate func init() { - Validator = validator.New() + Validator = validator.New(validator.WithRequiredStructEnabled()) _ = Validator.RegisterValidation("validregexp", isValidRegExp) } diff --git a/internal/server/manager.go b/internal/server/manager.go index cc7dd60..f00ad00 100644 --- a/internal/server/manager.go +++ b/internal/server/manager.go @@ -89,6 +89,8 @@ func (m *Manager) eventStatusCallback(serverName string) base.EventStatusCallbac } func (m *Manager) processEventStatus(serverName string, event dto.EventStatus) { + logger.Sugar.Debugw("received EventStatus from server", + "serverName", serverName, "eventStatus", event) m.statusLock.RLock() _, exists := m.statusMap[serverName] if !exists { diff --git a/pkg/dto/models.go b/pkg/dto/models.go index eb4ee99..4456dfb 100644 --- a/pkg/dto/models.go +++ b/pkg/dto/models.go @@ -12,6 +12,22 @@ const ( EventUnknown ) +var eventTypeNames = map[EventType]string{ + EventStart: "start", + EventStop: "stop", + EventShutdown: "shutdown", + EventError: "error", + EventUnknown: "unknown", +} + +func (e EventType) String() string { + name, ok := eventTypeNames[e] + if ok { + return name + } + return eventTypeNames[EventUnknown] +} + func TypeFromAction(action string) EventType { switch action { case "start":