package sshtun import ( "context" "net" "sync" "sync/atomic" "time" "github.com/Neur0toxine/sshpoke/pkg/proto/ssh" "github.com/function61/gokit/app/backoff" "go.uber.org/zap" ) type SessionCallback func(*ssh.Session) var NoopSessionCallback SessionCallback = func(*ssh.Session) {} type Tunnel struct { user string address Endpoint authMethods []ssh.AuthMethod log *zap.SugaredLogger tunConfig TunnelConfig connected atomic.Bool } type TunnelConfig struct { Forward Forward HostKeys []ssh.PublicKey NoPTY bool Shell bool FakeRemoteHost bool KeepAliveInterval uint KeepAliveMax uint } func New(address, user string, auth []ssh.AuthMethod, sc TunnelConfig, log *zap.SugaredLogger) *Tunnel { return &Tunnel{ address: AddrToEndpoint(address), user: user, authMethods: auth, tunConfig: sc, log: log.With(zap.String("sshServer", address)), } } func (t *Tunnel) Connect(ctx context.Context, bannerCb ssh.BannerCallback, sessionCb SessionCallback) { if t.connected.Load() { return } defer t.connected.Store(false) backoffTime := backoff.ExponentialWithCappedMax(100*time.Millisecond, 5*time.Second) for { t.connected.Store(true) err := t.connect(ctx, bannerCb, sessionCb) if err != nil { t.log.Error("connect error: ", err) } select { case <-ctx.Done(): return default: } time.Sleep(backoffTime()) } } func (t *Tunnel) buildHostKeyCallback() ssh.HostKeyCallback { if t.tunConfig.HostKeys == nil || len(t.tunConfig.HostKeys) == 0 { return ssh.InsecureIgnoreHostKey() } if len(t.tunConfig.HostKeys) == 1 { return ssh.FixedHostKey(t.tunConfig.HostKeys[0]) } return FixedHostKeys(t.tunConfig.HostKeys) } // connect once to the SSH server. if the connection breaks, we return error and the caller // will try to re-connect func (t *Tunnel) connect(ctx context.Context, bannerCb ssh.BannerCallback, sessionCb SessionCallback) error { t.log.Debug("connecting") sshConfig := &ssh.ClientConfig{ User: t.user, Auth: t.authMethods, HostKeyCallback: t.buildHostKeyCallback(), BannerCallback: bannerCb, } var sshClient *ssh.Client var errConnect error sshClient, errConnect = dialSSH(ctx, t.address.String(), sshConfig) if errConnect != nil { return errConnect } defer sshClient.Close() defer t.log.Debug("disconnecting") t.log.Debug("connected") listenerStopped := make(chan error) var wg sync.WaitGroup sess, err := sshClient.NewSession() if err != nil { t.log.Errorf("session error: %s", err) return err } defer sess.Close() if sessionCb == nil { sessionCb = func(*ssh.Session) {} } if !t.tunConfig.NoPTY { err = sess.RequestPty("xterm", 80, 40, ssh.TerminalModes{ ssh.ECHO: 0, ssh.IGNCR: 1, }) if err != nil { t.log.Warnf("PTY allocation failed: %s", err.Error()) } } if t.tunConfig.Shell { if err := sess.Shell(); err != nil { t.log.Warnf("failed to start shell: %s", err.Error()) } wg.Add(1) go func() { defer wg.Done() _ = sess.Wait() }() } wg.Add(1) go func() { defer wg.Done() sessionCb(sess) }() wg.Add(1) reverseErr := make(chan error) go func() { defer wg.Done() reverseErr <- t.reverseForwardOnePort(sshClient, listenerStopped) }() wg.Add(1) go t.keepAlive(ctx, sshClient, &wg) if err := <-reverseErr; err != nil { return err } select { case <-ctx.Done(): return nil case listenerFirstErr := <-listenerStopped: select { case <-ctx.Done(): return nil default: return listenerFirstErr } } } func dialSSH(ctx context.Context, addr string, sshConfig *ssh.ClientConfig) (*ssh.Client, error) { dialer := net.Dialer{ Timeout: 10 * time.Second, } conn, err := dialer.DialContext(ctx, "tcp", addr) if err != nil { return nil, err } clConn, newChan, reqChan, err := ssh.NewClientConn(conn, addr, sshConfig) if err != nil { return nil, err } return ssh.NewClient(clConn, newChan, reqChan), nil }