package ssh import ( "context" "encoding/base64" "errors" "fmt" "net" "path" "regexp" "strconv" "strings" "sync" "github.com/Neur0toxine/sshpoke/internal/config" "github.com/Neur0toxine/sshpoke/internal/server/driver/base" "github.com/Neur0toxine/sshpoke/internal/server/driver/ssh/sshtun" "github.com/Neur0toxine/sshpoke/internal/server/driver/ssh/types" "github.com/Neur0toxine/sshpoke/internal/server/driver/util" "github.com/Neur0toxine/sshpoke/pkg/dto" "github.com/Neur0toxine/sshpoke/pkg/proto/ssh" "github.com/Neur0toxine/sshpoke/pkg/proto/ssh/knownhosts" "go.uber.org/zap" ) const KnownHostsFile = "known_hosts" var ErrAlreadyInUse = errors.New("domain is already in use") type SSH struct { base.Base params Params auth []ssh.AuthMethod hostKeys []ssh.PublicKey hostKeyCallback ssh.HostKeyCallback conns map[string]conn clientVersion string rw sync.RWMutex wg sync.WaitGroup domainRegExp *regexp.Regexp } type conn struct { ctx context.Context cancel func() tun *sshtun.Tunnel } func New(ctx context.Context, name string, params config.DriverParams) (base.Driver, error) { drv := &SSH{ Base: base.New(ctx, name), conns: make(map[string]conn), } if err := util.UnmarshalParams(params, &drv.params); err != nil { return nil, err } if err := drv.buildHostKeys(); err != nil { return nil, err } matcher, err := makeDomainCatchRegExp(drv.params.DomainExtractRegex) if err != nil { return nil, fmt.Errorf("invalid domain_extract_regex: %w", err) } drv.domainRegExp = matcher drv.populateFromSSHConfig() drv.auth = drv.authenticators() drv.clientVersion = drv.buildClientVersion() drv.hostKeyCallback = drv.buildHostKeyCallback() return drv, nil } func (d *SSH) forward(val sshtun.Forward, domainMatcher func(string)) conn { tun := sshtun.New(d.params.Address, d.params.Auth.User, d.auth, sshtun.TunnelConfig{ Forward: val, HostKeyCallback: d.hostKeyCallback, Shell: d.params.Shell, ClientVersion: d.clientVersion, FakeRemoteHost: d.params.FakeRemoteHost, KeepAliveInterval: uint(d.params.KeepAlive.Interval), KeepAliveMax: uint(d.params.KeepAlive.MaxAttempts), }, d.Log()) ctx, cancel := context.WithCancel(d.Context()) tunDbgLog := d.Log().With("ssh-output", val.Remote.String()) var outputReaderCb sshtun.SessionCallback if d.params.ReadSessionsOutput == nil || (*d.params.ReadSessionsOutput) { outputReaderCb = sshtun.OutputReaderCallback(func(msg string) { msg = strings.TrimSpace(msg) if msg == "" { return } tunDbgLog.Debug("session: ", msg) if domainMatcher != nil { domainMatcher(msg) } }) } go tun.Connect(ctx, d.buildHandshakeLineCallback(domainMatcher, tunDbgLog), sshtun.BannerDebugLogCallback(tunDbgLog), outputReaderCb) return conn{ctx: ctx, cancel: cancel, tun: tun} } func (d *SSH) buildHandshakeLineCallback(domainMatcher func(string), tunDbgLog *zap.SugaredLogger) func(string) { if d.params.ReadRawPackets { return func(msg string) { msg = strings.TrimSpace(msg) if msg == "" { return } tunDbgLog.Debugf("ssh: %s", msg) if domainMatcher != nil { domainMatcher(msg) } } } return nil } func (d *SSH) buildHostKeyCallback() ssh.HostKeyCallback { keysCallback := func() ssh.HostKeyCallback { if d.hostKeys == nil || len(d.hostKeys) == 0 { return ssh.InsecureIgnoreHostKey() } if len(d.hostKeys) == 1 { return ssh.FixedHostKey(d.hostKeys[0]) } return sshtun.FixedHostKeys(d.hostKeys) }() if d.params.Auth.Type == types.AuthTypeKey && d.params.Auth.Directory != "" && len(d.hostKeys) == 0 { knownHostsPath := types.SmartPath(path.Join(string(d.params.Auth.Directory), KnownHostsFile)) resolvedPath, err := knownHostsPath.Resolve(false) if err != nil { return ssh.InsecureIgnoreHostKey() } hostKeyCallback, err := knownhosts.New(resolvedPath) if err != nil { return ssh.InsecureIgnoreHostKey() } return hostKeyCallback } return keysCallback } func (d *SSH) buildClientVersion() string { ver := strings.TrimSpace(d.params.ClientVersion) if ver == "" { return "" } if !strings.HasPrefix(ver, "SSH-2.0-") { d.Log().Warn( "client_version must have 'SSH-2.0-' prefix (see RFC-4253), this will be fixed automatically") ver = "SSH-2.0-" + ver } if !isValidClientVersion(ver) { d.Log().Warnf("invalid client_version value, using default...") return "" } return ver } func (d *SSH) buildHostKeys() error { if d.params.HostKeys == "" { return nil } hostKeys := []ssh.PublicKey{} for _, keyLine := range strings.Split(d.params.HostKeys, "\n") { key, err := d.pubKeyFromSSHKeyScan(keyLine) if err != nil { d.Log().Debugf("invalid public key: %s", keyLine) return fmt.Errorf("invalid public key for the host: %w", err) } if key != nil { hostKeys = append(hostKeys, key) } } d.hostKeys = hostKeys return nil } // pubKeyFromSSHKeyScan extracts host public key from ssh-keyscan output format. func (d *SSH) pubKeyFromSSHKeyScan(line string) (key ssh.PublicKey, err error) { line = strings.TrimSpace(line) if strings.HasPrefix(line, "#") || line == "" { // comment or empty line - should be ignored. return nil, nil } cols := strings.Fields(line) for i := len(cols) - 1; i >= 0; i-- { col := strings.TrimSpace(cols[i]) keyData, err := base64.StdEncoding.DecodeString(col) if err != nil { continue } key, err = ssh.ParsePublicKey(keyData) if err == nil { return key, nil } } return nil, errors.New("no public key in the provided data") } func (d *SSH) makeDomainMatcherFunc(containerID string) func(string) { if d.domainRegExp == nil { return nil } return func(msg string) { domain := d.domainRegExp.FindString(msg) if domain == "" { return } d.PushEventStatus(dto.EventStatus{ Type: dto.EventStart, ID: containerID, Domain: domain, }) } } func (d *SSH) populateFromSSHConfig() { if d.params.Auth.Directory == "" { return } cfg, err := parseSSHConfig(types.SmartPath(path.Join(string(d.params.Auth.Directory), "config"))) if err != nil { return } host := d.extractHostFromAddr(d.params.Address) hostCfg := &hostConfig{cfg: cfg, host: host} port, err := hostCfg.Get("Port") if err != nil { port = "22" } if hostName, err := hostCfg.Get("HostName"); err == nil && hostName != "" { d.params.Address = net.JoinHostPort(hostName, port) } if user, err := hostCfg.Get("User"); err == nil && user != "" { d.params.Auth.User = user } if keyfile, err := hostCfg.Get("IdentityFile"); err == nil && keyfile != "" { resolvedKeyFile, err := types.SmartPath(keyfile).Resolve(false) if err == nil { d.params.Auth.Type = types.AuthTypeKey d.params.Auth.Keyfile = resolvedKeyFile } } } func (d *SSH) extractHostFromAddr(addr string) string { host, _, err := net.SplitHostPort(addr) if err != nil { return addr } return host } func (d *SSH) Handle(event dto.Event) error { defer d.rw.Unlock() d.rw.Lock() switch event.Type { case dto.EventStart: if d.params.Mode == types.DomainModeSingle && len(d.conns) > 0 { return ErrAlreadyInUse } conn := d.forward(sshtun.Forward{ Local: d.localEndpoint(event.Container.IP, event.Container.Port), Remote: d.remoteEndpoint(event.Container.RemoteHost), }, d.makeDomainMatcherFunc(event.Container.ID)) d.conns[event.Container.ID] = conn d.wg.Add(1) case dto.EventStop: conn, found := d.conns[event.Container.ID] if !found { return nil } conn.cancel() delete(d.conns, event.Container.ID) d.propagateStop(event.Container.ID) d.wg.Done() case dto.EventShutdown: for id, conn := range d.conns { conn.cancel() delete(d.conns, id) d.propagateStop(id) d.wg.Done() } } return nil } func (d *SSH) propagateStop(containerID string) { d.PushEventStatus(dto.EventStatus{Type: dto.EventStop, ID: containerID}) } func (d *SSH) localEndpoint(ip net.IP, port uint16) sshtun.Endpoint { return sshtun.AddrToEndpoint(net.JoinHostPort(ip.String(), strconv.Itoa(int(port)))) } func (d *SSH) remoteEndpoint(remoteHost string) sshtun.Endpoint { port := int(d.params.ForwardPort) if port == 0 { port = 80 } if remoteHost == "" && !d.params.FakeRemoteHost { // Listen on all interfaces if no host was provided. remoteHost = "0.0.0.0" } return sshtun.Endpoint{ Host: remoteHost, Port: port, } } func (d *SSH) Driver() config.DriverType { return config.DriverSSH } func (d *SSH) WaitForShutdown() { go d.Handle(dto.Event{Type: dto.EventShutdown}) d.wg.Wait() } func (d *SSH) authenticators() []ssh.AuthMethod { auth := d.authenticator() if auth == nil { return nil } return []ssh.AuthMethod{auth} } func (d *SSH) authenticator() ssh.AuthMethod { switch d.params.Auth.Type { case types.AuthTypePasswordless: return sshtun.AuthPassword("") case types.AuthTypePassword: return sshtun.AuthPassword(d.params.Auth.Password) case types.AuthTypeKey: if d.params.Auth.Keyfile != "" { keyAuth, err := sshtun.AuthKeyFile( types.SmartPath(path.Join(d.params.Auth.Directory.String(), d.params.Auth.Keyfile))) if err != nil { return nil } return keyAuth } dirAuth, err := sshtun.AuthKeyDir(d.params.Auth.Directory) if err != nil { return nil } return dirAuth } return nil }