sshpoke/internal/server/driver/ssh/driver.go

157 lines
3.9 KiB
Go

package ssh
import (
"context"
"net"
"path"
"strconv"
"sync"
"github.com/Neur0toxine/sshpoke/internal/config"
"github.com/Neur0toxine/sshpoke/internal/server/driver/base"
"github.com/Neur0toxine/sshpoke/internal/server/driver/ssh/sshtun"
"github.com/Neur0toxine/sshpoke/internal/server/driver/ssh/types"
"github.com/Neur0toxine/sshpoke/internal/server/driver/util"
"github.com/Neur0toxine/sshpoke/pkg/dto"
"github.com/Neur0toxine/sshpoke/pkg/proto/ssh"
)
type (
	// SSH is the SSH tunnelling driver. For every started container it
	// opens one tunnel between the container's IP:port and an endpoint on
	// the configured SSH server, and tears it down when the container stops.
	SSH struct {
		base.Base
		// params holds the decoded driver configuration, possibly amended
		// from the ssh_config file in the auth directory.
		params Params
		// auth is the list of authentication methods handed to every tunnel.
		auth []ssh.AuthMethod
		// conns maps container IDs to their active tunnel; guarded by rw.
		conns map[string]conn
		rw    sync.RWMutex
		// wg counts tunnels currently registered in conns.
		wg sync.WaitGroup
	}

	// conn bundles a background tunnel with the cancellable context that
	// controls it.
	conn struct {
		ctx    context.Context
		cancel func()
		tun    *sshtun.Tunnel
	}
)
// New constructs the SSH driver called name. It decodes params into the
// driver's Params, lets the ssh_config file in the configured auth
// directory (if any) override user/auth settings, and prepares the
// authentication methods used by every tunnel. The only error returned is
// a failure to decode params.
func New(ctx context.Context, name string, params config.DriverParams) (base.Driver, error) {
	drv := new(SSH)
	drv.Base = base.New(ctx, name)
	drv.conns = make(map[string]conn)

	err := util.UnmarshalParams(params, &drv.params)
	if err != nil {
		return nil, err
	}

	drv.populateFromSSHConfig()
	drv.auth = drv.authenticators()
	return drv, nil
}
// forward starts a tunnel for val in a background goroutine and returns a
// conn handle for it. The tunnel's context is derived from the driver's
// context, so cancelling either the returned cancel func or the driver's
// context signals the tunnel to stop. Session output is printed through a
// logger tagged with the remote endpoint.
func (d *SSH) forward(val sshtun.Forward) conn {
	tunnel := sshtun.New(
		d.params.Address,
		d.params.Auth.User,
		d.params.DisableRemoteHostResolve,
		val,
		d.auth,
		d.Log(),
	)
	tunCtx, stop := context.WithCancel(d.Context())
	onSession := sshtun.StdoutPrinterSessionCallback(d.Log().With("ssh-output", val.Remote.String()))
	go tunnel.Connect(tunCtx, onSession)
	return conn{
		tun:    tunnel,
		ctx:    tunCtx,
		cancel: stop,
	}
}
// populateFromSSHConfig overrides authentication settings with values read
// from the "config" file inside the configured auth directory. It is
// best-effort. When no auth directory is set, or the config cannot be parsed,
// the params are left untouched.
//
// Keys looked up for the server host:
//   - User: replaces Auth.User.
//   - PasswordAuthentication "yes": switches to password authentication.
//   - IdentityFile: switches to key authentication with the resolved key
//     file. It is applied last, so it wins over PasswordAuthentication.
func (d *SSH) populateFromSSHConfig() {
	if d.params.Auth.Directory == "" {
		return
	}
	cfg, err := parseSSHConfig(types.SmartPath(path.Join(string(d.params.Auth.Directory), "config")))
	if err != nil {
		// Best-effort: without a readable config there is nothing to override.
		return
	}

	// ssh_config "Host" patterns match a bare host name, never "host:port".
	// Strip a port from the address (e.g. "example.com:22") before looking
	// anything up. Addresses without a port are used unchanged.
	host := d.params.Address
	if h, _, splitErr := net.SplitHostPort(host); splitErr == nil {
		host = h
	}

	if user, err := cfg.Get(host, "User"); err == nil && user != "" {
		d.params.Auth.User = user
	}
	if usePass, err := cfg.Get(host, "PasswordAuthentication"); err == nil && usePass == "yes" {
		d.params.Auth.Type = types.AuthTypePassword
	}
	if keyfile, err := cfg.Get(host, "IdentityFile"); err == nil && keyfile != "" {
		// A key file that cannot be resolved is ignored rather than fatal.
		if resolved, err := types.SmartPath(keyfile).Resolve(false); err == nil {
			d.params.Auth.Type = types.AuthTypeKey
			d.params.Auth.Keyfile = resolved
		}
	}
}
// Handle applies a container lifecycle event to the driver's tunnels:
//   - EventStart opens a tunnel from the container's IP:port to the remote
//     endpoint for its remote host. Any tunnel already registered for the
//     same container ID is cancelled first.
//   - EventStop cancels and forgets the container's tunnel, if any.
//   - EventShutdown cancels and forgets every tunnel.
//
// Each registered tunnel holds one slot in d.wg until it is cancelled, which
// is what WaitForShutdown waits on. Other event types are ignored. Handle
// currently always returns nil.
func (d *SSH) Handle(event dto.Event) error {
	d.rw.Lock()
	defer d.rw.Unlock()

	switch event.Type {
	case dto.EventStart:
		// Replacing an entry without cancelling it would leak the old tunnel
		// and leave its wg slot unreleased forever, hanging WaitForShutdown.
		d.stopLocked(event.Container.ID)
		local := net.JoinHostPort(event.Container.IP.String(), strconv.Itoa(int(event.Container.Port)))
		d.conns[event.Container.ID] = d.forward(sshtun.Forward{
			Local:  sshtun.AddrToEndpoint(local),
			Remote: d.remoteEndpoint(event.Container.RemoteHost),
		})
		d.wg.Add(1)
	case dto.EventStop:
		d.stopLocked(event.Container.ID)
	case dto.EventShutdown:
		for id := range d.conns {
			d.stopLocked(id)
		}
	}
	return nil
}

// stopLocked cancels the tunnel registered for id, removes it from d.conns
// and releases its wg slot. It is a no-op for unknown ids. The caller must
// hold d.rw for writing.
func (d *SSH) stopLocked(id string) {
	c, ok := d.conns[id]
	if !ok {
		return
	}
	c.cancel()
	delete(d.conns, id)
	d.wg.Done()
}
// remoteEndpoint builds the server-side endpoint for remoteHost. It uses the
// configured ForwardPort, or 80 when ForwardPort is unset (zero).
func (d *SSH) remoteEndpoint(remoteHost string) sshtun.Endpoint {
	const defaultForwardPort = 80

	endpoint := sshtun.Endpoint{Host: remoteHost, Port: int(d.params.ForwardPort)}
	if endpoint.Port == 0 {
		endpoint.Port = defaultForwardPort
	}
	return endpoint
}
// Driver reports the driver type implemented by SSH.
func (*SSH) Driver() config.DriverType { return config.DriverSSH }
// WaitForShutdown cancels every active tunnel and then blocks until the
// driver's WaitGroup drains.
//
// The WaitGroup counts tunnels registered by Handle, and a slot is released
// when its tunnel is cancelled. This therefore does not wait for the
// background tunnel goroutines themselves to exit.
func (d *SSH) WaitForShutdown() {
	// Running the shutdown synchronously guarantees every tunnel has been
	// cancelled before Wait is reached, with no detached goroutine. Handle
	// does not return errors, so there is nothing to propagate.
	_ = d.Handle(dto.Event{Type: dto.EventShutdown})
	d.wg.Wait()
}
// authenticators wraps the configured authentication method in a slice
// suitable for the SSH client. It returns nil when no method could be built.
func (d *SSH) authenticators() []ssh.AuthMethod {
	if method := d.authenticator(); method != nil {
		return []ssh.AuthMethod{method}
	}
	return nil
}
// authenticator builds the SSH authentication method selected by
// d.params.Auth.Type:
//   - AuthTypePasswordless: password auth with an empty password.
//   - AuthTypePassword: password auth with Auth.Password.
//   - AuthTypeKey: the key in Auth.Keyfile when it is set, otherwise the
//     keys found in Auth.Directory.
//
// A relative Keyfile is resolved against Auth.Directory. An absolute one,
// such as an IdentityFile resolved from ssh_config, is used as is.
//
// It returns nil for an unrecognized auth type or when the key(s) cannot be
// loaded; the load error is not reported.
func (d *SSH) authenticator() ssh.AuthMethod {
	switch d.params.Auth.Type {
	case types.AuthTypePasswordless:
		return sshtun.AuthPassword("")
	case types.AuthTypePassword:
		return sshtun.AuthPassword(d.params.Auth.Password)
	case types.AuthTypeKey:
		if keyfile := d.params.Auth.Keyfile; keyfile != "" {
			// path.Join does not treat an absolute second element specially.
			// Joining "/home/u/.ssh/id" onto the directory would produce a
			// nonexistent path, so only relative key files are joined.
			if !path.IsAbs(keyfile) {
				keyfile = path.Join(d.params.Auth.Directory.String(), keyfile)
			}
			keyAuth, err := sshtun.AuthKeyFile(types.SmartPath(keyfile))
			if err != nil {
				return nil
			}
			return keyAuth
		}
		dirAuth, err := sshtun.AuthKeyDir(d.params.Auth.Directory)
		if err != nil {
			return nil
		}
		return dirAuth
	}
	return nil
}