ssh: add known_hosts support, remove auth type override via PasswordAuthentication

This commit is contained in:
Pavel 2023-11-20 22:44:14 +03:00
parent a3a56bb795
commit fa9b7b1838
4 changed files with 56 additions and 25 deletions

View File

@ -50,7 +50,8 @@ servers:
# Remote user
user: user
# Directory with SSH keys. ssh-config from this directory will be used if `keyfile` is not provided.
# Only some of the ssh-config attributes are used.
# Supported ssh-config directives: HostName, Port, User, IdentityFile
# known_hosts from this directory will be used if `host_keys` is not provided.
directory: "~/.ssh"
# Expose mode (multiple domains or single domain). Allowed values: single, multi.
mode: multi

View File

@ -19,8 +19,11 @@ import (
"github.com/Neur0toxine/sshpoke/internal/server/driver/util"
"github.com/Neur0toxine/sshpoke/pkg/dto"
"github.com/Neur0toxine/sshpoke/pkg/proto/ssh"
"github.com/Neur0toxine/sshpoke/pkg/proto/ssh/knownhosts"
)
const KnownHostsFile = "known_hosts"
var ErrAlreadyInUse = errors.New("domain is already in use")
type SSH struct {
@ -28,6 +31,7 @@ type SSH struct {
params Params
auth []ssh.AuthMethod
hostKeys []ssh.PublicKey
hostKeyCallback ssh.HostKeyCallback
conns map[string]conn
clientVersion string
rw sync.RWMutex
@ -60,6 +64,7 @@ func New(ctx context.Context, name string, params config.DriverParams) (base.Dri
drv.populateFromSSHConfig()
drv.auth = drv.authenticators()
drv.clientVersion = drv.buildClientVersion()
drv.hostKeyCallback = drv.buildHostKeyCallback()
return drv, nil
}
@ -69,7 +74,7 @@ func (d *SSH) forward(val sshtun.Forward, domainMatcher func(string)) conn {
d.auth,
sshtun.TunnelConfig{
Forward: val,
HostKeys: d.hostKeys,
HostKeyCallback: d.hostKeyCallback,
NoPTY: d.params.NoPTY,
Shell: sshtun.BoolOrStr(d.params.Shell),
ClientVersion: d.clientVersion,
@ -91,6 +96,31 @@ func (d *SSH) forward(val sshtun.Forward, domainMatcher func(string)) conn {
return conn{ctx: ctx, cancel: cancel, tun: tun}
}
func (d *SSH) buildHostKeyCallback() ssh.HostKeyCallback {
keysCallback := func() ssh.HostKeyCallback {
if d.hostKeys == nil || len(d.hostKeys) == 0 {
return ssh.InsecureIgnoreHostKey()
}
if len(d.hostKeys) == 1 {
return ssh.FixedHostKey(d.hostKeys[0])
}
return sshtun.FixedHostKeys(d.hostKeys)
}()
if d.params.Auth.Type == types.AuthTypeKey && d.params.Auth.Directory != "" && len(d.hostKeys) == 0 {
knownHostsPath := types.SmartPath(path.Join(string(d.params.Auth.Directory), KnownHostsFile))
resolvedPath, err := knownHostsPath.Resolve(false)
if err != nil {
return ssh.InsecureIgnoreHostKey()
}
hostKeyCallback, err := knownhosts.New(resolvedPath)
if err != nil {
return ssh.InsecureIgnoreHostKey()
}
return hostKeyCallback
}
return keysCallback
}
func (d *SSH) buildClientVersion() string {
ver := strings.TrimSpace(d.params.ClientVersion)
if ver == "" {
@ -186,9 +216,6 @@ func (d *SSH) populateFromSSHConfig() {
if user, err := hostCfg.Get("User"); err == nil && user != "" {
d.params.Auth.User = user
}
if usePass, err := hostCfg.Get("PasswordAuthentication"); err == nil && usePass == "yes" {
d.params.Auth.Type = types.AuthTypePassword
}
if keyfile, err := hostCfg.Get("IdentityFile"); err == nil && keyfile != "" {
resolvedKeyFile, err := types.SmartPath(keyfile).Resolve(false)
if err == nil {

View File

@ -27,7 +27,7 @@ type Tunnel struct {
type TunnelConfig struct {
Forward Forward
HostKeys []ssh.PublicKey
HostKeyCallback ssh.HostKeyCallback
NoPTY bool
Shell BoolOrStr
ClientVersion string
@ -84,16 +84,6 @@ func (t *Tunnel) Connect(ctx context.Context, bannerCb ssh.BannerCallback, sessi
}
}
func (t *Tunnel) buildHostKeyCallback() ssh.HostKeyCallback {
if t.tunConfig.HostKeys == nil || len(t.tunConfig.HostKeys) == 0 {
return ssh.InsecureIgnoreHostKey()
}
if len(t.tunConfig.HostKeys) == 1 {
return ssh.FixedHostKey(t.tunConfig.HostKeys[0])
}
return FixedHostKeys(t.tunConfig.HostKeys)
}
// connect once to the SSH server. if the connection breaks, we return error and the caller
// will try to re-connect
func (t *Tunnel) connect(ctx context.Context, bannerCb ssh.BannerCallback, sessionCb SessionCallback) error {
@ -101,7 +91,7 @@ func (t *Tunnel) connect(ctx context.Context, bannerCb ssh.BannerCallback, sessi
sshConfig := &ssh.ClientConfig{
User: t.user,
Auth: t.authMethods,
HostKeyCallback: t.buildHostKeyCallback(),
HostKeyCallback: t.tunConfig.HostKeyCallback,
BannerCallback: bannerCb,
ClientVersion: t.tunConfig.ClientVersion,
}

View File

@ -34,3 +34,16 @@ func (f *fixedHostKeys) check(hostname string, remote net.Addr, key ssh.PublicKe
}
return nil
}
func CombineHostKeyCallbacks(callbacks ...ssh.HostKeyCallback) ssh.HostKeyCallback {
return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
var err error
for _, cb := range callbacks {
err = cb(hostname, remote, key)
if err == nil {
return nil
}
}
return err
}
}