ssh connection
This commit is contained in:
parent
d3029d09e7
commit
401d9123c8
8
go.mod
8
go.mod
@ -5,6 +5,7 @@ go 1.21.4
|
||||
require (
|
||||
github.com/docker/docker v24.0.7+incompatible
|
||||
github.com/docker/go-connections v0.4.0
|
||||
github.com/function61/gokit v0.0.0-20231117065306-355fe206d542
|
||||
github.com/go-playground/validator/v10 v10.16.0
|
||||
github.com/kevinburke/ssh_config v1.2.0
|
||||
github.com/mitchellh/mapstructure v1.5.0
|
||||
@ -13,7 +14,7 @@ require (
|
||||
github.com/spf13/viper v1.17.0
|
||||
go.uber.org/zap v1.26.0
|
||||
golang.design/x/lockfree v0.0.1
|
||||
golang.org/x/crypto v0.13.0
|
||||
golang.org/x/crypto v0.14.0
|
||||
google.golang.org/grpc v1.58.2
|
||||
google.golang.org/protobuf v1.31.0
|
||||
)
|
||||
@ -49,8 +50,9 @@ require (
|
||||
go.uber.org/multierr v1.11.0 // indirect
|
||||
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect
|
||||
golang.org/x/mod v0.12.0 // indirect
|
||||
golang.org/x/net v0.15.0 // indirect
|
||||
golang.org/x/sys v0.12.0 // indirect
|
||||
golang.org/x/net v0.17.0 // indirect
|
||||
golang.org/x/sync v0.3.0 // indirect
|
||||
golang.org/x/sys v0.13.0 // indirect
|
||||
golang.org/x/text v0.13.0 // indirect
|
||||
golang.org/x/tools v0.13.0 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20230920204549-e6e6cdab5c13 // indirect
|
||||
|
18
go.sum
18
go.sum
@ -77,6 +77,8 @@ github.com/frankban/quicktest v1.14.4 h1:g2rn0vABPOOXmZUj+vbmUp0lPoXEMuhTpIluN0X
|
||||
github.com/frankban/quicktest v1.14.4/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0=
|
||||
github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY=
|
||||
github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw=
|
||||
github.com/function61/gokit v0.0.0-20231117065306-355fe206d542 h1:a9BTN+DOboRkVih0suT4zrRZ4zLGFpBtHPGNd+EQ4pI=
|
||||
github.com/function61/gokit v0.0.0-20231117065306-355fe206d542/go.mod h1:sJY957+7ush4oj4ElOMhUFaFIriAFNAGYzVh2tFJNy0=
|
||||
github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU=
|
||||
github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA=
|
||||
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
|
||||
@ -255,8 +257,8 @@ golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8U
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
|
||||
golang.org/x/crypto v0.0.0-20220722155217-630584e8d5aa/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/crypto v0.13.0 h1:mvySKfSWJ+UKUii46M40LOvyWfN0s2U+46/jDd0e6Ck=
|
||||
golang.org/x/crypto v0.13.0/go.mod h1:y6Z2r+Rw4iayiXXAIxJIDAJ1zMW4yaTpebo8fPOliYc=
|
||||
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
|
||||
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
|
||||
@ -326,8 +328,8 @@ golang.org/x/net v0.0.0-20201209123823-ac852fbbde11/go.mod h1:m0MpNAwzfU5UDzcl9v
|
||||
golang.org/x/net v0.0.0-20201224014010-6772e930b67b/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.15.0 h1:ugBLEUaxABaB5AJqW9enI0ACdci2RUd4eP51NTBvuJ8=
|
||||
golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk=
|
||||
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
|
||||
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
@ -385,11 +387,11 @@ golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
|
||||
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.12.0 h1:/ZfYdc3zq+q02Rv9vGqTeSItdzZTSNDmfTi0mBAuidU=
|
||||
golang.org/x/term v0.12.0/go.mod h1:owVbMEjm3cBLCHdkQu9b1opXd4ETQWc3BhuQGKgXgvU=
|
||||
golang.org/x/term v0.13.0 h1:bb+I9cTfFazGW51MZqBVmZy7+JEJMouUHTUSKVQLBek=
|
||||
golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U=
|
||||
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
|
@ -3,46 +3,35 @@ package ssh
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/Neur0toxine/sshpoke/internal/config"
|
||||
"github.com/Neur0toxine/sshpoke/internal/server/driver/base"
|
||||
"github.com/Neur0toxine/sshpoke/internal/server/driver/ssh/sshproto"
|
||||
"github.com/Neur0toxine/sshpoke/internal/server/driver/ssh/types"
|
||||
"github.com/Neur0toxine/sshpoke/internal/server/driver/util"
|
||||
"github.com/Neur0toxine/sshpoke/internal/server/proto/sshtun"
|
||||
"github.com/Neur0toxine/sshpoke/pkg/dto"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type SSH struct {
|
||||
base.Base
|
||||
params Params
|
||||
sessions map[string]conn
|
||||
keys []ssh.Signer
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
type conn struct {
|
||||
container dto.Container
|
||||
tun *sshtun.Tunnel
|
||||
params Params
|
||||
proto *sshproto.Client
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
func New(ctx context.Context, name string, params config.DriverParams) (base.Driver, error) {
|
||||
drv := &SSH{
|
||||
Base: base.New(ctx, name),
|
||||
sessions: make(map[string]conn),
|
||||
Base: base.New(ctx, name),
|
||||
}
|
||||
if err := util.UnmarshalParams(params, &drv.params); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
drv.populateFromSSHConfig()
|
||||
if err := drv.parseKeys(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
drv.proto = sshproto.New(drv.params.Address, drv.params.Auth.User, drv.authenticators(), drv.Log())
|
||||
go drv.proto.Connect(drv.Context())
|
||||
return drv, nil
|
||||
}
|
||||
|
||||
@ -82,78 +71,34 @@ func (d *SSH) WaitForShutdown() {
|
||||
d.wg.Wait()
|
||||
}
|
||||
|
||||
func (d *SSH) parseKeys() error {
|
||||
if d.params.Auth.Type != types.AuthTypeKey {
|
||||
func (d *SSH) authenticators() []ssh.AuthMethod {
|
||||
auth := d.authenticator()
|
||||
if auth == nil {
|
||||
return nil
|
||||
}
|
||||
dir, err := d.params.Auth.Directory.Resolve(true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot parse keys: %s", err)
|
||||
}
|
||||
if d.params.Auth.Keyfile != "" {
|
||||
key, err := parseKey(path.Join(dir, d.params.Auth.Keyfile))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
d.keys = []ssh.Signer{key}
|
||||
return nil
|
||||
}
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("cannot read key directory: %s", err)
|
||||
}
|
||||
keys := []ssh.Signer{}
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
d.Log().Debugf("skipping '%s' because it's a directory", entry.Name())
|
||||
continue
|
||||
}
|
||||
info, err := entry.Info()
|
||||
if err != nil {
|
||||
d.Log().Debugf("skipping '%s' because stat failed: %s", entry.Name(), err)
|
||||
continue
|
||||
}
|
||||
if strings.HasSuffix(entry.Name(), ".pub") {
|
||||
d.Log().Debugf("skipping '%s' because it's probably a public key", entry.Name())
|
||||
continue
|
||||
}
|
||||
if entry.Name() == "config" {
|
||||
d.Log().Debugf("skipping '%s' because it's probably a ssh-config file", entry.Name())
|
||||
continue
|
||||
}
|
||||
if entry.Name() == "known_hosts" {
|
||||
d.Log().Debugf(
|
||||
"skipping '%s' because it's probably a list of hosts generated by OpenSSH", entry.Name())
|
||||
continue
|
||||
}
|
||||
// this file is too small to be a private key
|
||||
if info.Size() < 256 {
|
||||
d.Log().Debugf("skipping '%s' because the file is smaller than 256 bytes", entry.Name())
|
||||
continue
|
||||
}
|
||||
key, err := parseKey(path.Join(dir, entry.Name()))
|
||||
if err != nil {
|
||||
d.Log().Debugf("skipping '%s' because it's probably not a key: %s", entry.Name(), err)
|
||||
continue
|
||||
}
|
||||
d.Log().Debugf("loading key '%s', type: %s", entry.Name(), key.PublicKey().Type())
|
||||
keys = append(keys, key)
|
||||
}
|
||||
if len(keys) == 0 {
|
||||
return errors.New("no keys in the provided directory")
|
||||
}
|
||||
d.keys = keys
|
||||
return nil
|
||||
return []ssh.AuthMethod{auth}
|
||||
}
|
||||
|
||||
func parseKey(keyFile string) (ssh.Signer, error) {
|
||||
keyData, err := os.ReadFile(keyFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
func (d *SSH) authenticator() ssh.AuthMethod {
|
||||
switch d.params.Auth.Type {
|
||||
case types.AuthTypePasswordless:
|
||||
return sshproto.AuthPassword("")
|
||||
case types.AuthTypePassword:
|
||||
return sshproto.AuthPassword(d.params.Auth.Password)
|
||||
case types.AuthTypeKey:
|
||||
if d.params.Auth.Keyfile != "" {
|
||||
keyAuth, err := sshproto.AuthKeyFile(
|
||||
types.SmartPath(path.Join(d.params.Auth.Directory.String(), d.params.Auth.Keyfile)))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return keyAuth
|
||||
}
|
||||
dirAuth, err := sshproto.AuthKeyDir(d.params.Auth.Directory)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return dirAuth
|
||||
}
|
||||
key, err := ssh.ParsePrivateKey(keyData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return key, nil
|
||||
return nil
|
||||
}
|
||||
|
91
internal/server/driver/ssh/sshproto/auth.go
Normal file
91
internal/server/driver/ssh/sshproto/auth.go
Normal file
@ -0,0 +1,91 @@
|
||||
package sshproto
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
"strings"
|
||||
|
||||
"github.com/Neur0toxine/sshpoke/internal/server/driver/ssh/types"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
func AuthKeyFile(keyFile types.SmartPath) (ssh.AuthMethod, error) {
|
||||
key, err := readKey(keyFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ssh.PublicKeys(key), nil
|
||||
}
|
||||
|
||||
func AuthKeyDir(keyDir types.SmartPath) (ssh.AuthMethod, error) {
|
||||
keys, err := readKeys(keyDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ssh.PublicKeys(keys...), nil
|
||||
}
|
||||
|
||||
func AuthPassword(password string) ssh.AuthMethod {
|
||||
return ssh.Password(password)
|
||||
}
|
||||
|
||||
func readKeys(keyDir types.SmartPath) ([]ssh.Signer, error) {
|
||||
dir, err := keyDir.Resolve(true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot parse keys: %s", err)
|
||||
}
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cannot read key directory: %s", err)
|
||||
}
|
||||
keys := []ssh.Signer{}
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
info, err := entry.Info()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if strings.HasSuffix(entry.Name(), ".pub") {
|
||||
continue
|
||||
}
|
||||
if entry.Name() == "config" {
|
||||
continue
|
||||
}
|
||||
if entry.Name() == "known_hosts" {
|
||||
continue
|
||||
}
|
||||
// this file is too small to be a private key
|
||||
if info.Size() < 256 {
|
||||
continue
|
||||
}
|
||||
key, err := readKey(types.SmartPath(path.Join(dir, entry.Name())))
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
keys = append(keys, key)
|
||||
}
|
||||
if len(keys) == 0 {
|
||||
return nil, errors.New("no keys in the provided directory")
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
func readKey(keyFile types.SmartPath) (ssh.Signer, error) {
|
||||
fileName, err := keyFile.Resolve(false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keyData, err := os.ReadFile(fileName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
key, err := ssh.ParsePrivateKey(keyData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return key, nil
|
||||
}
|
21
internal/server/driver/ssh/sshproto/config.go
Normal file
21
internal/server/driver/ssh/sshproto/config.go
Normal file
@ -0,0 +1,21 @@
|
||||
package sshproto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
|
||||
"github.com/Neur0toxine/sshpoke/internal/server/driver/ssh/types"
|
||||
"github.com/kevinburke/ssh_config"
|
||||
)
|
||||
|
||||
func parseSSHConfig(filePath types.SmartPath) (*ssh_config.Config, error) {
|
||||
fileName, err := filePath.Resolve(false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
file, err := os.ReadFile(fileName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ssh_config.Decode(bytes.NewReader(file))
|
||||
}
|
19
internal/server/driver/ssh/sshproto/forward.go
Normal file
19
internal/server/driver/ssh/sshproto/forward.go
Normal file
@ -0,0 +1,19 @@
|
||||
package sshproto
|
||||
|
||||
import "fmt"
|
||||
|
||||
type Forward struct {
|
||||
// local service to be forwarded
|
||||
Local Endpoint `json:"local"`
|
||||
// remote forwarding port (on remote SSH server network)
|
||||
Remote Endpoint `json:"remote"`
|
||||
}
|
||||
|
||||
type Endpoint struct {
|
||||
Host string `json:"host"`
|
||||
Port int `json:"port"`
|
||||
}
|
||||
|
||||
func (endpoint *Endpoint) String() string {
|
||||
return fmt.Sprintf("%s:%d", endpoint.Host, endpoint.Port)
|
||||
}
|
242
internal/server/driver/ssh/sshproto/ssh.go
Normal file
242
internal/server/driver/ssh/sshproto/ssh.go
Normal file
@ -0,0 +1,242 @@
|
||||
package sshproto
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Neur0toxine/sshpoke/pkg/errtools"
|
||||
"github.com/function61/gokit/app/backoff"
|
||||
"github.com/function61/gokit/io/bidipipe"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
user string
|
||||
address string
|
||||
authMethods []ssh.AuthMethod
|
||||
log *zap.SugaredLogger
|
||||
connected atomic.Bool
|
||||
}
|
||||
|
||||
func New(address, user string, auth []ssh.AuthMethod, log *zap.SugaredLogger) *Client {
|
||||
return &Client{
|
||||
address: prepareAddress(address),
|
||||
user: user,
|
||||
authMethods: auth,
|
||||
log: log.With(zap.String("sshServer", address)),
|
||||
}
|
||||
}
|
||||
|
||||
func prepareAddress(address string) string {
|
||||
_, _, err := net.SplitHostPort(address)
|
||||
if err != nil && errtools.IsPortMissingErr(err) {
|
||||
return net.JoinHostPort(address, "22")
|
||||
}
|
||||
return address
|
||||
}
|
||||
|
||||
func (c *Client) Connect(ctx context.Context) {
|
||||
if c.connected.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
defer c.connected.Store(false)
|
||||
backoffTime := backoff.ExponentialWithCappedMax(100*time.Millisecond, 5*time.Second)
|
||||
for {
|
||||
c.connected.Store(true)
|
||||
err := c.connect(ctx)
|
||||
if err != nil {
|
||||
c.log.Error("connect error:", err)
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
time.Sleep(backoffTime())
|
||||
}
|
||||
}
|
||||
|
||||
// connect once to the SSH server. if the connection breaks, we return error and the caller
|
||||
// will try to re-connect
|
||||
func (c *Client) connect(ctx context.Context) error {
|
||||
c.log.Debug("connecting")
|
||||
sshConfig := &ssh.ClientConfig{
|
||||
User: c.user,
|
||||
Auth: c.authMethods,
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
}
|
||||
|
||||
var sshClient *ssh.Client
|
||||
var errConnect error
|
||||
|
||||
sshClient, errConnect = dialSSH(ctx, c.address, sshConfig)
|
||||
if errConnect != nil {
|
||||
return errConnect
|
||||
}
|
||||
|
||||
// always disconnect when function returns
|
||||
defer sshClient.Close()
|
||||
defer c.log.Debug("disconnecting")
|
||||
|
||||
c.log.Debug("connected")
|
||||
|
||||
<-ctx.Done()
|
||||
return nil
|
||||
}
|
||||
|
||||
// // connect once to the SSH server. if the connection breaks, we return error and the caller
|
||||
// // will try to re-connect
|
||||
// func connectToSshAndServe2(
|
||||
// ctx context.Context,
|
||||
// address string,
|
||||
// authConfig types.Auth,
|
||||
// auth ssh.AuthMethod,
|
||||
// log *zap.SugaredLogger,
|
||||
// ) error {
|
||||
// log = log.With(zap.String("sshServer", address))
|
||||
// log.Debug("connecting")
|
||||
// sshConfig := &ssh.ClientConfig{
|
||||
// User: authConfig.User,
|
||||
// Auth: []ssh.AuthMethod{auth},
|
||||
// HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
// }
|
||||
//
|
||||
// var sshClient *ssh.Client
|
||||
// var errConnect error
|
||||
//
|
||||
// sshClient, errConnect = dialSSH(ctx, address, sshConfig)
|
||||
// if errConnect != nil {
|
||||
// return errConnect
|
||||
// }
|
||||
//
|
||||
// // always disconnect when function returns
|
||||
// defer sshClient.Close()
|
||||
// defer log.Debug("disconnecting")
|
||||
//
|
||||
// log.Debug("connected")
|
||||
//
|
||||
// // each listener in reverseForwardOnePort() can return one error, so make sure channel
|
||||
// // has enough buffering so even if we return from here, the goroutines won't block trying
|
||||
// // to return an error
|
||||
// listenerStopped := make(chan error, len(conf.Forwards))
|
||||
//
|
||||
// for _, forward := range conf.Forwards {
|
||||
// // TODO: "if any fails, tear down all workers" -style error handling would be better
|
||||
// // handled with https://pkg.go.dev/golang.org/x/sync/errgroup?tab=doc
|
||||
// if err := reverseForwardOnePort(
|
||||
// forward,
|
||||
// sshClient,
|
||||
// listenerStopped,
|
||||
// makeLogger("reverseForwardOnePort"),
|
||||
// makeLogger,
|
||||
// ); err != nil {
|
||||
// // closes SSH connection if even one forward Listen() fails
|
||||
// return err
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// // we're connected and have succesfully started listening on all reverse forwards, wait
|
||||
// // for either user to ask us to stop or any of the listeners to error
|
||||
// select {
|
||||
// case <-ctx.Done(): // cancel requested
|
||||
// return nil
|
||||
// case listenerFirstErr := <-listenerStopped:
|
||||
// // one or more of the listeners encountered an error. react by closing the connection
|
||||
// // assumes all the other listeners failed too so no teardown necessary
|
||||
// select {
|
||||
// case <-ctx.Done(): // pretty much errors are to be expected if cancellation triggered
|
||||
// return nil
|
||||
// default:
|
||||
// return listenerFirstErr
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// blocking flow: calls Listen() on the SSH connection, and if succeeds returns non-nil error
|
||||
//
|
||||
// nonblocking flow: if Accept() call fails, stops goroutine and returns error on ch listenerStopped
|
||||
func reverseForwardOnePort(
|
||||
forward Forward,
|
||||
sshClient *ssh.Client,
|
||||
listenerStopped chan<- error,
|
||||
log *zap.SugaredLogger,
|
||||
) error {
|
||||
// reverse listen on remote server port
|
||||
listener, err := sshClient.Listen("tcp", forward.Remote.String())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer listener.Close()
|
||||
log.Debugf("forwarding %s -> %s", forward.Local.String(), forward.Remote.String())
|
||||
|
||||
// handle incoming connections on reverse forwarded tunnel
|
||||
for {
|
||||
client, err := listener.Accept()
|
||||
if err != nil {
|
||||
listenerStopped <- fmt.Errorf("error on Accept(): %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
// handle the connection in another goroutine, so we can support multiple concurrent
|
||||
// connections on the same port
|
||||
go handleReverseForwardConn(client, forward, log)
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func handleReverseForwardConn(client net.Conn, forward Forward, log *zap.SugaredLogger) {
|
||||
defer client.Close()
|
||||
|
||||
log.Debugf("%s connected", client.RemoteAddr())
|
||||
defer log.Debug("closed")
|
||||
|
||||
remote, err := net.Dial("tcp", forward.Local.String())
|
||||
if err != nil {
|
||||
log.Errorf("dial INTO local service error: %s", err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
// pipe data in both directions:
|
||||
// - client => remote
|
||||
// - remote => client
|
||||
//
|
||||
// - in effect, we act as a proxy between the reverse tunnel's client and locally-dialed
|
||||
// remote endpoint.
|
||||
// - the "client" and "remote" strings we give Pipe() is just for error&log messages
|
||||
// - this blocks until either of the parties' socket closes (or breaks)
|
||||
if err := bidipipe.Pipe(
|
||||
bidipipe.WithName("client", client),
|
||||
bidipipe.WithName("remote", remote),
|
||||
); err != nil {
|
||||
log.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func dialSSH(ctx context.Context, addr string, sshConfig *ssh.ClientConfig) (*ssh.Client, error) {
|
||||
dialer := net.Dialer{
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
conn, err := dialer.DialContext(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
clConn, newChan, reqChan, err := ssh.NewClientConn(conn, addr, sshConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ssh.NewClient(clConn, newChan, reqChan), nil
|
||||
}
|
@ -62,6 +62,10 @@ func (k SmartPath) Resolve(shouldBeDirectory bool) (result string, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (k SmartPath) String() string {
|
||||
return string(k)
|
||||
}
|
||||
|
||||
func (a Auth) Validate() error {
|
||||
if a.Type == AuthTypePassword && a.Password == "" {
|
||||
return fmt.Errorf("password must be provided for authentication type '%s'", AuthTypePassword)
|
||||
|
18
pkg/errtools/IsPortMissingErr.go
Normal file
18
pkg/errtools/IsPortMissingErr.go
Normal file
@ -0,0 +1,18 @@
|
||||
package errtools
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func IsPortMissingErr(err error) bool {
|
||||
for {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if strings.Contains(err.Error(), "missing port in address") {
|
||||
return true
|
||||
}
|
||||
err = errors.Unwrap(err)
|
||||
}
|
||||
}
|
@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/Neur0toxine/sshpoke/pkg/convert"
|
||||
"github.com/Neur0toxine/sshpoke/pkg/dto"
|
||||
"github.com/Neur0toxine/sshpoke/pkg/errtools"
|
||||
"github.com/Neur0toxine/sshpoke/pkg/plugin/pb"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
@ -57,12 +58,12 @@ func normalizeAddr(addr string) string {
|
||||
if strings.HasPrefix(addr, "grpc://") {
|
||||
addr = addr[7:]
|
||||
}
|
||||
host, port, err := net.SplitHostPort(addr)
|
||||
if err != nil && err.Error() == "missing port in address" {
|
||||
host, port, err = net.SplitHostPort(addr + ":" + strconv.Itoa(DefaultPort))
|
||||
_, _, err := net.SplitHostPort(addr)
|
||||
if err != nil && errtools.IsPortMissingErr(err) {
|
||||
addr = net.JoinHostPort(addr, strconv.Itoa(DefaultPort))
|
||||
}
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return host + ":" + port
|
||||
return addr
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user