243 lines
6.3 KiB
Go
243 lines
6.3 KiB
Go
|
package sshproto
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"fmt"
|
||
|
"net"
|
||
|
"sync/atomic"
|
||
|
"time"
|
||
|
|
||
|
"github.com/Neur0toxine/sshpoke/pkg/errtools"
|
||
|
"github.com/function61/gokit/app/backoff"
|
||
|
"github.com/function61/gokit/io/bidipipe"
|
||
|
"go.uber.org/zap"
|
||
|
"golang.org/x/crypto/ssh"
|
||
|
)
|
||
|
|
||
|
type Client struct {
|
||
|
user string
|
||
|
address string
|
||
|
authMethods []ssh.AuthMethod
|
||
|
log *zap.SugaredLogger
|
||
|
connected atomic.Bool
|
||
|
}
|
||
|
|
||
|
func New(address, user string, auth []ssh.AuthMethod, log *zap.SugaredLogger) *Client {
|
||
|
return &Client{
|
||
|
address: prepareAddress(address),
|
||
|
user: user,
|
||
|
authMethods: auth,
|
||
|
log: log.With(zap.String("sshServer", address)),
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func prepareAddress(address string) string {
|
||
|
_, _, err := net.SplitHostPort(address)
|
||
|
if err != nil && errtools.IsPortMissingErr(err) {
|
||
|
return net.JoinHostPort(address, "22")
|
||
|
}
|
||
|
return address
|
||
|
}
|
||
|
|
||
|
func (c *Client) Connect(ctx context.Context) {
|
||
|
if c.connected.Load() {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
defer c.connected.Store(false)
|
||
|
backoffTime := backoff.ExponentialWithCappedMax(100*time.Millisecond, 5*time.Second)
|
||
|
for {
|
||
|
c.connected.Store(true)
|
||
|
err := c.connect(ctx)
|
||
|
if err != nil {
|
||
|
c.log.Error("connect error:", err)
|
||
|
}
|
||
|
|
||
|
select {
|
||
|
case <-ctx.Done():
|
||
|
return
|
||
|
default:
|
||
|
}
|
||
|
|
||
|
time.Sleep(backoffTime())
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// connect once to the SSH server. if the connection breaks, we return error and the caller
|
||
|
// will try to re-connect
|
||
|
func (c *Client) connect(ctx context.Context) error {
|
||
|
c.log.Debug("connecting")
|
||
|
sshConfig := &ssh.ClientConfig{
|
||
|
User: c.user,
|
||
|
Auth: c.authMethods,
|
||
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||
|
}
|
||
|
|
||
|
var sshClient *ssh.Client
|
||
|
var errConnect error
|
||
|
|
||
|
sshClient, errConnect = dialSSH(ctx, c.address, sshConfig)
|
||
|
if errConnect != nil {
|
||
|
return errConnect
|
||
|
}
|
||
|
|
||
|
// always disconnect when function returns
|
||
|
defer sshClient.Close()
|
||
|
defer c.log.Debug("disconnecting")
|
||
|
|
||
|
c.log.Debug("connected")
|
||
|
|
||
|
<-ctx.Done()
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// // connect once to the SSH server. if the connection breaks, we return error and the caller
|
||
|
// // will try to re-connect
|
||
|
// func connectToSshAndServe2(
|
||
|
// ctx context.Context,
|
||
|
// address string,
|
||
|
// authConfig types.Auth,
|
||
|
// auth ssh.AuthMethod,
|
||
|
// log *zap.SugaredLogger,
|
||
|
// ) error {
|
||
|
// log = log.With(zap.String("sshServer", address))
|
||
|
// log.Debug("connecting")
|
||
|
// sshConfig := &ssh.ClientConfig{
|
||
|
// User: authConfig.User,
|
||
|
// Auth: []ssh.AuthMethod{auth},
|
||
|
// HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||
|
// }
|
||
|
//
|
||
|
// var sshClient *ssh.Client
|
||
|
// var errConnect error
|
||
|
//
|
||
|
// sshClient, errConnect = dialSSH(ctx, address, sshConfig)
|
||
|
// if errConnect != nil {
|
||
|
// return errConnect
|
||
|
// }
|
||
|
//
|
||
|
// // always disconnect when function returns
|
||
|
// defer sshClient.Close()
|
||
|
// defer log.Debug("disconnecting")
|
||
|
//
|
||
|
// log.Debug("connected")
|
||
|
//
|
||
|
// // each listener in reverseForwardOnePort() can return one error, so make sure channel
|
||
|
// // has enough buffering so even if we return from here, the goroutines won't block trying
|
||
|
// // to return an error
|
||
|
// listenerStopped := make(chan error, len(conf.Forwards))
|
||
|
//
|
||
|
// for _, forward := range conf.Forwards {
|
||
|
// // TODO: "if any fails, tear down all workers" -style error handling would be better
|
||
|
// // handled with https://pkg.go.dev/golang.org/x/sync/errgroup?tab=doc
|
||
|
// if err := reverseForwardOnePort(
|
||
|
// forward,
|
||
|
// sshClient,
|
||
|
// listenerStopped,
|
||
|
// makeLogger("reverseForwardOnePort"),
|
||
|
// makeLogger,
|
||
|
// ); err != nil {
|
||
|
// // closes SSH connection if even one forward Listen() fails
|
||
|
// return err
|
||
|
// }
|
||
|
// }
|
||
|
//
|
||
|
// // we're connected and have succesfully started listening on all reverse forwards, wait
|
||
|
// // for either user to ask us to stop or any of the listeners to error
|
||
|
// select {
|
||
|
// case <-ctx.Done(): // cancel requested
|
||
|
// return nil
|
||
|
// case listenerFirstErr := <-listenerStopped:
|
||
|
// // one or more of the listeners encountered an error. react by closing the connection
|
||
|
// // assumes all the other listeners failed too so no teardown necessary
|
||
|
// select {
|
||
|
// case <-ctx.Done(): // pretty much errors are to be expected if cancellation triggered
|
||
|
// return nil
|
||
|
// default:
|
||
|
// return listenerFirstErr
|
||
|
// }
|
||
|
// }
|
||
|
// }
|
||
|
|
||
|
// blocking flow: calls Listen() on the SSH connection, and if succeeds returns non-nil error
|
||
|
//
|
||
|
// nonblocking flow: if Accept() call fails, stops goroutine and returns error on ch listenerStopped
|
||
|
func reverseForwardOnePort(
|
||
|
forward Forward,
|
||
|
sshClient *ssh.Client,
|
||
|
listenerStopped chan<- error,
|
||
|
log *zap.SugaredLogger,
|
||
|
) error {
|
||
|
// reverse listen on remote server port
|
||
|
listener, err := sshClient.Listen("tcp", forward.Remote.String())
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
go func() {
|
||
|
defer listener.Close()
|
||
|
log.Debugf("forwarding %s -> %s", forward.Local.String(), forward.Remote.String())
|
||
|
|
||
|
// handle incoming connections on reverse forwarded tunnel
|
||
|
for {
|
||
|
client, err := listener.Accept()
|
||
|
if err != nil {
|
||
|
listenerStopped <- fmt.Errorf("error on Accept(): %w", err)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// handle the connection in another goroutine, so we can support multiple concurrent
|
||
|
// connections on the same port
|
||
|
go handleReverseForwardConn(client, forward, log)
|
||
|
}
|
||
|
}()
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func handleReverseForwardConn(client net.Conn, forward Forward, log *zap.SugaredLogger) {
|
||
|
defer client.Close()
|
||
|
|
||
|
log.Debugf("%s connected", client.RemoteAddr())
|
||
|
defer log.Debug("closed")
|
||
|
|
||
|
remote, err := net.Dial("tcp", forward.Local.String())
|
||
|
if err != nil {
|
||
|
log.Errorf("dial INTO local service error: %s", err.Error())
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// pipe data in both directions:
|
||
|
// - client => remote
|
||
|
// - remote => client
|
||
|
//
|
||
|
// - in effect, we act as a proxy between the reverse tunnel's client and locally-dialed
|
||
|
// remote endpoint.
|
||
|
// - the "client" and "remote" strings we give Pipe() is just for error&log messages
|
||
|
// - this blocks until either of the parties' socket closes (or breaks)
|
||
|
if err := bidipipe.Pipe(
|
||
|
bidipipe.WithName("client", client),
|
||
|
bidipipe.WithName("remote", remote),
|
||
|
); err != nil {
|
||
|
log.Error(err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func dialSSH(ctx context.Context, addr string, sshConfig *ssh.ClientConfig) (*ssh.Client, error) {
|
||
|
dialer := net.Dialer{
|
||
|
Timeout: 10 * time.Second,
|
||
|
}
|
||
|
|
||
|
conn, err := dialer.DialContext(ctx, "tcp", addr)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
clConn, newChan, reqChan, err := ssh.NewClientConn(conn, addr, sshConfig)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
return ssh.NewClient(clConn, newChan, reqChan), nil
|
||
|
}
|