diff --git a/go.mod b/go.mod index a9132ff..8ecb0e3 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.21.4 require ( github.com/docker/docker v24.0.7+incompatible github.com/docker/go-connections v0.4.0 + github.com/function61/gokit v0.0.0-20231117065306-355fe206d542 github.com/go-playground/validator/v10 v10.16.0 github.com/kevinburke/ssh_config v1.2.0 github.com/mitchellh/mapstructure v1.5.0 @@ -13,7 +14,7 @@ require ( github.com/spf13/viper v1.17.0 go.uber.org/zap v1.26.0 golang.design/x/lockfree v0.0.1 - golang.org/x/crypto v0.13.0 + golang.org/x/crypto v0.14.0 google.golang.org/grpc v1.58.2 google.golang.org/protobuf v1.31.0 ) @@ -49,8 +50,9 @@ require ( go.uber.org/multierr v1.11.0 // indirect golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect golang.org/x/mod v0.12.0 // indirect - golang.org/x/net v0.15.0 // indirect - golang.org/x/sys v0.12.0 // indirect + golang.org/x/net v0.17.0 // indirect + golang.org/x/sync v0.3.0 // indirect + golang.org/x/sys v0.13.0 // indirect golang.org/x/text v0.13.0 // indirect golang.org/x/tools v0.13.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20230920204549-e6e6cdab5c13 // indirect diff --git a/go.sum b/go.sum index 4757b9a..dec80be 100644 --- a/go.sum +++ b/go.sum @@ -77,6 +77,8 @@ github.com/frankban/quicktest v1.14.4 h1:g2rn0vABPOOXmZUj+vbmUp0lPoXEMuhTpIluN0X github.com/frankban/quicktest v1.14.4/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY= github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= +github.com/function61/gokit v0.0.0-20231117065306-355fe206d542 h1:a9BTN+DOboRkVih0suT4zrRZ4zLGFpBtHPGNd+EQ4pI= +github.com/function61/gokit v0.0.0-20231117065306-355fe206d542/go.mod h1:sJY957+7ush4oj4ElOMhUFaFIriAFNAGYzVh2tFJNy0= github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= @@ -255,8 +257,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4= golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck= -golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc= +golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= +golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= @@ -326,8 +328,8 @@ golang.org/x/net v0.0.0-20201209123823-ac852fbbde11/go.mod h1:m0MpNAwzfU5UDzcl9v golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= -golang.org/x/net v0.15.0 h1:ugBLEUaxABaB5AJqW9enI0ACdci2RUd4eP51NTBvuJ8= -golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= +golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM= +golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -385,11 +387,11 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= -golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE= +golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.12.0 h1:/ZfYdc3zq+q02Rv9vGqTeSItdzZTSNDmfTi0mBAuidU= -golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU= +golang.org/x/term v0.13.0 h1:bb+I9cTfFazGW51MZqBVmZy7+JEJMouUHTUSKVQLBek= +golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/internal/server/driver/ssh/driver.go b/internal/server/driver/ssh/driver.go index 69b81d7..f37d7e6 100644 --- a/internal/server/driver/ssh/driver.go +++ b/internal/server/driver/ssh/driver.go @@ -3,46 +3,35 @@ package ssh import ( "context" "errors" - "fmt" - "os" "path" - "strings" "sync" "github.com/Neur0toxine/sshpoke/internal/config" "github.com/Neur0toxine/sshpoke/internal/server/driver/base" + "github.com/Neur0toxine/sshpoke/internal/server/driver/ssh/sshproto" "github.com/Neur0toxine/sshpoke/internal/server/driver/ssh/types" "github.com/Neur0toxine/sshpoke/internal/server/driver/util" - "github.com/Neur0toxine/sshpoke/internal/server/proto/sshtun" "github.com/Neur0toxine/sshpoke/pkg/dto" "golang.org/x/crypto/ssh" ) type SSH struct { base.Base - params Params - sessions map[string]conn - keys []ssh.Signer - wg sync.WaitGroup -} - -type conn struct { - container dto.Container - tun *sshtun.Tunnel + params Params + proto *sshproto.Client + wg sync.WaitGroup } func New(ctx context.Context, name string, params config.DriverParams) (base.Driver, error) { drv := &SSH{ - Base: base.New(ctx, name), - sessions: make(map[string]conn), + Base: base.New(ctx, name), } if err := util.UnmarshalParams(params, &drv.params); err != nil { return nil, err } drv.populateFromSSHConfig() - if err := drv.parseKeys(); err != nil { - return nil, err - } + drv.proto = sshproto.New(drv.params.Address, drv.params.Auth.User, drv.authenticators(), drv.Log()) + go drv.proto.Connect(drv.Context()) return drv, nil } @@ -82,78 +71,34 @@ func (d *SSH) WaitForShutdown() { d.wg.Wait() } -func (d *SSH) parseKeys() error { - if d.params.Auth.Type != types.AuthTypeKey { +func (d *SSH) authenticators() []ssh.AuthMethod { + auth := d.authenticator() + if auth == nil { return nil } - dir, err := d.params.Auth.Directory.Resolve(true) - if err != nil { - return fmt.Errorf("cannot parse keys: %s", err) - } - if d.params.Auth.Keyfile != "" { - key, err := parseKey(path.Join(dir, d.params.Auth.Keyfile)) - if err != nil { - return err - } - d.keys = []ssh.Signer{key} - return nil - } - entries, err := os.ReadDir(dir) - if err != nil { - return fmt.Errorf("cannot read key directory: %s", err) - } - keys := []ssh.Signer{} - for _, entry := range entries { - if entry.IsDir() { - d.Log().Debugf("skipping '%s' because it's a directory", entry.Name()) - continue - } - info, err := entry.Info() - if err != nil { - d.Log().Debugf("skipping '%s' because stat failed: %s", entry.Name(), err) - continue - } - if strings.HasSuffix(entry.Name(), ".pub") { - d.Log().Debugf("skipping '%s' because it's probably a public key", entry.Name()) - continue - } - if entry.Name() == "config" { - d.Log().Debugf("skipping '%s' because it's probably a ssh-config file", entry.Name()) - continue - } - if entry.Name() == "known_hosts" { - d.Log().Debugf( - "skipping '%s' because it's probably a list of hosts generated by OpenSSH", entry.Name()) - continue - } - // this file is too small to be a private key - if info.Size() < 256 { - d.Log().Debugf("skipping '%s' because the file is smaller than 256 bytes", entry.Name()) - continue - } - key, err := parseKey(path.Join(dir, entry.Name())) - if err != nil { - d.Log().Debugf("skipping '%s' because it's probably not a key: %s", entry.Name(), err) - continue - } - d.Log().Debugf("loading key '%s', type: %s", entry.Name(), key.PublicKey().Type()) - keys = append(keys, key) - } - if len(keys) == 0 { - return errors.New("no keys in the provided directory") - } - d.keys = keys - return nil + return []ssh.AuthMethod{auth} } -func parseKey(keyFile string) (ssh.Signer, error) { - keyData, err := os.ReadFile(keyFile) - if err != nil { - return nil, err +func (d *SSH) authenticator() ssh.AuthMethod { + switch d.params.Auth.Type { + case types.AuthTypePasswordless: + return sshproto.AuthPassword("") + case types.AuthTypePassword: + return sshproto.AuthPassword(d.params.Auth.Password) + case types.AuthTypeKey: + if d.params.Auth.Keyfile != "" { + keyAuth, err := sshproto.AuthKeyFile( + types.SmartPath(path.Join(d.params.Auth.Directory.String(), d.params.Auth.Keyfile))) + if err != nil { + return nil + } + return keyAuth + } + dirAuth, err := sshproto.AuthKeyDir(d.params.Auth.Directory) + if err != nil { + return nil + } + return dirAuth } - key, err := ssh.ParsePrivateKey(keyData) - if err != nil { - return nil, err - } - return key, nil + return nil } diff --git a/internal/server/driver/ssh/sshproto/auth.go b/internal/server/driver/ssh/sshproto/auth.go new file mode 100644 index 0000000..edb5bd6 --- /dev/null +++ b/internal/server/driver/ssh/sshproto/auth.go @@ -0,0 +1,91 @@ +package sshproto + +import ( + "errors" + "fmt" + "os" + "path" + "strings" + + "github.com/Neur0toxine/sshpoke/internal/server/driver/ssh/types" + "golang.org/x/crypto/ssh" +) + +func AuthKeyFile(keyFile types.SmartPath) (ssh.AuthMethod, error) { + key, err := readKey(keyFile) + if err != nil { + return nil, err + } + return ssh.PublicKeys(key), nil +} + +func AuthKeyDir(keyDir types.SmartPath) (ssh.AuthMethod, error) { + keys, err := readKeys(keyDir) + if err != nil { + return nil, err + } + return ssh.PublicKeys(keys...), nil +} + +func AuthPassword(password string) ssh.AuthMethod { + return ssh.Password(password) +} + +func readKeys(keyDir types.SmartPath) ([]ssh.Signer, error) { + dir, err := keyDir.Resolve(true) + if err != nil { + return nil, fmt.Errorf("cannot parse keys: %s", err) + } + entries, err := os.ReadDir(dir) + if err != nil { + return nil, fmt.Errorf("cannot read key directory: %s", err) + } + keys := []ssh.Signer{} + for _, entry := range entries { + if entry.IsDir() { + continue + } + info, err := entry.Info() + if err != nil { + continue + } + if strings.HasSuffix(entry.Name(), ".pub") { + continue + } + if entry.Name() == "config" { + continue + } + if entry.Name() == "known_hosts" { + continue + } + // this file is too small to be a private key + if info.Size() < 256 { + continue + } + key, err := readKey(types.SmartPath(path.Join(dir, entry.Name()))) + if err != nil { + continue + } + keys = append(keys, key) + } + if len(keys) == 0 { + return nil, errors.New("no keys in the provided directory") + } + return keys, nil +} + +func readKey(keyFile types.SmartPath) (ssh.Signer, error) { + fileName, err := keyFile.Resolve(false) + if err != nil { + return nil, err + } + keyData, err := os.ReadFile(fileName) + if err != nil { + return nil, err + } + key, err := ssh.ParsePrivateKey(keyData) + if err != nil { + return nil, err + } + return key, nil +} diff --git a/internal/server/driver/ssh/sshproto/config.go b/internal/server/driver/ssh/sshproto/config.go new file mode 100644 index 0000000..b4b987a --- /dev/null +++ b/internal/server/driver/ssh/sshproto/config.go @@ -0,0 +1,21 @@ +package sshproto + +import ( + "bytes" + "os" + + "github.com/Neur0toxine/sshpoke/internal/server/driver/ssh/types" + "github.com/kevinburke/ssh_config" +) + +func parseSSHConfig(filePath types.SmartPath) (*ssh_config.Config, error) { + fileName, err := filePath.Resolve(false) + if err != nil { + return nil, err + } + file, err := os.ReadFile(fileName) + if err != nil { + return nil, err + } + return ssh_config.Decode(bytes.NewReader(file)) +} diff --git a/internal/server/driver/ssh/sshproto/forward.go b/internal/server/driver/ssh/sshproto/forward.go new file mode 100644 index 0000000..7b5040a --- /dev/null +++ b/internal/server/driver/ssh/sshproto/forward.go @@ -0,0 +1,19 @@ +package sshproto + +import "fmt" + +type Forward struct { + // local service to be forwarded + Local Endpoint `json:"local"` + // remote forwarding port (on remote SSH server network) + Remote Endpoint `json:"remote"` +} + +type Endpoint struct { + Host string `json:"host"` + Port int `json:"port"` +} + +func (endpoint *Endpoint) String() string { + return fmt.Sprintf("%s:%d", endpoint.Host, endpoint.Port) +} diff --git a/internal/server/driver/ssh/sshproto/ssh.go b/internal/server/driver/ssh/sshproto/ssh.go new file mode 100644 index 0000000..7e7dce6 --- /dev/null +++ b/internal/server/driver/ssh/sshproto/ssh.go @@ -0,0 +1,242 @@ +package sshproto + +import ( + "context" + "fmt" + "net" + "sync/atomic" + "time" + + "github.com/Neur0toxine/sshpoke/pkg/errtools" + "github.com/function61/gokit/app/backoff" + "github.com/function61/gokit/io/bidipipe" + "go.uber.org/zap" + "golang.org/x/crypto/ssh" +) + +type Client struct { + user string + address string + authMethods []ssh.AuthMethod + log *zap.SugaredLogger + connected atomic.Bool +} + +func New(address, user string, auth []ssh.AuthMethod, log *zap.SugaredLogger) *Client { + return &Client{ + address: prepareAddress(address), + user: user, + authMethods: auth, + log: log.With(zap.String("sshServer", address)), + } +} + +func prepareAddress(address string) string { + _, _, err := net.SplitHostPort(address) + if err != nil && errtools.IsPortMissingErr(err) { + return net.JoinHostPort(address, "22") + } + return address +} + +func (c *Client) Connect(ctx context.Context) { + if c.connected.Load() { + return + } + + defer c.connected.Store(false) + backoffTime := backoff.ExponentialWithCappedMax(100*time.Millisecond, 5*time.Second) + for { + c.connected.Store(true) + err := c.connect(ctx) + if err != nil { + c.log.Error("connect error:", err) + } + + select { + case <-ctx.Done(): + return + default: + } + + time.Sleep(backoffTime()) + } +} + +// connect once to the SSH server. if the connection breaks, we return error and the caller +// will try to re-connect +func (c *Client) connect(ctx context.Context) error { + c.log.Debug("connecting") + sshConfig := &ssh.ClientConfig{ + User: c.user, + Auth: c.authMethods, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + + var sshClient *ssh.Client + var errConnect error + + sshClient, errConnect = dialSSH(ctx, c.address, sshConfig) + if errConnect != nil { + return errConnect + } + + // always disconnect when function returns + defer sshClient.Close() + defer c.log.Debug("disconnecting") + + c.log.Debug("connected") + + <-ctx.Done() + return nil +} + +// // connect once to the SSH server. if the connection breaks, we return error and the caller +// // will try to re-connect +// func connectToSshAndServe2( +// ctx context.Context, +// address string, +// authConfig types.Auth, +// auth ssh.AuthMethod, +// log *zap.SugaredLogger, +// ) error { +// log = log.With(zap.String("sshServer", address)) +// log.Debug("connecting") +// sshConfig := &ssh.ClientConfig{ +// User: authConfig.User, +// Auth: []ssh.AuthMethod{auth}, +// HostKeyCallback: ssh.InsecureIgnoreHostKey(), +// } +// +// var sshClient *ssh.Client +// var errConnect error +// +// sshClient, errConnect = dialSSH(ctx, address, sshConfig) +// if errConnect != nil { +// return errConnect +// } +// +// // always disconnect when function returns +// defer sshClient.Close() +// defer log.Debug("disconnecting") +// +// log.Debug("connected") +// +// // each listener in reverseForwardOnePort() can return one error, so make sure channel +// // has enough buffering so even if we return from here, the goroutines won't block trying +// // to return an error +// listenerStopped := make(chan error, len(conf.Forwards)) +// +// for _, forward := range conf.Forwards { +// // TODO: "if any fails, tear down all workers" -style error handling would be better +// // handled with https://pkg.go.dev/golang.org/x/sync/errgroup?tab=doc +// if err := reverseForwardOnePort( +// forward, +// sshClient, +// listenerStopped, +// makeLogger("reverseForwardOnePort"), +// makeLogger, +// ); err != nil { +// // closes SSH connection if even one forward Listen() fails +// return err +// } +// } +// +// // we're connected and have succesfully started listening on all reverse forwards, wait +// // for either user to ask us to stop or any of the listeners to error +// select { +// case <-ctx.Done(): // cancel requested +// return nil +// case listenerFirstErr := <-listenerStopped: +// // one or more of the listeners encountered an error. react by closing the connection +// // assumes all the other listeners failed too so no teardown necessary +// select { +// case <-ctx.Done(): // pretty much errors are to be expected if cancellation triggered +// return nil +// default: +// return listenerFirstErr +// } +// } +// } + +// blocking flow: calls Listen() on the SSH connection, and if succeeds returns non-nil error +// +// nonblocking flow: if Accept() call fails, stops goroutine and returns error on ch listenerStopped +func reverseForwardOnePort( + forward Forward, + sshClient *ssh.Client, + listenerStopped chan<- error, + log *zap.SugaredLogger, +) error { + // reverse listen on remote server port + listener, err := sshClient.Listen("tcp", forward.Remote.String()) + if err != nil { + return err + } + + go func() { + defer listener.Close() + log.Debugf("forwarding %s -> %s", forward.Local.String(), forward.Remote.String()) + + // handle incoming connections on reverse forwarded tunnel + for { + client, err := listener.Accept() + if err != nil { + listenerStopped <- fmt.Errorf("error on Accept(): %w", err) + return + } + + // handle the connection in another goroutine, so we can support multiple concurrent + // connections on the same port + go handleReverseForwardConn(client, forward, log) + } + }() + + return nil +} + +func handleReverseForwardConn(client net.Conn, forward Forward, log *zap.SugaredLogger) { + defer client.Close() + + log.Debugf("%s connected", client.RemoteAddr()) + defer log.Debug("closed") + + remote, err := net.Dial("tcp", forward.Local.String()) + if err != nil { + log.Errorf("dial INTO local service error: %s", err.Error()) + return + } + + // pipe data in both directions: + // - client => remote + // - remote => client + // + // - in effect, we act as a proxy between the reverse tunnel's client and locally-dialed + // remote endpoint. + // - the "client" and "remote" strings we give Pipe() is just for error&log messages + // - this blocks until either of the parties' socket closes (or breaks) + if err := bidipipe.Pipe( + bidipipe.WithName("client", client), + bidipipe.WithName("remote", remote), + ); err != nil { + log.Error(err) + } +} + +func dialSSH(ctx context.Context, addr string, sshConfig *ssh.ClientConfig) (*ssh.Client, error) { + dialer := net.Dialer{ + Timeout: 10 * time.Second, + } + + conn, err := dialer.DialContext(ctx, "tcp", addr) + if err != nil { + return nil, err + } + + clConn, newChan, reqChan, err := ssh.NewClientConn(conn, addr, sshConfig) + if err != nil { + return nil, err + } + + return ssh.NewClient(clConn, newChan, reqChan), nil +} diff --git a/internal/server/driver/ssh/types/auth.go b/internal/server/driver/ssh/types/auth.go index 96b9733..293e722 100644 --- a/internal/server/driver/ssh/types/auth.go +++ b/internal/server/driver/ssh/types/auth.go @@ -62,6 +62,10 @@ func (k SmartPath) Resolve(shouldBeDirectory bool) (result string, err error) { return } +func (k SmartPath) String() string { + return string(k) +} + func (a Auth) Validate() error { if a.Type == AuthTypePassword && a.Password == "" { return fmt.Errorf("password must be provided for authentication type '%s'", AuthTypePassword) diff --git a/pkg/errtools/IsPortMissingErr.go b/pkg/errtools/IsPortMissingErr.go new file mode 100644 index 0000000..67287f6 --- /dev/null +++ b/pkg/errtools/IsPortMissingErr.go @@ -0,0 +1,18 @@ +package errtools + +import ( + "errors" + "strings" +) + +func IsPortMissingErr(err error) bool { + for { + if err == nil { + return false + } + if strings.Contains(err.Error(), "missing port in address") { + return true + } + err = errors.Unwrap(err) + } +} diff --git a/pkg/plugin/client.go b/pkg/plugin/client.go index d24be02..b2dd345 100644 --- a/pkg/plugin/client.go +++ b/pkg/plugin/client.go @@ -9,6 +9,7 @@ import ( "github.com/Neur0toxine/sshpoke/pkg/convert" "github.com/Neur0toxine/sshpoke/pkg/dto" + "github.com/Neur0toxine/sshpoke/pkg/errtools" "github.com/Neur0toxine/sshpoke/pkg/plugin/pb" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" @@ -57,12 +58,12 @@ func normalizeAddr(addr string) string { if strings.HasPrefix(addr, "grpc://") { addr = addr[7:] } - host, port, err := net.SplitHostPort(addr) - if err != nil && err.Error() == "missing port in address" { - host, port, err = net.SplitHostPort(addr + ":" + strconv.Itoa(DefaultPort)) + _, _, err := net.SplitHostPort(addr) + if err != nil && errtools.IsPortMissingErr(err) { + addr = net.JoinHostPort(addr, strconv.Itoa(DefaultPort)) } if err != nil { return "" } - return host + ":" + port + return addr }