diff --git a/config.example.yml b/config.example.yml index ca64c98..a506a78 100644 --- a/config.example.yml +++ b/config.example.yml @@ -50,7 +50,8 @@ servers: # Remote user user: user # Directory with SSH keys. ssh-config from this directory will be used if `keyfile` is not provided. - # Only some of the ssh-config attributes are used. + # Supported ssh-config directives: HostName, Port, User, IdentityFile + # known_hosts from this directory will be used if `host_keys` is not provided. directory: "~/.ssh" # Expose mode (multiple domains or single domain). Allowed values: single, multi. mode: multi diff --git a/internal/server/driver/ssh/driver.go b/internal/server/driver/ssh/driver.go index 5172203..1f057e0 100644 --- a/internal/server/driver/ssh/driver.go +++ b/internal/server/driver/ssh/driver.go @@ -19,20 +19,24 @@ import ( "github.com/Neur0toxine/sshpoke/internal/server/driver/util" "github.com/Neur0toxine/sshpoke/pkg/dto" "github.com/Neur0toxine/sshpoke/pkg/proto/ssh" + "github.com/Neur0toxine/sshpoke/pkg/proto/ssh/knownhosts" ) +const KnownHostsFile = "known_hosts" + var ErrAlreadyInUse = errors.New("domain is already in use") type SSH struct { base.Base - params Params - auth []ssh.AuthMethod - hostKeys []ssh.PublicKey - conns map[string]conn - clientVersion string - rw sync.RWMutex - wg sync.WaitGroup - domainRegExp *regexp.Regexp + params Params + auth []ssh.AuthMethod + hostKeys []ssh.PublicKey + hostKeyCallback ssh.HostKeyCallback + conns map[string]conn + clientVersion string + rw sync.RWMutex + wg sync.WaitGroup + domainRegExp *regexp.Regexp } type conn struct { @@ -60,6 +64,7 @@ func New(ctx context.Context, name string, params config.DriverParams) (base.Dri drv.populateFromSSHConfig() drv.auth = drv.authenticators() drv.clientVersion = drv.buildClientVersion() + drv.hostKeyCallback = drv.buildHostKeyCallback() return drv, nil } @@ -69,7 +74,7 @@ func (d *SSH) forward(val sshtun.Forward, domainMatcher func(string)) conn { d.auth, sshtun.TunnelConfig{ Forward: val, - HostKeys: d.hostKeys, + HostKeyCallback: d.hostKeyCallback, NoPTY: d.params.NoPTY, Shell: sshtun.BoolOrStr(d.params.Shell), ClientVersion: d.clientVersion, @@ -91,6 +96,31 @@ func (d *SSH) forward(val sshtun.Forward, domainMatcher func(string)) conn { return conn{ctx: ctx, cancel: cancel, tun: tun} } +func (d *SSH) buildHostKeyCallback() ssh.HostKeyCallback { + keysCallback := func() ssh.HostKeyCallback { + if d.hostKeys == nil || len(d.hostKeys) == 0 { + return ssh.InsecureIgnoreHostKey() + } + if len(d.hostKeys) == 1 { + return ssh.FixedHostKey(d.hostKeys[0]) + } + return sshtun.FixedHostKeys(d.hostKeys) + }() + if d.params.Auth.Type == types.AuthTypeKey && d.params.Auth.Directory != "" && len(d.hostKeys) == 0 { + knownHostsPath := types.SmartPath(path.Join(string(d.params.Auth.Directory), KnownHostsFile)) + resolvedPath, err := knownHostsPath.Resolve(false) + if err != nil { + return ssh.InsecureIgnoreHostKey() + } + hostKeyCallback, err := knownhosts.New(resolvedPath) + if err != nil { + return ssh.InsecureIgnoreHostKey() + } + return hostKeyCallback + } + return keysCallback +} + func (d *SSH) buildClientVersion() string { ver := strings.TrimSpace(d.params.ClientVersion) if ver == "" { @@ -186,9 +216,6 @@ func (d *SSH) populateFromSSHConfig() { if user, err := hostCfg.Get("User"); err == nil && user != "" { d.params.Auth.User = user } - if usePass, err := hostCfg.Get("PasswordAuthentication"); err == nil && usePass == "yes" { - d.params.Auth.Type = types.AuthTypePassword - } if keyfile, err := hostCfg.Get("IdentityFile"); err == nil && keyfile != "" { resolvedKeyFile, err := types.SmartPath(keyfile).Resolve(false) if err == nil { diff --git a/internal/server/driver/ssh/sshtun/connect.go b/internal/server/driver/ssh/sshtun/connect.go index 5917817..c12b451 100644 --- a/internal/server/driver/ssh/sshtun/connect.go +++ b/internal/server/driver/ssh/sshtun/connect.go @@ -27,7 +27,7 @@ type Tunnel struct { type TunnelConfig struct { Forward Forward - HostKeys []ssh.PublicKey + HostKeyCallback ssh.HostKeyCallback NoPTY bool Shell BoolOrStr ClientVersion string @@ -84,16 +84,6 @@ func (t *Tunnel) Connect(ctx context.Context, bannerCb ssh.BannerCallback, sessi } } -func (t *Tunnel) buildHostKeyCallback() ssh.HostKeyCallback { - if t.tunConfig.HostKeys == nil || len(t.tunConfig.HostKeys) == 0 { - return ssh.InsecureIgnoreHostKey() - } - if len(t.tunConfig.HostKeys) == 1 { - return ssh.FixedHostKey(t.tunConfig.HostKeys[0]) - } - return FixedHostKeys(t.tunConfig.HostKeys) -} - // connect once to the SSH server. if the connection breaks, we return error and the caller // will try to re-connect func (t *Tunnel) connect(ctx context.Context, bannerCb ssh.BannerCallback, sessionCb SessionCallback) error { @@ -101,7 +91,7 @@ func (t *Tunnel) connect(ctx context.Context, bannerCb ssh.BannerCallback, sessi sshConfig := &ssh.ClientConfig{ User: t.user, Auth: t.authMethods, - HostKeyCallback: t.buildHostKeyCallback(), + HostKeyCallback: t.tunConfig.HostKeyCallback, BannerCallback: bannerCb, ClientVersion: t.tunConfig.ClientVersion, } diff --git a/internal/server/driver/ssh/sshtun/fixed_host_keys.go b/internal/server/driver/ssh/sshtun/fixed_host_keys.go index f95a5eb..efa2767 100644 --- a/internal/server/driver/ssh/sshtun/fixed_host_keys.go +++ b/internal/server/driver/ssh/sshtun/fixed_host_keys.go @@ -34,3 +34,16 @@ func (f *fixedHostKeys) check(hostname string, remote net.Addr, key ssh.PublicKe } return nil } + +func CombineHostKeyCallbacks(callbacks ...ssh.HostKeyCallback) ssh.HostKeyCallback { + return func(hostname string, remote net.Addr, key ssh.PublicKey) error { + var err error + for _, cb := range callbacks { + err = cb(hostname, remote, key) + if err == nil { + return nil + } + } + return err + } +}