sshpoke/internal/server/driver/ssh/sshtun/ssh.go
2023-11-18 21:23:29 +03:00

211 lines
4.8 KiB
Go

package sshtun
import (
"context"
"fmt"
"net"
"os"
"sync"
"sync/atomic"
"time"
"github.com/function61/gokit/app/backoff"
"github.com/function61/gokit/io/bidipipe"
"go.uber.org/zap"
"golang.org/x/crypto/ssh"
)
// SessionCallback receives the *ssh.Session opened on each successful connection.
// The tunnel invokes it on its own goroutine, and the session is closed when that
// connection attempt ends.
type SessionCallback func(*ssh.Session)

// NoopSessionCallback is a SessionCallback that ignores the session it is given.
var NoopSessionCallback = SessionCallback(func(*ssh.Session) {})
// Tunnel holds the configuration and runtime state of one reverse SSH tunnel: an SSH
// connection on which a remote address is listened on, with accepted connections
// relayed to a local address.
type Tunnel struct {
	// user is the SSH login name (ssh.ClientConfig.User).
	user string
	// address is the SSH server endpoint that gets dialed.
	address Endpoint
	// forward.Remote is listened on via the SSH server; forward.Local is dialed
	// locally for every accepted connection.
	forward Forward
	// authMethods are passed as ssh.ClientConfig.Auth.
	authMethods []ssh.AuthMethod
	// log is tagged with the SSH server address by New.
	log *zap.SugaredLogger
	// connected is true while a Connect loop is running on this Tunnel.
	connected atomic.Bool
	// fakeRemoteHost switches reverseForwardOnePort to the newSishHostListener path,
	// which terminates the process afterwards.
	// NOTE(review): looks like a debugging mode — confirm intent.
	fakeRemoteHost bool
}
// New builds a Tunnel that, once Connect is called, logs in to the SSH server at
// address as user with the given auth methods and reverse-forwards according to
// forward. The supplied logger is tagged with the server address. fakeRemoteHost
// selects the alternative remote-listener mode.
func New(address, user string, fakeRemoteHost bool,
	forward Forward, auth []ssh.AuthMethod, log *zap.SugaredLogger) *Tunnel {
	t := new(Tunnel)
	t.address = AddrToEndpoint(address)
	t.user = user
	t.fakeRemoteHost = fakeRemoteHost
	t.forward = forward
	t.authMethods = auth
	t.log = log.With(zap.String("sshServer", address))
	return t
}
// Connect keeps the tunnel up until ctx is cancelled. Each iteration makes one
// connection attempt; when an attempt ends, the error (if any) is logged and Connect
// waits with capped exponential backoff (100ms growing to at most 5s) before
// retrying. If a Connect loop is already running on this Tunnel, the call returns
// immediately.
func (c *Tunnel) Connect(ctx context.Context, sessionCb SessionCallback) {
	// CompareAndSwap makes the check and the claim a single atomic step, so two
	// concurrent callers cannot both enter the reconnect loop.
	if !c.connected.CompareAndSwap(false, true) {
		return
	}
	defer c.connected.Store(false)

	backoffTime := backoff.ExponentialWithCappedMax(100*time.Millisecond, 5*time.Second)
	for {
		if err := c.connect(ctx, sessionCb); err != nil {
			c.log.Errorf("connect error: %s", err)
		}

		// Wait before reconnecting, but return promptly if ctx is cancelled
		// during the wait (a plain time.Sleep would ignore cancellation).
		timer := time.NewTimer(backoffTime())
		select {
		case <-ctx.Done():
			timer.Stop()
			return
		case <-timer.C:
		}
	}
}
// connect makes a single connection attempt to the SSH server.
//
// It dials and authenticates, opens one session (handed to sessionCb on its own
// goroutine), and starts the reverse forward. It then blocks until ctx is cancelled,
// which returns nil, or until the forward's listener stops, which returns the
// listener's error. An error is also returned if dialing, opening the session or
// starting the forward fails. The caller is expected to reconnect.
func (c *Tunnel) connect(ctx context.Context, sessionCb SessionCallback) error {
	c.log.Debug("connecting")
	sshConfig := &ssh.ClientConfig{
		User:            c.user,
		Auth:            c.authMethods,
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
	}
	sshClient, err := dialSSH(ctx, c.address.String(), sshConfig)
	if err != nil {
		return err
	}
	defer sshClient.Close()
	defer c.log.Debug("disconnecting")
	c.log.Debug("connected")

	sess, err := sshClient.NewSession()
	if err != nil {
		// Returned (not logged here) so the caller logs it exactly once.
		return fmt.Errorf("session error: %w", err)
	}
	defer sess.Close()

	if sessionCb == nil {
		sessionCb = NoopSessionCallback
	}
	go sessionCb(sess)

	// Buffered so the accept goroutine can always deliver its error and exit.
	// When this function returns on ctx cancellation, the deferred sshClient.Close
	// makes Accept fail. With an unbuffered channel, that goroutine would then
	// block forever on a send nobody receives.
	listenerStopped := make(chan error, 1)
	if err := c.reverseForwardOnePort(sshClient, listenerStopped); err != nil {
		return err
	}

	select {
	case <-ctx.Done():
		return nil
	case listenerErr := <-listenerStopped:
		// A listener failure caused by shutdown is not an error worth reporting.
		if ctx.Err() != nil {
			return nil
		}
		return listenerErr
	}
}
// reverseForwardOnePort asks the SSH server to listen on c.forward.Remote and relays
// every connection accepted there to c.forward.Local.
//
// Synchronous part: the remote Listen call. If it fails, its error is returned;
// otherwise nil is returned right away.
//
// Asynchronous part: a goroutine accepts connections and serves each one on its own
// goroutine. When Accept fails, the goroutine sends the wrapped error on
// listenerStopped once, closes the listener and exits. The send is a plain blocking
// send, so listenerStopped must be buffered or have a receiver.
//
// In fakeRemoteHost mode it instead hands sshClient to newSishHostListener, sleeps
// one second and terminates the process with exit status 0. It never returns.
func (c *Tunnel) reverseForwardOnePort(sshClient *ssh.Client, listenerStopped chan<- error) error {
	if c.fakeRemoteHost {
		newSishHostListener(sshClient).ListenFakeRemoteHost(c.forward.Remote)
		time.Sleep(time.Second)
		os.Exit(0)
	}

	listener, err := sshClient.Listen("tcp", c.forward.Remote.String())
	if err != nil {
		return err
	}

	go func() {
		defer listener.Close()
		c.log.Debugf("forwarding %s <- %s", c.forward.Local.String(), c.forward.Remote.String())

		var acceptErr error
		for acceptErr == nil {
			var conn net.Conn
			if conn, acceptErr = listener.Accept(); acceptErr == nil {
				go handleReverseForwardConn(conn, c.forward, c.log)
			}
		}
		listenerStopped <- fmt.Errorf("error on Accept(): %w", acceptErr)
	}()

	return nil
}
// listenTCPWithoutResolving is an empty stub: it has no body, and nothing in the
// visible source calls it.
// NOTE(review): presumably a placeholder for a remote listener that skips address
// resolution (cf. the fakeRemoteHost path) — implement it or remove it.
func (c *Tunnel) listenTCPWithoutResolving() {
}
// handleReverseForwardConn serves one connection accepted on the remote side of the
// tunnel. It dials forward.Local over TCP and pipes data in both directions between
// client and that local connection until either side closes or breaks. Both
// connections are closed before it returns. Dial and pipe errors are logged, not
// returned.
func handleReverseForwardConn(client net.Conn, forward Forward, log *zap.SugaredLogger) {
	defer client.Close()
	log.Debugf("%s connected", client.RemoteAddr())
	defer log.Debug("closed")

	remote, err := net.Dial("tcp", forward.Local.String())
	if err != nil {
		log.Errorf("dial INTO local service error: %s", err.Error())
		return
	}
	// Close the locally-dialed connection here rather than relying on bidipipe.Pipe
	// to do it. If Pipe already closed it, the second Close just returns an error,
	// which is deliberately ignored.
	defer remote.Close()

	// Pipe data in both directions (client => remote and remote => client), acting
	// as a proxy between the reverse tunnel's client and the local service.
	// The "client"/"remote" names are only used in error and log messages.
	// This blocks until either party's socket closes (or breaks).
	if err := bidipipe.Pipe(
		bidipipe.WithName("client", client),
		bidipipe.WithName("remote", remote),
	); err != nil {
		log.Error(err)
	}
}
// dialSSH opens a TCP connection to addr and performs the SSH client handshake on it
// using sshConfig.
//
// Both the TCP connect and the handshake are bounded by a 10 second timeout (or
// ctx's deadline, if earlier), and both abort when ctx is cancelled. The handshake
// (ssh.NewClientConn) takes no context, so the bound is enforced with the
// connection's I/O deadline. That deadline is cleared once the handshake succeeds,
// because the established connection is long-lived.
func dialSSH(ctx context.Context, addr string, sshConfig *ssh.ClientConfig) (*ssh.Client, error) {
	const connectTimeout = 10 * time.Second

	dialer := net.Dialer{
		Timeout: connectTimeout,
	}
	conn, err := dialer.DialContext(ctx, "tcp", addr)
	if err != nil {
		return nil, err
	}

	deadline := time.Now().Add(connectTimeout)
	if ctxDeadline, ok := ctx.Deadline(); ok && ctxDeadline.Before(deadline) {
		deadline = ctxDeadline
	}
	if err := conn.SetDeadline(deadline); err != nil {
		_ = conn.Close()
		return nil, err
	}

	// On ctx cancellation, expire the deadline so the blocked handshake I/O fails.
	// Waiting on watcherDone ensures the watcher has exited before the deadline is
	// cleared below, so it cannot break the connection afterwards.
	handshakeDone := make(chan struct{})
	watcherDone := make(chan struct{})
	go func() {
		defer close(watcherDone)
		select {
		case <-ctx.Done():
			_ = conn.SetDeadline(time.Now())
		case <-handshakeDone:
		}
	}()

	clConn, newChan, reqChan, err := ssh.NewClientConn(conn, addr, sshConfig)
	close(handshakeDone)
	<-watcherDone
	if err != nil {
		_ = conn.Close() // harmless if already closed
		return nil, err
	}
	if err := ctx.Err(); err != nil {
		_ = clConn.Close()
		return nil, err
	}
	if err := conn.SetDeadline(time.Time{}); err != nil {
		_ = clConn.Close()
		return nil, err
	}
	return ssh.NewClient(clConn, newChan, reqChan), nil
}