sshpoke/internal/server/driver/ssh/sshproto/ssh.go
2023-11-18 17:51:04 +03:00

243 lines
6.3 KiB
Go

package sshproto
import (
"context"
"fmt"
"net"
"sync/atomic"
"time"
"github.com/Neur0toxine/sshpoke/pkg/errtools"
"github.com/function61/gokit/app/backoff"
"github.com/function61/gokit/io/bidipipe"
"go.uber.org/zap"
"golang.org/x/crypto/ssh"
)
// Client holds everything needed to dial one SSH server: the login user, the
// host:port address, the authentication methods and a logger. Connect runs the
// connection loop for it.
type Client struct {
	user, address string
	authMethods   []ssh.AuthMethod
	log           *zap.SugaredLogger
	// connected is set while a Connect loop is running for this client.
	connected atomic.Bool
}
// New returns a Client for the SSH server at address that authenticates as user
// with the given auth methods. The address is normalized by prepareAddress, which
// defaults the port to 22. The logger is tagged with that normalized address, so
// log lines show exactly what is being dialed.
func New(address, user string, auth []ssh.AuthMethod, log *zap.SugaredLogger) *Client {
	addr := prepareAddress(address)
	return &Client{
		address:     addr,
		user:        user,
		authMethods: auth,
		// Use the resolved address (port included) rather than the raw input so
		// the log field matches the address actually dialed.
		log: log.With(zap.String("sshServer", addr)),
	}
}
// prepareAddress normalizes an SSH server address into host:port form. It
// defaults the port to 22 when the address carries none. Plain hosts ("example.com",
// "10.0.0.1"), bracketed IPv6 literals ("[::1]") and bare IPv6 literals ("::1") all
// get ":22". Addresses that already include a port are returned unchanged, as are
// any other SplitHostPort failures.
func prepareAddress(address string) string {
	_, _, err := net.SplitHostPort(address)
	if err == nil {
		return address
	}
	// A bare IPv6 literal makes SplitHostPort fail with "too many colons" rather
	// than "missing port", so recognise it explicitly.
	if net.ParseIP(address) != nil {
		return net.JoinHostPort(address, "22")
	}
	if errtools.IsPortMissingErr(err) {
		host := address
		// JoinHostPort adds its own brackets around hosts containing ':', so strip
		// existing ones; otherwise "[::1]" would become "[[::1]]:22".
		if len(host) >= 2 && host[0] == '[' && host[len(host)-1] == ']' {
			host = host[1 : len(host)-1]
		}
		return net.JoinHostPort(host, "22")
	}
	return address
}
// Connect keeps (re)connecting to the SSH server until ctx is cancelled. After
// each failed or dropped connection it waits for an exponential backoff delay
// (100ms initial, capped at 5s) before trying again. It blocks until ctx is done.
//
// Only one Connect loop runs per Client. A call made while another is already
// running returns immediately.
func (c *Client) Connect(ctx context.Context) {
	// CompareAndSwap makes the "already running" check and the claim a single
	// atomic step, so two concurrent callers cannot both enter the loop.
	if !c.connected.CompareAndSwap(false, true) {
		return
	}
	defer c.connected.Store(false)
	backoffTime := backoff.ExponentialWithCappedMax(100*time.Millisecond, 5*time.Second)
	for {
		if err := c.connect(ctx); err != nil {
			c.log.Errorf("connect error: %s", err)
		}
		// Wait for the backoff delay, but wake up immediately on cancellation
		// instead of sleeping through it.
		timer := time.NewTimer(backoffTime())
		select {
		case <-ctx.Done():
			timer.Stop()
			return
		case <-timer.C:
		}
	}
}
// connect dials the SSH server once and holds the connection open until either
// ctx is cancelled or the connection drops.
//
// It returns nil on cancellation. It returns a non-nil error when dialing fails
// or the established connection breaks, so the caller can back off and re-connect.
func (c *Client) connect(ctx context.Context) error {
	c.log.Debug("connecting")
	sshConfig := &ssh.ClientConfig{
		User: c.user,
		Auth: c.authMethods,
		// NOTE(review): host keys are not verified, which leaves the connection open
		// to man-in-the-middle attacks. Consider ssh.FixedHostKey or a known_hosts callback.
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
	}
	sshClient, err := dialSSH(ctx, c.address, sshConfig)
	if err != nil {
		return err
	}
	// always disconnect when function returns
	defer sshClient.Close()
	defer c.log.Debug("disconnecting")
	c.log.Debug("connected")

	// Wait blocks until the underlying connection is closed. Run it in a goroutine
	// so ctx can be watched at the same time. The one-slot buffer lets the goroutine
	// finish after the deferred Close even when we return because of ctx.
	connClosed := make(chan error, 1)
	go func() { connClosed <- sshClient.Wait() }()

	select {
	case <-ctx.Done():
		return nil
	case waitErr := <-connClosed:
		// If cancellation raced with the connection closing, treat it as a
		// cancellation rather than a failure worth logging.
		if ctx.Err() != nil {
			return nil
		}
		if waitErr == nil {
			return fmt.Errorf("ssh connection closed")
		}
		return fmt.Errorf("ssh connection closed: %w", waitErr)
	}
}
// // connect once to the SSH server. if the connection breaks, we return error and the caller
// // will try to re-connect
// func connectToSshAndServe2(
// ctx context.Context,
// address string,
// authConfig types.Auth,
// auth ssh.AuthMethod,
// log *zap.SugaredLogger,
// ) error {
// log = log.With(zap.String("sshServer", address))
// log.Debug("connecting")
// sshConfig := &ssh.ClientConfig{
// User: authConfig.User,
// Auth: []ssh.AuthMethod{auth},
// HostKeyCallback: ssh.InsecureIgnoreHostKey(),
// }
//
// var sshClient *ssh.Client
// var errConnect error
//
// sshClient, errConnect = dialSSH(ctx, address, sshConfig)
// if errConnect != nil {
// return errConnect
// }
//
// // always disconnect when function returns
// defer sshClient.Close()
// defer log.Debug("disconnecting")
//
// log.Debug("connected")
//
// // each listener in reverseForwardOnePort() can return one error, so make sure channel
// // has enough buffering so even if we return from here, the goroutines won't block trying
// // to return an error
// listenerStopped := make(chan error, len(conf.Forwards))
//
// for _, forward := range conf.Forwards {
// // TODO: "if any fails, tear down all workers" -style error handling would be better
// // handled with https://pkg.go.dev/golang.org/x/sync/errgroup?tab=doc
// if err := reverseForwardOnePort(
// forward,
// sshClient,
// listenerStopped,
// makeLogger("reverseForwardOnePort"),
// makeLogger,
// ); err != nil {
// // closes SSH connection if even one forward Listen() fails
// return err
// }
// }
//
// // we're connected and have successfully started listening on all reverse forwards, wait
// // for either user to ask us to stop or any of the listeners to error
// select {
// case <-ctx.Done(): // cancel requested
// return nil
// case listenerFirstErr := <-listenerStopped:
// // one or more of the listeners encountered an error. react by closing the connection
// // assumes all the other listeners failed too so no teardown necessary
// select {
// case <-ctx.Done(): // pretty much errors are to be expected if cancellation triggered
// return nil
// default:
// return listenerFirstErr
// }
// }
// }
// reverseForwardOnePort asks the SSH server to listen on forward.Remote. It
// proxies every connection accepted there to forward.Local.
//
// Blocking flow: calls Listen() on the SSH connection. If that fails, it returns
// the (wrapped) error; on success it returns nil.
//
// Nonblocking flow: a goroutine accepts connections until Accept() fails. It then
// sends that error on listenerStopped, closes the listener and exits. The send
// blocks while listenerStopped has no free buffer space, so callers should size
// the channel for one error per forward.
func reverseForwardOnePort(
	forward Forward,
	sshClient *ssh.Client,
	listenerStopped chan<- error,
	log *zap.SugaredLogger,
) error {
	// reverse listen on remote server port
	listener, err := sshClient.Listen("tcp", forward.Remote.String())
	if err != nil {
		return fmt.Errorf("reverse listen on %s: %w", forward.Remote.String(), err)
	}
	go func() {
		defer listener.Close()
		log.Debugf("forwarding %s -> %s", forward.Local.String(), forward.Remote.String())
		// handle incoming connections on reverse forwarded tunnel
		for {
			client, err := listener.Accept()
			if err != nil {
				listenerStopped <- fmt.Errorf("error on Accept(): %w", err)
				return
			}
			// each connection gets its own goroutine so multiple clients can use
			// the same forwarded port concurrently
			go handleReverseForwardConn(client, forward, log)
		}
	}()
	return nil
}
// handleReverseForwardConn proxies one connection accepted on the reverse tunnel
// to the local service at forward.Local. It copies data in both directions until
// either side closes or breaks. Both connections are closed before it returns.
func handleReverseForwardConn(client net.Conn, forward Forward, log *zap.SugaredLogger) {
	// Same bound as the SSH dial in dialSSH. Without it, an unreachable local
	// service would pin this goroutine (and the client connection) for the
	// OS-level connect timeout.
	const localDialTimeout = 10 * time.Second

	defer client.Close()
	log.Debugf("%s connected", client.RemoteAddr())
	defer log.Debug("closed")
	remote, err := net.DialTimeout("tcp", forward.Local.String(), localDialTimeout)
	if err != nil {
		log.Errorf("dial INTO local service error: %s", err.Error())
		return
	}
	// Release the local connection here explicitly instead of relying solely on
	// bidipipe to do it. A second Close on a net.Conn only returns an error, which
	// is ignored.
	defer remote.Close()
	// pipe data in both directions:
	// - client => remote
	// - remote => client
	//
	// - in effect, we act as a proxy between the reverse tunnel's client and
	//   the locally-dialed remote endpoint.
	// - the "client" and "remote" names given to Pipe() are only used in error
	//   and log messages
	// - this blocks until either of the parties' socket closes (or breaks)
	if err := bidipipe.Pipe(
		bidipipe.WithName("client", client),
		bidipipe.WithName("remote", remote),
	); err != nil {
		log.Error(err)
	}
}
// dialSSH opens a TCP connection to addr and performs the SSH client handshake
// on it. It returns a ready-to-use *ssh.Client.
//
// Both the TCP dial and the SSH handshake are bounded by a 10s timeout, or by
// ctx's deadline if that is earlier. The TCP dial additionally honours ctx
// cancellation.
func dialSSH(ctx context.Context, addr string, sshConfig *ssh.ClientConfig) (*ssh.Client, error) {
	const timeout = 10 * time.Second
	dialer := net.Dialer{
		Timeout: timeout,
	}
	conn, err := dialer.DialContext(ctx, "tcp", addr)
	if err != nil {
		return nil, err
	}
	// ssh.NewClientConn takes no context or timeout. A server that accepts TCP but
	// stalls the handshake would otherwise block the caller forever, so bound the
	// handshake with a socket deadline.
	deadline := time.Now().Add(timeout)
	if ctxDeadline, ok := ctx.Deadline(); ok && ctxDeadline.Before(deadline) {
		deadline = ctxDeadline
	}
	if err := conn.SetDeadline(deadline); err != nil {
		conn.Close()
		return nil, err
	}
	clConn, newChan, reqChan, err := ssh.NewClientConn(conn, addr, sshConfig)
	if err != nil {
		// Make sure the socket is released on handshake failure. A redundant Close
		// on an already-closed net.Conn is harmless.
		conn.Close()
		return nil, err
	}
	// The established connection is long-lived, so lift the handshake deadline.
	// This also applies to the transport's already-pending reads.
	if err := conn.SetDeadline(time.Time{}); err != nil {
		clConn.Close()
		return nil, err
	}
	return ssh.NewClient(clConn, newChan, reqChan), nil
}