ssh: pty, keepAlive, domain matcher, refactor

This commit is contained in:
Pavel 2023-11-19 13:06:38 +03:00
parent eb4d78c483
commit 666daabe95
14 changed files with 437 additions and 243 deletions

View File

@ -61,7 +61,7 @@ func (p *pluginAPI) receiverForContext(ctx context.Context) plugin.Plugin {
}
func StartPluginAPI() {
port := config.Default.PluginAPIPort
port := config.Default.API.PluginPort
if port == 0 {
port = plugin2.DefaultPort
}

View File

@ -12,12 +12,17 @@ var Default Config
type Config struct {
Debug bool `mapstructure:"debug"`
PluginAPIPort int `mapstructure:"plugin_api_port" validate:"gte=0,lte=65535"`
API API `mapstructure:"api"`
Docker DockerConfig `mapstructure:"docker"`
DefaultServer string `mapstructure:"default_server"`
Servers []Server `mapstructure:"servers"`
}
type API struct {
WebPort int `mapstructure:"web_port" validate:"gte=0,lte=65535"`
PluginPort int `mapstructure:"plugin_port" validate:"gte=0,lte=65535"`
}
type DockerConfig struct {
FromEnv *bool `mapstructure:"from_env,omitempty"`
CertPath string `mapstructure:"cert_path"`

View File

@ -3,8 +3,10 @@ package ssh
import (
"context"
"errors"
"fmt"
"net"
"path"
"regexp"
"strconv"
"sync"
@ -21,11 +23,12 @@ var ErrAlreadyInUse = errors.New("domain is already in use")
type SSH struct {
base.Base
params Params
auth []ssh.AuthMethod
conns map[string]conn
rw sync.RWMutex
wg sync.WaitGroup
params Params
auth []ssh.AuthMethod
conns map[string]conn
rw sync.RWMutex
wg sync.WaitGroup
domainRegExp *regexp.Regexp
}
type conn struct {
@ -42,26 +45,58 @@ func New(ctx context.Context, name string, params config.DriverParams) (base.Dri
if err := util.UnmarshalParams(params, &drv.params); err != nil {
return nil, err
}
matcher, err := makeDomainCatchRegExp(drv.params.DomainExtractRegex)
if err != nil {
return nil, fmt.Errorf("invalid domain_extract_regex: %w", err)
}
drv.domainRegExp = matcher
drv.populateFromSSHConfig()
drv.auth = drv.authenticators()
return drv, nil
}
func (d *SSH) forward(val sshtun.Forward) conn {
func (d *SSH) forward(val sshtun.Forward, domainMatcher func(string)) conn {
tun := sshtun.New(d.params.Address,
d.params.Auth.User,
d.params.FakeRemoteHost,
val,
d.auth,
sshtun.SessionConfig{
NoPTY: d.params.NoPTY,
FakeRemoteHost: d.params.FakeRemoteHost,
KeepAliveInterval: uint(d.params.KeepAlive.Interval),
KeepAliveMax: uint(d.params.KeepAlive.MaxAttempts),
},
d.Log())
ctx, cancel := context.WithCancel(d.Context())
tunDbgLog := d.Log().With("ssh-output", val.Remote.String())
go tun.Connect(ctx,
sshtun.StdoutPrinterBannerCallback(tunDbgLog),
sshtun.StdoutPrinterSessionCallback(tunDbgLog))
sshtun.BannerDebugLogCallback(tunDbgLog),
sshtun.OutputReaderCallback(func(msg string) {
d.Log().Debug(msg)
if domainMatcher != nil {
domainMatcher(msg)
}
}))
return conn{ctx: ctx, cancel: cancel, tun: tun}
}
func (d *SSH) makeDomainMatcherFunc(containerID string) func(string) {
if d.domainRegExp == nil {
return nil
}
return func(msg string) {
domain := d.domainRegExp.FindString(msg)
if domain == "" {
return
}
d.PushEventStatus(dto.EventStatus{
Type: dto.EventStart,
ID: containerID,
Domain: domain,
})
}
}
func (d *SSH) populateFromSSHConfig() {
if d.params.Auth.Directory == "" {
return
@ -94,9 +129,9 @@ func (d *SSH) Handle(event dto.Event) error {
return ErrAlreadyInUse
}
conn := d.forward(sshtun.Forward{
Local: sshtun.AddrToEndpoint(net.JoinHostPort(event.Container.IP.String(), strconv.Itoa(int(event.Container.Port)))),
Local: d.localEndpoint(event.Container.IP, event.Container.Port),
Remote: d.remoteEndpoint(event.Container.RemoteHost),
})
}, d.makeDomainMatcherFunc(event.Container.ID))
d.conns[event.Container.ID] = conn
d.wg.Add(1)
case dto.EventStop:
@ -106,17 +141,27 @@ func (d *SSH) Handle(event dto.Event) error {
}
conn.cancel()
delete(d.conns, event.Container.ID)
d.propagateStop(event.Container.ID)
d.wg.Done()
case dto.EventShutdown:
for id, conn := range d.conns {
conn.cancel()
delete(d.conns, id)
d.propagateStop(id)
d.wg.Done()
}
}
return nil
}
func (d *SSH) propagateStop(containerID string) {
d.PushEventStatus(dto.EventStatus{Type: dto.EventStop, ID: containerID})
}
func (d *SSH) localEndpoint(ip net.IP, port uint16) sshtun.Endpoint {
return sshtun.AddrToEndpoint(net.JoinHostPort(ip.String(), strconv.Itoa(int(port))))
}
func (d *SSH) remoteEndpoint(remoteHost string) sshtun.Endpoint {
port := int(d.params.ForwardPort)
if port == 0 {

View File

@ -1,6 +1,8 @@
package ssh
import (
"errors"
"github.com/Neur0toxine/sshpoke/internal/server/driver/ssh/types"
"github.com/Neur0toxine/sshpoke/internal/server/driver/util"
)
@ -11,16 +13,19 @@ type Params struct {
ForwardPort uint16 `mapstructure:"forward_port"`
Auth types.Auth `mapstructure:"auth"`
KeepAlive types.KeepAlive `mapstructure:"keepalive"`
Domain string `mapstructure:"domain"`
DomainProto string `mapstructure:"domain_proto"`
DomainExtractRegex string `mapstructure:"domain_extract_regex" validate:"validregexp"`
Mode types.DomainMode `mapstructure:"mode" validate:"required,oneof=single multi"`
FakeRemoteHost bool `mapstructure:"fake_remote_host"`
NoPTY bool `mapstructure:"nopty"`
Commands types.Commands `mapstructure:"commands"`
}
func (p *Params) Validate() error {
if err := util.Validator.Struct(p); err != nil {
return err
}
if p.NoPTY && (len(p.Commands.OnConnect) > 0 || len(p.Commands.OnDisconnect) > 0) {
return errors.New("commands aren't available without PTY (nopty = true)")
}
return p.Auth.Validate()
}

View File

@ -0,0 +1,28 @@
package ssh
import (
"fmt"
"regexp"
"strings"
)
var builtInRegExps = map[string]*regexp.Regexp{
"webUrl": regexp.MustCompile(`(?m)https?:\/\/(?:www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b(?:[-a-zA-Z0-9()@:%_\+.~#?&\/=]*)`),
"httpUrl": regexp.MustCompile(`(?m)http:\/\/(?:www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b(?:[-a-zA-Z0-9()@:%_\+.~#?&\/=]*)`),
"httpsUrl": regexp.MustCompile(`(?m)https:\/\/(?:www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b(?:[-a-zA-Z0-9()@:%_\+.~#?&\/=]*)`),
}
func makeDomainCatchRegExp(expr string) (*regexp.Regexp, error) {
if expr == "" {
return nil, nil
}
if strings.HasPrefix(expr, "!") {
exprName := expr[1:]
builtIn, found := builtInRegExps[exprName]
if !found {
return nil, fmt.Errorf("no builtin regexp with name '%s'", exprName)
}
return builtIn, nil
}
return regexp.Compile(expr)
}

View File

@ -0,0 +1,171 @@
package sshtun
import (
"context"
"net"
"sync"
"sync/atomic"
"time"
"github.com/Neur0toxine/sshpoke/pkg/proto/ssh"
"github.com/function61/gokit/app/backoff"
"go.uber.org/zap"
)
type SessionCallback func(*ssh.Session)
var NoopSessionCallback SessionCallback = func(*ssh.Session) {}
type Tunnel struct {
user string
address Endpoint
forward Forward
authMethods []ssh.AuthMethod
log *zap.SugaredLogger
sessConfig SessionConfig
connected atomic.Bool
}
type SessionConfig struct {
NoPTY bool
FakeRemoteHost bool
KeepAliveInterval uint
KeepAliveMax uint
}
func New(address, user string, forward Forward, auth []ssh.AuthMethod, sc SessionConfig, log *zap.SugaredLogger) *Tunnel {
return &Tunnel{
address: AddrToEndpoint(address),
user: user,
forward: forward,
authMethods: auth,
sessConfig: sc,
log: log.With(zap.String("sshServer", address)),
}
}
func (t *Tunnel) Connect(ctx context.Context, bannerCb ssh.BannerCallback, sessionCb SessionCallback) {
if t.connected.Load() {
return
}
defer t.connected.Store(false)
backoffTime := backoff.ExponentialWithCappedMax(100*time.Millisecond, 5*time.Second)
for {
t.connected.Store(true)
err := t.connect(ctx, bannerCb, sessionCb)
if err != nil {
t.log.Error("connect error: ", err)
}
select {
case <-ctx.Done():
return
default:
}
time.Sleep(backoffTime())
}
}
// connect once to the SSH server. if the connection breaks, we return error and the caller
// will try to re-connect
func (t *Tunnel) connect(ctx context.Context, bannerCb ssh.BannerCallback, sessionCb SessionCallback) error {
t.log.Debug("connecting")
sshConfig := &ssh.ClientConfig{
User: t.user,
Auth: t.authMethods,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
BannerCallback: bannerCb,
}
var sshClient *ssh.Client
var errConnect error
sshClient, errConnect = dialSSH(ctx, t.address.String(), sshConfig)
if errConnect != nil {
return errConnect
}
defer sshClient.Close()
defer t.log.Debug("disconnecting")
t.log.Debug("connected")
listenerStopped := make(chan error)
var wg sync.WaitGroup
sess, err := sshClient.NewSession()
if err != nil {
t.log.Errorf("session error: %s", err)
return err
}
defer sess.Close()
if sessionCb == nil {
sessionCb = func(*ssh.Session) {}
}
if !t.sessConfig.NoPTY {
err = sess.RequestPty("xterm", 80, 40, ssh.TerminalModes{
ssh.ECHO: 0,
ssh.IGNCR: 1,
})
if err != nil {
t.log.Warnf("PTY allocation failed: %s", err.Error())
} else {
if err := sess.Shell(); err != nil {
t.log.Warnf("failed to start shell: %s", err.Error())
}
}
}
wg.Add(1)
go func() {
defer wg.Done()
sessionCb(sess)
}()
wg.Add(1)
reverseErr := make(chan error)
go func() {
defer wg.Done()
reverseErr <- t.reverseForwardOnePort(sshClient, listenerStopped)
}()
wg.Add(1)
go t.keepAlive(ctx, sshClient, &wg)
if err := <-reverseErr; err != nil {
return err
}
select {
case <-ctx.Done():
return nil
case listenerFirstErr := <-listenerStopped:
select {
case <-ctx.Done():
return nil
default:
return listenerFirstErr
}
}
}
func dialSSH(ctx context.Context, addr string, sshConfig *ssh.ClientConfig) (*ssh.Client, error) {
dialer := net.Dialer{
Timeout: 10 * time.Second,
}
conn, err := dialer.DialContext(ctx, "tcp", addr)
if err != nil {
return nil, err
}
clConn, newChan, reqChan, err := ssh.NewClientConn(conn, addr, sshConfig)
if err != nil {
return nil, err
}
return ssh.NewClient(clConn, newChan, reqChan), nil
}

View File

@ -4,8 +4,12 @@ import (
"fmt"
"net"
"strconv"
"strings"
"github.com/Neur0toxine/sshpoke/pkg/errtools"
"github.com/Neur0toxine/sshpoke/pkg/proto/ssh"
"github.com/function61/gokit/io/bidipipe"
"go.uber.org/zap"
)
type Forward struct {
@ -35,3 +39,77 @@ type Endpoint struct {
func (endpoint *Endpoint) String() string {
return fmt.Sprintf("%s:%d", endpoint.Host, endpoint.Port)
}
func (t *Tunnel) ipFromAddr(addr net.Addr) net.IP {
host, _, _ := net.SplitHostPort(addr.String())
return net.ParseIP(host)
}
// blocking flow: calls Listen() on the SSH connection, and if succeeds returns non-nil error
//
// nonblocking flow: if Accept() call fails, stops goroutine and returns error on ch listenerStopped
func (t *Tunnel) reverseForwardOnePort(sshClient *ssh.Client, listenerStopped chan<- error) error {
var (
listener net.Listener
err error
)
if t.sessConfig.FakeRemoteHost {
listener, err = sshClient.ListenTCP(&net.TCPAddr{
IP: t.ipFromAddr(sshClient.Conn.RemoteAddr()),
Port: t.forward.Remote.Port,
}, t.forward.Remote.Host)
} else {
listener, err = sshClient.Listen("tcp", t.forward.Remote.String())
}
if err != nil {
return err
}
go func() {
defer listener.Close()
t.log.Debugf("forwarding %s <- %s", t.forward.Local.String(), t.forward.Remote.String())
for {
client, err := listener.Accept()
if err != nil {
listenerStopped <- fmt.Errorf("error on Accept(): %w", err)
return
}
go handleReverseForwardConn(client, t.forward, t.log)
}
}()
return nil
}
func handleReverseForwardConn(client net.Conn, forward Forward, log *zap.SugaredLogger) {
defer client.Close()
remote, err := net.Dial("tcp", forward.Local.String())
if err != nil {
log.Errorf("cannot dial local service: %s", err.Error())
return
}
log.Debugf("proxying %s <-> %s", forward.Local.String(), client.RemoteAddr())
// pipe data in both directions:
// - client => remote
// - remote => client
//
// - in effect, we act as a proxy between the reverse tunnel's client and locally-dialed
// remote endpoint.
// - the "client" and "remote" strings we give Pipe() is just for error&log messages
// - this blocks until either of the parties' socket closes (or breaks)
if err := bidipipe.Pipe(
bidipipe.WithName("client", client),
bidipipe.WithName("remote", remote),
); err != nil {
// we can safely ignore those errors.
if strings.Contains(err.Error(), "use of closed network connection") {
return
}
log.Error(err)
}
}

View File

@ -0,0 +1,62 @@
package sshtun
import (
"context"
"io"
"sync"
"sync/atomic"
"time"
"github.com/Neur0toxine/sshpoke/pkg/proto/ssh"
)
// keepAlive periodically sends messages to invoke a response.
// If the server does not respond after some period of time,
// assume that the underlying net.Conn abruptly died.
func (t *Tunnel) keepAlive(ctx context.Context, client *ssh.Client, wg *sync.WaitGroup) {
defer wg.Done()
if t.sessConfig.KeepAliveInterval == 0 || t.sessConfig.KeepAliveMax == 0 {
return
}
// Detect when the SSH connection is closed.
wait := make(chan error, 1)
wg.Add(1)
go func() {
defer wg.Done()
wait <- client.Wait()
}()
// Repeatedly check if the remote server is still alive.
var aliveCount int32
ticker := time.NewTicker(time.Duration(t.sessConfig.KeepAliveInterval) * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
t.log.Debug("stopping keep-alive...")
_ = client.Close()
return
case err := <-wait:
if err != nil && err != io.EOF {
t.log.Error("ssh error:", err)
}
return
case <-ticker.C:
if n := atomic.AddInt32(&aliveCount, 1); n > int32(t.sessConfig.KeepAliveMax) {
t.log.Error("keep-alive failed, closing connection...")
_ = client.Close()
return
}
}
wg.Add(1)
go func() {
defer wg.Done()
_, _, err := client.SendRequest("keepalive@openssh.com", true, nil)
if err == nil {
atomic.StoreInt32(&aliveCount, 0)
}
}()
}
}

View File

@ -9,7 +9,7 @@ import (
"go.uber.org/zap"
)
func StdoutPrinterSessionCallback(log *zap.SugaredLogger) SessionCallback {
func OutputReaderCallback(callback func(string)) SessionCallback {
return func(session *ssh.Session) {
stdout, err := session.StdoutPipe()
if err != nil {
@ -27,7 +27,7 @@ func StdoutPrinterSessionCallback(log *zap.SugaredLogger) SessionCallback {
combined.Read(func(r io.Reader) error {
scan := bufio.NewScanner(r)
for scan.Scan() {
log.Debug(scan.Text())
callback(scan.Text())
}
return nil
})
@ -35,7 +35,7 @@ func StdoutPrinterSessionCallback(log *zap.SugaredLogger) SessionCallback {
}
}
func StdoutPrinterBannerCallback(log *zap.SugaredLogger) ssh.BannerCallback {
func BannerDebugLogCallback(log *zap.SugaredLogger) ssh.BannerCallback {
return func(msg string) error {
log.Debug(msg)
return nil

View File

@ -1,224 +0,0 @@
package sshtun
import (
"context"
"fmt"
"net"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/Neur0toxine/sshpoke/pkg/proto/ssh"
"github.com/function61/gokit/app/backoff"
"github.com/function61/gokit/io/bidipipe"
"go.uber.org/zap"
)
type SessionCallback func(*ssh.Session)
var NoopSessionCallback SessionCallback = func(*ssh.Session) {}
type Tunnel struct {
user string
address Endpoint
forward Forward
authMethods []ssh.AuthMethod
log *zap.SugaredLogger
connected atomic.Bool
fakeRemoteHost bool
}
func New(address, user string, fakeRemoteHost bool,
forward Forward, auth []ssh.AuthMethod, log *zap.SugaredLogger) *Tunnel {
return &Tunnel{
address: AddrToEndpoint(address),
user: user,
fakeRemoteHost: fakeRemoteHost,
forward: forward,
authMethods: auth,
log: log.With(zap.String("sshServer", address)),
}
}
func (c *Tunnel) Connect(ctx context.Context, bannerCb ssh.BannerCallback, sessionCb SessionCallback) {
if c.connected.Load() {
return
}
defer c.connected.Store(false)
backoffTime := backoff.ExponentialWithCappedMax(100*time.Millisecond, 5*time.Second)
for {
c.connected.Store(true)
err := c.connect(ctx, bannerCb, sessionCb)
if err != nil {
c.log.Error("connect error:", err)
}
select {
case <-ctx.Done():
return
default:
}
time.Sleep(backoffTime())
}
}
// connect once to the SSH server. if the connection breaks, we return error and the caller
// will try to re-connect
func (c *Tunnel) connect(ctx context.Context, bannerCb ssh.BannerCallback, sessionCb SessionCallback) error {
c.log.Debug("connecting")
sshConfig := &ssh.ClientConfig{
User: c.user,
Auth: c.authMethods,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
BannerCallback: bannerCb,
}
var sshClient *ssh.Client
var errConnect error
sshClient, errConnect = dialSSH(ctx, c.address.String(), sshConfig)
if errConnect != nil {
return errConnect
}
defer sshClient.Close()
defer c.log.Debug("disconnecting")
c.log.Debug("connected")
listenerStopped := make(chan error)
var wg sync.WaitGroup
sess, err := sshClient.NewSession()
if err != nil {
c.log.Errorf("session error: %s", err)
return err
}
defer sess.Close()
if sessionCb == nil {
sessionCb = func(*ssh.Session) {}
}
wg.Add(1)
go func() {
defer wg.Done()
sessionCb(sess)
}()
wg.Add(1)
reverseErr := make(chan error)
go func() {
defer wg.Done()
reverseErr <- c.reverseForwardOnePort(sshClient, listenerStopped)
}()
if err := <-reverseErr; err != nil {
return err
}
select {
case <-ctx.Done():
return nil
case listenerFirstErr := <-listenerStopped:
select {
case <-ctx.Done():
return nil
default:
return listenerFirstErr
}
}
}
// blocking flow: calls Listen() on the SSH connection, and if succeeds returns non-nil error
//
// nonblocking flow: if Accept() call fails, stops goroutine and returns error on ch listenerStopped
func (c *Tunnel) reverseForwardOnePort(sshClient *ssh.Client, listenerStopped chan<- error) error {
var (
listener net.Listener
err error
)
if c.fakeRemoteHost {
listener, err = sshClient.ListenTCP(&net.TCPAddr{
IP: c.ipFromAddr(sshClient.Conn.RemoteAddr()),
Port: c.forward.Remote.Port,
}, c.forward.Remote.Host)
} else {
listener, err = sshClient.Listen("tcp", c.forward.Remote.String())
}
if err != nil {
return err
}
go func() {
defer listener.Close()
c.log.Debugf("forwarding %s <- %s", c.forward.Local.String(), c.forward.Remote.String())
for {
client, err := listener.Accept()
if err != nil {
listenerStopped <- fmt.Errorf("error on Accept(): %w", err)
return
}
go handleReverseForwardConn(client, c.forward, c.log)
}
}()
return nil
}
func (c *Tunnel) ipFromAddr(addr net.Addr) net.IP {
host, _, _ := net.SplitHostPort(addr.String())
return net.ParseIP(host)
}
func handleReverseForwardConn(client net.Conn, forward Forward, log *zap.SugaredLogger) {
defer client.Close()
remote, err := net.Dial("tcp", forward.Local.String())
if err != nil {
log.Errorf("cannot dial local service: %s", err.Error())
return
}
log.Debugf("proxying %s <-> %s", forward.Local.String(), client.RemoteAddr())
// pipe data in both directions:
// - client => remote
// - remote => client
//
// - in effect, we act as a proxy between the reverse tunnel's client and locally-dialed
// remote endpoint.
// - the "client" and "remote" strings we give Pipe() is just for error&log messages
// - this blocks until either of the parties' socket closes (or breaks)
if err := bidipipe.Pipe(
bidipipe.WithName("client", client),
bidipipe.WithName("remote", remote),
); err != nil {
// we can safely ignore those errors.
if strings.Contains(err.Error(), "use of closed network connection") {
return
}
log.Error(err)
}
}
func dialSSH(ctx context.Context, addr string, sshConfig *ssh.ClientConfig) (*ssh.Client, error) {
dialer := net.Dialer{
Timeout: 10 * time.Second,
}
conn, err := dialer.DialContext(ctx, "tcp", addr)
if err != nil {
return nil, err
}
clConn, newChan, reqChan, err := ssh.NewClientConn(conn, addr, sshConfig)
if err != nil {
return nil, err
}
return ssh.NewClient(clConn, newChan, reqChan), nil
}

View File

@ -0,0 +1,6 @@
package types
type Commands struct {
OnConnect []string `mapstructure:"on_connect"`
OnDisconnect []string `mapstructure:"on_disconnect"`
}

View File

@ -9,7 +9,7 @@ import (
var Validator *validator.Validate
func init() {
Validator = validator.New()
Validator = validator.New(validator.WithRequiredStructEnabled())
_ = Validator.RegisterValidation("validregexp", isValidRegExp)
}

View File

@ -89,6 +89,8 @@ func (m *Manager) eventStatusCallback(serverName string) base.EventStatusCallbac
}
func (m *Manager) processEventStatus(serverName string, event dto.EventStatus) {
logger.Sugar.Debugw("received EventStatus from server",
"serverName", serverName, "eventStatus", event)
m.statusLock.RLock()
_, exists := m.statusMap[serverName]
if !exists {

View File

@ -12,6 +12,22 @@ const (
EventUnknown
)
var eventTypeNames = map[EventType]string{
EventStart: "start",
EventStop: "stop",
EventShutdown: "shutdown",
EventError: "error",
EventUnknown: "unknown",
}
func (e EventType) String() string {
name, ok := eventTypeNames[e]
if ok {
return name
}
return eventTypeNames[EventUnknown]
}
func TypeFromAction(action string) EventType {
switch action {
case "start":