ssh: pty, keepAlive, domain matcher, refactor
This commit is contained in:
parent
eb4d78c483
commit
666daabe95
@ -61,7 +61,7 @@ func (p *pluginAPI) receiverForContext(ctx context.Context) plugin.Plugin {
|
||||
}
|
||||
|
||||
func StartPluginAPI() {
|
||||
port := config.Default.PluginAPIPort
|
||||
port := config.Default.API.PluginPort
|
||||
if port == 0 {
|
||||
port = plugin2.DefaultPort
|
||||
}
|
||||
|
@ -12,12 +12,17 @@ var Default Config
|
||||
|
||||
type Config struct {
|
||||
Debug bool `mapstructure:"debug"`
|
||||
PluginAPIPort int `mapstructure:"plugin_api_port" validate:"gte=0,lte=65535"`
|
||||
API API `mapstructure:"api"`
|
||||
Docker DockerConfig `mapstructure:"docker"`
|
||||
DefaultServer string `mapstructure:"default_server"`
|
||||
Servers []Server `mapstructure:"servers"`
|
||||
}
|
||||
|
||||
type API struct {
|
||||
WebPort int `mapstructure:"web_port" validate:"gte=0,lte=65535"`
|
||||
PluginPort int `mapstructure:"plugin_port" validate:"gte=0,lte=65535"`
|
||||
}
|
||||
|
||||
type DockerConfig struct {
|
||||
FromEnv *bool `mapstructure:"from_env,omitempty"`
|
||||
CertPath string `mapstructure:"cert_path"`
|
||||
|
@ -3,8 +3,10 @@ package ssh
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"path"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
@ -26,6 +28,7 @@ type SSH struct {
|
||||
conns map[string]conn
|
||||
rw sync.RWMutex
|
||||
wg sync.WaitGroup
|
||||
domainRegExp *regexp.Regexp
|
||||
}
|
||||
|
||||
type conn struct {
|
||||
@ -42,26 +45,58 @@ func New(ctx context.Context, name string, params config.DriverParams) (base.Dri
|
||||
if err := util.UnmarshalParams(params, &drv.params); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
matcher, err := makeDomainCatchRegExp(drv.params.DomainExtractRegex)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid domain_extract_regex: %w", err)
|
||||
}
|
||||
drv.domainRegExp = matcher
|
||||
drv.populateFromSSHConfig()
|
||||
drv.auth = drv.authenticators()
|
||||
return drv, nil
|
||||
}
|
||||
|
||||
func (d *SSH) forward(val sshtun.Forward) conn {
|
||||
func (d *SSH) forward(val sshtun.Forward, domainMatcher func(string)) conn {
|
||||
tun := sshtun.New(d.params.Address,
|
||||
d.params.Auth.User,
|
||||
d.params.FakeRemoteHost,
|
||||
val,
|
||||
d.auth,
|
||||
sshtun.SessionConfig{
|
||||
NoPTY: d.params.NoPTY,
|
||||
FakeRemoteHost: d.params.FakeRemoteHost,
|
||||
KeepAliveInterval: uint(d.params.KeepAlive.Interval),
|
||||
KeepAliveMax: uint(d.params.KeepAlive.MaxAttempts),
|
||||
},
|
||||
d.Log())
|
||||
ctx, cancel := context.WithCancel(d.Context())
|
||||
tunDbgLog := d.Log().With("ssh-output", val.Remote.String())
|
||||
go tun.Connect(ctx,
|
||||
sshtun.StdoutPrinterBannerCallback(tunDbgLog),
|
||||
sshtun.StdoutPrinterSessionCallback(tunDbgLog))
|
||||
sshtun.BannerDebugLogCallback(tunDbgLog),
|
||||
sshtun.OutputReaderCallback(func(msg string) {
|
||||
d.Log().Debug(msg)
|
||||
if domainMatcher != nil {
|
||||
domainMatcher(msg)
|
||||
}
|
||||
}))
|
||||
return conn{ctx: ctx, cancel: cancel, tun: tun}
|
||||
}
|
||||
|
||||
func (d *SSH) makeDomainMatcherFunc(containerID string) func(string) {
|
||||
if d.domainRegExp == nil {
|
||||
return nil
|
||||
}
|
||||
return func(msg string) {
|
||||
domain := d.domainRegExp.FindString(msg)
|
||||
if domain == "" {
|
||||
return
|
||||
}
|
||||
d.PushEventStatus(dto.EventStatus{
|
||||
Type: dto.EventStart,
|
||||
ID: containerID,
|
||||
Domain: domain,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (d *SSH) populateFromSSHConfig() {
|
||||
if d.params.Auth.Directory == "" {
|
||||
return
|
||||
@ -94,9 +129,9 @@ func (d *SSH) Handle(event dto.Event) error {
|
||||
return ErrAlreadyInUse
|
||||
}
|
||||
conn := d.forward(sshtun.Forward{
|
||||
Local: sshtun.AddrToEndpoint(net.JoinHostPort(event.Container.IP.String(), strconv.Itoa(int(event.Container.Port)))),
|
||||
Local: d.localEndpoint(event.Container.IP, event.Container.Port),
|
||||
Remote: d.remoteEndpoint(event.Container.RemoteHost),
|
||||
})
|
||||
}, d.makeDomainMatcherFunc(event.Container.ID))
|
||||
d.conns[event.Container.ID] = conn
|
||||
d.wg.Add(1)
|
||||
case dto.EventStop:
|
||||
@ -106,17 +141,27 @@ func (d *SSH) Handle(event dto.Event) error {
|
||||
}
|
||||
conn.cancel()
|
||||
delete(d.conns, event.Container.ID)
|
||||
d.propagateStop(event.Container.ID)
|
||||
d.wg.Done()
|
||||
case dto.EventShutdown:
|
||||
for id, conn := range d.conns {
|
||||
conn.cancel()
|
||||
delete(d.conns, id)
|
||||
d.propagateStop(id)
|
||||
d.wg.Done()
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *SSH) propagateStop(containerID string) {
|
||||
d.PushEventStatus(dto.EventStatus{Type: dto.EventStop, ID: containerID})
|
||||
}
|
||||
|
||||
func (d *SSH) localEndpoint(ip net.IP, port uint16) sshtun.Endpoint {
|
||||
return sshtun.AddrToEndpoint(net.JoinHostPort(ip.String(), strconv.Itoa(int(port))))
|
||||
}
|
||||
|
||||
func (d *SSH) remoteEndpoint(remoteHost string) sshtun.Endpoint {
|
||||
port := int(d.params.ForwardPort)
|
||||
if port == 0 {
|
||||
|
@ -1,6 +1,8 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/Neur0toxine/sshpoke/internal/server/driver/ssh/types"
|
||||
"github.com/Neur0toxine/sshpoke/internal/server/driver/util"
|
||||
)
|
||||
@ -11,16 +13,19 @@ type Params struct {
|
||||
ForwardPort uint16 `mapstructure:"forward_port"`
|
||||
Auth types.Auth `mapstructure:"auth"`
|
||||
KeepAlive types.KeepAlive `mapstructure:"keepalive"`
|
||||
Domain string `mapstructure:"domain"`
|
||||
DomainProto string `mapstructure:"domain_proto"`
|
||||
DomainExtractRegex string `mapstructure:"domain_extract_regex" validate:"validregexp"`
|
||||
Mode types.DomainMode `mapstructure:"mode" validate:"required,oneof=single multi"`
|
||||
FakeRemoteHost bool `mapstructure:"fake_remote_host"`
|
||||
NoPTY bool `mapstructure:"nopty"`
|
||||
Commands types.Commands `mapstructure:"commands"`
|
||||
}
|
||||
|
||||
func (p *Params) Validate() error {
|
||||
if err := util.Validator.Struct(p); err != nil {
|
||||
return err
|
||||
}
|
||||
if p.NoPTY && (len(p.Commands.OnConnect) > 0 || len(p.Commands.OnDisconnect) > 0) {
|
||||
return errors.New("commands aren't available without PTY (nopty = true)")
|
||||
}
|
||||
return p.Auth.Validate()
|
||||
}
|
||||
|
28
internal/server/driver/ssh/regexp.go
Normal file
28
internal/server/driver/ssh/regexp.go
Normal file
@ -0,0 +1,28 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var builtInRegExps = map[string]*regexp.Regexp{
|
||||
"webUrl": regexp.MustCompile(`(?m)https?:\/\/(?:www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b(?:[-a-zA-Z0-9()@:%_\+.~#?&\/=]*)`),
|
||||
"httpUrl": regexp.MustCompile(`(?m)http:\/\/(?:www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b(?:[-a-zA-Z0-9()@:%_\+.~#?&\/=]*)`),
|
||||
"httpsUrl": regexp.MustCompile(`(?m)https:\/\/(?:www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b(?:[-a-zA-Z0-9()@:%_\+.~#?&\/=]*)`),
|
||||
}
|
||||
|
||||
func makeDomainCatchRegExp(expr string) (*regexp.Regexp, error) {
|
||||
if expr == "" {
|
||||
return nil, nil
|
||||
}
|
||||
if strings.HasPrefix(expr, "!") {
|
||||
exprName := expr[1:]
|
||||
builtIn, found := builtInRegExps[exprName]
|
||||
if !found {
|
||||
return nil, fmt.Errorf("no builtin regexp with name '%s'", exprName)
|
||||
}
|
||||
return builtIn, nil
|
||||
}
|
||||
return regexp.Compile(expr)
|
||||
}
|
171
internal/server/driver/ssh/sshtun/connect.go
Normal file
171
internal/server/driver/ssh/sshtun/connect.go
Normal file
@ -0,0 +1,171 @@
|
||||
package sshtun
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Neur0toxine/sshpoke/pkg/proto/ssh"
|
||||
"github.com/function61/gokit/app/backoff"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type SessionCallback func(*ssh.Session)
|
||||
|
||||
var NoopSessionCallback SessionCallback = func(*ssh.Session) {}
|
||||
|
||||
type Tunnel struct {
|
||||
user string
|
||||
address Endpoint
|
||||
forward Forward
|
||||
authMethods []ssh.AuthMethod
|
||||
log *zap.SugaredLogger
|
||||
sessConfig SessionConfig
|
||||
connected atomic.Bool
|
||||
}
|
||||
|
||||
type SessionConfig struct {
|
||||
NoPTY bool
|
||||
FakeRemoteHost bool
|
||||
KeepAliveInterval uint
|
||||
KeepAliveMax uint
|
||||
}
|
||||
|
||||
func New(address, user string, forward Forward, auth []ssh.AuthMethod, sc SessionConfig, log *zap.SugaredLogger) *Tunnel {
|
||||
return &Tunnel{
|
||||
address: AddrToEndpoint(address),
|
||||
user: user,
|
||||
forward: forward,
|
||||
authMethods: auth,
|
||||
sessConfig: sc,
|
||||
log: log.With(zap.String("sshServer", address)),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tunnel) Connect(ctx context.Context, bannerCb ssh.BannerCallback, sessionCb SessionCallback) {
|
||||
if t.connected.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
defer t.connected.Store(false)
|
||||
backoffTime := backoff.ExponentialWithCappedMax(100*time.Millisecond, 5*time.Second)
|
||||
for {
|
||||
t.connected.Store(true)
|
||||
err := t.connect(ctx, bannerCb, sessionCb)
|
||||
if err != nil {
|
||||
t.log.Error("connect error: ", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
time.Sleep(backoffTime())
|
||||
}
|
||||
}
|
||||
|
||||
// connect once to the SSH server. if the connection breaks, we return error and the caller
|
||||
// will try to re-connect
|
||||
func (t *Tunnel) connect(ctx context.Context, bannerCb ssh.BannerCallback, sessionCb SessionCallback) error {
|
||||
t.log.Debug("connecting")
|
||||
sshConfig := &ssh.ClientConfig{
|
||||
User: t.user,
|
||||
Auth: t.authMethods,
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
BannerCallback: bannerCb,
|
||||
}
|
||||
|
||||
var sshClient *ssh.Client
|
||||
var errConnect error
|
||||
|
||||
sshClient, errConnect = dialSSH(ctx, t.address.String(), sshConfig)
|
||||
if errConnect != nil {
|
||||
return errConnect
|
||||
}
|
||||
|
||||
defer sshClient.Close()
|
||||
defer t.log.Debug("disconnecting")
|
||||
|
||||
t.log.Debug("connected")
|
||||
listenerStopped := make(chan error)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
sess, err := sshClient.NewSession()
|
||||
if err != nil {
|
||||
t.log.Errorf("session error: %s", err)
|
||||
return err
|
||||
}
|
||||
defer sess.Close()
|
||||
|
||||
if sessionCb == nil {
|
||||
sessionCb = func(*ssh.Session) {}
|
||||
}
|
||||
|
||||
if !t.sessConfig.NoPTY {
|
||||
err = sess.RequestPty("xterm", 80, 40, ssh.TerminalModes{
|
||||
ssh.ECHO: 0,
|
||||
ssh.IGNCR: 1,
|
||||
})
|
||||
if err != nil {
|
||||
t.log.Warnf("PTY allocation failed: %s", err.Error())
|
||||
} else {
|
||||
if err := sess.Shell(); err != nil {
|
||||
t.log.Warnf("failed to start shell: %s", err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
sessionCb(sess)
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
reverseErr := make(chan error)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
reverseErr <- t.reverseForwardOnePort(sshClient, listenerStopped)
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go t.keepAlive(ctx, sshClient, &wg)
|
||||
|
||||
if err := <-reverseErr; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case listenerFirstErr := <-listenerStopped:
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
default:
|
||||
return listenerFirstErr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func dialSSH(ctx context.Context, addr string, sshConfig *ssh.ClientConfig) (*ssh.Client, error) {
|
||||
dialer := net.Dialer{
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
conn, err := dialer.DialContext(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
clConn, newChan, reqChan, err := ssh.NewClientConn(conn, addr, sshConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ssh.NewClient(clConn, newChan, reqChan), nil
|
||||
}
|
@ -4,8 +4,12 @@ import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/Neur0toxine/sshpoke/pkg/errtools"
|
||||
"github.com/Neur0toxine/sshpoke/pkg/proto/ssh"
|
||||
"github.com/function61/gokit/io/bidipipe"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type Forward struct {
|
||||
@ -35,3 +39,77 @@ type Endpoint struct {
|
||||
func (endpoint *Endpoint) String() string {
|
||||
return fmt.Sprintf("%s:%d", endpoint.Host, endpoint.Port)
|
||||
}
|
||||
|
||||
func (t *Tunnel) ipFromAddr(addr net.Addr) net.IP {
|
||||
host, _, _ := net.SplitHostPort(addr.String())
|
||||
return net.ParseIP(host)
|
||||
}
|
||||
|
||||
// blocking flow: calls Listen() on the SSH connection, and if succeeds returns non-nil error
|
||||
//
|
||||
// nonblocking flow: if Accept() call fails, stops goroutine and returns error on ch listenerStopped
|
||||
func (t *Tunnel) reverseForwardOnePort(sshClient *ssh.Client, listenerStopped chan<- error) error {
|
||||
var (
|
||||
listener net.Listener
|
||||
err error
|
||||
)
|
||||
if t.sessConfig.FakeRemoteHost {
|
||||
listener, err = sshClient.ListenTCP(&net.TCPAddr{
|
||||
IP: t.ipFromAddr(sshClient.Conn.RemoteAddr()),
|
||||
Port: t.forward.Remote.Port,
|
||||
}, t.forward.Remote.Host)
|
||||
} else {
|
||||
listener, err = sshClient.Listen("tcp", t.forward.Remote.String())
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer listener.Close()
|
||||
t.log.Debugf("forwarding %s <- %s", t.forward.Local.String(), t.forward.Remote.String())
|
||||
|
||||
for {
|
||||
client, err := listener.Accept()
|
||||
if err != nil {
|
||||
listenerStopped <- fmt.Errorf("error on Accept(): %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
go handleReverseForwardConn(client, t.forward, t.log)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleReverseForwardConn(client net.Conn, forward Forward, log *zap.SugaredLogger) {
|
||||
defer client.Close()
|
||||
|
||||
remote, err := net.Dial("tcp", forward.Local.String())
|
||||
if err != nil {
|
||||
log.Errorf("cannot dial local service: %s", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("proxying %s <-> %s", forward.Local.String(), client.RemoteAddr())
|
||||
|
||||
// pipe data in both directions:
|
||||
// - client => remote
|
||||
// - remote => client
|
||||
//
|
||||
// - in effect, we act as a proxy between the reverse tunnel's client and locally-dialed
|
||||
// remote endpoint.
|
||||
// - the "client" and "remote" strings we give Pipe() is just for error&log messages
|
||||
// - this blocks until either of the parties' socket closes (or breaks)
|
||||
if err := bidipipe.Pipe(
|
||||
bidipipe.WithName("client", client),
|
||||
bidipipe.WithName("remote", remote),
|
||||
); err != nil {
|
||||
// we can safely ignore those errors.
|
||||
if strings.Contains(err.Error(), "use of closed network connection") {
|
||||
return
|
||||
}
|
||||
log.Error(err)
|
||||
}
|
||||
}
|
||||
|
62
internal/server/driver/ssh/sshtun/keepalive.go
Normal file
62
internal/server/driver/ssh/sshtun/keepalive.go
Normal file
@ -0,0 +1,62 @@
|
||||
package sshtun
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Neur0toxine/sshpoke/pkg/proto/ssh"
|
||||
)
|
||||
|
||||
// keepAlive periodically sends messages to invoke a response.
|
||||
// If the server does not respond after some period of time,
|
||||
// assume that the underlying net.Conn abruptly died.
|
||||
func (t *Tunnel) keepAlive(ctx context.Context, client *ssh.Client, wg *sync.WaitGroup) {
|
||||
defer wg.Done()
|
||||
if t.sessConfig.KeepAliveInterval == 0 || t.sessConfig.KeepAliveMax == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// Detect when the SSH connection is closed.
|
||||
wait := make(chan error, 1)
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
wait <- client.Wait()
|
||||
}()
|
||||
|
||||
// Repeatedly check if the remote server is still alive.
|
||||
var aliveCount int32
|
||||
ticker := time.NewTicker(time.Duration(t.sessConfig.KeepAliveInterval) * time.Second)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
t.log.Debug("stopping keep-alive...")
|
||||
_ = client.Close()
|
||||
return
|
||||
case err := <-wait:
|
||||
if err != nil && err != io.EOF {
|
||||
t.log.Error("ssh error:", err)
|
||||
}
|
||||
return
|
||||
case <-ticker.C:
|
||||
if n := atomic.AddInt32(&aliveCount, 1); n > int32(t.sessConfig.KeepAliveMax) {
|
||||
t.log.Error("keep-alive failed, closing connection...")
|
||||
_ = client.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _, err := client.SendRequest("keepalive@openssh.com", true, nil)
|
||||
if err == nil {
|
||||
atomic.StoreInt32(&aliveCount, 0)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
@ -9,7 +9,7 @@ import (
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func StdoutPrinterSessionCallback(log *zap.SugaredLogger) SessionCallback {
|
||||
func OutputReaderCallback(callback func(string)) SessionCallback {
|
||||
return func(session *ssh.Session) {
|
||||
stdout, err := session.StdoutPipe()
|
||||
if err != nil {
|
||||
@ -27,7 +27,7 @@ func StdoutPrinterSessionCallback(log *zap.SugaredLogger) SessionCallback {
|
||||
combined.Read(func(r io.Reader) error {
|
||||
scan := bufio.NewScanner(r)
|
||||
for scan.Scan() {
|
||||
log.Debug(scan.Text())
|
||||
callback(scan.Text())
|
||||
}
|
||||
return nil
|
||||
})
|
||||
@ -35,7 +35,7 @@ func StdoutPrinterSessionCallback(log *zap.SugaredLogger) SessionCallback {
|
||||
}
|
||||
}
|
||||
|
||||
func StdoutPrinterBannerCallback(log *zap.SugaredLogger) ssh.BannerCallback {
|
||||
func BannerDebugLogCallback(log *zap.SugaredLogger) ssh.BannerCallback {
|
||||
return func(msg string) error {
|
||||
log.Debug(msg)
|
||||
return nil
|
||||
|
@ -1,224 +0,0 @@
|
||||
package sshtun
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Neur0toxine/sshpoke/pkg/proto/ssh"
|
||||
"github.com/function61/gokit/app/backoff"
|
||||
"github.com/function61/gokit/io/bidipipe"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type SessionCallback func(*ssh.Session)
|
||||
|
||||
var NoopSessionCallback SessionCallback = func(*ssh.Session) {}
|
||||
|
||||
type Tunnel struct {
|
||||
user string
|
||||
address Endpoint
|
||||
forward Forward
|
||||
authMethods []ssh.AuthMethod
|
||||
log *zap.SugaredLogger
|
||||
connected atomic.Bool
|
||||
fakeRemoteHost bool
|
||||
}
|
||||
|
||||
func New(address, user string, fakeRemoteHost bool,
|
||||
forward Forward, auth []ssh.AuthMethod, log *zap.SugaredLogger) *Tunnel {
|
||||
return &Tunnel{
|
||||
address: AddrToEndpoint(address),
|
||||
user: user,
|
||||
fakeRemoteHost: fakeRemoteHost,
|
||||
forward: forward,
|
||||
authMethods: auth,
|
||||
log: log.With(zap.String("sshServer", address)),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Tunnel) Connect(ctx context.Context, bannerCb ssh.BannerCallback, sessionCb SessionCallback) {
|
||||
if c.connected.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
defer c.connected.Store(false)
|
||||
backoffTime := backoff.ExponentialWithCappedMax(100*time.Millisecond, 5*time.Second)
|
||||
for {
|
||||
c.connected.Store(true)
|
||||
err := c.connect(ctx, bannerCb, sessionCb)
|
||||
if err != nil {
|
||||
c.log.Error("connect error:", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
time.Sleep(backoffTime())
|
||||
}
|
||||
}
|
||||
|
||||
// connect once to the SSH server. if the connection breaks, we return error and the caller
|
||||
// will try to re-connect
|
||||
func (c *Tunnel) connect(ctx context.Context, bannerCb ssh.BannerCallback, sessionCb SessionCallback) error {
|
||||
c.log.Debug("connecting")
|
||||
sshConfig := &ssh.ClientConfig{
|
||||
User: c.user,
|
||||
Auth: c.authMethods,
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
BannerCallback: bannerCb,
|
||||
}
|
||||
|
||||
var sshClient *ssh.Client
|
||||
var errConnect error
|
||||
|
||||
sshClient, errConnect = dialSSH(ctx, c.address.String(), sshConfig)
|
||||
if errConnect != nil {
|
||||
return errConnect
|
||||
}
|
||||
|
||||
defer sshClient.Close()
|
||||
defer c.log.Debug("disconnecting")
|
||||
|
||||
c.log.Debug("connected")
|
||||
listenerStopped := make(chan error)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
sess, err := sshClient.NewSession()
|
||||
if err != nil {
|
||||
c.log.Errorf("session error: %s", err)
|
||||
return err
|
||||
}
|
||||
defer sess.Close()
|
||||
|
||||
if sessionCb == nil {
|
||||
sessionCb = func(*ssh.Session) {}
|
||||
}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
sessionCb(sess)
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
reverseErr := make(chan error)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
reverseErr <- c.reverseForwardOnePort(sshClient, listenerStopped)
|
||||
}()
|
||||
|
||||
if err := <-reverseErr; err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
case listenerFirstErr := <-listenerStopped:
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
default:
|
||||
return listenerFirstErr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// blocking flow: calls Listen() on the SSH connection, and if succeeds returns non-nil error
|
||||
//
|
||||
// nonblocking flow: if Accept() call fails, stops goroutine and returns error on ch listenerStopped
|
||||
func (c *Tunnel) reverseForwardOnePort(sshClient *ssh.Client, listenerStopped chan<- error) error {
|
||||
var (
|
||||
listener net.Listener
|
||||
err error
|
||||
)
|
||||
if c.fakeRemoteHost {
|
||||
listener, err = sshClient.ListenTCP(&net.TCPAddr{
|
||||
IP: c.ipFromAddr(sshClient.Conn.RemoteAddr()),
|
||||
Port: c.forward.Remote.Port,
|
||||
}, c.forward.Remote.Host)
|
||||
} else {
|
||||
listener, err = sshClient.Listen("tcp", c.forward.Remote.String())
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer listener.Close()
|
||||
c.log.Debugf("forwarding %s <- %s", c.forward.Local.String(), c.forward.Remote.String())
|
||||
|
||||
for {
|
||||
client, err := listener.Accept()
|
||||
if err != nil {
|
||||
listenerStopped <- fmt.Errorf("error on Accept(): %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
go handleReverseForwardConn(client, c.forward, c.log)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Tunnel) ipFromAddr(addr net.Addr) net.IP {
|
||||
host, _, _ := net.SplitHostPort(addr.String())
|
||||
return net.ParseIP(host)
|
||||
}
|
||||
|
||||
func handleReverseForwardConn(client net.Conn, forward Forward, log *zap.SugaredLogger) {
|
||||
defer client.Close()
|
||||
|
||||
remote, err := net.Dial("tcp", forward.Local.String())
|
||||
if err != nil {
|
||||
log.Errorf("cannot dial local service: %s", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("proxying %s <-> %s", forward.Local.String(), client.RemoteAddr())
|
||||
|
||||
// pipe data in both directions:
|
||||
// - client => remote
|
||||
// - remote => client
|
||||
//
|
||||
// - in effect, we act as a proxy between the reverse tunnel's client and locally-dialed
|
||||
// remote endpoint.
|
||||
// - the "client" and "remote" strings we give Pipe() is just for error&log messages
|
||||
// - this blocks until either of the parties' socket closes (or breaks)
|
||||
if err := bidipipe.Pipe(
|
||||
bidipipe.WithName("client", client),
|
||||
bidipipe.WithName("remote", remote),
|
||||
); err != nil {
|
||||
// we can safely ignore those errors.
|
||||
if strings.Contains(err.Error(), "use of closed network connection") {
|
||||
return
|
||||
}
|
||||
log.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func dialSSH(ctx context.Context, addr string, sshConfig *ssh.ClientConfig) (*ssh.Client, error) {
|
||||
dialer := net.Dialer{
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
conn, err := dialer.DialContext(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
clConn, newChan, reqChan, err := ssh.NewClientConn(conn, addr, sshConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ssh.NewClient(clConn, newChan, reqChan), nil
|
||||
}
|
6
internal/server/driver/ssh/types/commands.go
Normal file
6
internal/server/driver/ssh/types/commands.go
Normal file
@ -0,0 +1,6 @@
|
||||
package types
|
||||
|
||||
type Commands struct {
|
||||
OnConnect []string `mapstructure:"on_connect"`
|
||||
OnDisconnect []string `mapstructure:"on_disconnect"`
|
||||
}
|
@ -9,7 +9,7 @@ import (
|
||||
var Validator *validator.Validate
|
||||
|
||||
func init() {
|
||||
Validator = validator.New()
|
||||
Validator = validator.New(validator.WithRequiredStructEnabled())
|
||||
_ = Validator.RegisterValidation("validregexp", isValidRegExp)
|
||||
}
|
||||
|
||||
|
@ -89,6 +89,8 @@ func (m *Manager) eventStatusCallback(serverName string) base.EventStatusCallbac
|
||||
}
|
||||
|
||||
func (m *Manager) processEventStatus(serverName string, event dto.EventStatus) {
|
||||
logger.Sugar.Debugw("received EventStatus from server",
|
||||
"serverName", serverName, "eventStatus", event)
|
||||
m.statusLock.RLock()
|
||||
_, exists := m.statusMap[serverName]
|
||||
if !exists {
|
||||
|
@ -12,6 +12,22 @@ const (
|
||||
EventUnknown
|
||||
)
|
||||
|
||||
var eventTypeNames = map[EventType]string{
|
||||
EventStart: "start",
|
||||
EventStop: "stop",
|
||||
EventShutdown: "shutdown",
|
||||
EventError: "error",
|
||||
EventUnknown: "unknown",
|
||||
}
|
||||
|
||||
func (e EventType) String() string {
|
||||
name, ok := eventTypeNames[e]
|
||||
if ok {
|
||||
return name
|
||||
}
|
||||
return eventTypeNames[EventUnknown]
|
||||
}
|
||||
|
||||
func TypeFromAction(action string) EventType {
|
||||
switch action {
|
||||
case "start":
|
||||
|
Loading…
Reference in New Issue
Block a user