diff --git a/config.example.yml b/config.example.yml new file mode 100644 index 0000000..f9f6efe --- /dev/null +++ b/config.example.yml @@ -0,0 +1,140 @@ +# Enable or disable debug logging. +debug: true +# API settings. +api: + # Local port for Web API. Will be bound to localhost. + web_port: 25680 + # Local port for plugin API. Will listen on all interfaces because it has auth. + plugin_port: 25681 +# Docker client preferences. +docker: + # Extract client params from the environment. + from_env: true + # Cert path for the Docker client. + cert_path: ~ + # Set it to false to disable TLS cert verification. + tls_verify: true + # Docker host. Can be useful for running containers alongside remote plugin (although it sounds weird to do so). + host: ~ + # Docker version. + version: ~ +# Default server to use if `sshpoke.server` is not specified in the target container labels. +default_server: mine +# Servers configuration. +servers: + # Server name. + - name: mine + # Server driver. Each driver has its own set of params. Supported drivers: ssh, plugin, null. + driver: ssh + params: + # SSH server address + address: "your1.server:2222" + # Remote port to be used for forwarding. + forward_port: 80 + # This disables remote host resolution and forcibly uses server IP for remote host. + # It's the same as this syntax for sish: `ssh -R addr:80:localhost:80 your.sish.server` + # Set this to true if you're using sish, otherwise you'll get weird domains with IP's in them. + fake_remote_host: true + # Disables PTY request for this server. + nopty: true + # Requests interactive shell for SSH sessions. Should be `true` for the `commands`. + shell: false + # Authentication data. + auth: + # Authentication type. Supported types: key, password, passwordless + type: key + # Remote user + user: user + # Directory with SSH keys. ssh-config from this directory will be used if `keyfile` is not provided. + # Only some of the ssh-config attributes are used. + directory: "~/.ssh" + # Expose mode (multiple domains or single domain). Allowed values: single, multi. + mode: multi + # Keep-alive settings. Remove to disable keep-alive completely. + keepalive: + # Interval for keep-alive requests in seconds. + interval: 1 + # How many attempts should fail to forcibly restart the connection. + max_attempts: 2 + # Regular expression that will be used to extract domain from stdout & stderr. Useful for services like sish or + # localhost.run. `commands` output will also be parsed by this regex. + # With `!name` syntax you can use some built-in expressions: + # - !webUrl - any HTTP or HTTPS URL. + # - !httpUrl - any HTTP URL. + # - !httpsUrl - any HTTPS URL. + domain_extract_regex: "!httpsUrl" + # Host keys to prevent MITM. You can obtain those via `ssh-keyscan
` (specify `-p` for non-standard port). + # Always use '|' YAML syntax here (not '>') or sshpoke won't be able to parse keys. + host_keys: | + # ssh.neur0tx.site:2222 SSH-2.0-sish + # ssh.neur0tx.site:2222 SSH-2.0-sish + # ssh.neur0tx.site:2222 SSH-2.0-sish + [ssh.neur0tx.site]:2222 ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIEvxbqK0u8UjqEtrO/83GPS7MeoFp6C3+7KjOHd8+1GF + # ssh.neur0tx.site:2222 SSH-2.0-sish + # ssh.neur0tx.site:2222 SSH-2.0-sish + - name: ssh-demo-single-domain + driver: ssh + params: + auth: + type: key + user: user + directory: "~/.ssh" + keyfile: id_ed25519 + address: "your2.server" + forward_port: 80 + fake_remote_host: true + nopty: false + shell: true + mode: single + keepalive: + interval: 1 + max_attempts: 2 + domain_extract_regex: "!webUrl" + - name: ssh-demo-commands + driver: ssh + params: + address: "your3.server" + forward_port: 8080 + auth: + type: key + user: user + directory: "~/.ssh" + mode: multi + keepalive: + interval: 1 + max_attempts: 2 + domain_extract_regex: "!webUrl" + # Commands that will be executed on the host. + commands: + # These commands will be executed after connect. + on_connect: + - echo https://`date +%s`.proxy.test + # These commands will be executed before disconnect. + on_disconnect: + - echo disconnect from `cat /etc/hostname` + - name: ssh-demo-with-password + driver: ssh + params: + address: "ssh.neur0tx.site" + forward_port: 8081 + auth: + type: password + user: user + # Remote user password. + password: password + mode: multi + keepalive: + interval: 1 + max_attempts: 2 + domain_extract_regex: "!httpUrl" + commands: + on_connect: + - echo http://`date +%s`.proxy.test + - name: plugin-demo + driver: plugin + params: + # This token will be used by plugin while connecting to gRPC API. + token: key + - name: noop + # Null driver doesn't do anything. This driver will automatically be used for servers with invalid 'driver' value. + driver: null diff --git a/internal/server/driver/ssh/driver.go b/internal/server/driver/ssh/driver.go index 39a6aa2..4b576d2 100644 --- a/internal/server/driver/ssh/driver.go +++ b/internal/server/driver/ssh/driver.go @@ -2,12 +2,14 @@ package ssh import ( "context" + "encoding/base64" "errors" "fmt" "net" "path" "regexp" "strconv" + "strings" "sync" "github.com/Neur0toxine/sshpoke/internal/config" @@ -25,6 +27,7 @@ type SSH struct { base.Base params Params auth []ssh.AuthMethod + hostKeys []ssh.PublicKey conns map[string]conn rw sync.RWMutex wg sync.WaitGroup @@ -45,6 +48,9 @@ func New(ctx context.Context, name string, params config.DriverParams) (base.Dri if err := util.UnmarshalParams(params, &drv.params); err != nil { return nil, err } + if err := drv.buildHostKeys(); err != nil { + return nil, err + } matcher, err := makeDomainCatchRegExp(drv.params.DomainExtractRegex) if err != nil { return nil, fmt.Errorf("invalid domain_extract_regex: %w", err) @@ -58,9 +64,10 @@ func New(ctx context.Context, name string, params config.DriverParams) (base.Dri func (d *SSH) forward(val sshtun.Forward, domainMatcher func(string)) conn { tun := sshtun.New(d.params.Address, d.params.Auth.User, - val, d.auth, - sshtun.SessionConfig{ + sshtun.TunnelConfig{ + Forward: val, + HostKeys: d.hostKeys, NoPTY: d.params.NoPTY, Shell: d.params.Shell, FakeRemoteHost: d.params.FakeRemoteHost, @@ -81,6 +88,46 @@ func (d *SSH) forward(val sshtun.Forward, domainMatcher func(string)) conn { return conn{ctx: ctx, cancel: cancel, tun: tun} } +func (d *SSH) buildHostKeys() error { + if d.params.HostKeys == "" { + return nil + } + hostKeys := []ssh.PublicKey{} + for _, keyLine := range strings.Split(d.params.HostKeys, "\n") { + key, err := d.pubKeyFromSSHKeyScan(keyLine) + if err != nil { + d.Log().Debugf("invalid public key: %s", keyLine) + return fmt.Errorf("invalid public key for the host: %w", err) + } + if key != nil { + hostKeys = append(hostKeys, key) + } + } + d.hostKeys = hostKeys + return nil +} + +// pubKeyFromSSHKeyScan extracts host public key from ssh-keyscan output format. +func (d *SSH) pubKeyFromSSHKeyScan(line string) (key ssh.PublicKey, err error) { + line = strings.TrimSpace(line) + if strings.HasPrefix(line, "#") || line == "" { // comment or empty line - should be ignored. + return nil, nil + } + cols := strings.Fields(line) + for i := len(cols) - 1; i >= 0; i-- { + col := strings.TrimSpace(cols[i]) + keyData, err := base64.StdEncoding.DecodeString(col) + if err != nil { + continue + } + key, err = ssh.ParsePublicKey(keyData) + if err == nil { + return key, nil + } + } + return nil, errors.New("no public key in the provided data") +} + func (d *SSH) makeDomainMatcherFunc(containerID string) func(string) { if d.domainRegExp == nil { return nil @@ -106,13 +153,23 @@ func (d *SSH) populateFromSSHConfig() { if err != nil { return } - if user, err := cfg.Get(d.params.Address, "User"); err == nil && user != "" { + + host := d.extractHostFromAddr(d.params.Address) + hostCfg := &hostConfig{cfg: cfg, host: host} + port, err := hostCfg.Get("Port") + if err != nil { + port = "22" + } + if hostName, err := hostCfg.Get("HostName"); err == nil && hostName != "" { + d.params.Address = net.JoinHostPort(hostName, port) + } + if user, err := hostCfg.Get("User"); err == nil && user != "" { d.params.Auth.User = user } - if usePass, err := cfg.Get(d.params.Address, "PasswordAuthentication"); err == nil && usePass == "yes" { + if usePass, err := hostCfg.Get("PasswordAuthentication"); err == nil && usePass == "yes" { d.params.Auth.Type = types.AuthTypePassword } - if keyfile, err := cfg.Get(d.params.Address, "IdentityFile"); err == nil && keyfile != "" { + if keyfile, err := hostCfg.Get("IdentityFile"); err == nil && keyfile != "" { resolvedKeyFile, err := types.SmartPath(keyfile).Resolve(false) if err == nil { d.params.Auth.Type = types.AuthTypeKey @@ -121,6 +178,14 @@ func (d *SSH) populateFromSSHConfig() { } } +func (d *SSH) extractHostFromAddr(addr string) string { + host, _, err := net.SplitHostPort(addr) + if err != nil { + return addr + } + return host +} + func (d *SSH) Handle(event dto.Event) error { defer d.rw.Unlock() d.rw.Lock() diff --git a/internal/server/driver/ssh/params.go b/internal/server/driver/ssh/params.go index f8cee14..35d3069 100644 --- a/internal/server/driver/ssh/params.go +++ b/internal/server/driver/ssh/params.go @@ -9,6 +9,7 @@ import ( type Params struct { Address string `mapstructure:"address" validate:"required"` + HostKeys string `mapstructure:"host_keys"` DefaultHost *string `mapstructure:"default_host,omitempty"` ForwardPort uint16 `mapstructure:"forward_port"` Auth types.Auth `mapstructure:"auth"` diff --git a/internal/server/driver/ssh/sshconfig.go b/internal/server/driver/ssh/sshconfig.go index c74767f..da08b30 100644 --- a/internal/server/driver/ssh/sshconfig.go +++ b/internal/server/driver/ssh/sshconfig.go @@ -3,11 +3,22 @@ package ssh import ( "bytes" "os" + "strings" "github.com/Neur0toxine/sshpoke/internal/server/driver/ssh/types" "github.com/kevinburke/ssh_config" ) +type hostConfig struct { + cfg *ssh_config.Config + host string +} + +func (c *hostConfig) Get(key string) (string, error) { + val, err := c.cfg.Get(c.host, key) + return strings.TrimSpace(val), err +} + func parseSSHConfig(filePath types.SmartPath) (*ssh_config.Config, error) { fileName, err := filePath.Resolve(false) if err != nil { diff --git a/internal/server/driver/ssh/sshtun/connect.go b/internal/server/driver/ssh/sshtun/connect.go index 97e51bf..5c83f2e 100644 --- a/internal/server/driver/ssh/sshtun/connect.go +++ b/internal/server/driver/ssh/sshtun/connect.go @@ -19,14 +19,15 @@ var NoopSessionCallback SessionCallback = func(*ssh.Session) {} type Tunnel struct { user string address Endpoint - forward Forward authMethods []ssh.AuthMethod log *zap.SugaredLogger - sessConfig SessionConfig + tunConfig TunnelConfig connected atomic.Bool } -type SessionConfig struct { +type TunnelConfig struct { + Forward Forward + HostKeys []ssh.PublicKey NoPTY bool Shell bool FakeRemoteHost bool @@ -34,13 +35,12 @@ type SessionConfig struct { KeepAliveMax uint } -func New(address, user string, forward Forward, auth []ssh.AuthMethod, sc SessionConfig, log *zap.SugaredLogger) *Tunnel { +func New(address, user string, auth []ssh.AuthMethod, sc TunnelConfig, log *zap.SugaredLogger) *Tunnel { return &Tunnel{ address: AddrToEndpoint(address), user: user, - forward: forward, authMethods: auth, - sessConfig: sc, + tunConfig: sc, log: log.With(zap.String("sshServer", address)), } } @@ -69,6 +69,16 @@ func (t *Tunnel) Connect(ctx context.Context, bannerCb ssh.BannerCallback, sessi } } +func (t *Tunnel) buildHostKeyCallback() ssh.HostKeyCallback { + if t.tunConfig.HostKeys == nil || len(t.tunConfig.HostKeys) == 0 { + return ssh.InsecureIgnoreHostKey() + } + if len(t.tunConfig.HostKeys) == 1 { + return ssh.FixedHostKey(t.tunConfig.HostKeys[0]) + } + return FixedHostKeys(t.tunConfig.HostKeys) +} + // connect once to the SSH server. if the connection breaks, we return error and the caller // will try to re-connect func (t *Tunnel) connect(ctx context.Context, bannerCb ssh.BannerCallback, sessionCb SessionCallback) error { @@ -76,7 +86,7 @@ func (t *Tunnel) connect(ctx context.Context, bannerCb ssh.BannerCallback, sessi sshConfig := &ssh.ClientConfig{ User: t.user, Auth: t.authMethods, - HostKeyCallback: ssh.InsecureIgnoreHostKey(), + HostKeyCallback: t.buildHostKeyCallback(), BannerCallback: bannerCb, } @@ -106,7 +116,7 @@ func (t *Tunnel) connect(ctx context.Context, bannerCb ssh.BannerCallback, sessi sessionCb = func(*ssh.Session) {} } - if !t.sessConfig.NoPTY { + if !t.tunConfig.NoPTY { err = sess.RequestPty("xterm", 80, 40, ssh.TerminalModes{ ssh.ECHO: 0, ssh.IGNCR: 1, @@ -115,7 +125,7 @@ func (t *Tunnel) connect(ctx context.Context, bannerCb ssh.BannerCallback, sessi t.log.Warnf("PTY allocation failed: %s", err.Error()) } } - if t.sessConfig.Shell { + if t.tunConfig.Shell { if err := sess.Shell(); err != nil { t.log.Warnf("failed to start shell: %s", err.Error()) } diff --git a/internal/server/driver/ssh/sshtun/fixed_host_keys.go b/internal/server/driver/ssh/sshtun/fixed_host_keys.go new file mode 100644 index 0000000..f95a5eb --- /dev/null +++ b/internal/server/driver/ssh/sshtun/fixed_host_keys.go @@ -0,0 +1,36 @@ +package sshtun + +import ( + "bytes" + "fmt" + "net" + + "github.com/Neur0toxine/sshpoke/pkg/proto/ssh" +) + +func FixedHostKeys(keys []ssh.PublicKey) ssh.HostKeyCallback { + m := make(map[string]ssh.PublicKey) + for _, key := range keys { + m[key.Type()] = key + } + hk := &fixedHostKeys{keys: m} + return hk.check +} + +type fixedHostKeys struct { + keys map[string]ssh.PublicKey +} + +func (f *fixedHostKeys) check(hostname string, remote net.Addr, key ssh.PublicKey) error { + if f.keys == nil { + return fmt.Errorf("ssh: host keys should be defined") + } + if len(f.keys) == 0 { + return fmt.Errorf("ssh: no host keys were provided") + } + hostKey, found := f.keys[key.Type()] + if !found || !bytes.Equal(key.Marshal(), hostKey.Marshal()) { + return fmt.Errorf("ssh: host key mismatch") + } + return nil +} diff --git a/internal/server/driver/ssh/sshtun/forward.go b/internal/server/driver/ssh/sshtun/forward.go index 4cfa959..9e46edb 100644 --- a/internal/server/driver/ssh/sshtun/forward.go +++ b/internal/server/driver/ssh/sshtun/forward.go @@ -53,13 +53,13 @@ func (t *Tunnel) reverseForwardOnePort(sshClient *ssh.Client, listenerStopped ch listener net.Listener err error ) - if t.sessConfig.FakeRemoteHost { + if t.tunConfig.FakeRemoteHost { listener, err = sshClient.ListenTCP(&net.TCPAddr{ IP: t.ipFromAddr(sshClient.Conn.RemoteAddr()), - Port: t.forward.Remote.Port, - }, t.forward.Remote.Host) + Port: t.tunConfig.Forward.Remote.Port, + }, t.tunConfig.Forward.Remote.Host) } else { - listener, err = sshClient.Listen("tcp", t.forward.Remote.String()) + listener, err = sshClient.Listen("tcp", t.tunConfig.Forward.Remote.String()) } if err != nil { return err @@ -67,7 +67,8 @@ func (t *Tunnel) reverseForwardOnePort(sshClient *ssh.Client, listenerStopped ch go func() { defer listener.Close() - t.log.Debugf("forwarding %s <- %s", t.forward.Local.String(), t.forward.Remote.String()) + t.log.Debugf("forwarding %s <- %s", + t.tunConfig.Forward.Local.String(), t.tunConfig.Forward.Remote.String()) for { client, err := listener.Accept() @@ -76,7 +77,7 @@ func (t *Tunnel) reverseForwardOnePort(sshClient *ssh.Client, listenerStopped ch return } - go handleReverseForwardConn(client, t.forward, t.log) + go handleReverseForwardConn(client, t.tunConfig.Forward, t.log) } }() diff --git a/internal/server/driver/ssh/sshtun/keepalive.go b/internal/server/driver/ssh/sshtun/keepalive.go index a6de49c..54a5e4e 100644 --- a/internal/server/driver/ssh/sshtun/keepalive.go +++ b/internal/server/driver/ssh/sshtun/keepalive.go @@ -15,7 +15,7 @@ import ( // assume that the underlying net.Conn abruptly died. func (t *Tunnel) keepAlive(ctx context.Context, client *ssh.Client, wg *sync.WaitGroup) { defer wg.Done() - if t.sessConfig.KeepAliveInterval == 0 || t.sessConfig.KeepAliveMax == 0 { + if t.tunConfig.KeepAliveInterval == 0 || t.tunConfig.KeepAliveMax == 0 { return } @@ -29,7 +29,7 @@ func (t *Tunnel) keepAlive(ctx context.Context, client *ssh.Client, wg *sync.Wai // Repeatedly check if the remote server is still alive. var aliveCount int32 - ticker := time.NewTicker(time.Duration(t.sessConfig.KeepAliveInterval) * time.Second) + ticker := time.NewTicker(time.Duration(t.tunConfig.KeepAliveInterval) * time.Second) defer ticker.Stop() for { select { @@ -43,7 +43,7 @@ func (t *Tunnel) keepAlive(ctx context.Context, client *ssh.Client, wg *sync.Wai } return case <-ticker.C: - if n := atomic.AddInt32(&aliveCount, 1); n > int32(t.sessConfig.KeepAliveMax) { + if n := atomic.AddInt32(&aliveCount, 1); n > int32(t.tunConfig.KeepAliveMax) { t.log.Error("keep-alive failed, closing connection...") _ = client.Close() return