sshpoke/internal/server/driver/ssh/sshtun/connect.go

211 lines
4.4 KiB
Go

package sshtun
import (
"context"
"net"
"sync"
"sync/atomic"
"time"
"github.com/Neur0toxine/sshpoke/pkg/proto/ssh"
"github.com/function61/gokit/app/backoff"
"go.uber.org/zap"
)
// SessionCallback is handed the SSH session opened for each tunnel connection.
// The tunnel invokes it on its own goroutine.
type SessionCallback func(*ssh.Session)

// NoopSessionCallback is a SessionCallback that does nothing with the session.
var NoopSessionCallback = SessionCallback(func(*ssh.Session) {})
// Tunnel is an SSH client for a single server. Each connection made by Connect
// opens one session and calls reverseForwardOnePort on the client.
// Tunnel contains an atomic.Bool and must be used through a pointer (see New).
type Tunnel struct {
	// user is the SSH login name (ssh.ClientConfig.User).
	user string
	// address is the SSH server endpoint that is dialed.
	address Endpoint
	// authMethods are passed as ssh.ClientConfig.Auth.
	authMethods []ssh.AuthMethod
	// log carries an "sshServer" field with the server address.
	log *zap.SugaredLogger
	// tunConfig holds the per-tunnel options given to New.
	tunConfig TunnelConfig
	// connected is true while a Connect loop is running; Connect checks it to
	// avoid starting a second loop.
	connected atomic.Bool
}
// TunnelConfig holds the per-tunnel options supplied to New.
type TunnelConfig struct {
	// Forward describes the port forward for this tunnel.
	// NOTE(review): not read in this file; presumably consumed by
	// reverseForwardOnePort — confirm.
	Forward Forward
	// HostKeys pins the acceptable server host keys. When empty, host key
	// verification is disabled (ssh.InsecureIgnoreHostKey).
	HostKeys []ssh.PublicKey
	// NoPTY skips the pseudo-terminal request on the session.
	NoPTY bool
	// Shell selects what runs on the session: "" or "false" runs nothing,
	// "true" calls Session.Shell, and any other value is started as a command.
	Shell BoolOrStr
	// ClientVersion is passed through as ssh.ClientConfig.ClientVersion.
	ClientVersion string
	// FakeRemoteHost is not read in this file.
	// NOTE(review): document its effect where it is consumed.
	FakeRemoteHost bool
	// KeepAliveInterval and KeepAliveMax are not read in this file; presumably
	// they configure keepAlive — units and exact meaning unconfirmed.
	KeepAliveInterval uint
	KeepAliveMax      uint
}
// BoolOrStr is a configuration value that is either one of the boolean
// literals "true"/"false" or an arbitrary string (for example a command).
type BoolOrStr string

// IsBool reports whether the value is exactly "true" or "false" (case-sensitive).
func (b BoolOrStr) IsBool() bool {
	switch b {
	case "true", "false":
		return true
	default:
		return false
	}
}

// Falsy reports whether the value is empty or exactly "false".
func (b BoolOrStr) Falsy() bool {
	switch b {
	case "", "false":
		return true
	}
	return false
}

// String returns the raw underlying value.
func (b BoolOrStr) String() string { return string(b) }
// New builds a Tunnel that will log in as user to the SSH server at address
// using the given auth methods and options. The returned logger context is
// tagged with the server address. Nothing is dialed until Connect is called.
func New(address, user string, auth []ssh.AuthMethod, sc TunnelConfig, log *zap.SugaredLogger) *Tunnel {
	endpoint := AddrToEndpoint(address)
	serverLog := log.With(zap.String("sshServer", address))
	tun := &Tunnel{
		log:         serverLog,
		tunConfig:   sc,
		authMethods: auth,
		user:        user,
		address:     endpoint,
	}
	return tun
}
// Connect keeps the tunnel connected until ctx is cancelled. Each iteration
// makes one connection attempt via connect; when it returns (failure or a
// dropped connection) any error is logged and Connect waits for the next
// interval of a capped exponential backoff (100ms base, 5s cap) before retrying.
// Cancellation of ctx is honoured immediately, including during the wait.
//
// Only one Connect loop runs per Tunnel: a call made while another is active
// returns at once. bannerCb is passed to the SSH client config; sessionCb is
// given each new session (nil means no callback).
func (t *Tunnel) Connect(ctx context.Context, bannerCb ssh.BannerCallback, sessionCb SessionCallback) {
	// CompareAndSwap makes "is a loop already running?" and "mark running" a
	// single atomic step; a Load followed by a later Store lets two concurrent
	// callers both get past the check.
	if !t.connected.CompareAndSwap(false, true) {
		return
	}
	defer t.connected.Store(false)

	nextDelay := backoff.ExponentialWithCappedMax(100*time.Millisecond, 5*time.Second)
	for {
		if err := t.connect(ctx, bannerCb, sessionCb); err != nil {
			t.log.Error("connect error: ", err)
		}
		if ctx.Err() != nil {
			return
		}
		// A timer in a select (instead of time.Sleep) lets cancellation cut the
		// backoff wait short.
		timer := time.NewTimer(nextDelay())
		select {
		case <-ctx.Done():
			timer.Stop()
			return
		case <-timer.C:
		}
	}
}
// buildHostKeyCallback chooses how the server's host key is verified, based on
// the pinned keys in the tunnel config:
//   - none configured: any host key is accepted (insecure);
//   - exactly one: ssh.FixedHostKey for that key;
//   - several: FixedHostKeys over the whole list.
func (t *Tunnel) buildHostKeyCallback() ssh.HostKeyCallback {
	pinned := t.tunConfig.HostKeys
	switch len(pinned) { // len of a nil slice is 0, so nil falls into case 0
	case 0:
		return ssh.InsecureIgnoreHostKey()
	case 1:
		return ssh.FixedHostKey(pinned[0])
	default:
		return FixedHostKeys(pinned)
	}
}
// connect makes a single connection attempt to the SSH server and blocks for as
// long as that connection is in use; Connect calls it in a loop and reconnects
// whenever it returns.
//
// Steps: dial and handshake (dialSSH), open one session, optionally request a
// PTY and start a shell or command (TunnelConfig.NoPTY / Shell), hand the
// session to sessionCb on its own goroutine, start keepAlive, then set up the
// reverse forward with reverseForwardOnePort.
//
// It returns nil when ctx is cancelled; the error of the dial, session or
// reverse-forward step when one of those fails; otherwise the first error
// received on the listener channel given to reverseForwardOnePort. The session
// and the client are closed on return.
func (t *Tunnel) connect(ctx context.Context, bannerCb ssh.BannerCallback, sessionCb SessionCallback) error {
	t.log.Debug("connecting")
	sshConfig := &ssh.ClientConfig{
		User:            t.user,
		Auth:            t.authMethods,
		HostKeyCallback: t.buildHostKeyCallback(),
		BannerCallback:  bannerCb,
		ClientVersion:   t.tunConfig.ClientVersion,
	}
	sshClient, err := dialSSH(ctx, t.address.String(), sshConfig)
	if err != nil {
		return err
	}
	defer sshClient.Close()
	defer t.log.Debug("disconnecting")
	t.log.Debug("connected")

	sess, err := sshClient.NewSession()
	if err != nil {
		// Not logged here: Connect logs every error returned by this method, so
		// logging it as well would report the same failure twice.
		return err
	}
	defer sess.Close()
	if sessionCb == nil {
		sessionCb = NoopSessionCallback
	}

	// NOTE(review): wg tracks the helper goroutines and is handed to keepAlive,
	// but nothing waits on it; waiting here is not safe while sessionCb may
	// block indefinitely.
	var wg sync.WaitGroup

	if !t.tunConfig.NoPTY {
		// NOTE(review): in golang.org/x/crypto/ssh the size arguments are
		// (height, width), i.e. 80 rows x 40 columns — confirm this is intended.
		err := sess.RequestPty("xterm", 80, 40, ssh.TerminalModes{
			ssh.ECHO:  0,
			ssh.IGNCR: 1,
		})
		if err != nil {
			t.log.Warnf("PTY allocation failed: %s", err.Error())
		}
	}
	if !t.tunConfig.Shell.Falsy() {
		if t.tunConfig.Shell.IsBool() {
			if err := sess.Shell(); err != nil {
				t.log.Warnf("failed to start empty shell: %s", err.Error())
			}
		} else if err := sess.Start(t.tunConfig.Shell.String()); err != nil {
			t.log.Warnf("failed to start shell '%s': %s", t.tunConfig.Shell, err.Error())
		}
		wg.Add(1)
		go func() {
			defer wg.Done()
			_ = sess.Wait() // result intentionally ignored
		}()
	}

	wg.Add(1)
	go func() {
		defer wg.Done()
		sessionCb(sess)
	}()

	// Start keepAlive first so it runs concurrently with the reverse-forward
	// setup below.
	wg.Add(1)
	go t.keepAlive(ctx, sshClient, &wg)

	// Buffered so that a stop notification sent after this function has already
	// returned (e.g. because ctx was cancelled) does not block its sender forever.
	// NOTE(review): assumes reverseForwardOnePort sends at most once — confirm.
	listenerStopped := make(chan error, 1)
	if err := t.reverseForwardOnePort(sshClient, listenerStopped); err != nil {
		return err
	}

	select {
	case <-ctx.Done():
		return nil
	case listenerFirstErr := <-listenerStopped:
		// If ctx was cancelled at the same time, report a clean shutdown.
		select {
		case <-ctx.Done():
			return nil
		default:
			return listenerFirstErr
		}
	}
}
// dialSSH opens a TCP connection to addr and performs the SSH client handshake
// over it with sshConfig, returning a ready *ssh.Client.
//
// The TCP dial is bounded by a 10-second timeout and by ctx. The handshake is
// bounded by a deadline on the connection and is aborted if ctx is cancelled
// while it is in progress. Without this, a server that accepts TCP but never
// completes the handshake would block the caller (and its reconnect loop)
// indefinitely. On any failure the TCP connection is closed.
func dialSSH(ctx context.Context, addr string, sshConfig *ssh.ClientConfig) (*ssh.Client, error) {
	const (
		dialTimeout      = 10 * time.Second
		handshakeTimeout = 30 * time.Second
	)
	dialer := net.Dialer{
		Timeout: dialTimeout,
	}
	conn, err := dialer.DialContext(ctx, "tcp", addr)
	if err != nil {
		return nil, err
	}

	// The handshake does plain reads/writes on conn, so a connection deadline is
	// the only way to bound it.
	if err := conn.SetDeadline(time.Now().Add(handshakeTimeout)); err != nil {
		_ = conn.Close()
		return nil, err
	}
	// Abort the handshake as soon as ctx is cancelled: an expired deadline makes
	// the pending I/O fail.
	stopWatch := make(chan struct{})
	watchDone := make(chan struct{})
	go func() {
		defer close(watchDone)
		select {
		case <-ctx.Done():
			_ = conn.SetDeadline(time.Now())
		case <-stopWatch:
		}
	}()

	clConn, newChan, reqChan, err := ssh.NewClientConn(conn, addr, sshConfig)
	close(stopWatch)
	<-watchDone // the watcher must be finished before the deadline is cleared
	if err != nil {
		_ = conn.Close() // harmless if the handshake already closed it
		return nil, err
	}
	// The handshake deadline must not apply to the established session.
	if err := conn.SetDeadline(time.Time{}); err != nil {
		_ = clConn.Close()
		return nil, err
	}
	return ssh.NewClient(clConn, newChan, reqChan), nil
}