Compare commits

...

2 Commits

Author SHA1 Message Date
39faf83c78 ssh: host key verification; config file example 2023-11-19 17:03:12 +03:00
f654deb300 ssh: shell start support 2023-11-19 14:12:39 +03:00
8 changed files with 299 additions and 26 deletions

140
config.example.yml Normal file
View File

@ -0,0 +1,140 @@
# Enable or disable debug logging.
debug: true
# API settings.
api:
# Local port for Web API. Will be bound to localhost.
web_port: 25680
# Local port for plugin API. Will listen on all interfaces because it has auth.
plugin_port: 25681
# Docker client preferences.
docker:
# Extract client params from the environment.
from_env: true
# Cert path for the Docker client.
cert_path: ~
# Set it to false to disable TLS cert verification.
tls_verify: true
# Docker host. Can be useful for running containers alongside remote plugin (although it sounds weird to do so).
host: ~
# Docker version.
version: ~
# Default server to use if `sshpoke.server` is not specified in the target container labels.
default_server: mine
# Servers configuration.
servers:
# Server name.
- name: mine
# Server driver. Each driver has its own set of params. Supported drivers: ssh, plugin, null.
driver: ssh
params:
# SSH server address
address: "your1.server:2222"
# Remote port to be used for forwarding.
forward_port: 80
# This disables remote host resolution and forcibly uses server IP for remote host.
# It's the same as this syntax for sish: `ssh -R addr:80:localhost:80 your.sish.server`
# Set this to true if you're using sish, otherwise you'll get weird domains with IP's in them.
fake_remote_host: true
# Disables PTY request for this server.
nopty: true
# Requests interactive shell for SSH sessions. Should be `true` for the `commands`.
shell: false
# Authentication data.
auth:
# Authentication type. Supported types: key, password, passwordless
type: key
# Remote user
user: user
# Directory with SSH keys. ssh-config from this directory will be used if `keyfile` is not provided.
# Only some of the ssh-config attributes are used.
directory: "~/.ssh"
# Expose mode (multiple domains or single domain). Allowed values: single, multi.
mode: multi
# Keep-alive settings. Remove to disable keep-alive completely.
keepalive:
# Interval for keep-alive requests in seconds.
interval: 1
# How many attempts should fail to forcibly restart the connection.
max_attempts: 2
# Regular expression that will be used to extract domain from stdout & stderr. Useful for services like sish or
# localhost.run. `commands` output will also be parsed by this regex.
# With `!name` syntax you can use some built-in expressions:
# - !webUrl - any HTTP or HTTPS URL.
# - !httpUrl - any HTTP URL.
# - !httpsUrl - any HTTPS URL.
domain_extract_regex: "!httpsUrl"
# Host keys to prevent MITM. You can obtain those via `ssh-keyscan <address>` (specify `-p` for non-standard port).
# Always use '|' YAML syntax here (not '>') or sshpoke won't be able to parse keys.
host_keys: |
# ssh.neur0tx.site:2222 SSH-2.0-sish
# ssh.neur0tx.site:2222 SSH-2.0-sish
# ssh.neur0tx.site:2222 SSH-2.0-sish
[ssh.neur0tx.site]:2222 ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIEvxbqK0u8UjqEtrO/83GPS7MeoFp6C3+7KjOHd8+1GF
# ssh.neur0tx.site:2222 SSH-2.0-sish
# ssh.neur0tx.site:2222 SSH-2.0-sish
- name: ssh-demo-single-domain
driver: ssh
params:
auth:
type: key
user: user
directory: "~/.ssh"
keyfile: id_ed25519
address: "your2.server"
forward_port: 80
fake_remote_host: true
nopty: false
shell: true
mode: single
keepalive:
interval: 1
max_attempts: 2
domain_extract_regex: "!webUrl"
- name: ssh-demo-commands
driver: ssh
params:
address: "your3.server"
forward_port: 8080
auth:
type: key
user: user
directory: "~/.ssh"
mode: multi
keepalive:
interval: 1
max_attempts: 2
domain_extract_regex: "!webUrl"
# Commands that will be executed on the host.
commands:
# These commands will be executed after connect.
on_connect:
- echo https://`date +%s`.proxy.test
# These commands will be executed before disconnect.
on_disconnect:
- echo disconnect from `cat /etc/hostname`
- name: ssh-demo-with-password
driver: ssh
params:
address: "ssh.neur0tx.site"
forward_port: 8081
auth:
type: password
user: user
# Remote user password.
password: password
mode: multi
keepalive:
interval: 1
max_attempts: 2
domain_extract_regex: "!httpUrl"
commands:
on_connect:
- echo http://`date +%s`.proxy.test
- name: plugin-demo
driver: plugin
params:
# This token will be used by plugin while connecting to gRPC API.
token: key
- name: noop
# Null driver doesn't do anything. This driver will automatically be used for servers with invalid 'driver' value.
driver: null

View File

@ -2,12 +2,14 @@ package ssh
import ( import (
"context" "context"
"encoding/base64"
"errors" "errors"
"fmt" "fmt"
"net" "net"
"path" "path"
"regexp" "regexp"
"strconv" "strconv"
"strings"
"sync" "sync"
"github.com/Neur0toxine/sshpoke/internal/config" "github.com/Neur0toxine/sshpoke/internal/config"
@ -25,6 +27,7 @@ type SSH struct {
base.Base base.Base
params Params params Params
auth []ssh.AuthMethod auth []ssh.AuthMethod
hostKeys []ssh.PublicKey
conns map[string]conn conns map[string]conn
rw sync.RWMutex rw sync.RWMutex
wg sync.WaitGroup wg sync.WaitGroup
@ -45,6 +48,9 @@ func New(ctx context.Context, name string, params config.DriverParams) (base.Dri
if err := util.UnmarshalParams(params, &drv.params); err != nil { if err := util.UnmarshalParams(params, &drv.params); err != nil {
return nil, err return nil, err
} }
if err := drv.buildHostKeys(); err != nil {
return nil, err
}
matcher, err := makeDomainCatchRegExp(drv.params.DomainExtractRegex) matcher, err := makeDomainCatchRegExp(drv.params.DomainExtractRegex)
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid domain_extract_regex: %w", err) return nil, fmt.Errorf("invalid domain_extract_regex: %w", err)
@ -58,10 +64,12 @@ func New(ctx context.Context, name string, params config.DriverParams) (base.Dri
func (d *SSH) forward(val sshtun.Forward, domainMatcher func(string)) conn { func (d *SSH) forward(val sshtun.Forward, domainMatcher func(string)) conn {
tun := sshtun.New(d.params.Address, tun := sshtun.New(d.params.Address,
d.params.Auth.User, d.params.Auth.User,
val,
d.auth, d.auth,
sshtun.SessionConfig{ sshtun.TunnelConfig{
Forward: val,
HostKeys: d.hostKeys,
NoPTY: d.params.NoPTY, NoPTY: d.params.NoPTY,
Shell: d.params.Shell,
FakeRemoteHost: d.params.FakeRemoteHost, FakeRemoteHost: d.params.FakeRemoteHost,
KeepAliveInterval: uint(d.params.KeepAlive.Interval), KeepAliveInterval: uint(d.params.KeepAlive.Interval),
KeepAliveMax: uint(d.params.KeepAlive.MaxAttempts), KeepAliveMax: uint(d.params.KeepAlive.MaxAttempts),
@ -80,6 +88,46 @@ func (d *SSH) forward(val sshtun.Forward, domainMatcher func(string)) conn {
return conn{ctx: ctx, cancel: cancel, tun: tun} return conn{ctx: ctx, cancel: cancel, tun: tun}
} }
func (d *SSH) buildHostKeys() error {
if d.params.HostKeys == "" {
return nil
}
hostKeys := []ssh.PublicKey{}
for _, keyLine := range strings.Split(d.params.HostKeys, "\n") {
key, err := d.pubKeyFromSSHKeyScan(keyLine)
if err != nil {
d.Log().Debugf("invalid public key: %s", keyLine)
return fmt.Errorf("invalid public key for the host: %w", err)
}
if key != nil {
hostKeys = append(hostKeys, key)
}
}
d.hostKeys = hostKeys
return nil
}
// pubKeyFromSSHKeyScan extracts host public key from ssh-keyscan output format.
func (d *SSH) pubKeyFromSSHKeyScan(line string) (key ssh.PublicKey, err error) {
line = strings.TrimSpace(line)
if strings.HasPrefix(line, "#") || line == "" { // comment or empty line - should be ignored.
return nil, nil
}
cols := strings.Fields(line)
for i := len(cols) - 1; i >= 0; i-- {
col := strings.TrimSpace(cols[i])
keyData, err := base64.StdEncoding.DecodeString(col)
if err != nil {
continue
}
key, err = ssh.ParsePublicKey(keyData)
if err == nil {
return key, nil
}
}
return nil, errors.New("no public key in the provided data")
}
func (d *SSH) makeDomainMatcherFunc(containerID string) func(string) { func (d *SSH) makeDomainMatcherFunc(containerID string) func(string) {
if d.domainRegExp == nil { if d.domainRegExp == nil {
return nil return nil
@ -105,13 +153,23 @@ func (d *SSH) populateFromSSHConfig() {
if err != nil { if err != nil {
return return
} }
if user, err := cfg.Get(d.params.Address, "User"); err == nil && user != "" {
host := d.extractHostFromAddr(d.params.Address)
hostCfg := &hostConfig{cfg: cfg, host: host}
port, err := hostCfg.Get("Port")
if err != nil {
port = "22"
}
if hostName, err := hostCfg.Get("HostName"); err == nil && hostName != "" {
d.params.Address = net.JoinHostPort(hostName, port)
}
if user, err := hostCfg.Get("User"); err == nil && user != "" {
d.params.Auth.User = user d.params.Auth.User = user
} }
if usePass, err := cfg.Get(d.params.Address, "PasswordAuthentication"); err == nil && usePass == "yes" { if usePass, err := hostCfg.Get("PasswordAuthentication"); err == nil && usePass == "yes" {
d.params.Auth.Type = types.AuthTypePassword d.params.Auth.Type = types.AuthTypePassword
} }
if keyfile, err := cfg.Get(d.params.Address, "IdentityFile"); err == nil && keyfile != "" { if keyfile, err := hostCfg.Get("IdentityFile"); err == nil && keyfile != "" {
resolvedKeyFile, err := types.SmartPath(keyfile).Resolve(false) resolvedKeyFile, err := types.SmartPath(keyfile).Resolve(false)
if err == nil { if err == nil {
d.params.Auth.Type = types.AuthTypeKey d.params.Auth.Type = types.AuthTypeKey
@ -120,6 +178,14 @@ func (d *SSH) populateFromSSHConfig() {
} }
} }
func (d *SSH) extractHostFromAddr(addr string) string {
host, _, err := net.SplitHostPort(addr)
if err != nil {
return addr
}
return host
}
func (d *SSH) Handle(event dto.Event) error { func (d *SSH) Handle(event dto.Event) error {
defer d.rw.Unlock() defer d.rw.Unlock()
d.rw.Lock() d.rw.Lock()

View File

@ -9,6 +9,7 @@ import (
type Params struct { type Params struct {
Address string `mapstructure:"address" validate:"required"` Address string `mapstructure:"address" validate:"required"`
HostKeys string `mapstructure:"host_keys"`
DefaultHost *string `mapstructure:"default_host,omitempty"` DefaultHost *string `mapstructure:"default_host,omitempty"`
ForwardPort uint16 `mapstructure:"forward_port"` ForwardPort uint16 `mapstructure:"forward_port"`
Auth types.Auth `mapstructure:"auth"` Auth types.Auth `mapstructure:"auth"`
@ -17,6 +18,7 @@ type Params struct {
Mode types.DomainMode `mapstructure:"mode" validate:"required,oneof=single multi"` Mode types.DomainMode `mapstructure:"mode" validate:"required,oneof=single multi"`
FakeRemoteHost bool `mapstructure:"fake_remote_host"` FakeRemoteHost bool `mapstructure:"fake_remote_host"`
NoPTY bool `mapstructure:"nopty"` NoPTY bool `mapstructure:"nopty"`
Shell bool `mapstructure:"shell"`
Commands types.Commands `mapstructure:"commands"` Commands types.Commands `mapstructure:"commands"`
} }

View File

@ -3,11 +3,22 @@ package ssh
import ( import (
"bytes" "bytes"
"os" "os"
"strings"
"github.com/Neur0toxine/sshpoke/internal/server/driver/ssh/types" "github.com/Neur0toxine/sshpoke/internal/server/driver/ssh/types"
"github.com/kevinburke/ssh_config" "github.com/kevinburke/ssh_config"
) )
type hostConfig struct {
cfg *ssh_config.Config
host string
}
func (c *hostConfig) Get(key string) (string, error) {
val, err := c.cfg.Get(c.host, key)
return strings.TrimSpace(val), err
}
func parseSSHConfig(filePath types.SmartPath) (*ssh_config.Config, error) { func parseSSHConfig(filePath types.SmartPath) (*ssh_config.Config, error) {
fileName, err := filePath.Resolve(false) fileName, err := filePath.Resolve(false)
if err != nil { if err != nil {

View File

@ -19,27 +19,28 @@ var NoopSessionCallback SessionCallback = func(*ssh.Session) {}
type Tunnel struct { type Tunnel struct {
user string user string
address Endpoint address Endpoint
forward Forward
authMethods []ssh.AuthMethod authMethods []ssh.AuthMethod
log *zap.SugaredLogger log *zap.SugaredLogger
sessConfig SessionConfig tunConfig TunnelConfig
connected atomic.Bool connected atomic.Bool
} }
type SessionConfig struct { type TunnelConfig struct {
Forward Forward
HostKeys []ssh.PublicKey
NoPTY bool NoPTY bool
Shell bool
FakeRemoteHost bool FakeRemoteHost bool
KeepAliveInterval uint KeepAliveInterval uint
KeepAliveMax uint KeepAliveMax uint
} }
func New(address, user string, forward Forward, auth []ssh.AuthMethod, sc SessionConfig, log *zap.SugaredLogger) *Tunnel { func New(address, user string, auth []ssh.AuthMethod, sc TunnelConfig, log *zap.SugaredLogger) *Tunnel {
return &Tunnel{ return &Tunnel{
address: AddrToEndpoint(address), address: AddrToEndpoint(address),
user: user, user: user,
forward: forward,
authMethods: auth, authMethods: auth,
sessConfig: sc, tunConfig: sc,
log: log.With(zap.String("sshServer", address)), log: log.With(zap.String("sshServer", address)),
} }
} }
@ -68,6 +69,16 @@ func (t *Tunnel) Connect(ctx context.Context, bannerCb ssh.BannerCallback, sessi
} }
} }
func (t *Tunnel) buildHostKeyCallback() ssh.HostKeyCallback {
if t.tunConfig.HostKeys == nil || len(t.tunConfig.HostKeys) == 0 {
return ssh.InsecureIgnoreHostKey()
}
if len(t.tunConfig.HostKeys) == 1 {
return ssh.FixedHostKey(t.tunConfig.HostKeys[0])
}
return FixedHostKeys(t.tunConfig.HostKeys)
}
// connect once to the SSH server. if the connection breaks, we return error and the caller // connect once to the SSH server. if the connection breaks, we return error and the caller
// will try to re-connect // will try to re-connect
func (t *Tunnel) connect(ctx context.Context, bannerCb ssh.BannerCallback, sessionCb SessionCallback) error { func (t *Tunnel) connect(ctx context.Context, bannerCb ssh.BannerCallback, sessionCb SessionCallback) error {
@ -75,7 +86,7 @@ func (t *Tunnel) connect(ctx context.Context, bannerCb ssh.BannerCallback, sessi
sshConfig := &ssh.ClientConfig{ sshConfig := &ssh.ClientConfig{
User: t.user, User: t.user,
Auth: t.authMethods, Auth: t.authMethods,
HostKeyCallback: ssh.InsecureIgnoreHostKey(), HostKeyCallback: t.buildHostKeyCallback(),
BannerCallback: bannerCb, BannerCallback: bannerCb,
} }
@ -105,19 +116,25 @@ func (t *Tunnel) connect(ctx context.Context, bannerCb ssh.BannerCallback, sessi
sessionCb = func(*ssh.Session) {} sessionCb = func(*ssh.Session) {}
} }
if !t.sessConfig.NoPTY { if !t.tunConfig.NoPTY {
err = sess.RequestPty("xterm", 80, 40, ssh.TerminalModes{ err = sess.RequestPty("xterm", 80, 40, ssh.TerminalModes{
ssh.ECHO: 0, ssh.ECHO: 0,
ssh.IGNCR: 1, ssh.IGNCR: 1,
}) })
if err != nil { if err != nil {
t.log.Warnf("PTY allocation failed: %s", err.Error()) t.log.Warnf("PTY allocation failed: %s", err.Error())
} else {
if err := sess.Shell(); err != nil {
t.log.Warnf("failed to start shell: %s", err.Error())
}
} }
} }
if t.tunConfig.Shell {
if err := sess.Shell(); err != nil {
t.log.Warnf("failed to start shell: %s", err.Error())
}
wg.Add(1)
go func() {
defer wg.Done()
_ = sess.Wait()
}()
}
wg.Add(1) wg.Add(1)
go func() { go func() {

View File

@ -0,0 +1,36 @@
package sshtun
import (
"bytes"
"fmt"
"net"
"github.com/Neur0toxine/sshpoke/pkg/proto/ssh"
)
func FixedHostKeys(keys []ssh.PublicKey) ssh.HostKeyCallback {
m := make(map[string]ssh.PublicKey)
for _, key := range keys {
m[key.Type()] = key
}
hk := &fixedHostKeys{keys: m}
return hk.check
}
type fixedHostKeys struct {
keys map[string]ssh.PublicKey
}
func (f *fixedHostKeys) check(hostname string, remote net.Addr, key ssh.PublicKey) error {
if f.keys == nil {
return fmt.Errorf("ssh: host keys should be defined")
}
if len(f.keys) == 0 {
return fmt.Errorf("ssh: no host keys were provided")
}
hostKey, found := f.keys[key.Type()]
if !found || !bytes.Equal(key.Marshal(), hostKey.Marshal()) {
return fmt.Errorf("ssh: host key mismatch")
}
return nil
}

View File

@ -53,13 +53,13 @@ func (t *Tunnel) reverseForwardOnePort(sshClient *ssh.Client, listenerStopped ch
listener net.Listener listener net.Listener
err error err error
) )
if t.sessConfig.FakeRemoteHost { if t.tunConfig.FakeRemoteHost {
listener, err = sshClient.ListenTCP(&net.TCPAddr{ listener, err = sshClient.ListenTCP(&net.TCPAddr{
IP: t.ipFromAddr(sshClient.Conn.RemoteAddr()), IP: t.ipFromAddr(sshClient.Conn.RemoteAddr()),
Port: t.forward.Remote.Port, Port: t.tunConfig.Forward.Remote.Port,
}, t.forward.Remote.Host) }, t.tunConfig.Forward.Remote.Host)
} else { } else {
listener, err = sshClient.Listen("tcp", t.forward.Remote.String()) listener, err = sshClient.Listen("tcp", t.tunConfig.Forward.Remote.String())
} }
if err != nil { if err != nil {
return err return err
@ -67,7 +67,8 @@ func (t *Tunnel) reverseForwardOnePort(sshClient *ssh.Client, listenerStopped ch
go func() { go func() {
defer listener.Close() defer listener.Close()
t.log.Debugf("forwarding %s <- %s", t.forward.Local.String(), t.forward.Remote.String()) t.log.Debugf("forwarding %s <- %s",
t.tunConfig.Forward.Local.String(), t.tunConfig.Forward.Remote.String())
for { for {
client, err := listener.Accept() client, err := listener.Accept()
@ -76,7 +77,7 @@ func (t *Tunnel) reverseForwardOnePort(sshClient *ssh.Client, listenerStopped ch
return return
} }
go handleReverseForwardConn(client, t.forward, t.log) go handleReverseForwardConn(client, t.tunConfig.Forward, t.log)
} }
}() }()

View File

@ -15,7 +15,7 @@ import (
// assume that the underlying net.Conn abruptly died. // assume that the underlying net.Conn abruptly died.
func (t *Tunnel) keepAlive(ctx context.Context, client *ssh.Client, wg *sync.WaitGroup) { func (t *Tunnel) keepAlive(ctx context.Context, client *ssh.Client, wg *sync.WaitGroup) {
defer wg.Done() defer wg.Done()
if t.sessConfig.KeepAliveInterval == 0 || t.sessConfig.KeepAliveMax == 0 { if t.tunConfig.KeepAliveInterval == 0 || t.tunConfig.KeepAliveMax == 0 {
return return
} }
@ -29,7 +29,7 @@ func (t *Tunnel) keepAlive(ctx context.Context, client *ssh.Client, wg *sync.Wai
// Repeatedly check if the remote server is still alive. // Repeatedly check if the remote server is still alive.
var aliveCount int32 var aliveCount int32
ticker := time.NewTicker(time.Duration(t.sessConfig.KeepAliveInterval) * time.Second) ticker := time.NewTicker(time.Duration(t.tunConfig.KeepAliveInterval) * time.Second)
defer ticker.Stop() defer ticker.Stop()
for { for {
select { select {
@ -43,7 +43,7 @@ func (t *Tunnel) keepAlive(ctx context.Context, client *ssh.Client, wg *sync.Wai
} }
return return
case <-ticker.C: case <-ticker.C:
if n := atomic.AddInt32(&aliveCount, 1); n > int32(t.sessConfig.KeepAliveMax) { if n := atomic.AddInt32(&aliveCount, 1); n > int32(t.tunConfig.KeepAliveMax) {
t.log.Error("keep-alive failed, closing connection...") t.log.Error("keep-alive failed, closing connection...")
_ = client.Close() _ = client.Close()
return return