// File: sshpoke/internal/server/driver/ssh/sshtun/connect.go

package sshtun
import (
"context"
"net"
"sync"
"sync/atomic"
"time"
"github.com/Neur0toxine/sshpoke/pkg/proto/ssh"
"github.com/function61/gokit/app/backoff"
"go.uber.org/zap"
)
// SessionCallback receives the SSH session opened on each successful
// connection to the server.
type SessionCallback func(*ssh.Session)

// NoopSessionCallback is a SessionCallback that does nothing with the session.
var NoopSessionCallback = SessionCallback(func(*ssh.Session) {})
// Tunnel holds the state needed to (re)connect to a single SSH server and
// serve the configured forward over it; see Connect. Build it with New.
type Tunnel struct {
	user        string             // user name to authenticate as
	address     Endpoint           // server endpoint, derived from the address passed to New
	authMethods []ssh.AuthMethod   // authentication methods offered to the server
	log         *zap.SugaredLogger // logger tagged with the server address ("sshServer")
	tunConfig   TunnelConfig       // per-tunnel options (forward, host key check, shell, ...)
	connected   atomic.Bool        // true while a Connect loop is running for this Tunnel
}
// TunnelConfig holds the per-tunnel connection options.
type TunnelConfig struct {
	// Forward describes the port forward to establish.
	// NOTE(review): consumed outside this view (presumably by reverseForwardOnePort) — confirm.
	Forward Forward
	// HostKeyCallback is passed to ssh.ClientConfig to verify the server's host key.
	HostKeyCallback ssh.HostKeyCallback
	// Shell controls what runs on the main session:
	//   nil or a non-falsy boolean-like value ("true", "1") -> interactive shell;
	//   falsy ("", "0", "false")                          -> nothing is started;
	//   any other string                                   -> started as a command.
	Shell *BoolOrStr
	// ClientVersion is passed as ssh.ClientConfig.ClientVersion.
	ClientVersion string
	// FakeRemoteHost is not used in this view.
	// NOTE(review): semantics to be confirmed where it is read.
	FakeRemoteHost bool
	// KeepAliveInterval / KeepAliveMax presumably configure keepAlive (period and
	// tolerated misses); units are not visible here — TODO confirm.
	KeepAliveInterval uint
	KeepAliveMax      uint
}
// BoolOrStr is a configuration value that is either a boolean-like flag or an
// arbitrary string (for example a shell command). All methods are safe to call
// on a nil *BoolOrStr.
type BoolOrStr string

// IsBool reports whether b holds one of the boolean-like values "true",
// "false", "1", "0" or the empty string. A nil b is not considered boolean.
func (b *BoolOrStr) IsBool() bool {
	if b != nil {
		switch *b {
		case "true", "false", "1", "0", "":
			return true
		}
	}
	return false
}

// Falsy reports whether b means "off": a nil b, "", "0" or "false".
func (b *BoolOrStr) Falsy() bool {
	if b == nil {
		return true
	}
	switch *b {
	case "", "0", "false":
		return true
	default:
		return false
	}
}

// String returns the raw value of b, or "" when b is nil.
func (b *BoolOrStr) String() string {
	var raw string
	if b != nil {
		raw = string(*b)
	}
	return raw
}
// New returns a Tunnel for the SSH server at address (converted with
// AddrToEndpoint) that authenticates as user with auth and behaves according
// to sc. Every log entry produced by the tunnel carries the original address
// in the "sshServer" field.
func New(address, user string, auth []ssh.AuthMethod, sc TunnelConfig, log *zap.SugaredLogger) *Tunnel {
	tun := new(Tunnel)
	tun.address = AddrToEndpoint(address)
	tun.user = user
	tun.authMethods = auth
	tun.tunConfig = sc
	tun.log = log.With(zap.String("sshServer", address))
	return tun
}
// Connect keeps the tunnel to the SSH server running until ctx is cancelled.
//
// Each iteration makes one connection attempt via connect and logs any error
// it returns. It then waits according to an exponential backoff (100ms
// initial, capped at 5s) before reconnecting. The wait is cut short as soon as
// ctx is cancelled, at which point Connect returns.
//
// At most one Connect loop runs per Tunnel. A call made while another loop is
// active returns immediately. hsLineCb, bannerCb and sessionCb are passed
// unchanged to every connection attempt.
func (t *Tunnel) Connect(ctx context.Context,
	hsLineCb func(msg string), bannerCb ssh.BannerCallback, sessionCb SessionCallback) {
	// CompareAndSwap makes "is a loop already running?" and "mark it running"
	// one atomic step. A separate Load + Store let two concurrent callers both
	// start a loop.
	if !t.connected.CompareAndSwap(false, true) {
		return
	}
	defer t.connected.Store(false)

	backoffTime := backoff.ExponentialWithCappedMax(100*time.Millisecond, 5*time.Second)
	for {
		if err := t.connect(ctx, hsLineCb, bannerCb, sessionCb); err != nil {
			t.log.Error("connect error: ", err)
		}

		// Wait before reconnecting, but stop immediately if ctx is cancelled
		// instead of sleeping through the whole backoff delay.
		timer := time.NewTimer(backoffTime())
		select {
		case <-ctx.Done():
			timer.Stop()
			return
		case <-timer.C:
		}
	}
}
// connect once to the SSH server. if the connection breaks, we return error and the caller
// will try to re-connect
func (t *Tunnel) connect(ctx context.Context,
hsLineCb func(string), bannerCb ssh.BannerCallback, sessionCb SessionCallback) error {
t.log.Debug("connecting")
sshConfig := &ssh.ClientConfig{
Config: ssh.Config{
HandshakePacketReader: newHandshakePerLineReader(hsLineCb),
},
User: t.user,
Auth: t.authMethods,
HostKeyCallback: t.tunConfig.HostKeyCallback,
BannerCallback: bannerCb,
ClientVersion: t.tunConfig.ClientVersion,
}
var sshClient *ssh.Client
var errConnect error
sshClient, errConnect = dialSSH(ctx, t.address.String(), sshConfig)
if errConnect != nil {
return errConnect
}
defer sshClient.Close()
defer t.log.Debug("disconnecting")
t.log.Debug("connected")
listenerStopped := make(chan error)
var wg sync.WaitGroup
sess, err := sshClient.NewSession()
if err != nil {
t.log.Errorf("session error: %s", err)
return err
}
defer sess.Close()
if sessionCb == nil {
sessionCb = func(*ssh.Session) {}
}
if t.tunConfig.Shell == nil || !t.tunConfig.Shell.Falsy() {
if t.tunConfig.Shell == nil || t.tunConfig.Shell.IsBool() {
if err := sess.Shell(); err != nil {
t.log.Warnf("failed to start empty shell: %s", err.Error())
}
} else {
if err := sess.Start(t.tunConfig.Shell.String()); err != nil {
t.log.Warnf("failed to start shell '%s': %s", t.tunConfig.Shell, err.Error())
}
2023-11-19 14:12:39 +03:00
}
wg.Add(1)
go func() {
defer wg.Done()
_ = sess.Wait()
t.log.Debug("main session has been terminated")
2023-11-19 14:12:39 +03:00
}()
}
wg.Add(1)
go func() {
defer wg.Done()
sessionCb(sess)
}()
wg.Add(1)
reverseErr := make(chan error)
go func() {
defer wg.Done()
reverseErr <- t.reverseForwardOnePort(sshClient, listenerStopped)
}()
wg.Add(1)
go t.keepAlive(ctx, sshClient, &wg)
if err := <-reverseErr; err != nil {
return err
}
select {
case <-ctx.Done():
return nil
case listenerFirstErr := <-listenerStopped:
select {
case <-ctx.Done():
return nil
default:
return listenerFirstErr
}
}
}
// dialSSH opens a TCP connection to addr and performs the SSH client handshake
// with sshConfig, returning a ready *ssh.Client.
//
// The dialer timeout (10s) only bounds the TCP connect, so the handshake is
// bounded separately. It must finish within handshakeTimeout, or by ctx's
// deadline if that is earlier, and it is aborted as soon as ctx is cancelled.
// Without this, a server that accepts TCP but never completes the SSH
// handshake would block the caller indefinitely.
//
// On any failure the underlying TCP connection is closed. If ctx is cancelled
// during the handshake, ctx.Err() is returned.
func dialSSH(ctx context.Context, addr string, sshConfig *ssh.ClientConfig) (*ssh.Client, error) {
	const (
		dialTimeout      = 10 * time.Second
		handshakeTimeout = 30 * time.Second
	)
	dialer := net.Dialer{
		Timeout: dialTimeout,
	}
	conn, err := dialer.DialContext(ctx, "tcp", addr)
	if err != nil {
		return nil, err
	}

	deadline := time.Now().Add(handshakeTimeout)
	if ctxDeadline, ok := ctx.Deadline(); ok && ctxDeadline.Before(deadline) {
		deadline = ctxDeadline
	}
	if err := conn.SetDeadline(deadline); err != nil {
		_ = conn.Close()
		return nil, err
	}

	// Closing conn on cancellation makes the blocked I/O inside NewClientConn
	// fail. aborted reports whether the watcher closed conn. Receiving from it
	// also guarantees the watcher has exited, so it cannot close a connection
	// we are about to hand out.
	handshakeDone := make(chan struct{})
	aborted := make(chan bool, 1)
	go func() {
		select {
		case <-ctx.Done():
			_ = conn.Close()
			aborted <- true
		case <-handshakeDone:
			aborted <- false
		}
	}()

	clConn, newChan, reqChan, err := ssh.NewClientConn(conn, addr, sshConfig)
	close(handshakeDone)
	if <-aborted {
		if err == nil {
			_ = clConn.Close()
		}
		return nil, ctx.Err()
	}
	if err != nil {
		_ = conn.Close()
		return nil, err
	}

	// The handshake deadline must not apply to the long-lived connection.
	if err := conn.SetDeadline(time.Time{}); err != nil {
		_ = clConn.Close()
		return nil, err
	}
	return ssh.NewClient(clConn, newChan, reqChan), nil
}