package sshtun import ( "context" "fmt" "net" "os" "sync" "sync/atomic" "time" "github.com/function61/gokit/app/backoff" "github.com/function61/gokit/io/bidipipe" "go.uber.org/zap" "golang.org/x/crypto/ssh" ) type SessionCallback func(*ssh.Session) var NoopSessionCallback SessionCallback = func(*ssh.Session) {} type Tunnel struct { user string address Endpoint forward Forward authMethods []ssh.AuthMethod log *zap.SugaredLogger connected atomic.Bool fakeRemoteHost bool } func New(address, user string, fakeRemoteHost bool, forward Forward, auth []ssh.AuthMethod, log *zap.SugaredLogger) *Tunnel { return &Tunnel{ address: AddrToEndpoint(address), user: user, fakeRemoteHost: fakeRemoteHost, forward: forward, authMethods: auth, log: log.With(zap.String("sshServer", address)), } } func (c *Tunnel) Connect(ctx context.Context, sessionCb SessionCallback) { if c.connected.Load() { return } defer c.connected.Store(false) backoffTime := backoff.ExponentialWithCappedMax(100*time.Millisecond, 5*time.Second) for { c.connected.Store(true) err := c.connect(ctx, sessionCb) if err != nil { c.log.Error("connect error:", err) } select { case <-ctx.Done(): return default: } time.Sleep(backoffTime()) } } // connect once to the SSH server. if the connection breaks, we return error and the caller // will try to re-connect func (c *Tunnel) connect(ctx context.Context, sessionCb SessionCallback) error { c.log.Debug("connecting") sshConfig := &ssh.ClientConfig{ User: c.user, Auth: c.authMethods, HostKeyCallback: ssh.InsecureIgnoreHostKey(), } var sshClient *ssh.Client var errConnect error sshClient, errConnect = dialSSH(ctx, c.address.String(), sshConfig) if errConnect != nil { return errConnect } defer sshClient.Close() defer c.log.Debug("disconnecting") c.log.Debug("connected") listenerStopped := make(chan error) sess, err := sshClient.NewSession() if err != nil { c.log.Errorf("session error: %s", err) return err } defer sess.Close() var wg sync.WaitGroup if sessionCb == nil { sessionCb = func(*ssh.Session) {} } wg.Add(2) go func() { defer wg.Done() sessionCb(sess) }() reverseErr := make(chan error) go func() { defer wg.Done() reverseErr <- c.reverseForwardOnePort(sshClient, listenerStopped) }() if err := <-reverseErr; err != nil { return err } select { case <-ctx.Done(): return nil case listenerFirstErr := <-listenerStopped: select { case <-ctx.Done(): return nil default: return listenerFirstErr } } } // blocking flow: calls Listen() on the SSH connection, and if succeeds returns non-nil error // // nonblocking flow: if Accept() call fails, stops goroutine and returns error on ch listenerStopped func (c *Tunnel) reverseForwardOnePort(sshClient *ssh.Client, listenerStopped chan<- error) error { if c.fakeRemoteHost { newSishHostListener(sshClient).ListenFakeRemoteHost(c.forward.Remote) time.Sleep(time.Second) os.Exit(0) } listener, err := sshClient.Listen("tcp", c.forward.Remote.String()) if err != nil { return err } go func() { defer listener.Close() c.log.Debugf("forwarding %s <- %s", c.forward.Local.String(), c.forward.Remote.String()) for { client, err := listener.Accept() if err != nil { listenerStopped <- fmt.Errorf("error on Accept(): %w", err) return } go handleReverseForwardConn(client, c.forward, c.log) } }() return nil } func (c *Tunnel) listenTCPWithoutResolving() { } func handleReverseForwardConn(client net.Conn, forward Forward, log *zap.SugaredLogger) { defer client.Close() log.Debugf("%s connected", client.RemoteAddr()) defer log.Debug("closed") remote, err := net.Dial("tcp", forward.Local.String()) if err != nil { log.Errorf("dial INTO local service error: %s", err.Error()) return } // pipe data in both directions: // - client => remote // - remote => client // // - in effect, we act as a proxy between the reverse tunnel's client and locally-dialed // remote endpoint. // - the "client" and "remote" strings we give Pipe() is just for error&log messages // - this blocks until either of the parties' socket closes (or breaks) if err := bidipipe.Pipe( bidipipe.WithName("client", client), bidipipe.WithName("remote", remote), ); err != nil { log.Error(err) } } func dialSSH(ctx context.Context, addr string, sshConfig *ssh.ClientConfig) (*ssh.Client, error) { dialer := net.Dialer{ Timeout: 10 * time.Second, } conn, err := dialer.DialContext(ctx, "tcp", addr) if err != nil { return nil, err } clConn, newChan, reqChan, err := ssh.NewClientConn(conn, addr, sshConfig) if err != nil { return nil, err } return ssh.NewClient(clConn, newChan, reqChan), nil }