diff --git a/config.example.yml b/config.example.yml index f9f6efe..ca64c98 100644 --- a/config.example.yml +++ b/config.example.yml @@ -38,7 +38,11 @@ servers: # Disables PTY request for this server. nopty: true # Requests interactive shell for SSH sessions. Should be `true` for the `commands`. + # You can also pass a string with shell binary, for example, "/bin/sh". + # Note: commands will be executed using provided shell binary. shell: false + # Spoof client version with provided (value below is taken directly from OpenSSH). This value must be compliant with RFC-4253. + client_version: "SSH-2.0-OpenSSH_9.5" # Authentication data. auth: # Authentication type. Supported types: key, password, passwordless @@ -84,7 +88,7 @@ servers: forward_port: 80 fake_remote_host: true nopty: false - shell: true + shell: "/usr/bin/bash" mode: single keepalive: interval: 1 diff --git a/internal/server/driver/ssh/driver.go b/internal/server/driver/ssh/driver.go index 4b576d2..5172203 100644 --- a/internal/server/driver/ssh/driver.go +++ b/internal/server/driver/ssh/driver.go @@ -25,13 +25,14 @@ var ErrAlreadyInUse = errors.New("domain is already in use") type SSH struct { base.Base - params Params - auth []ssh.AuthMethod - hostKeys []ssh.PublicKey - conns map[string]conn - rw sync.RWMutex - wg sync.WaitGroup - domainRegExp *regexp.Regexp + params Params + auth []ssh.AuthMethod + hostKeys []ssh.PublicKey + conns map[string]conn + clientVersion string + rw sync.RWMutex + wg sync.WaitGroup + domainRegExp *regexp.Regexp } type conn struct { @@ -58,6 +59,7 @@ func New(ctx context.Context, name string, params config.DriverParams) (base.Dri drv.domainRegExp = matcher drv.populateFromSSHConfig() drv.auth = drv.authenticators() + drv.clientVersion = drv.buildClientVersion() return drv, nil } @@ -69,7 +71,8 @@ func (d *SSH) forward(val sshtun.Forward, domainMatcher func(string)) conn { Forward: val, HostKeys: d.hostKeys, NoPTY: d.params.NoPTY, - Shell: d.params.Shell, + Shell: sshtun.BoolOrStr(d.params.Shell), + ClientVersion: d.clientVersion, FakeRemoteHost: d.params.FakeRemoteHost, KeepAliveInterval: uint(d.params.KeepAlive.Interval), KeepAliveMax: uint(d.params.KeepAlive.MaxAttempts), @@ -88,6 +91,23 @@ func (d *SSH) forward(val sshtun.Forward, domainMatcher func(string)) conn { return conn{ctx: ctx, cancel: cancel, tun: tun} } +func (d *SSH) buildClientVersion() string { + ver := strings.TrimSpace(d.params.ClientVersion) + if ver == "" { + return "" + } + if !strings.HasPrefix(ver, "SSH-2.0-") { + d.Log().Warn( + "client_version must have 'SSH-2.0-' prefix (see RFC-4253), this will be fixed automatically") + ver = "SSH-2.0-" + ver + } + if !isValidClientVersion(ver) { + d.Log().Warnf("invalid client_version value, using default...") + return "" + } + return ver +} + func (d *SSH) buildHostKeys() error { if d.params.HostKeys == "" { return nil diff --git a/internal/server/driver/ssh/params.go b/internal/server/driver/ssh/params.go index 35d3069..e61ecca 100644 --- a/internal/server/driver/ssh/params.go +++ b/internal/server/driver/ssh/params.go @@ -18,7 +18,8 @@ type Params struct { Mode types.DomainMode `mapstructure:"mode" validate:"required,oneof=single multi"` FakeRemoteHost bool `mapstructure:"fake_remote_host"` NoPTY bool `mapstructure:"nopty"` - Shell bool `mapstructure:"shell"` + Shell string `mapstructure:"shell"` + ClientVersion string `mapstructure:"client_version"` Commands types.Commands `mapstructure:"commands"` } diff --git a/internal/server/driver/ssh/sshtun/connect.go b/internal/server/driver/ssh/sshtun/connect.go index 5c83f2e..5917817 100644 --- a/internal/server/driver/ssh/sshtun/connect.go +++ b/internal/server/driver/ssh/sshtun/connect.go @@ -29,12 +29,27 @@ type TunnelConfig struct { Forward Forward HostKeys []ssh.PublicKey NoPTY bool - Shell bool + Shell BoolOrStr + ClientVersion string FakeRemoteHost bool KeepAliveInterval uint KeepAliveMax uint } +type BoolOrStr string + +func (b BoolOrStr) IsBool() bool { + return b == "true" || b == "false" +} + +func (b BoolOrStr) Falsy() bool { + return b == "" || b == "false" +} + +func (b BoolOrStr) String() string { + return string(b) +} + func New(address, user string, auth []ssh.AuthMethod, sc TunnelConfig, log *zap.SugaredLogger) *Tunnel { return &Tunnel{ address: AddrToEndpoint(address), @@ -88,6 +103,7 @@ func (t *Tunnel) connect(ctx context.Context, bannerCb ssh.BannerCallback, sessi Auth: t.authMethods, HostKeyCallback: t.buildHostKeyCallback(), BannerCallback: bannerCb, + ClientVersion: t.tunConfig.ClientVersion, } var sshClient *ssh.Client @@ -125,9 +141,15 @@ func (t *Tunnel) connect(ctx context.Context, bannerCb ssh.BannerCallback, sessi t.log.Warnf("PTY allocation failed: %s", err.Error()) } } - if t.tunConfig.Shell { - if err := sess.Shell(); err != nil { - t.log.Warnf("failed to start shell: %s", err.Error()) + if !t.tunConfig.Shell.Falsy() { + if t.tunConfig.Shell.IsBool() { + if err := sess.Shell(); err != nil { + t.log.Warnf("failed to start empty shell: %s", err.Error()) + } + } else { + if err := sess.Start(t.tunConfig.Shell.String()); err != nil { + t.log.Warnf("failed to start shell '%s': %s", t.tunConfig.Shell, err.Error()) + } } wg.Add(1) go func() { diff --git a/internal/server/driver/ssh/validate.go b/internal/server/driver/ssh/validate.go new file mode 100644 index 0000000..cb47c0f --- /dev/null +++ b/internal/server/driver/ssh/validate.go @@ -0,0 +1,10 @@ +package ssh + +import "regexp" + +var clientVersionVerifier = regexp.MustCompile(`^[a-zA-Z0-9\.\-\_]+\x{20}?[a-zA-Z0-9\.\-\_]+?$`) + +// isValidClientVersion returns true if provided SSH client version string is compliant with RFC-4253. +func isValidClientVersion(ver string) bool { + return clientVersionVerifier.MatchString(ver) +} diff --git a/internal/server/driver/util/params.go b/internal/server/driver/util/params.go index 24ab7e1..1052ad8 100644 --- a/internal/server/driver/util/params.go +++ b/internal/server/driver/util/params.go @@ -10,7 +10,14 @@ type ValidationAvailable interface { } func UnmarshalParams(params config.DriverParams, target ValidationAvailable) error { - if err := mapstructure.Decode(params, target); err != nil { + dec, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ + Result: target, + WeaklyTypedInput: true, + }) + if err != nil { + return err + } + if err := dec.Decode(params); err != nil { return err } if val, canValidate := target.(ValidationAvailable); canValidate {