ssh: add known_hosts support, remove auth type override via PasswordAuthentication

This commit is contained in:
Pavel 2023-11-20 22:44:14 +03:00
parent a3a56bb795
commit fa9b7b1838
4 changed files with 56 additions and 25 deletions

View File

@ -50,7 +50,8 @@ servers:
# Remote user # Remote user
user: user user: user
# Directory with SSH keys. ssh-config from this directory will be used if `keyfile` is not provided. # Directory with SSH keys. ssh-config from this directory will be used if `keyfile` is not provided.
# Only some of the ssh-config attributes are used. # Supported ssh-config directives: HostName, Port, User, IdentityFile
# known_hosts from this directory will be used if `host_keys` is not provided.
directory: "~/.ssh" directory: "~/.ssh"
# Expose mode (multiple domains or single domain). Allowed values: single, multi. # Expose mode (multiple domains or single domain). Allowed values: single, multi.
mode: multi mode: multi

View File

@ -19,8 +19,11 @@ import (
"github.com/Neur0toxine/sshpoke/internal/server/driver/util" "github.com/Neur0toxine/sshpoke/internal/server/driver/util"
"github.com/Neur0toxine/sshpoke/pkg/dto" "github.com/Neur0toxine/sshpoke/pkg/dto"
"github.com/Neur0toxine/sshpoke/pkg/proto/ssh" "github.com/Neur0toxine/sshpoke/pkg/proto/ssh"
"github.com/Neur0toxine/sshpoke/pkg/proto/ssh/knownhosts"
) )
const KnownHostsFile = "known_hosts"
var ErrAlreadyInUse = errors.New("domain is already in use") var ErrAlreadyInUse = errors.New("domain is already in use")
type SSH struct { type SSH struct {
@ -28,6 +31,7 @@ type SSH struct {
params Params params Params
auth []ssh.AuthMethod auth []ssh.AuthMethod
hostKeys []ssh.PublicKey hostKeys []ssh.PublicKey
hostKeyCallback ssh.HostKeyCallback
conns map[string]conn conns map[string]conn
clientVersion string clientVersion string
rw sync.RWMutex rw sync.RWMutex
@ -60,6 +64,7 @@ func New(ctx context.Context, name string, params config.DriverParams) (base.Dri
drv.populateFromSSHConfig() drv.populateFromSSHConfig()
drv.auth = drv.authenticators() drv.auth = drv.authenticators()
drv.clientVersion = drv.buildClientVersion() drv.clientVersion = drv.buildClientVersion()
drv.hostKeyCallback = drv.buildHostKeyCallback()
return drv, nil return drv, nil
} }
@ -69,7 +74,7 @@ func (d *SSH) forward(val sshtun.Forward, domainMatcher func(string)) conn {
d.auth, d.auth,
sshtun.TunnelConfig{ sshtun.TunnelConfig{
Forward: val, Forward: val,
HostKeys: d.hostKeys, HostKeyCallback: d.hostKeyCallback,
NoPTY: d.params.NoPTY, NoPTY: d.params.NoPTY,
Shell: sshtun.BoolOrStr(d.params.Shell), Shell: sshtun.BoolOrStr(d.params.Shell),
ClientVersion: d.clientVersion, ClientVersion: d.clientVersion,
@ -91,6 +96,31 @@ func (d *SSH) forward(val sshtun.Forward, domainMatcher func(string)) conn {
return conn{ctx: ctx, cancel: cancel, tun: tun} return conn{ctx: ctx, cancel: cancel, tun: tun}
} }
func (d *SSH) buildHostKeyCallback() ssh.HostKeyCallback {
keysCallback := func() ssh.HostKeyCallback {
if d.hostKeys == nil || len(d.hostKeys) == 0 {
return ssh.InsecureIgnoreHostKey()
}
if len(d.hostKeys) == 1 {
return ssh.FixedHostKey(d.hostKeys[0])
}
return sshtun.FixedHostKeys(d.hostKeys)
}()
if d.params.Auth.Type == types.AuthTypeKey && d.params.Auth.Directory != "" && len(d.hostKeys) == 0 {
knownHostsPath := types.SmartPath(path.Join(string(d.params.Auth.Directory), KnownHostsFile))
resolvedPath, err := knownHostsPath.Resolve(false)
if err != nil {
return ssh.InsecureIgnoreHostKey()
}
hostKeyCallback, err := knownhosts.New(resolvedPath)
if err != nil {
return ssh.InsecureIgnoreHostKey()
}
return hostKeyCallback
}
return keysCallback
}
func (d *SSH) buildClientVersion() string { func (d *SSH) buildClientVersion() string {
ver := strings.TrimSpace(d.params.ClientVersion) ver := strings.TrimSpace(d.params.ClientVersion)
if ver == "" { if ver == "" {
@ -186,9 +216,6 @@ func (d *SSH) populateFromSSHConfig() {
if user, err := hostCfg.Get("User"); err == nil && user != "" { if user, err := hostCfg.Get("User"); err == nil && user != "" {
d.params.Auth.User = user d.params.Auth.User = user
} }
if usePass, err := hostCfg.Get("PasswordAuthentication"); err == nil && usePass == "yes" {
d.params.Auth.Type = types.AuthTypePassword
}
if keyfile, err := hostCfg.Get("IdentityFile"); err == nil && keyfile != "" { if keyfile, err := hostCfg.Get("IdentityFile"); err == nil && keyfile != "" {
resolvedKeyFile, err := types.SmartPath(keyfile).Resolve(false) resolvedKeyFile, err := types.SmartPath(keyfile).Resolve(false)
if err == nil { if err == nil {

View File

@ -27,7 +27,7 @@ type Tunnel struct {
type TunnelConfig struct { type TunnelConfig struct {
Forward Forward Forward Forward
HostKeys []ssh.PublicKey HostKeyCallback ssh.HostKeyCallback
NoPTY bool NoPTY bool
Shell BoolOrStr Shell BoolOrStr
ClientVersion string ClientVersion string
@ -84,16 +84,6 @@ func (t *Tunnel) Connect(ctx context.Context, bannerCb ssh.BannerCallback, sessi
} }
} }
func (t *Tunnel) buildHostKeyCallback() ssh.HostKeyCallback {
if t.tunConfig.HostKeys == nil || len(t.tunConfig.HostKeys) == 0 {
return ssh.InsecureIgnoreHostKey()
}
if len(t.tunConfig.HostKeys) == 1 {
return ssh.FixedHostKey(t.tunConfig.HostKeys[0])
}
return FixedHostKeys(t.tunConfig.HostKeys)
}
// connect once to the SSH server. if the connection breaks, we return error and the caller // connect once to the SSH server. if the connection breaks, we return error and the caller
// will try to re-connect // will try to re-connect
func (t *Tunnel) connect(ctx context.Context, bannerCb ssh.BannerCallback, sessionCb SessionCallback) error { func (t *Tunnel) connect(ctx context.Context, bannerCb ssh.BannerCallback, sessionCb SessionCallback) error {
@ -101,7 +91,7 @@ func (t *Tunnel) connect(ctx context.Context, bannerCb ssh.BannerCallback, sessi
sshConfig := &ssh.ClientConfig{ sshConfig := &ssh.ClientConfig{
User: t.user, User: t.user,
Auth: t.authMethods, Auth: t.authMethods,
HostKeyCallback: t.buildHostKeyCallback(), HostKeyCallback: t.tunConfig.HostKeyCallback,
BannerCallback: bannerCb, BannerCallback: bannerCb,
ClientVersion: t.tunConfig.ClientVersion, ClientVersion: t.tunConfig.ClientVersion,
} }

View File

@ -34,3 +34,16 @@ func (f *fixedHostKeys) check(hostname string, remote net.Addr, key ssh.PublicKe
} }
return nil return nil
} }
func CombineHostKeyCallbacks(callbacks ...ssh.HostKeyCallback) ssh.HostKeyCallback {
return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
var err error
for _, cb := range callbacks {
err = cb(hostname, remote, key)
if err == nil {
return nil
}
}
return err
}
}