package ssh import ( "context" "errors" "fmt" "os" "path" "strings" "sync" "github.com/Neur0toxine/sshpoke/internal/config" "github.com/Neur0toxine/sshpoke/internal/server/driver/base" "github.com/Neur0toxine/sshpoke/internal/server/driver/ssh/types" "github.com/Neur0toxine/sshpoke/internal/server/driver/util" "github.com/Neur0toxine/sshpoke/internal/server/proto/sshtun" "github.com/Neur0toxine/sshpoke/pkg/dto" "golang.org/x/crypto/ssh" ) type SSH struct { base.Base params Params sessions map[string]conn keys []ssh.Signer wg sync.WaitGroup } type conn struct { container dto.Container tun *sshtun.Tunnel } func New(ctx context.Context, name string, params config.DriverParams) (base.Driver, error) { drv := &SSH{ Base: base.New(ctx, name), sessions: make(map[string]conn), } if err := util.UnmarshalParams(params, &drv.params); err != nil { return nil, err } drv.populateFromSSHConfig() if err := drv.parseKeys(); err != nil { return nil, err } return drv, nil } func (d *SSH) populateFromSSHConfig() { if d.params.Auth.Directory == "" { return } cfg, err := parseSSHConfig(types.SmartPath(path.Join(string(d.params.Auth.Directory), "config"))) if err != nil { return } if user, err := cfg.Get(d.params.Address, "User"); err == nil && user != "" { d.params.Auth.User = user } if usePass, err := cfg.Get(d.params.Address, "PasswordAuthentication"); err == nil && usePass == "yes" { d.params.Auth.Type = types.AuthTypePassword } if keyfile, err := cfg.Get(d.params.Address, "IdentityFile"); err == nil && keyfile != "" { resolvedKeyFile, err := types.SmartPath(keyfile).Resolve(false) if err == nil { d.params.Auth.Type = types.AuthTypeKey d.params.Auth.Keyfile = resolvedKeyFile } } } func (d *SSH) Handle(event dto.Event) error { // TODO: Implement event handling & connections management. return errors.New("server handler is not implemented yet") } func (d *SSH) Driver() config.DriverType { return config.DriverSSH } func (d *SSH) WaitForShutdown() { d.wg.Wait() } func (d *SSH) parseKeys() error { if d.params.Auth.Type != types.AuthTypeKey { return nil } dir, err := d.params.Auth.Directory.Resolve(true) if err != nil { return fmt.Errorf("cannot parse keys: %s", err) } if d.params.Auth.Keyfile != "" { key, err := parseKey(path.Join(dir, d.params.Auth.Keyfile)) if err != nil { return err } d.keys = []ssh.Signer{key} return nil } entries, err := os.ReadDir(dir) if err != nil { return fmt.Errorf("cannot read key directory: %s", err) } keys := []ssh.Signer{} for _, entry := range entries { if entry.IsDir() { d.Log().Debugf("skipping '%s' because it's a directory", entry.Name()) continue } info, err := entry.Info() if err != nil { d.Log().Debugf("skipping '%s' because stat failed: %s", entry.Name(), err) continue } if strings.HasSuffix(entry.Name(), ".pub") { d.Log().Debugf("skipping '%s' because it's probably a public key", entry.Name()) continue } if entry.Name() == "config" { d.Log().Debugf("skipping '%s' because it's probably a ssh-config file", entry.Name()) continue } if entry.Name() == "known_hosts" { d.Log().Debugf( "skipping '%s' because it's probably a list of hosts generated by OpenSSH", entry.Name()) continue } // this file is too small to be a private key if info.Size() < 256 { d.Log().Debugf("skipping '%s' because the file is smaller than 256 bytes", entry.Name()) continue } key, err := parseKey(path.Join(dir, entry.Name())) if err != nil { d.Log().Debugf("skipping '%s' because it's probably not a key: %s", entry.Name(), err) continue } d.Log().Debugf("loading key '%s', type: %s", entry.Name(), key.PublicKey().Type()) keys = append(keys, key) } if len(keys) == 0 { return errors.New("no keys in the provided directory") } d.keys = keys return nil } func parseKey(keyFile string) (ssh.Signer, error) { keyData, err := os.ReadFile(keyFile) if err != nil { return nil, err } key, err := ssh.ParsePrivateKey(keyData) if err != nil { return nil, err } return key, nil }