sshpoke/internal/server/driver/ssh/driver.go

329 lines
8.4 KiB
Go

package ssh
import (
"context"
"encoding/base64"
"errors"
"fmt"
"net"
"path"
"regexp"
"strconv"
"strings"
"sync"
"github.com/Neur0toxine/sshpoke/internal/config"
"github.com/Neur0toxine/sshpoke/internal/server/driver/base"
"github.com/Neur0toxine/sshpoke/internal/server/driver/ssh/sshtun"
"github.com/Neur0toxine/sshpoke/internal/server/driver/ssh/types"
"github.com/Neur0toxine/sshpoke/internal/server/driver/util"
"github.com/Neur0toxine/sshpoke/pkg/dto"
"github.com/Neur0toxine/sshpoke/pkg/proto/ssh"
"github.com/Neur0toxine/sshpoke/pkg/proto/ssh/knownhosts"
)
const (
	// KnownHostsFile is the file name looked up inside the auth directory when
	// host keys are verified against a known_hosts file.
	KnownHostsFile = "known_hosts"
)

var (
	// ErrAlreadyInUse is returned by Handle when a start event arrives in
	// single-domain mode while another tunnel is still active.
	ErrAlreadyInUse = errors.New("domain is already in use")
)
// SSH is a driver that exposes containers through reverse SSH tunnels, one
// tunnel per container, opened against the configured SSH server.
type SSH struct {
	// Embedded driver base; presumably the source of Log, Context and
	// PushEventStatus used throughout this file — confirm in package base.
	base.Base
	// params holds the decoded driver configuration; New may amend it from the
	// ssh config file found in the auth directory.
	params Params
	// auth is the list of auth methods handed to every tunnel (nil when no
	// authenticator could be built).
	auth []ssh.AuthMethod
	// hostKeys are the keys parsed from the host_keys parameter.
	hostKeys []ssh.PublicKey
	// hostKeyCallback is the server host key verifier shared by all tunnels.
	hostKeyCallback ssh.HostKeyCallback
	// conns maps a container ID to its active tunnel; guarded by rw.
	conns map[string]conn
	// clientVersion is the SSH client identification string; empty selects the
	// tunnel library's default.
	clientVersion string
	// rw serialises Handle calls and therefore all access to conns and wg.
	rw sync.RWMutex
	// wg holds one count per active tunnel; WaitForShutdown waits on it.
	wg sync.WaitGroup
	// domainRegExp extracts the published domain from tunnel output; nil
	// disables domain detection.
	domainRegExp *regexp.Regexp
}
// conn is the handle for one running tunnel, as returned by forward.
type conn struct {
	// ctx is the tunnel's context, derived from the driver context.
	ctx context.Context
	// cancel stops the tunnel by cancelling ctx.
	cancel func()
	// tun is the tunnel started with ctx.
	tun *sshtun.Tunnel
}
// New builds the SSH driver called name from the raw driver params.
//
// It decodes params, parses any explicit host keys and compiles the
// domain-extraction regexp, returning an error if any of these fail. It then
// merges settings from the ssh config file in the auth directory. Only after
// that does it derive the auth methods, client version and host key verifier,
// because those depend on the merged values.
func New(ctx context.Context, name string, params config.DriverParams) (base.Driver, error) {
	d := &SSH{
		Base:  base.New(ctx, name),
		conns: make(map[string]conn),
	}

	if err := util.UnmarshalParams(params, &d.params); err != nil {
		return nil, err
	}
	if err := d.buildHostKeys(); err != nil {
		return nil, err
	}

	re, err := makeDomainCatchRegExp(d.params.DomainExtractRegex)
	if err != nil {
		return nil, fmt.Errorf("invalid domain_extract_regex: %w", err)
	}
	d.domainRegExp = re

	// Must precede the derived fields below: it can rewrite the address, user
	// and key settings they are computed from.
	d.populateFromSSHConfig()

	d.auth = d.authenticators()
	d.clientVersion = d.buildClientVersion()
	d.hostKeyCallback = d.buildHostKeyCallback()

	return d, nil
}
// forward starts a tunnel for val in a background goroutine and returns the
// handle used to stop it.
//
// The tunnel context is derived from the driver context, so the tunnel stops
// when the returned cancel func is called or when the driver context ends.
// Banner and output lines are logged at debug level, tagged with the remote
// endpoint. Each output line is also passed to domainMatcher when it is
// non-nil.
func (d *SSH) forward(val sshtun.Forward, domainMatcher func(string)) conn {
	tun := sshtun.New(d.params.Address,
		d.params.Auth.User,
		d.auth,
		sshtun.TunnelConfig{
			Forward:           val,
			HostKeyCallback:   d.hostKeyCallback,
			NoPTY:             d.params.NoPTY,
			Shell:             sshtun.BoolOrStr(d.params.Shell),
			ClientVersion:     d.clientVersion,
			FakeRemoteHost:    d.params.FakeRemoteHost,
			KeepAliveInterval: uint(d.params.KeepAlive.Interval),
			KeepAliveMax:      uint(d.params.KeepAlive.MaxAttempts),
		},
		d.Log())
	ctx, cancel := context.WithCancel(d.Context())
	// Tag tunnel output with its remote endpoint so that concurrent tunnels can
	// be told apart in the logs. Both the banner and the output callbacks use it.
	tunDbgLog := d.Log().With("ssh-output", val.Remote.String())
	go tun.Connect(ctx,
		sshtun.BannerDebugLogCallback(tunDbgLog),
		sshtun.OutputReaderCallback(func(msg string) {
			tunDbgLog.Debug(msg)
			if domainMatcher != nil {
				domainMatcher(msg)
			}
		}))
	return conn{ctx: ctx, cancel: cancel, tun: tun}
}
// buildHostKeyCallback chooses how the SSH server's host key is verified.
//
// Precedence:
//   - key auth with an auth directory and no explicit host_keys: verify against
//     the KnownHostsFile in that directory;
//   - exactly one explicit host key: ssh.FixedHostKey;
//   - several explicit host keys: sshtun.FixedHostKeys;
//   - no host keys at all: no verification (ssh.InsecureIgnoreHostKey).
//
// If the known_hosts file cannot be resolved or loaded, verification is
// disabled as before. A warning is logged in that case so that the insecure
// fallback is not silent.
func (d *SSH) buildHostKeyCallback() ssh.HostKeyCallback {
	if d.params.Auth.Type == types.AuthTypeKey && d.params.Auth.Directory != "" && len(d.hostKeys) == 0 {
		knownHostsPath := types.SmartPath(path.Join(string(d.params.Auth.Directory), KnownHostsFile))
		resolvedPath, err := knownHostsPath.Resolve(false)
		if err != nil {
			d.Log().Warnf("cannot resolve %s path, host key verification is disabled: %s", KnownHostsFile, err)
			return ssh.InsecureIgnoreHostKey()
		}
		hostKeyCallback, err := knownhosts.New(resolvedPath)
		if err != nil {
			d.Log().Warnf("cannot load %s, host key verification is disabled: %s", resolvedPath, err)
			return ssh.InsecureIgnoreHostKey()
		}
		return hostKeyCallback
	}

	// len of a nil slice is 0, so a separate nil check is unnecessary.
	switch len(d.hostKeys) {
	case 0:
		return ssh.InsecureIgnoreHostKey()
	case 1:
		return ssh.FixedHostKey(d.hostKeys[0])
	default:
		return sshtun.FixedHostKeys(d.hostKeys)
	}
}
// buildClientVersion normalises the client_version parameter into an SSH
// identification string.
//
// An empty or blank value yields "" so the default is used. A value without
// the mandatory "SSH-2.0-" prefix gets it prepended, with a warning. A value
// that still fails isValidClientVersion is dropped, with a warning, and "" is
// returned.
func (d *SSH) buildClientVersion() string {
	const prefix = "SSH-2.0-"

	version := strings.TrimSpace(d.params.ClientVersion)
	switch {
	case version == "":
		return ""
	case !strings.HasPrefix(version, prefix):
		d.Log().Warn(
			"client_version must have 'SSH-2.0-' prefix (see RFC-4253), this will be fixed automatically")
		version = prefix + version
	}

	if isValidClientVersion(version) {
		return version
	}
	d.Log().Warnf("invalid client_version value, using default...")
	return ""
}
// buildHostKeys parses the host_keys parameter, one key per line in
// ssh-keyscan format, into d.hostKeys.
//
// Blank and comment lines are skipped. The first malformed line aborts parsing
// with an error, and d.hostKeys is left untouched. An empty parameter is a
// no-op.
func (d *SSH) buildHostKeys() error {
	if d.params.HostKeys == "" {
		return nil
	}

	lines := strings.Split(d.params.HostKeys, "\n")
	parsed := make([]ssh.PublicKey, 0, len(lines))
	for _, line := range lines {
		pk, err := d.pubKeyFromSSHKeyScan(line)
		switch {
		case err != nil:
			d.Log().Debugf("invalid public key: %s", line)
			return fmt.Errorf("invalid public key for the host: %w", err)
		case pk != nil:
			parsed = append(parsed, pk)
		}
	}

	d.hostKeys = parsed
	return nil
}
// pubKeyFromSSHKeyScan extracts a host public key from one line of
// ssh-keyscan output.
//
// Blank lines and lines starting with "#" yield (nil, nil). Otherwise the
// whitespace-separated fields are scanned from last to first. The first field
// that base64-decodes into a parseable public key is returned. If no field
// qualifies, an error is returned.
func (d *SSH) pubKeyFromSSHKeyScan(line string) (key ssh.PublicKey, err error) {
	trimmed := strings.TrimSpace(line)
	if trimmed == "" || strings.HasPrefix(trimmed, "#") {
		return nil, nil
	}

	fields := strings.Fields(trimmed)
	for idx := len(fields) - 1; idx >= 0; idx-- {
		raw, decodeErr := base64.StdEncoding.DecodeString(strings.TrimSpace(fields[idx]))
		if decodeErr != nil {
			continue
		}
		if parsed, parseErr := ssh.ParsePublicKey(raw); parseErr == nil {
			return parsed, nil
		}
	}
	return nil, errors.New("no public key in the provided data")
}
// makeDomainMatcherFunc returns a tunnel-output callback for containerID.
//
// For every output line in which d.domainRegExp finds a match, the callback
// pushes an EventStart status carrying that domain. It returns nil when no
// domain regexp is configured, which disables matching.
func (d *SSH) makeDomainMatcherFunc(containerID string) func(string) {
	if d.domainRegExp == nil {
		return nil
	}
	return func(line string) {
		if found := d.domainRegExp.FindString(line); found != "" {
			d.PushEventStatus(dto.EventStatus{
				Type:   dto.EventStart,
				ID:     containerID,
				Domain: found,
			})
		}
	}
}
// populateFromSSHConfig merges settings from "<auth directory>/config" for the
// host named in d.params.Address into d.params.
//
// Supported keys:
//   - HostName and Port rewrite the address;
//   - User replaces the auth user;
//   - IdentityFile switches to key auth using the resolved key file.
//
// A missing auth directory, or a config file that fails to parse, leaves
// params unchanged.
//
// Port resolution: a Port from the config wins; otherwise the port already in
// the address is kept; otherwise 22. Previously an explicit address port was
// replaced by 22 whenever only HostName was configured. A configured Port was
// also ignored when HostName was absent.
func (d *SSH) populateFromSSHConfig() {
	if d.params.Auth.Directory == "" {
		return
	}
	cfg, err := parseSSHConfig(types.SmartPath(path.Join(string(d.params.Auth.Directory), "config")))
	if err != nil {
		return
	}

	host := d.extractHostFromAddr(d.params.Address)
	port := "22"
	if _, addrPort, err := net.SplitHostPort(d.params.Address); err == nil && addrPort != "" {
		port = addrPort
	}

	hostCfg := &hostConfig{cfg: cfg, host: host}
	changed := false
	if cfgPort, err := hostCfg.Get("Port"); err == nil && cfgPort != "" {
		port, changed = cfgPort, true
	}
	if hostName, err := hostCfg.Get("HostName"); err == nil && hostName != "" {
		host, changed = hostName, true
	}
	// Only rebuild the address when the config overrode something, so an
	// address given without a port is not rewritten needlessly.
	if changed {
		d.params.Address = net.JoinHostPort(host, port)
	}

	if user, err := hostCfg.Get("User"); err == nil && user != "" {
		d.params.Auth.User = user
	}
	if keyfile, err := hostCfg.Get("IdentityFile"); err == nil && keyfile != "" {
		if resolvedKeyFile, err := types.SmartPath(keyfile).Resolve(false); err == nil {
			d.params.Auth.Type = types.AuthTypeKey
			d.params.Auth.Keyfile = resolvedKeyFile
		}
	}
}
// extractHostFromAddr returns the host part of a "host:port" address. If addr
// cannot be split (e.g. it carries no port), it is returned unchanged.
func (d *SSH) extractHostFromAddr(addr string) string {
	if host, _, err := net.SplitHostPort(addr); err == nil {
		return host
	}
	return addr
}
// Handle applies one container lifecycle event to the set of active tunnels.
// Calls are serialised by d.rw, which guards d.conns and d.wg.
//
//   - dto.EventStart opens a tunnel for the container. In single-domain mode it
//     returns ErrAlreadyInUse while any tunnel is active. A repeated start for
//     a container that already has a tunnel replaces that tunnel.
//   - dto.EventStop closes the container's tunnel and publishes a stop status.
//     It is a no-op for unknown containers.
//   - dto.EventShutdown closes every tunnel and publishes a stop status for
//     each one.
//
// Every active tunnel owns exactly one d.wg slot, released here when the
// tunnel is closed; WaitForShutdown waits for all slots to be released.
func (d *SSH) Handle(event dto.Event) error {
	d.rw.Lock()
	defer d.rw.Unlock()

	switch event.Type {
	case dto.EventStart:
		if d.params.Mode == types.DomainModeSingle && len(d.conns) > 0 {
			return ErrAlreadyInUse
		}
		id := event.Container.ID
		prev, replacing := d.conns[id]
		if replacing {
			// Stop the old tunnel so its goroutine does not leak, and reuse its
			// WaitGroup slot. A second Add would never be released, because stop
			// and shutdown call Done once per map entry.
			prev.cancel()
		}
		d.conns[id] = d.forward(sshtun.Forward{
			Local:  d.localEndpoint(event.Container.IP, event.Container.Port),
			Remote: d.remoteEndpoint(event.Container.RemoteHost),
		}, d.makeDomainMatcherFunc(id))
		if !replacing {
			d.wg.Add(1)
		}
	case dto.EventStop:
		c, found := d.conns[event.Container.ID]
		if !found {
			return nil
		}
		c.cancel()
		delete(d.conns, event.Container.ID)
		d.propagateStop(event.Container.ID)
		d.wg.Done()
	case dto.EventShutdown:
		// Deleting the current key while ranging over a map is permitted in Go.
		for id, c := range d.conns {
			c.cancel()
			delete(d.conns, id)
			d.propagateStop(id)
			d.wg.Done()
		}
	}
	return nil
}
// propagateStop publishes an EventStop status for containerID.
func (d *SSH) propagateStop(containerID string) {
	status := dto.EventStatus{ID: containerID, Type: dto.EventStop}
	d.PushEventStatus(status)
}
// localEndpoint converts a container IP and port into the tunnel's local
// endpoint. IPv6 addresses are bracketed by net.JoinHostPort.
func (d *SSH) localEndpoint(ip net.IP, port uint16) sshtun.Endpoint {
	hostPort := net.JoinHostPort(ip.String(), strconv.FormatUint(uint64(port), 10))
	return sshtun.AddrToEndpoint(hostPort)
}
// remoteEndpoint builds the remote side of a forward for remoteHost. It uses
// the configured forward port, or 80 when none is set.
func (d *SSH) remoteEndpoint(remoteHost string) sshtun.Endpoint {
	endpoint := sshtun.Endpoint{Host: remoteHost, Port: 80}
	if p := int(d.params.ForwardPort); p != 0 {
		endpoint.Port = p
	}
	return endpoint
}
// Driver reports the driver type implemented by SSH.
func (d *SSH) Driver() config.DriverType {
	var kind config.DriverType = config.DriverSSH
	return kind
}
// WaitForShutdown asynchronously dispatches a shutdown event, which closes
// every tunnel. It then blocks until all tunnel WaitGroup slots are released.
func (d *SSH) WaitForShutdown() {
	shutdown := dto.Event{Type: dto.EventShutdown}
	go func() {
		// A shutdown event always returns nil from Handle.
		_ = d.Handle(shutdown)
	}()
	d.wg.Wait()
}
// authenticators wraps the configured authenticator in a slice. It returns nil
// when no authenticator could be built.
func (d *SSH) authenticators() []ssh.AuthMethod {
	if method := d.authenticator(); method != nil {
		return []ssh.AuthMethod{method}
	}
	return nil
}
// authenticator builds the ssh.AuthMethod for the configured auth type. It
// returns nil for unknown types and when key material cannot be loaded; load
// failures are logged.
//
// For key auth, an explicit Keyfile is used when set. Otherwise keys are
// loaded from the auth directory.
func (d *SSH) authenticator() ssh.AuthMethod {
	switch d.params.Auth.Type {
	case types.AuthTypePasswordless:
		return sshtun.AuthPassword("")
	case types.AuthTypePassword:
		return sshtun.AuthPassword(d.params.Auth.Password)
	case types.AuthTypeKey:
		if d.params.Auth.Keyfile != "" {
			keyPath := d.params.Auth.Keyfile
			// Only relative key paths are resolved against the auth directory.
			// path.Join would otherwise glue an absolute (or "~"-relative) path onto
			// the directory: "/home/u/.ssh" + "/home/u/.ssh/id" -> "/home/u/.ssh/home/u/.ssh/id".
			// That would break the resolved IdentityFile stored by populateFromSSHConfig.
			// NOTE(review): "~" is assumed to be expanded by SmartPath — confirm.
			if !path.IsAbs(keyPath) && !strings.HasPrefix(keyPath, "~") {
				keyPath = path.Join(d.params.Auth.Directory.String(), keyPath)
			}
			keyAuth, err := sshtun.AuthKeyFile(types.SmartPath(keyPath))
			if err != nil {
				d.Log().Warnf("cannot load ssh key %s: %s", keyPath, err)
				return nil
			}
			return keyAuth
		}
		dirAuth, err := sshtun.AuthKeyDir(d.params.Auth.Directory)
		if err != nil {
			d.Log().Warnf("cannot load ssh keys from %s: %s", d.params.Auth.Directory, err)
			return nil
		}
		return dirAuth
	}
	return nil
}