211 lines
4.8 KiB
Go
211 lines
4.8 KiB
Go
package sshtun
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
"os"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/function61/gokit/app/backoff"
|
|
"github.com/function61/gokit/io/bidipipe"
|
|
"go.uber.org/zap"
|
|
"golang.org/x/crypto/ssh"
|
|
)
|
|
|
|
// SessionCallback receives the SSH session that connect opens after each successful
// connection to the server. It is run in its own goroutine.
type SessionCallback func(sess *ssh.Session)
|
|
|
|
// NoopSessionCallback is a SessionCallback that ignores the session it is given.
var NoopSessionCallback = SessionCallback(func(_ *ssh.Session) {})
|
|
|
|
type Tunnel struct {
|
|
user string
|
|
address Endpoint
|
|
forward Forward
|
|
authMethods []ssh.AuthMethod
|
|
log *zap.SugaredLogger
|
|
connected atomic.Bool
|
|
fakeRemoteHost bool
|
|
}
|
|
|
|
func New(address, user string, fakeRemoteHost bool,
|
|
forward Forward, auth []ssh.AuthMethod, log *zap.SugaredLogger) *Tunnel {
|
|
return &Tunnel{
|
|
address: AddrToEndpoint(address),
|
|
user: user,
|
|
fakeRemoteHost: fakeRemoteHost,
|
|
forward: forward,
|
|
authMethods: auth,
|
|
log: log.With(zap.String("sshServer", address)),
|
|
}
|
|
}
|
|
|
|
func (c *Tunnel) Connect(ctx context.Context, sessionCb SessionCallback) {
|
|
if c.connected.Load() {
|
|
return
|
|
}
|
|
|
|
defer c.connected.Store(false)
|
|
backoffTime := backoff.ExponentialWithCappedMax(100*time.Millisecond, 5*time.Second)
|
|
for {
|
|
c.connected.Store(true)
|
|
err := c.connect(ctx, sessionCb)
|
|
if err != nil {
|
|
c.log.Error("connect error:", err)
|
|
}
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
default:
|
|
}
|
|
|
|
time.Sleep(backoffTime())
|
|
}
|
|
}
|
|
|
|
// connect once to the SSH server. if the connection breaks, we return error and the caller
|
|
// will try to re-connect
|
|
func (c *Tunnel) connect(ctx context.Context, sessionCb SessionCallback) error {
|
|
c.log.Debug("connecting")
|
|
sshConfig := &ssh.ClientConfig{
|
|
User: c.user,
|
|
Auth: c.authMethods,
|
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
|
}
|
|
|
|
var sshClient *ssh.Client
|
|
var errConnect error
|
|
|
|
sshClient, errConnect = dialSSH(ctx, c.address.String(), sshConfig)
|
|
if errConnect != nil {
|
|
return errConnect
|
|
}
|
|
|
|
defer sshClient.Close()
|
|
defer c.log.Debug("disconnecting")
|
|
|
|
c.log.Debug("connected")
|
|
listenerStopped := make(chan error)
|
|
|
|
sess, err := sshClient.NewSession()
|
|
if err != nil {
|
|
c.log.Errorf("session error: %s", err)
|
|
return err
|
|
}
|
|
defer sess.Close()
|
|
|
|
var wg sync.WaitGroup
|
|
if sessionCb == nil {
|
|
sessionCb = func(*ssh.Session) {}
|
|
}
|
|
wg.Add(2)
|
|
go func() {
|
|
defer wg.Done()
|
|
sessionCb(sess)
|
|
}()
|
|
reverseErr := make(chan error)
|
|
go func() {
|
|
defer wg.Done()
|
|
reverseErr <- c.reverseForwardOnePort(sshClient, listenerStopped)
|
|
}()
|
|
|
|
if err := <-reverseErr; err != nil {
|
|
return err
|
|
}
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil
|
|
case listenerFirstErr := <-listenerStopped:
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil
|
|
default:
|
|
return listenerFirstErr
|
|
}
|
|
}
|
|
}
|
|
|
|
// blocking flow: calls Listen() on the SSH connection, and if succeeds returns non-nil error
|
|
//
|
|
// nonblocking flow: if Accept() call fails, stops goroutine and returns error on ch listenerStopped
|
|
func (c *Tunnel) reverseForwardOnePort(sshClient *ssh.Client, listenerStopped chan<- error) error {
|
|
if c.fakeRemoteHost {
|
|
newSishHostListener(sshClient).ListenFakeRemoteHost(c.forward.Remote)
|
|
time.Sleep(time.Second)
|
|
os.Exit(0)
|
|
}
|
|
listener, err := sshClient.Listen("tcp", c.forward.Remote.String())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
go func() {
|
|
defer listener.Close()
|
|
c.log.Debugf("forwarding %s <- %s", c.forward.Local.String(), c.forward.Remote.String())
|
|
|
|
for {
|
|
client, err := listener.Accept()
|
|
if err != nil {
|
|
listenerStopped <- fmt.Errorf("error on Accept(): %w", err)
|
|
return
|
|
}
|
|
|
|
go handleReverseForwardConn(client, c.forward, c.log)
|
|
}
|
|
}()
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *Tunnel) listenTCPWithoutResolving() {
|
|
}
|
|
|
|
func handleReverseForwardConn(client net.Conn, forward Forward, log *zap.SugaredLogger) {
|
|
defer client.Close()
|
|
|
|
log.Debugf("%s connected", client.RemoteAddr())
|
|
defer log.Debug("closed")
|
|
|
|
remote, err := net.Dial("tcp", forward.Local.String())
|
|
if err != nil {
|
|
log.Errorf("dial INTO local service error: %s", err.Error())
|
|
return
|
|
}
|
|
|
|
// pipe data in both directions:
|
|
// - client => remote
|
|
// - remote => client
|
|
//
|
|
// - in effect, we act as a proxy between the reverse tunnel's client and locally-dialed
|
|
// remote endpoint.
|
|
// - the "client" and "remote" strings we give Pipe() is just for error&log messages
|
|
// - this blocks until either of the parties' socket closes (or breaks)
|
|
if err := bidipipe.Pipe(
|
|
bidipipe.WithName("client", client),
|
|
bidipipe.WithName("remote", remote),
|
|
); err != nil {
|
|
log.Error(err)
|
|
}
|
|
}
|
|
|
|
func dialSSH(ctx context.Context, addr string, sshConfig *ssh.ClientConfig) (*ssh.Client, error) {
|
|
dialer := net.Dialer{
|
|
Timeout: 10 * time.Second,
|
|
}
|
|
|
|
conn, err := dialer.DialContext(ctx, "tcp", addr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
clConn, newChan, reqChan, err := ssh.NewClientConn(conn, addr, sshConfig)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return ssh.NewClient(clConn, newChan, reqChan), nil
|
|
}
|