206 lines
4.3 KiB
Go
206 lines
4.3 KiB
Go
package sshtun
|
|
|
|
import (
|
|
"context"
|
|
"net"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/Neur0toxine/sshpoke/pkg/proto/ssh"
|
|
"github.com/function61/gokit/app/backoff"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
type SessionCallback func(*ssh.Session)
|
|
|
|
var NoopSessionCallback SessionCallback = func(*ssh.Session) {}
|
|
|
|
// Tunnel holds what is needed to dial one SSH server and serve a reverse port
// forward over it. Build it with New and run it with Connect.
type Tunnel struct {
	// user is the SSH login name, passed as ClientConfig.User.
	user string
	// address is the SSH server endpoint; its String() form is what gets dialed.
	address Endpoint
	// authMethods are passed as ClientConfig.Auth.
	authMethods []ssh.AuthMethod
	// log is tagged by New with an "sshServer" field holding the server address.
	log *zap.SugaredLogger
	// tunConfig carries the per-tunnel options (host key check, PTY/shell, client version, ...).
	tunConfig TunnelConfig
	// connected is set while a Connect loop is running for this Tunnel and makes
	// a second Connect call return early.
	connected atomic.Bool
}
|
|
|
|
// TunnelConfig holds the per-tunnel options consulted while connecting.
type TunnelConfig struct {
	// Forward describes the port forward to set up.
	// NOTE(review): not read in this file; presumably consumed by
	// reverseForwardOnePort — confirm.
	Forward Forward
	// HostKeyCallback verifies the server host key (ClientConfig.HostKeyCallback).
	HostKeyCallback ssh.HostKeyCallback
	// NoPTY disables the PTY request that is otherwise made on the session.
	NoPTY bool
	// Shell controls what runs in the session: a falsy value ("", "0", "false")
	// runs nothing, the other boolean spellings ("true", "1") start the default
	// shell, and any other string is started as a command.
	Shell BoolOrStr
	// ClientVersion is passed as ClientConfig.ClientVersion.
	ClientVersion string
	// NOTE(review): FakeRemoteHost is not read in this file; its meaning cannot
	// be determined from here — confirm against the forwarding code.
	FakeRemoteHost bool
	// NOTE(review): KeepAliveInterval is not read in this file; presumably it
	// tunes keepAlive — confirm the unit.
	KeepAliveInterval uint
	// NOTE(review): KeepAliveMax is not read in this file; presumably the limit
	// used by keepAlive — confirm semantics.
	KeepAliveMax uint
}
|
|
|
|
// BoolOrStr is a configuration value that holds either a boolean spelling or
// free-form text. The recognized boolean spellings are "", "0", "1", "false"
// and "true"; matching is case-sensitive.
type BoolOrStr string

// IsBool reports whether b is one of the boolean spellings listed on BoolOrStr.
func (b BoolOrStr) IsBool() bool {
	switch b {
	case "true", "1":
		return true
	default:
		// Every false spelling is also a boolean spelling.
		return b.Falsy()
	}
}

// Falsy reports whether b is a false value: "", "0" or "false".
func (b BoolOrStr) Falsy() bool {
	switch b {
	case "", "0", "false":
		return true
	}
	return false
}

// String returns the raw underlying text, so BoolOrStr satisfies fmt.Stringer.
func (b BoolOrStr) String() string {
	raw := string(b)
	return raw
}
|
|
|
|
func New(address, user string, auth []ssh.AuthMethod, sc TunnelConfig, log *zap.SugaredLogger) *Tunnel {
|
|
return &Tunnel{
|
|
address: AddrToEndpoint(address),
|
|
user: user,
|
|
authMethods: auth,
|
|
tunConfig: sc,
|
|
log: log.With(zap.String("sshServer", address)),
|
|
}
|
|
}
|
|
|
|
func (t *Tunnel) Connect(ctx context.Context,
|
|
hsLineCb func(msg string), bannerCb ssh.BannerCallback, sessionCb SessionCallback) {
|
|
if t.connected.Load() {
|
|
return
|
|
}
|
|
|
|
defer t.connected.Store(false)
|
|
backoffTime := backoff.ExponentialWithCappedMax(100*time.Millisecond, 5*time.Second)
|
|
for {
|
|
t.connected.Store(true)
|
|
err := t.connect(ctx, hsLineCb, bannerCb, sessionCb)
|
|
if err != nil {
|
|
t.log.Error("connect error: ", err)
|
|
}
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
default:
|
|
}
|
|
|
|
time.Sleep(backoffTime())
|
|
}
|
|
}
|
|
|
|
// connect once to the SSH server. if the connection breaks, we return error and the caller
|
|
// will try to re-connect
|
|
func (t *Tunnel) connect(ctx context.Context,
|
|
hsLineCb func(string), bannerCb ssh.BannerCallback, sessionCb SessionCallback) error {
|
|
t.log.Debug("connecting")
|
|
sshConfig := &ssh.ClientConfig{
|
|
Config: ssh.Config{
|
|
HandshakePacketReader: newHandshakePerLineReader(hsLineCb),
|
|
},
|
|
User: t.user,
|
|
Auth: t.authMethods,
|
|
HostKeyCallback: t.tunConfig.HostKeyCallback,
|
|
BannerCallback: bannerCb,
|
|
ClientVersion: t.tunConfig.ClientVersion,
|
|
}
|
|
|
|
var sshClient *ssh.Client
|
|
var errConnect error
|
|
|
|
sshClient, errConnect = dialSSH(ctx, t.address.String(), sshConfig)
|
|
if errConnect != nil {
|
|
return errConnect
|
|
}
|
|
|
|
defer sshClient.Close()
|
|
defer t.log.Debug("disconnecting")
|
|
|
|
t.log.Debug("connected")
|
|
listenerStopped := make(chan error)
|
|
|
|
var wg sync.WaitGroup
|
|
sess, err := sshClient.NewSession()
|
|
if err != nil {
|
|
t.log.Errorf("session error: %s", err)
|
|
return err
|
|
}
|
|
defer sess.Close()
|
|
|
|
if sessionCb == nil {
|
|
sessionCb = func(*ssh.Session) {}
|
|
}
|
|
|
|
if !t.tunConfig.NoPTY {
|
|
err = sess.RequestPty("xterm", 80, 40, ssh.TerminalModes{
|
|
ssh.ECHO: 0,
|
|
ssh.IGNCR: 1,
|
|
})
|
|
if err != nil {
|
|
t.log.Warnf("PTY allocation failed: %s", err.Error())
|
|
}
|
|
}
|
|
if !t.tunConfig.Shell.Falsy() {
|
|
if t.tunConfig.Shell.IsBool() {
|
|
if err := sess.Shell(); err != nil {
|
|
t.log.Warnf("failed to start empty shell: %s", err.Error())
|
|
}
|
|
} else {
|
|
if err := sess.Start(t.tunConfig.Shell.String()); err != nil {
|
|
t.log.Warnf("failed to start shell '%s': %s", t.tunConfig.Shell, err.Error())
|
|
}
|
|
}
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
_ = sess.Wait()
|
|
}()
|
|
}
|
|
|
|
wg.Add(1)
|
|
go func() {
|
|
defer wg.Done()
|
|
sessionCb(sess)
|
|
}()
|
|
|
|
wg.Add(1)
|
|
reverseErr := make(chan error)
|
|
go func() {
|
|
defer wg.Done()
|
|
reverseErr <- t.reverseForwardOnePort(sshClient, listenerStopped)
|
|
}()
|
|
|
|
wg.Add(1)
|
|
go t.keepAlive(ctx, sshClient, &wg)
|
|
|
|
if err := <-reverseErr; err != nil {
|
|
return err
|
|
}
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil
|
|
case listenerFirstErr := <-listenerStopped:
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil
|
|
default:
|
|
return listenerFirstErr
|
|
}
|
|
}
|
|
}
|
|
|
|
func dialSSH(ctx context.Context, addr string, sshConfig *ssh.ClientConfig) (*ssh.Client, error) {
|
|
dialer := net.Dialer{
|
|
Timeout: 10 * time.Second,
|
|
}
|
|
|
|
conn, err := dialer.DialContext(ctx, "tcp", addr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
clConn, newChan, reqChan, err := ssh.NewClientConn(conn, addr, sshConfig)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return ssh.NewClient(clConn, newChan, reqChan), nil
|
|
}
|