package sshproto

import (
	"context"
	"errors"
	"fmt"
	"net"
	"sync/atomic"
	"time"

	"github.com/Neur0toxine/sshpoke/pkg/errtools"
	"github.com/function61/gokit/app/backoff"
	"github.com/function61/gokit/io/bidipipe"
	"go.uber.org/zap"
	"golang.org/x/crypto/ssh"
)

// Client holds everything needed to (re)establish a connection to a single
// SSH server. Construct it with New; the zero value has no address or logger
// and is not usable.
type Client struct {
	user        string
	address     string
	authMethods []ssh.AuthMethod
	log         *zap.SugaredLogger
	// connected is true while a Connect loop is running for this Client and
	// is used to keep a second loop from starting concurrently.
	connected atomic.Bool
}

// New returns a Client for the SSH server at address that authenticates as
// user with the given auth methods. The address is normalized by
// prepareAddress (port 22 is added when the port is missing), and the logger
// is tagged with the normalized address under the "sshServer" key.
func New(address, user string, auth []ssh.AuthMethod, log *zap.SugaredLogger) *Client {
	address = prepareAddress(address)
	return &Client{
		address:     address,
		user:        user,
		authMethods: auth,
		// Tag logs with the address that is actually dialed.
		log: log.With(zap.String("sshServer", address)),
	}
}

// prepareAddress appends the default SSH port 22 to address when
// net.SplitHostPort fails with a missing-port error (as classified by
// errtools.IsPortMissingErr). Any other address, including one that fails to
// parse for a different reason, is returned unchanged.
func prepareAddress(address string) string {
	_, _, err := net.SplitHostPort(address)
	if err != nil && errtools.IsPortMissingErr(err) {
		return net.JoinHostPort(address, "22")
	}
	return address
}

// Connect keeps a connection to the SSH server open until ctx is cancelled.
// Whenever a connection attempt fails or an established connection drops, it
// logs the error and retries after a capped exponential backoff (100ms growing
// to at most 5s). It blocks until ctx is done. If another Connect call is
// already running on c, it returns immediately.
func (c *Client) Connect(ctx context.Context) {
	// CompareAndSwap performs the "already running?" check and claims the
	// flag in one atomic step, so two concurrent callers cannot both proceed.
	if !c.connected.CompareAndSwap(false, true) {
		return
	}
	defer c.connected.Store(false)

	backoffTime := backoff.ExponentialWithCappedMax(100*time.Millisecond, 5*time.Second)
	for {
		if err := c.connect(ctx); err != nil {
			c.log.Errorf("connect error: %s", err)
		}

		// Wait out the backoff delay, but stop immediately on cancellation
		// instead of sleeping through it.
		timer := time.NewTimer(backoffTime())
		select {
		case <-ctx.Done():
			timer.Stop()
			return
		case <-timer.C:
		}
	}
}

// connect dials the SSH server once and holds the connection open. It returns
// nil when ctx is cancelled, and an error if dialing fails or the established
// connection shuts down; the Connect loop then re-connects.
func (c *Client) connect(ctx context.Context) error {
	c.log.Debug("connecting")
	sshConfig := &ssh.ClientConfig{
		User: c.user,
		Auth: c.authMethods,
		// NOTE(review): the server host key is not verified, so the server's
		// identity is not authenticated.
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
	}

	sshClient, err := dialSSH(ctx, c.address, sshConfig)
	if err != nil {
		return err
	}

	// always disconnect when function returns
	defer sshClient.Close()
	defer c.log.Debug("disconnecting")

	c.log.Debug("connected")

	// Wait() returns once the SSH connection shuts down. The channel is
	// buffered so this goroutine can still deliver its result and exit after
	// we return on cancellation (the deferred Close unblocks Wait).
	closed := make(chan error, 1)
	go func() { closed <- sshClient.Wait() }()

	select {
	case <-ctx.Done():
		return nil
	case waitErr := <-closed:
		if waitErr == nil {
			return errors.New("ssh connection closed")
		}
		return fmt.Errorf("ssh connection closed: %w", waitErr)
	}
}

// // connect once to the SSH server. if the connection breaks, we return error and the caller
// // will try to re-connect
// func connectToSshAndServe2(
// 	ctx context.Context,
// 	address string,
// 	authConfig types.Auth,
// 	auth ssh.AuthMethod,
// 	log *zap.SugaredLogger,
// ) error {
// 	log = log.With(zap.String("sshServer", address))
// 	log.Debug("connecting")
// 	sshConfig := &ssh.ClientConfig{
// 		User:            authConfig.User,
// 		Auth:            []ssh.AuthMethod{auth},
// 		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
// 	}
//
// 	var sshClient *ssh.Client
// 	var errConnect error
//
// 	sshClient, errConnect = dialSSH(ctx, address, sshConfig)
// 	if errConnect != nil {
// 		return errConnect
// 	}
//
// 	// always disconnect when function returns
// 	defer sshClient.Close()
// 	defer log.Debug("disconnecting")
//
// 	log.Debug("connected")
//
// 	// each listener in reverseForwardOnePort() can return one error, so make sure channel
// 	// has enough buffering so even if we return from here, the goroutines won't block trying
// 	// to return an error
// 	listenerStopped := make(chan error, len(conf.Forwards))
//
// 	for _, forward := range conf.Forwards {
// 		// TODO: "if any fails, tear down all workers" -style error handling would be better
// 		// handled with https://pkg.go.dev/golang.org/x/sync/errgroup?tab=doc
// 		if err := reverseForwardOnePort(
// 			forward,
// 			sshClient,
// 			listenerStopped,
// 			makeLogger("reverseForwardOnePort"),
// 			makeLogger,
// 		); err != nil {
// 			// closes SSH connection if even one forward Listen() fails
// 			return err
// 		}
// 	}
//
// 	// we're connected and have successfully started listening on all reverse forwards, wait
// 	// for either user to ask us to stop or any of the listeners to error
// 	select {
// 	case <-ctx.Done(): // cancel requested
// 		return nil
// 	case listenerFirstErr := <-listenerStopped:
// 		// one or more of the listeners encountered an error. react by closing the connection
// 		// assumes all the other listeners failed too so no teardown necessary
// 		select {
// 		case <-ctx.Done(): // pretty much errors are to be expected if cancellation triggered
// 			return nil
// 		default:
// 			return listenerFirstErr
// 		}
// 	}
// }

// reverseForwardOnePort asks the SSH server to listen on forward.Remote and
// proxies every connection accepted there to forward.Local.
//
// Blocking flow: calls Listen() on the SSH connection and returns a non-nil
// error if it fails; otherwise it returns nil right away.
//
// Non-blocking flow: the accept loop runs in its own goroutine. When Accept()
// fails, the goroutine sends the error on listenerStopped, closes the listener
// and exits. The send needs buffer space or an active reader on
// listenerStopped, otherwise that goroutine blocks forever.
func reverseForwardOnePort(
	forward Forward,
	sshClient *ssh.Client,
	listenerStopped chan<- error,
	log *zap.SugaredLogger,
) error {
	// reverse listen on remote server port
	listener, err := sshClient.Listen("tcp", forward.Remote.String())
	if err != nil {
		return err
	}

	go func() {
		defer listener.Close()

		log.Debugf("forwarding %s -> %s", forward.Local.String(), forward.Remote.String())

		// handle incoming connections on reverse forwarded tunnel
		for {
			client, err := listener.Accept()
			if err != nil {
				listenerStopped <- fmt.Errorf("error on Accept(): %w", err)
				return
			}

			// handle the connection in another goroutine, so we can support multiple concurrent
			// connections on the same port
			go handleReverseForwardConn(client, forward, log)
		}
	}()

	return nil
}

// handleReverseForwardConn serves one connection accepted on the reverse
// tunnel. It dials forward.Local over TCP and pipes data in both directions
// until either side closes or breaks. Both connections are closed on return.
func handleReverseForwardConn(client net.Conn, forward Forward, log *zap.SugaredLogger) {
	defer client.Close()

	log.Debugf("%s connected", client.RemoteAddr())
	defer log.Debug("closed")

	remote, err := net.Dial("tcp", forward.Local.String())
	if err != nil {
		log.Errorf("dial INTO local service error: %s", err.Error())
		return
	}
	// Always release the local connection; closing it a second time is
	// harmless if bidipipe.Pipe has already closed it.
	defer remote.Close()

	// pipe data in both directions:
	// - client => remote
	// - remote => client
	//
	// - in effect, we act as a proxy between the reverse tunnel's client and locally-dialed
	//   remote endpoint.
	// - the "client" and "remote" strings we give Pipe() is just for error&log messages
	// - this blocks until either of the parties' socket closes (or breaks)
	if err := bidipipe.Pipe(
		bidipipe.WithName("client", client),
		bidipipe.WithName("remote", remote),
	); err != nil {
		log.Error(err)
	}
}

// dialSSH opens a TCP connection to addr (aborted if ctx is cancelled) and
// performs the SSH client handshake on it. The TCP connect and the handshake
// are each bounded by a 10 second timeout. On success the caller owns the
// returned client and must Close it.
func dialSSH(ctx context.Context, addr string, sshConfig *ssh.ClientConfig) (*ssh.Client, error) {
	const timeout = 10 * time.Second

	dialer := net.Dialer{Timeout: timeout}
	conn, err := dialer.DialContext(ctx, "tcp", addr)
	if err != nil {
		return nil, err
	}

	// The dialer timeout only covers establishing the TCP connection. Without
	// a deadline, a server that accepts but never finishes the SSH handshake
	// would block NewClientConn indefinitely.
	if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil {
		_ = conn.Close()
		return nil, err
	}

	clConn, newChan, reqChan, err := ssh.NewClientConn(conn, addr, sshConfig)
	if err != nil {
		// Make sure the socket is released; a repeated Close only returns an error.
		_ = conn.Close()
		return nil, err
	}

	// Clear the handshake deadline so the established session is not cut off.
	if err := conn.SetDeadline(time.Time{}); err != nil {
		_ = clConn.Close()
		return nil, err
	}

	return ssh.NewClient(clConn, newChan, reqChan), nil
}