ssh: pty, keepAlive, domain matcher, refactor
This commit is contained in:
parent
eb4d78c483
commit
666daabe95
@ -61,7 +61,7 @@ func (p *pluginAPI) receiverForContext(ctx context.Context) plugin.Plugin {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func StartPluginAPI() {
|
func StartPluginAPI() {
|
||||||
port := config.Default.PluginAPIPort
|
port := config.Default.API.PluginPort
|
||||||
if port == 0 {
|
if port == 0 {
|
||||||
port = plugin2.DefaultPort
|
port = plugin2.DefaultPort
|
||||||
}
|
}
|
||||||
|
@ -12,12 +12,17 @@ var Default Config
|
|||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Debug bool `mapstructure:"debug"`
|
Debug bool `mapstructure:"debug"`
|
||||||
PluginAPIPort int `mapstructure:"plugin_api_port" validate:"gte=0,lte=65535"`
|
API API `mapstructure:"api"`
|
||||||
Docker DockerConfig `mapstructure:"docker"`
|
Docker DockerConfig `mapstructure:"docker"`
|
||||||
DefaultServer string `mapstructure:"default_server"`
|
DefaultServer string `mapstructure:"default_server"`
|
||||||
Servers []Server `mapstructure:"servers"`
|
Servers []Server `mapstructure:"servers"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type API struct {
|
||||||
|
WebPort int `mapstructure:"web_port" validate:"gte=0,lte=65535"`
|
||||||
|
PluginPort int `mapstructure:"plugin_port" validate:"gte=0,lte=65535"`
|
||||||
|
}
|
||||||
|
|
||||||
type DockerConfig struct {
|
type DockerConfig struct {
|
||||||
FromEnv *bool `mapstructure:"from_env,omitempty"`
|
FromEnv *bool `mapstructure:"from_env,omitempty"`
|
||||||
CertPath string `mapstructure:"cert_path"`
|
CertPath string `mapstructure:"cert_path"`
|
||||||
|
@ -3,8 +3,10 @@ package ssh
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"path"
|
"path"
|
||||||
|
"regexp"
|
||||||
"strconv"
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@ -26,6 +28,7 @@ type SSH struct {
|
|||||||
conns map[string]conn
|
conns map[string]conn
|
||||||
rw sync.RWMutex
|
rw sync.RWMutex
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
|
domainRegExp *regexp.Regexp
|
||||||
}
|
}
|
||||||
|
|
||||||
type conn struct {
|
type conn struct {
|
||||||
@ -42,26 +45,58 @@ func New(ctx context.Context, name string, params config.DriverParams) (base.Dri
|
|||||||
if err := util.UnmarshalParams(params, &drv.params); err != nil {
|
if err := util.UnmarshalParams(params, &drv.params); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
matcher, err := makeDomainCatchRegExp(drv.params.DomainExtractRegex)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid domain_extract_regex: %w", err)
|
||||||
|
}
|
||||||
|
drv.domainRegExp = matcher
|
||||||
drv.populateFromSSHConfig()
|
drv.populateFromSSHConfig()
|
||||||
drv.auth = drv.authenticators()
|
drv.auth = drv.authenticators()
|
||||||
return drv, nil
|
return drv, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *SSH) forward(val sshtun.Forward) conn {
|
func (d *SSH) forward(val sshtun.Forward, domainMatcher func(string)) conn {
|
||||||
tun := sshtun.New(d.params.Address,
|
tun := sshtun.New(d.params.Address,
|
||||||
d.params.Auth.User,
|
d.params.Auth.User,
|
||||||
d.params.FakeRemoteHost,
|
|
||||||
val,
|
val,
|
||||||
d.auth,
|
d.auth,
|
||||||
|
sshtun.SessionConfig{
|
||||||
|
NoPTY: d.params.NoPTY,
|
||||||
|
FakeRemoteHost: d.params.FakeRemoteHost,
|
||||||
|
KeepAliveInterval: uint(d.params.KeepAlive.Interval),
|
||||||
|
KeepAliveMax: uint(d.params.KeepAlive.MaxAttempts),
|
||||||
|
},
|
||||||
d.Log())
|
d.Log())
|
||||||
ctx, cancel := context.WithCancel(d.Context())
|
ctx, cancel := context.WithCancel(d.Context())
|
||||||
tunDbgLog := d.Log().With("ssh-output", val.Remote.String())
|
tunDbgLog := d.Log().With("ssh-output", val.Remote.String())
|
||||||
go tun.Connect(ctx,
|
go tun.Connect(ctx,
|
||||||
sshtun.StdoutPrinterBannerCallback(tunDbgLog),
|
sshtun.BannerDebugLogCallback(tunDbgLog),
|
||||||
sshtun.StdoutPrinterSessionCallback(tunDbgLog))
|
sshtun.OutputReaderCallback(func(msg string) {
|
||||||
|
d.Log().Debug(msg)
|
||||||
|
if domainMatcher != nil {
|
||||||
|
domainMatcher(msg)
|
||||||
|
}
|
||||||
|
}))
|
||||||
return conn{ctx: ctx, cancel: cancel, tun: tun}
|
return conn{ctx: ctx, cancel: cancel, tun: tun}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *SSH) makeDomainMatcherFunc(containerID string) func(string) {
|
||||||
|
if d.domainRegExp == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return func(msg string) {
|
||||||
|
domain := d.domainRegExp.FindString(msg)
|
||||||
|
if domain == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
d.PushEventStatus(dto.EventStatus{
|
||||||
|
Type: dto.EventStart,
|
||||||
|
ID: containerID,
|
||||||
|
Domain: domain,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (d *SSH) populateFromSSHConfig() {
|
func (d *SSH) populateFromSSHConfig() {
|
||||||
if d.params.Auth.Directory == "" {
|
if d.params.Auth.Directory == "" {
|
||||||
return
|
return
|
||||||
@ -94,9 +129,9 @@ func (d *SSH) Handle(event dto.Event) error {
|
|||||||
return ErrAlreadyInUse
|
return ErrAlreadyInUse
|
||||||
}
|
}
|
||||||
conn := d.forward(sshtun.Forward{
|
conn := d.forward(sshtun.Forward{
|
||||||
Local: sshtun.AddrToEndpoint(net.JoinHostPort(event.Container.IP.String(), strconv.Itoa(int(event.Container.Port)))),
|
Local: d.localEndpoint(event.Container.IP, event.Container.Port),
|
||||||
Remote: d.remoteEndpoint(event.Container.RemoteHost),
|
Remote: d.remoteEndpoint(event.Container.RemoteHost),
|
||||||
})
|
}, d.makeDomainMatcherFunc(event.Container.ID))
|
||||||
d.conns[event.Container.ID] = conn
|
d.conns[event.Container.ID] = conn
|
||||||
d.wg.Add(1)
|
d.wg.Add(1)
|
||||||
case dto.EventStop:
|
case dto.EventStop:
|
||||||
@ -106,17 +141,27 @@ func (d *SSH) Handle(event dto.Event) error {
|
|||||||
}
|
}
|
||||||
conn.cancel()
|
conn.cancel()
|
||||||
delete(d.conns, event.Container.ID)
|
delete(d.conns, event.Container.ID)
|
||||||
|
d.propagateStop(event.Container.ID)
|
||||||
d.wg.Done()
|
d.wg.Done()
|
||||||
case dto.EventShutdown:
|
case dto.EventShutdown:
|
||||||
for id, conn := range d.conns {
|
for id, conn := range d.conns {
|
||||||
conn.cancel()
|
conn.cancel()
|
||||||
delete(d.conns, id)
|
delete(d.conns, id)
|
||||||
|
d.propagateStop(id)
|
||||||
d.wg.Done()
|
d.wg.Done()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *SSH) propagateStop(containerID string) {
|
||||||
|
d.PushEventStatus(dto.EventStatus{Type: dto.EventStop, ID: containerID})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *SSH) localEndpoint(ip net.IP, port uint16) sshtun.Endpoint {
|
||||||
|
return sshtun.AddrToEndpoint(net.JoinHostPort(ip.String(), strconv.Itoa(int(port))))
|
||||||
|
}
|
||||||
|
|
||||||
func (d *SSH) remoteEndpoint(remoteHost string) sshtun.Endpoint {
|
func (d *SSH) remoteEndpoint(remoteHost string) sshtun.Endpoint {
|
||||||
port := int(d.params.ForwardPort)
|
port := int(d.params.ForwardPort)
|
||||||
if port == 0 {
|
if port == 0 {
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
package ssh
|
package ssh
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
|
|
||||||
"github.com/Neur0toxine/sshpoke/internal/server/driver/ssh/types"
|
"github.com/Neur0toxine/sshpoke/internal/server/driver/ssh/types"
|
||||||
"github.com/Neur0toxine/sshpoke/internal/server/driver/util"
|
"github.com/Neur0toxine/sshpoke/internal/server/driver/util"
|
||||||
)
|
)
|
||||||
@ -11,16 +13,19 @@ type Params struct {
|
|||||||
ForwardPort uint16 `mapstructure:"forward_port"`
|
ForwardPort uint16 `mapstructure:"forward_port"`
|
||||||
Auth types.Auth `mapstructure:"auth"`
|
Auth types.Auth `mapstructure:"auth"`
|
||||||
KeepAlive types.KeepAlive `mapstructure:"keepalive"`
|
KeepAlive types.KeepAlive `mapstructure:"keepalive"`
|
||||||
Domain string `mapstructure:"domain"`
|
|
||||||
DomainProto string `mapstructure:"domain_proto"`
|
|
||||||
DomainExtractRegex string `mapstructure:"domain_extract_regex" validate:"validregexp"`
|
DomainExtractRegex string `mapstructure:"domain_extract_regex" validate:"validregexp"`
|
||||||
Mode types.DomainMode `mapstructure:"mode" validate:"required,oneof=single multi"`
|
Mode types.DomainMode `mapstructure:"mode" validate:"required,oneof=single multi"`
|
||||||
FakeRemoteHost bool `mapstructure:"fake_remote_host"`
|
FakeRemoteHost bool `mapstructure:"fake_remote_host"`
|
||||||
|
NoPTY bool `mapstructure:"nopty"`
|
||||||
|
Commands types.Commands `mapstructure:"commands"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Params) Validate() error {
|
func (p *Params) Validate() error {
|
||||||
if err := util.Validator.Struct(p); err != nil {
|
if err := util.Validator.Struct(p); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if p.NoPTY && (len(p.Commands.OnConnect) > 0 || len(p.Commands.OnDisconnect) > 0) {
|
||||||
|
return errors.New("commands aren't available without PTY (nopty = true)")
|
||||||
|
}
|
||||||
return p.Auth.Validate()
|
return p.Auth.Validate()
|
||||||
}
|
}
|
||||||
|
28
internal/server/driver/ssh/regexp.go
Normal file
28
internal/server/driver/ssh/regexp.go
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
package ssh
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
var builtInRegExps = map[string]*regexp.Regexp{
|
||||||
|
"webUrl": regexp.MustCompile(`(?m)https?:\/\/(?:www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b(?:[-a-zA-Z0-9()@:%_\+.~#?&\/=]*)`),
|
||||||
|
"httpUrl": regexp.MustCompile(`(?m)http:\/\/(?:www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b(?:[-a-zA-Z0-9()@:%_\+.~#?&\/=]*)`),
|
||||||
|
"httpsUrl": regexp.MustCompile(`(?m)https:\/\/(?:www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]{1,6}\b(?:[-a-zA-Z0-9()@:%_\+.~#?&\/=]*)`),
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeDomainCatchRegExp(expr string) (*regexp.Regexp, error) {
|
||||||
|
if expr == "" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
if strings.HasPrefix(expr, "!") {
|
||||||
|
exprName := expr[1:]
|
||||||
|
builtIn, found := builtInRegExps[exprName]
|
||||||
|
if !found {
|
||||||
|
return nil, fmt.Errorf("no builtin regexp with name '%s'", exprName)
|
||||||
|
}
|
||||||
|
return builtIn, nil
|
||||||
|
}
|
||||||
|
return regexp.Compile(expr)
|
||||||
|
}
|
171
internal/server/driver/ssh/sshtun/connect.go
Normal file
171
internal/server/driver/ssh/sshtun/connect.go
Normal file
@ -0,0 +1,171 @@
|
|||||||
|
package sshtun
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Neur0toxine/sshpoke/pkg/proto/ssh"
|
||||||
|
"github.com/function61/gokit/app/backoff"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SessionCallback func(*ssh.Session)
|
||||||
|
|
||||||
|
var NoopSessionCallback SessionCallback = func(*ssh.Session) {}
|
||||||
|
|
||||||
|
type Tunnel struct {
|
||||||
|
user string
|
||||||
|
address Endpoint
|
||||||
|
forward Forward
|
||||||
|
authMethods []ssh.AuthMethod
|
||||||
|
log *zap.SugaredLogger
|
||||||
|
sessConfig SessionConfig
|
||||||
|
connected atomic.Bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type SessionConfig struct {
|
||||||
|
NoPTY bool
|
||||||
|
FakeRemoteHost bool
|
||||||
|
KeepAliveInterval uint
|
||||||
|
KeepAliveMax uint
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(address, user string, forward Forward, auth []ssh.AuthMethod, sc SessionConfig, log *zap.SugaredLogger) *Tunnel {
|
||||||
|
return &Tunnel{
|
||||||
|
address: AddrToEndpoint(address),
|
||||||
|
user: user,
|
||||||
|
forward: forward,
|
||||||
|
authMethods: auth,
|
||||||
|
sessConfig: sc,
|
||||||
|
log: log.With(zap.String("sshServer", address)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Tunnel) Connect(ctx context.Context, bannerCb ssh.BannerCallback, sessionCb SessionCallback) {
|
||||||
|
if t.connected.Load() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
defer t.connected.Store(false)
|
||||||
|
backoffTime := backoff.ExponentialWithCappedMax(100*time.Millisecond, 5*time.Second)
|
||||||
|
for {
|
||||||
|
t.connected.Store(true)
|
||||||
|
err := t.connect(ctx, bannerCb, sessionCb)
|
||||||
|
if err != nil {
|
||||||
|
t.log.Error("connect error: ", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(backoffTime())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// connect once to the SSH server. if the connection breaks, we return error and the caller
|
||||||
|
// will try to re-connect
|
||||||
|
func (t *Tunnel) connect(ctx context.Context, bannerCb ssh.BannerCallback, sessionCb SessionCallback) error {
|
||||||
|
t.log.Debug("connecting")
|
||||||
|
sshConfig := &ssh.ClientConfig{
|
||||||
|
User: t.user,
|
||||||
|
Auth: t.authMethods,
|
||||||
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||||
|
BannerCallback: bannerCb,
|
||||||
|
}
|
||||||
|
|
||||||
|
var sshClient *ssh.Client
|
||||||
|
var errConnect error
|
||||||
|
|
||||||
|
sshClient, errConnect = dialSSH(ctx, t.address.String(), sshConfig)
|
||||||
|
if errConnect != nil {
|
||||||
|
return errConnect
|
||||||
|
}
|
||||||
|
|
||||||
|
defer sshClient.Close()
|
||||||
|
defer t.log.Debug("disconnecting")
|
||||||
|
|
||||||
|
t.log.Debug("connected")
|
||||||
|
listenerStopped := make(chan error)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
sess, err := sshClient.NewSession()
|
||||||
|
if err != nil {
|
||||||
|
t.log.Errorf("session error: %s", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer sess.Close()
|
||||||
|
|
||||||
|
if sessionCb == nil {
|
||||||
|
sessionCb = func(*ssh.Session) {}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !t.sessConfig.NoPTY {
|
||||||
|
err = sess.RequestPty("xterm", 80, 40, ssh.TerminalModes{
|
||||||
|
ssh.ECHO: 0,
|
||||||
|
ssh.IGNCR: 1,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.log.Warnf("PTY allocation failed: %s", err.Error())
|
||||||
|
} else {
|
||||||
|
if err := sess.Shell(); err != nil {
|
||||||
|
t.log.Warnf("failed to start shell: %s", err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
sessionCb(sess)
|
||||||
|
}()
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
reverseErr := make(chan error)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
reverseErr <- t.reverseForwardOnePort(sshClient, listenerStopped)
|
||||||
|
}()
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go t.keepAlive(ctx, sshClient, &wg)
|
||||||
|
|
||||||
|
if err := <-reverseErr; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil
|
||||||
|
case listenerFirstErr := <-listenerStopped:
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return listenerFirstErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func dialSSH(ctx context.Context, addr string, sshConfig *ssh.ClientConfig) (*ssh.Client, error) {
|
||||||
|
dialer := net.Dialer{
|
||||||
|
Timeout: 10 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := dialer.DialContext(ctx, "tcp", addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
clConn, newChan, reqChan, err := ssh.NewClientConn(conn, addr, sshConfig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return ssh.NewClient(clConn, newChan, reqChan), nil
|
||||||
|
}
|
@ -4,8 +4,12 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/Neur0toxine/sshpoke/pkg/errtools"
|
"github.com/Neur0toxine/sshpoke/pkg/errtools"
|
||||||
|
"github.com/Neur0toxine/sshpoke/pkg/proto/ssh"
|
||||||
|
"github.com/function61/gokit/io/bidipipe"
|
||||||
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Forward struct {
|
type Forward struct {
|
||||||
@ -35,3 +39,77 @@ type Endpoint struct {
|
|||||||
func (endpoint *Endpoint) String() string {
|
func (endpoint *Endpoint) String() string {
|
||||||
return fmt.Sprintf("%s:%d", endpoint.Host, endpoint.Port)
|
return fmt.Sprintf("%s:%d", endpoint.Host, endpoint.Port)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *Tunnel) ipFromAddr(addr net.Addr) net.IP {
|
||||||
|
host, _, _ := net.SplitHostPort(addr.String())
|
||||||
|
return net.ParseIP(host)
|
||||||
|
}
|
||||||
|
|
||||||
|
// blocking flow: calls Listen() on the SSH connection, and if succeeds returns non-nil error
|
||||||
|
//
|
||||||
|
// nonblocking flow: if Accept() call fails, stops goroutine and returns error on ch listenerStopped
|
||||||
|
func (t *Tunnel) reverseForwardOnePort(sshClient *ssh.Client, listenerStopped chan<- error) error {
|
||||||
|
var (
|
||||||
|
listener net.Listener
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
if t.sessConfig.FakeRemoteHost {
|
||||||
|
listener, err = sshClient.ListenTCP(&net.TCPAddr{
|
||||||
|
IP: t.ipFromAddr(sshClient.Conn.RemoteAddr()),
|
||||||
|
Port: t.forward.Remote.Port,
|
||||||
|
}, t.forward.Remote.Host)
|
||||||
|
} else {
|
||||||
|
listener, err = sshClient.Listen("tcp", t.forward.Remote.String())
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer listener.Close()
|
||||||
|
t.log.Debugf("forwarding %s <- %s", t.forward.Local.String(), t.forward.Remote.String())
|
||||||
|
|
||||||
|
for {
|
||||||
|
client, err := listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
listenerStopped <- fmt.Errorf("error on Accept(): %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
go handleReverseForwardConn(client, t.forward, t.log)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleReverseForwardConn(client net.Conn, forward Forward, log *zap.SugaredLogger) {
|
||||||
|
defer client.Close()
|
||||||
|
|
||||||
|
remote, err := net.Dial("tcp", forward.Local.String())
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("cannot dial local service: %s", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("proxying %s <-> %s", forward.Local.String(), client.RemoteAddr())
|
||||||
|
|
||||||
|
// pipe data in both directions:
|
||||||
|
// - client => remote
|
||||||
|
// - remote => client
|
||||||
|
//
|
||||||
|
// - in effect, we act as a proxy between the reverse tunnel's client and locally-dialed
|
||||||
|
// remote endpoint.
|
||||||
|
// - the "client" and "remote" strings we give Pipe() is just for error&log messages
|
||||||
|
// - this blocks until either of the parties' socket closes (or breaks)
|
||||||
|
if err := bidipipe.Pipe(
|
||||||
|
bidipipe.WithName("client", client),
|
||||||
|
bidipipe.WithName("remote", remote),
|
||||||
|
); err != nil {
|
||||||
|
// we can safely ignore those errors.
|
||||||
|
if strings.Contains(err.Error(), "use of closed network connection") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
62
internal/server/driver/ssh/sshtun/keepalive.go
Normal file
62
internal/server/driver/ssh/sshtun/keepalive.go
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
package sshtun
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Neur0toxine/sshpoke/pkg/proto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
// keepAlive periodically sends messages to invoke a response.
|
||||||
|
// If the server does not respond after some period of time,
|
||||||
|
// assume that the underlying net.Conn abruptly died.
|
||||||
|
func (t *Tunnel) keepAlive(ctx context.Context, client *ssh.Client, wg *sync.WaitGroup) {
|
||||||
|
defer wg.Done()
|
||||||
|
if t.sessConfig.KeepAliveInterval == 0 || t.sessConfig.KeepAliveMax == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Detect when the SSH connection is closed.
|
||||||
|
wait := make(chan error, 1)
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
wait <- client.Wait()
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Repeatedly check if the remote server is still alive.
|
||||||
|
var aliveCount int32
|
||||||
|
ticker := time.NewTicker(time.Duration(t.sessConfig.KeepAliveInterval) * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
t.log.Debug("stopping keep-alive...")
|
||||||
|
_ = client.Close()
|
||||||
|
return
|
||||||
|
case err := <-wait:
|
||||||
|
if err != nil && err != io.EOF {
|
||||||
|
t.log.Error("ssh error:", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
if n := atomic.AddInt32(&aliveCount, 1); n > int32(t.sessConfig.KeepAliveMax) {
|
||||||
|
t.log.Error("keep-alive failed, closing connection...")
|
||||||
|
_ = client.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
_, _, err := client.SendRequest("keepalive@openssh.com", true, nil)
|
||||||
|
if err == nil {
|
||||||
|
atomic.StoreInt32(&aliveCount, 0)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}
|
@ -9,7 +9,7 @@ import (
|
|||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
func StdoutPrinterSessionCallback(log *zap.SugaredLogger) SessionCallback {
|
func OutputReaderCallback(callback func(string)) SessionCallback {
|
||||||
return func(session *ssh.Session) {
|
return func(session *ssh.Session) {
|
||||||
stdout, err := session.StdoutPipe()
|
stdout, err := session.StdoutPipe()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -27,7 +27,7 @@ func StdoutPrinterSessionCallback(log *zap.SugaredLogger) SessionCallback {
|
|||||||
combined.Read(func(r io.Reader) error {
|
combined.Read(func(r io.Reader) error {
|
||||||
scan := bufio.NewScanner(r)
|
scan := bufio.NewScanner(r)
|
||||||
for scan.Scan() {
|
for scan.Scan() {
|
||||||
log.Debug(scan.Text())
|
callback(scan.Text())
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
@ -35,7 +35,7 @@ func StdoutPrinterSessionCallback(log *zap.SugaredLogger) SessionCallback {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func StdoutPrinterBannerCallback(log *zap.SugaredLogger) ssh.BannerCallback {
|
func BannerDebugLogCallback(log *zap.SugaredLogger) ssh.BannerCallback {
|
||||||
return func(msg string) error {
|
return func(msg string) error {
|
||||||
log.Debug(msg)
|
log.Debug(msg)
|
||||||
return nil
|
return nil
|
||||||
|
@ -1,224 +0,0 @@
|
|||||||
package sshtun
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/Neur0toxine/sshpoke/pkg/proto/ssh"
|
|
||||||
"github.com/function61/gokit/app/backoff"
|
|
||||||
"github.com/function61/gokit/io/bidipipe"
|
|
||||||
"go.uber.org/zap"
|
|
||||||
)
|
|
||||||
|
|
||||||
type SessionCallback func(*ssh.Session)
|
|
||||||
|
|
||||||
var NoopSessionCallback SessionCallback = func(*ssh.Session) {}
|
|
||||||
|
|
||||||
type Tunnel struct {
|
|
||||||
user string
|
|
||||||
address Endpoint
|
|
||||||
forward Forward
|
|
||||||
authMethods []ssh.AuthMethod
|
|
||||||
log *zap.SugaredLogger
|
|
||||||
connected atomic.Bool
|
|
||||||
fakeRemoteHost bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func New(address, user string, fakeRemoteHost bool,
|
|
||||||
forward Forward, auth []ssh.AuthMethod, log *zap.SugaredLogger) *Tunnel {
|
|
||||||
return &Tunnel{
|
|
||||||
address: AddrToEndpoint(address),
|
|
||||||
user: user,
|
|
||||||
fakeRemoteHost: fakeRemoteHost,
|
|
||||||
forward: forward,
|
|
||||||
authMethods: auth,
|
|
||||||
log: log.With(zap.String("sshServer", address)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Tunnel) Connect(ctx context.Context, bannerCb ssh.BannerCallback, sessionCb SessionCallback) {
|
|
||||||
if c.connected.Load() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
defer c.connected.Store(false)
|
|
||||||
backoffTime := backoff.ExponentialWithCappedMax(100*time.Millisecond, 5*time.Second)
|
|
||||||
for {
|
|
||||||
c.connected.Store(true)
|
|
||||||
err := c.connect(ctx, bannerCb, sessionCb)
|
|
||||||
if err != nil {
|
|
||||||
c.log.Error("connect error:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
time.Sleep(backoffTime())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// connect once to the SSH server. if the connection breaks, we return error and the caller
|
|
||||||
// will try to re-connect
|
|
||||||
func (c *Tunnel) connect(ctx context.Context, bannerCb ssh.BannerCallback, sessionCb SessionCallback) error {
|
|
||||||
c.log.Debug("connecting")
|
|
||||||
sshConfig := &ssh.ClientConfig{
|
|
||||||
User: c.user,
|
|
||||||
Auth: c.authMethods,
|
|
||||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
|
||||||
BannerCallback: bannerCb,
|
|
||||||
}
|
|
||||||
|
|
||||||
var sshClient *ssh.Client
|
|
||||||
var errConnect error
|
|
||||||
|
|
||||||
sshClient, errConnect = dialSSH(ctx, c.address.String(), sshConfig)
|
|
||||||
if errConnect != nil {
|
|
||||||
return errConnect
|
|
||||||
}
|
|
||||||
|
|
||||||
defer sshClient.Close()
|
|
||||||
defer c.log.Debug("disconnecting")
|
|
||||||
|
|
||||||
c.log.Debug("connected")
|
|
||||||
listenerStopped := make(chan error)
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
sess, err := sshClient.NewSession()
|
|
||||||
if err != nil {
|
|
||||||
c.log.Errorf("session error: %s", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer sess.Close()
|
|
||||||
|
|
||||||
if sessionCb == nil {
|
|
||||||
sessionCb = func(*ssh.Session) {}
|
|
||||||
}
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
sessionCb(sess)
|
|
||||||
}()
|
|
||||||
|
|
||||||
wg.Add(1)
|
|
||||||
reverseErr := make(chan error)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
reverseErr <- c.reverseForwardOnePort(sshClient, listenerStopped)
|
|
||||||
}()
|
|
||||||
|
|
||||||
if err := <-reverseErr; err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return nil
|
|
||||||
case listenerFirstErr := <-listenerStopped:
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return nil
|
|
||||||
default:
|
|
||||||
return listenerFirstErr
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// blocking flow: calls Listen() on the SSH connection, and if succeeds returns non-nil error
|
|
||||||
//
|
|
||||||
// nonblocking flow: if Accept() call fails, stops goroutine and returns error on ch listenerStopped
|
|
||||||
func (c *Tunnel) reverseForwardOnePort(sshClient *ssh.Client, listenerStopped chan<- error) error {
|
|
||||||
var (
|
|
||||||
listener net.Listener
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
if c.fakeRemoteHost {
|
|
||||||
listener, err = sshClient.ListenTCP(&net.TCPAddr{
|
|
||||||
IP: c.ipFromAddr(sshClient.Conn.RemoteAddr()),
|
|
||||||
Port: c.forward.Remote.Port,
|
|
||||||
}, c.forward.Remote.Host)
|
|
||||||
} else {
|
|
||||||
listener, err = sshClient.Listen("tcp", c.forward.Remote.String())
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer listener.Close()
|
|
||||||
c.log.Debugf("forwarding %s <- %s", c.forward.Local.String(), c.forward.Remote.String())
|
|
||||||
|
|
||||||
for {
|
|
||||||
client, err := listener.Accept()
|
|
||||||
if err != nil {
|
|
||||||
listenerStopped <- fmt.Errorf("error on Accept(): %w", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
go handleReverseForwardConn(client, c.forward, c.log)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Tunnel) ipFromAddr(addr net.Addr) net.IP {
|
|
||||||
host, _, _ := net.SplitHostPort(addr.String())
|
|
||||||
return net.ParseIP(host)
|
|
||||||
}
|
|
||||||
|
|
||||||
func handleReverseForwardConn(client net.Conn, forward Forward, log *zap.SugaredLogger) {
|
|
||||||
defer client.Close()
|
|
||||||
|
|
||||||
remote, err := net.Dial("tcp", forward.Local.String())
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("cannot dial local service: %s", err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("proxying %s <-> %s", forward.Local.String(), client.RemoteAddr())
|
|
||||||
|
|
||||||
// pipe data in both directions:
|
|
||||||
// - client => remote
|
|
||||||
// - remote => client
|
|
||||||
//
|
|
||||||
// - in effect, we act as a proxy between the reverse tunnel's client and locally-dialed
|
|
||||||
// remote endpoint.
|
|
||||||
// - the "client" and "remote" strings we give Pipe() is just for error&log messages
|
|
||||||
// - this blocks until either of the parties' socket closes (or breaks)
|
|
||||||
if err := bidipipe.Pipe(
|
|
||||||
bidipipe.WithName("client", client),
|
|
||||||
bidipipe.WithName("remote", remote),
|
|
||||||
); err != nil {
|
|
||||||
// we can safely ignore those errors.
|
|
||||||
if strings.Contains(err.Error(), "use of closed network connection") {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
log.Error(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func dialSSH(ctx context.Context, addr string, sshConfig *ssh.ClientConfig) (*ssh.Client, error) {
|
|
||||||
dialer := net.Dialer{
|
|
||||||
Timeout: 10 * time.Second,
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := dialer.DialContext(ctx, "tcp", addr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
clConn, newChan, reqChan, err := ssh.NewClientConn(conn, addr, sshConfig)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return ssh.NewClient(clConn, newChan, reqChan), nil
|
|
||||||
}
|
|
6
internal/server/driver/ssh/types/commands.go
Normal file
6
internal/server/driver/ssh/types/commands.go
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
package types
|
||||||
|
|
||||||
|
type Commands struct {
|
||||||
|
OnConnect []string `mapstructure:"on_connect"`
|
||||||
|
OnDisconnect []string `mapstructure:"on_disconnect"`
|
||||||
|
}
|
@ -9,7 +9,7 @@ import (
|
|||||||
var Validator *validator.Validate
|
var Validator *validator.Validate
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
Validator = validator.New()
|
Validator = validator.New(validator.WithRequiredStructEnabled())
|
||||||
_ = Validator.RegisterValidation("validregexp", isValidRegExp)
|
_ = Validator.RegisterValidation("validregexp", isValidRegExp)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -89,6 +89,8 @@ func (m *Manager) eventStatusCallback(serverName string) base.EventStatusCallbac
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) processEventStatus(serverName string, event dto.EventStatus) {
|
func (m *Manager) processEventStatus(serverName string, event dto.EventStatus) {
|
||||||
|
logger.Sugar.Debugw("received EventStatus from server",
|
||||||
|
"serverName", serverName, "eventStatus", event)
|
||||||
m.statusLock.RLock()
|
m.statusLock.RLock()
|
||||||
_, exists := m.statusMap[serverName]
|
_, exists := m.statusMap[serverName]
|
||||||
if !exists {
|
if !exists {
|
||||||
|
@ -12,6 +12,22 @@ const (
|
|||||||
EventUnknown
|
EventUnknown
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var eventTypeNames = map[EventType]string{
|
||||||
|
EventStart: "start",
|
||||||
|
EventStop: "stop",
|
||||||
|
EventShutdown: "shutdown",
|
||||||
|
EventError: "error",
|
||||||
|
EventUnknown: "unknown",
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e EventType) String() string {
|
||||||
|
name, ok := eventTypeNames[e]
|
||||||
|
if ok {
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
return eventTypeNames[EventUnknown]
|
||||||
|
}
|
||||||
|
|
||||||
func TypeFromAction(action string) EventType {
|
func TypeFromAction(action string) EventType {
|
||||||
switch action {
|
switch action {
|
||||||
case "start":
|
case "start":
|
||||||
|
Loading…
Reference in New Issue
Block a user