From 408f88ebd20872ed53b7a969fcedfe83a96cae45 Mon Sep 17 00:00:00 2001 From: Neur0toxine Date: Sat, 18 Nov 2023 23:02:41 +0300 Subject: [PATCH] ssh: mode parameter support, banner support --- internal/server/driver/ssh/driver.go | 16 ++++++++++++++-- internal/server/driver/ssh/sshtun/forward.go | 2 +- internal/server/driver/ssh/sshtun/printer.go | 7 +++++++ internal/server/driver/ssh/sshtun/ssh.go | 13 ++++++++----- 4 files changed, 30 insertions(+), 8 deletions(-) diff --git a/internal/server/driver/ssh/driver.go b/internal/server/driver/ssh/driver.go index 7520f70..eb06d61 100644 --- a/internal/server/driver/ssh/driver.go +++ b/internal/server/driver/ssh/driver.go @@ -2,6 +2,7 @@ package ssh import ( "context" + "errors" "net" "path" "strconv" @@ -16,6 +17,8 @@ import ( "github.com/Neur0toxine/sshpoke/pkg/proto/ssh" ) +var ErrAlreadyInUse = errors.New("domain is already in use") + type SSH struct { base.Base params Params @@ -45,9 +48,15 @@ func New(ctx context.Context, name string, params config.DriverParams) (base.Dri } func (d *SSH) forward(val sshtun.Forward) conn { - tun := sshtun.New(d.params.Address, d.params.Auth.User, d.params.FakeRemoteHost, val, d.auth, d.Log()) + tun := sshtun.New(d.params.Address, + d.params.Auth.User, + d.params.FakeRemoteHost, + val, + d.auth, + d.Log()) ctx, cancel := context.WithCancel(d.Context()) - go tun.Connect(ctx, sshtun.StdoutPrinterSessionCallback(d.Log().With("ssh-output", val.Remote.String()))) + tunDbgLog := d.Log().With("ssh-output", val.Remote.String()) + go tun.Connect(ctx, sshtun.StdoutPrinterBannerCallback(tunDbgLog), sshtun.StdoutPrinterSessionCallback(tunDbgLog)) return conn{ctx: ctx, cancel: cancel, tun: tun} } @@ -79,6 +88,9 @@ func (d *SSH) Handle(event dto.Event) error { d.rw.Lock() switch event.Type { case dto.EventStart: + if d.params.Mode == types.DomainModeSingle && len(d.conns) > 0 { + return ErrAlreadyInUse + } conn := d.forward(sshtun.Forward{ Local: sshtun.AddrToEndpoint(net.JoinHostPort(event.Container.IP.String(), strconv.Itoa(int(event.Container.Port)))), Remote: d.remoteEndpoint(event.Container.RemoteHost), diff --git a/internal/server/driver/ssh/sshtun/forward.go b/internal/server/driver/ssh/sshtun/forward.go index e7f0ec3..05716f9 100644 --- a/internal/server/driver/ssh/sshtun/forward.go +++ b/internal/server/driver/ssh/sshtun/forward.go @@ -18,7 +18,7 @@ type Forward struct { func AddrToEndpoint(address string) Endpoint { host, port, err := net.SplitHostPort(address) if err != nil && errtools.IsPortMissingErr(err) { - return Endpoint{Host: host, Port: 22} + return Endpoint{Host: address, Port: 22} } portNum, err := strconv.Atoi(port) if err != nil { diff --git a/internal/server/driver/ssh/sshtun/printer.go b/internal/server/driver/ssh/sshtun/printer.go index b5653cd..c176a47 100644 --- a/internal/server/driver/ssh/sshtun/printer.go +++ b/internal/server/driver/ssh/sshtun/printer.go @@ -19,3 +19,10 @@ func StdoutPrinterSessionCallback(log *zap.SugaredLogger) SessionCallback { } } } + +func StdoutPrinterBannerCallback(log *zap.SugaredLogger) ssh.BannerCallback { + return func(msg string) error { + log.Debug(msg) + return nil + } +} diff --git a/internal/server/driver/ssh/sshtun/ssh.go b/internal/server/driver/ssh/sshtun/ssh.go index d1b86f8..1c13011 100644 --- a/internal/server/driver/ssh/sshtun/ssh.go +++ b/internal/server/driver/ssh/sshtun/ssh.go @@ -41,7 +41,7 @@ func New(address, user string, fakeRemoteHost bool, } } -func (c *Tunnel) Connect(ctx context.Context, sessionCb SessionCallback) { +func (c *Tunnel) Connect(ctx context.Context, bannerCb ssh.BannerCallback, sessionCb SessionCallback) { if c.connected.Load() { return } @@ -50,7 +50,7 @@ func (c *Tunnel) Connect(ctx context.Context, sessionCb SessionCallback) { backoffTime := backoff.ExponentialWithCappedMax(100*time.Millisecond, 5*time.Second) for { c.connected.Store(true) - err := c.connect(ctx, sessionCb) + err := c.connect(ctx, bannerCb, sessionCb) if err != nil { c.log.Error("connect error:", err) } @@ -67,12 +67,13 @@ func (c *Tunnel) Connect(ctx context.Context, sessionCb SessionCallback) { // connect once to the SSH server. if the connection breaks, we return error and the caller // will try to re-connect -func (c *Tunnel) connect(ctx context.Context, sessionCb SessionCallback) error { +func (c *Tunnel) connect(ctx context.Context, bannerCb ssh.BannerCallback, sessionCb SessionCallback) error { c.log.Debug("connecting") sshConfig := &ssh.ClientConfig{ User: c.user, Auth: c.authMethods, HostKeyCallback: ssh.InsecureIgnoreHostKey(), + BannerCallback: bannerCb, } var sshClient *ssh.Client @@ -89,6 +90,7 @@ func (c *Tunnel) connect(ctx context.Context, sessionCb SessionCallback) error { c.log.Debug("connected") listenerStopped := make(chan error) + var wg sync.WaitGroup sess, err := sshClient.NewSession() if err != nil { c.log.Errorf("session error: %s", err) @@ -96,15 +98,16 @@ func (c *Tunnel) connect(ctx context.Context, sessionCb SessionCallback) error { } defer sess.Close() - var wg sync.WaitGroup if sessionCb == nil { sessionCb = func(*ssh.Session) {} } - wg.Add(2) + wg.Add(1) go func() { defer wg.Done() sessionCb(sess) }() + + wg.Add(1) reverseErr := make(chan error) go func() { defer wg.Done()