282 lines
7.0 KiB
Go
282 lines
7.0 KiB
Go
package ssh
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"path"
|
|
"regexp"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
|
|
"github.com/Neur0toxine/sshpoke/internal/config"
|
|
"github.com/Neur0toxine/sshpoke/internal/server/driver/base"
|
|
"github.com/Neur0toxine/sshpoke/internal/server/driver/ssh/sshtun"
|
|
"github.com/Neur0toxine/sshpoke/internal/server/driver/ssh/types"
|
|
"github.com/Neur0toxine/sshpoke/internal/server/driver/util"
|
|
"github.com/Neur0toxine/sshpoke/pkg/dto"
|
|
"github.com/Neur0toxine/sshpoke/pkg/proto/ssh"
|
|
)
|
|
|
|
var (
	// ErrAlreadyInUse is returned by (*SSH).Handle for a start event when the
	// driver runs in single-domain mode and a tunnel is already registered.
	ErrAlreadyInUse = errors.New("domain is already in use")
)
|
|
|
|
type SSH struct {
|
|
base.Base
|
|
params Params
|
|
auth []ssh.AuthMethod
|
|
hostKeys []ssh.PublicKey
|
|
conns map[string]conn
|
|
rw sync.RWMutex
|
|
wg sync.WaitGroup
|
|
domainRegExp *regexp.Regexp
|
|
}
|
|
|
|
type conn struct {
|
|
ctx context.Context
|
|
cancel func()
|
|
tun *sshtun.Tunnel
|
|
}
|
|
|
|
func New(ctx context.Context, name string, params config.DriverParams) (base.Driver, error) {
|
|
drv := &SSH{
|
|
Base: base.New(ctx, name),
|
|
conns: make(map[string]conn),
|
|
}
|
|
if err := util.UnmarshalParams(params, &drv.params); err != nil {
|
|
return nil, err
|
|
}
|
|
if err := drv.buildHostKeys(); err != nil {
|
|
return nil, err
|
|
}
|
|
matcher, err := makeDomainCatchRegExp(drv.params.DomainExtractRegex)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("invalid domain_extract_regex: %w", err)
|
|
}
|
|
drv.domainRegExp = matcher
|
|
drv.populateFromSSHConfig()
|
|
drv.auth = drv.authenticators()
|
|
return drv, nil
|
|
}
|
|
|
|
func (d *SSH) forward(val sshtun.Forward, domainMatcher func(string)) conn {
|
|
tun := sshtun.New(d.params.Address,
|
|
d.params.Auth.User,
|
|
d.auth,
|
|
sshtun.TunnelConfig{
|
|
Forward: val,
|
|
HostKeys: d.hostKeys,
|
|
NoPTY: d.params.NoPTY,
|
|
Shell: d.params.Shell,
|
|
FakeRemoteHost: d.params.FakeRemoteHost,
|
|
KeepAliveInterval: uint(d.params.KeepAlive.Interval),
|
|
KeepAliveMax: uint(d.params.KeepAlive.MaxAttempts),
|
|
},
|
|
d.Log())
|
|
ctx, cancel := context.WithCancel(d.Context())
|
|
tunDbgLog := d.Log().With("ssh-output", val.Remote.String())
|
|
go tun.Connect(ctx,
|
|
sshtun.BannerDebugLogCallback(tunDbgLog),
|
|
sshtun.OutputReaderCallback(func(msg string) {
|
|
d.Log().Debug(msg)
|
|
if domainMatcher != nil {
|
|
domainMatcher(msg)
|
|
}
|
|
}))
|
|
return conn{ctx: ctx, cancel: cancel, tun: tun}
|
|
}
|
|
|
|
func (d *SSH) buildHostKeys() error {
|
|
if d.params.HostKeys == "" {
|
|
return nil
|
|
}
|
|
hostKeys := []ssh.PublicKey{}
|
|
for _, keyLine := range strings.Split(d.params.HostKeys, "\n") {
|
|
key, err := d.pubKeyFromSSHKeyScan(keyLine)
|
|
if err != nil {
|
|
d.Log().Debugf("invalid public key: %s", keyLine)
|
|
return fmt.Errorf("invalid public key for the host: %w", err)
|
|
}
|
|
if key != nil {
|
|
hostKeys = append(hostKeys, key)
|
|
}
|
|
}
|
|
d.hostKeys = hostKeys
|
|
return nil
|
|
}
|
|
|
|
// pubKeyFromSSHKeyScan extracts host public key from ssh-keyscan output format.
|
|
func (d *SSH) pubKeyFromSSHKeyScan(line string) (key ssh.PublicKey, err error) {
|
|
line = strings.TrimSpace(line)
|
|
if strings.HasPrefix(line, "#") || line == "" { // comment or empty line - should be ignored.
|
|
return nil, nil
|
|
}
|
|
cols := strings.Fields(line)
|
|
for i := len(cols) - 1; i >= 0; i-- {
|
|
col := strings.TrimSpace(cols[i])
|
|
keyData, err := base64.StdEncoding.DecodeString(col)
|
|
if err != nil {
|
|
continue
|
|
}
|
|
key, err = ssh.ParsePublicKey(keyData)
|
|
if err == nil {
|
|
return key, nil
|
|
}
|
|
}
|
|
return nil, errors.New("no public key in the provided data")
|
|
}
|
|
|
|
func (d *SSH) makeDomainMatcherFunc(containerID string) func(string) {
|
|
if d.domainRegExp == nil {
|
|
return nil
|
|
}
|
|
return func(msg string) {
|
|
domain := d.domainRegExp.FindString(msg)
|
|
if domain == "" {
|
|
return
|
|
}
|
|
d.PushEventStatus(dto.EventStatus{
|
|
Type: dto.EventStart,
|
|
ID: containerID,
|
|
Domain: domain,
|
|
})
|
|
}
|
|
}
|
|
|
|
func (d *SSH) populateFromSSHConfig() {
|
|
if d.params.Auth.Directory == "" {
|
|
return
|
|
}
|
|
cfg, err := parseSSHConfig(types.SmartPath(path.Join(string(d.params.Auth.Directory), "config")))
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
host := d.extractHostFromAddr(d.params.Address)
|
|
hostCfg := &hostConfig{cfg: cfg, host: host}
|
|
port, err := hostCfg.Get("Port")
|
|
if err != nil {
|
|
port = "22"
|
|
}
|
|
if hostName, err := hostCfg.Get("HostName"); err == nil && hostName != "" {
|
|
d.params.Address = net.JoinHostPort(hostName, port)
|
|
}
|
|
if user, err := hostCfg.Get("User"); err == nil && user != "" {
|
|
d.params.Auth.User = user
|
|
}
|
|
if usePass, err := hostCfg.Get("PasswordAuthentication"); err == nil && usePass == "yes" {
|
|
d.params.Auth.Type = types.AuthTypePassword
|
|
}
|
|
if keyfile, err := hostCfg.Get("IdentityFile"); err == nil && keyfile != "" {
|
|
resolvedKeyFile, err := types.SmartPath(keyfile).Resolve(false)
|
|
if err == nil {
|
|
d.params.Auth.Type = types.AuthTypeKey
|
|
d.params.Auth.Keyfile = resolvedKeyFile
|
|
}
|
|
}
|
|
}
|
|
|
|
func (d *SSH) extractHostFromAddr(addr string) string {
|
|
host, _, err := net.SplitHostPort(addr)
|
|
if err != nil {
|
|
return addr
|
|
}
|
|
return host
|
|
}
|
|
|
|
func (d *SSH) Handle(event dto.Event) error {
|
|
defer d.rw.Unlock()
|
|
d.rw.Lock()
|
|
switch event.Type {
|
|
case dto.EventStart:
|
|
if d.params.Mode == types.DomainModeSingle && len(d.conns) > 0 {
|
|
return ErrAlreadyInUse
|
|
}
|
|
conn := d.forward(sshtun.Forward{
|
|
Local: d.localEndpoint(event.Container.IP, event.Container.Port),
|
|
Remote: d.remoteEndpoint(event.Container.RemoteHost),
|
|
}, d.makeDomainMatcherFunc(event.Container.ID))
|
|
d.conns[event.Container.ID] = conn
|
|
d.wg.Add(1)
|
|
case dto.EventStop:
|
|
conn, found := d.conns[event.Container.ID]
|
|
if !found {
|
|
return nil
|
|
}
|
|
conn.cancel()
|
|
delete(d.conns, event.Container.ID)
|
|
d.propagateStop(event.Container.ID)
|
|
d.wg.Done()
|
|
case dto.EventShutdown:
|
|
for id, conn := range d.conns {
|
|
conn.cancel()
|
|
delete(d.conns, id)
|
|
d.propagateStop(id)
|
|
d.wg.Done()
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (d *SSH) propagateStop(containerID string) {
|
|
d.PushEventStatus(dto.EventStatus{Type: dto.EventStop, ID: containerID})
|
|
}
|
|
|
|
func (d *SSH) localEndpoint(ip net.IP, port uint16) sshtun.Endpoint {
|
|
return sshtun.AddrToEndpoint(net.JoinHostPort(ip.String(), strconv.Itoa(int(port))))
|
|
}
|
|
|
|
func (d *SSH) remoteEndpoint(remoteHost string) sshtun.Endpoint {
|
|
port := int(d.params.ForwardPort)
|
|
if port == 0 {
|
|
port = 80
|
|
}
|
|
return sshtun.Endpoint{
|
|
Host: remoteHost,
|
|
Port: port,
|
|
}
|
|
}
|
|
|
|
func (d *SSH) Driver() config.DriverType {
|
|
return config.DriverSSH
|
|
}
|
|
|
|
func (d *SSH) WaitForShutdown() {
|
|
go d.Handle(dto.Event{Type: dto.EventShutdown})
|
|
d.wg.Wait()
|
|
}
|
|
|
|
func (d *SSH) authenticators() []ssh.AuthMethod {
|
|
auth := d.authenticator()
|
|
if auth == nil {
|
|
return nil
|
|
}
|
|
return []ssh.AuthMethod{auth}
|
|
}
|
|
|
|
func (d *SSH) authenticator() ssh.AuthMethod {
|
|
switch d.params.Auth.Type {
|
|
case types.AuthTypePasswordless:
|
|
return sshtun.AuthPassword("")
|
|
case types.AuthTypePassword:
|
|
return sshtun.AuthPassword(d.params.Auth.Password)
|
|
case types.AuthTypeKey:
|
|
if d.params.Auth.Keyfile != "" {
|
|
keyAuth, err := sshtun.AuthKeyFile(
|
|
types.SmartPath(path.Join(d.params.Auth.Directory.String(), d.params.Auth.Keyfile)))
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
return keyAuth
|
|
}
|
|
dirAuth, err := sshtun.AuthKeyDir(d.params.Auth.Directory)
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
return dirAuth
|
|
}
|
|
return nil
|
|
}
|