From 11a5f48d682ef88d9fcbd183cb1a95cd09d50bdc Mon Sep 17 00:00:00 2001 From: Neur0toxine Date: Sat, 18 Nov 2023 21:23:29 +0300 Subject: [PATCH] sish support (wip) --- internal/docker/api.go | 2 +- internal/docker/convert.go | 22 +- internal/server/driver/ssh/driver.go | 76 ++- internal/server/driver/ssh/params.go | 18 +- .../server/driver/ssh/sshproto/forward.go | 19 - internal/server/driver/ssh/sshproto/ssh.go | 242 --------- .../driver/ssh/{sshproto => sshtun}/auth.go | 2 +- .../driver/ssh/{sshproto => sshtun}/config.go | 2 +- internal/server/driver/ssh/sshtun/forward.go | 37 ++ internal/server/driver/ssh/sshtun/printer.go | 21 + .../server/driver/ssh/sshtun/sish_compat.go | 76 +++ internal/server/driver/ssh/sshtun/ssh.go | 210 ++++++++ internal/server/proto/sshtun/tunnel.go | 249 --------- internal/server/proto/sshtun/tunnel_test.go | 509 ------------------ pkg/convert/convert.go | 28 +- pkg/dto/models.go | 14 +- ...rtMissingErr.go => is_port_missing_err.go} | 0 pkg/plugin/pb/pb.proto | 2 +- 18 files changed, 454 insertions(+), 1075 deletions(-) delete mode 100644 internal/server/driver/ssh/sshproto/forward.go delete mode 100644 internal/server/driver/ssh/sshproto/ssh.go rename internal/server/driver/ssh/{sshproto => sshtun}/auth.go (99%) rename internal/server/driver/ssh/{sshproto => sshtun}/config.go (96%) create mode 100644 internal/server/driver/ssh/sshtun/forward.go create mode 100644 internal/server/driver/ssh/sshtun/printer.go create mode 100644 internal/server/driver/ssh/sshtun/sish_compat.go create mode 100644 internal/server/driver/ssh/sshtun/ssh.go delete mode 100644 internal/server/proto/sshtun/tunnel.go delete mode 100644 internal/server/proto/sshtun/tunnel_test.go rename pkg/errtools/{IsPortMissingErr.go => is_port_missing_err.go} (100%) diff --git a/internal/docker/api.go b/internal/docker/api.go index a69c3cc..8ce2fb7 100644 --- a/internal/docker/api.go +++ b/internal/docker/api.go @@ -110,7 +110,7 @@ func (d *Docker) Listen() (chan dto.Event, error) { "container.ip", converted.IP.String(), "container.port", converted.Port, "container.server", converted.Server, - "container.prefix", converted.Prefix) + "container.remote_host", converted.RemoteHost) output <- newEvent case err := <-errSource: if errors.Is(err, context.Canceled) { diff --git a/internal/docker/convert.go b/internal/docker/convert.go index b6cfb6c..baa301a 100644 --- a/internal/docker/convert.go +++ b/internal/docker/convert.go @@ -13,11 +13,11 @@ import ( ) type labelsConfig struct { - Enable boolStr `mapstructure:"sshpoke.enable"` - Network string `mapstructure:"sshpoke.network"` - Server string `mapstructure:"sshpoke.server"` - Port string `mapstructure:"sshpoke.port"` - Prefix string `mapstructure:"sshpoke.prefix"` + Enable boolStr `mapstructure:"sshpoke.enable"` + Network string `mapstructure:"sshpoke.network"` + Server string `mapstructure:"sshpoke.server"` + Port string `mapstructure:"sshpoke.port"` + RemoteHost string `mapstructure:"sshpoke.remote_host"` } type boolStr string @@ -77,12 +77,12 @@ func dockerContainerToInternal(container types.Container) (result dto.Container, } return dto.Container{ - ID: container.ID, - Names: container.Names, - IP: ip, - Port: uint16(port), - Server: labels.Server, - Prefix: labels.Prefix, + ID: container.ID, + Names: container.Names, + IP: ip, + Port: uint16(port), + Server: labels.Server, + RemoteHost: labels.RemoteHost, }, true } diff --git a/internal/server/driver/ssh/driver.go b/internal/server/driver/ssh/driver.go index f37d7e6..220ff43 100644 --- a/internal/server/driver/ssh/driver.go +++ b/internal/server/driver/ssh/driver.go @@ -2,13 +2,14 @@ package ssh import ( "context" - "errors" + "net" "path" + "strconv" "sync" "github.com/Neur0toxine/sshpoke/internal/config" "github.com/Neur0toxine/sshpoke/internal/server/driver/base" - "github.com/Neur0toxine/sshpoke/internal/server/driver/ssh/sshproto" + "github.com/Neur0toxine/sshpoke/internal/server/driver/ssh/sshtun" "github.com/Neur0toxine/sshpoke/internal/server/driver/ssh/types" "github.com/Neur0toxine/sshpoke/internal/server/driver/util" "github.com/Neur0toxine/sshpoke/pkg/dto" @@ -18,23 +19,38 @@ import ( type SSH struct { base.Base params Params - proto *sshproto.Client + auth []ssh.AuthMethod + conns map[string]conn + rw sync.RWMutex wg sync.WaitGroup } +type conn struct { + ctx context.Context + cancel func() + tun *sshtun.Tunnel +} + func New(ctx context.Context, name string, params config.DriverParams) (base.Driver, error) { drv := &SSH{ - Base: base.New(ctx, name), + Base: base.New(ctx, name), + conns: make(map[string]conn), } if err := util.UnmarshalParams(params, &drv.params); err != nil { return nil, err } drv.populateFromSSHConfig() - drv.proto = sshproto.New(drv.params.Address, drv.params.Auth.User, drv.authenticators(), drv.Log()) - go drv.proto.Connect(drv.Context()) + drv.auth = drv.authenticators() return drv, nil } +func (d *SSH) forward(val sshtun.Forward) conn { + tun := sshtun.New(d.params.Address, d.params.Auth.User, d.params.DisableRemoteHostResolve, val, d.auth, d.Log()) + ctx, cancel := context.WithCancel(d.Context()) + go tun.Connect(ctx, sshtun.StdoutPrinterSessionCallback(d.Log().With("ssh-output", val.Remote.String()))) + return conn{ctx: ctx, cancel: cancel, tun: tun} +} + func (d *SSH) populateFromSSHConfig() { if d.params.Auth.Directory == "" { return @@ -59,8 +75,43 @@ func (d *SSH) populateFromSSHConfig() { } func (d *SSH) Handle(event dto.Event) error { - // TODO: Implement event handling & connections management. - return errors.New("server handler is not implemented yet") + defer d.rw.Unlock() + d.rw.Lock() + switch event.Type { + case dto.EventStart: + conn := d.forward(sshtun.Forward{ + Local: sshtun.AddrToEndpoint(net.JoinHostPort(event.Container.IP.String(), strconv.Itoa(int(event.Container.Port)))), + Remote: d.remoteEndpoint(event.Container.RemoteHost), + }) + d.conns[event.Container.ID] = conn + d.wg.Add(1) + case dto.EventStop: + conn, found := d.conns[event.Container.ID] + if !found { + return nil + } + conn.cancel() + delete(d.conns, event.Container.ID) + d.wg.Done() + case dto.EventShutdown: + for id, conn := range d.conns { + conn.cancel() + delete(d.conns, id) + d.wg.Done() + } + } + return nil +} + +func (d *SSH) remoteEndpoint(remoteHost string) sshtun.Endpoint { + port := int(d.params.ForwardPort) + if port == 0 { + port = 80 + } + return sshtun.Endpoint{ + Host: remoteHost, + Port: port, + } } func (d *SSH) Driver() config.DriverType { @@ -68,6 +119,7 @@ func (d *SSH) Driver() config.DriverType { } func (d *SSH) WaitForShutdown() { + go d.Handle(dto.Event{Type: dto.EventShutdown}) d.wg.Wait() } @@ -82,19 +134,19 @@ func (d *SSH) authenticators() []ssh.AuthMethod { func (d *SSH) authenticator() ssh.AuthMethod { switch d.params.Auth.Type { case types.AuthTypePasswordless: - return sshproto.AuthPassword("") + return sshtun.AuthPassword("") case types.AuthTypePassword: - return sshproto.AuthPassword(d.params.Auth.Password) + return sshtun.AuthPassword(d.params.Auth.Password) case types.AuthTypeKey: if d.params.Auth.Keyfile != "" { - keyAuth, err := sshproto.AuthKeyFile( + keyAuth, err := sshtun.AuthKeyFile( types.SmartPath(path.Join(d.params.Auth.Directory.String(), d.params.Auth.Keyfile))) if err != nil { return nil } return keyAuth } - dirAuth, err := sshproto.AuthKeyDir(d.params.Auth.Directory) + dirAuth, err := sshtun.AuthKeyDir(d.params.Auth.Directory) if err != nil { return nil } diff --git a/internal/server/driver/ssh/params.go b/internal/server/driver/ssh/params.go index 2602ac8..ae695bf 100644 --- a/internal/server/driver/ssh/params.go +++ b/internal/server/driver/ssh/params.go @@ -6,14 +6,16 @@ import ( ) type Params struct { - Address string `mapstructure:"address" validate:"required"` - Auth types.Auth `mapstructure:"auth"` - KeepAlive types.KeepAlive `mapstructure:"keepalive"` - Domain string `mapstructure:"domain"` - DomainProto string `mapstructure:"domain_proto"` - DomainExtractRegex string `mapstructure:"domain_extract_regex" validate:"validregexp"` - Mode types.DomainMode `mapstructure:"mode" validate:"required,oneof=single multi"` - Prefix bool `mapstructure:"prefix"` + Address string `mapstructure:"address" validate:"required"` + DefaultHost *string `mapstructure:"default_host,omitempty"` + ForwardPort uint16 `mapstructure:"forward_port"` + Auth types.Auth `mapstructure:"auth"` + KeepAlive types.KeepAlive `mapstructure:"keepalive"` + Domain string `mapstructure:"domain"` + DomainProto string `mapstructure:"domain_proto"` + DomainExtractRegex string `mapstructure:"domain_extract_regex" validate:"validregexp"` + Mode types.DomainMode `mapstructure:"mode" validate:"required,oneof=single multi"` + DisableRemoteHostResolve bool `mapstructure:"disable_remote_host_resolve"` } func (p *Params) Validate() error { diff --git a/internal/server/driver/ssh/sshproto/forward.go b/internal/server/driver/ssh/sshproto/forward.go deleted file mode 100644 index 7b5040a..0000000 --- a/internal/server/driver/ssh/sshproto/forward.go +++ /dev/null @@ -1,19 +0,0 @@ -package sshproto - -import "fmt" - -type Forward struct { - // local service to be forwarded - Local Endpoint `json:"local"` - // remote forwarding port (on remote SSH server network) - Remote Endpoint `json:"remote"` -} - -type Endpoint struct { - Host string `json:"host"` - Port int `json:"port"` -} - -func (endpoint *Endpoint) String() string { - return fmt.Sprintf("%s:%d", endpoint.Host, endpoint.Port) -} diff --git a/internal/server/driver/ssh/sshproto/ssh.go b/internal/server/driver/ssh/sshproto/ssh.go deleted file mode 100644 index 7e7dce6..0000000 --- a/internal/server/driver/ssh/sshproto/ssh.go +++ /dev/null @@ -1,242 +0,0 @@ -package sshproto - -import ( - "context" - "fmt" - "net" - "sync/atomic" - "time" - - "github.com/Neur0toxine/sshpoke/pkg/errtools" - "github.com/function61/gokit/app/backoff" - "github.com/function61/gokit/io/bidipipe" - "go.uber.org/zap" - "golang.org/x/crypto/ssh" -) - -type Client struct { - user string - address string - authMethods []ssh.AuthMethod - log *zap.SugaredLogger - connected atomic.Bool -} - -func New(address, user string, auth []ssh.AuthMethod, log *zap.SugaredLogger) *Client { - return &Client{ - address: prepareAddress(address), - user: user, - authMethods: auth, - log: log.With(zap.String("sshServer", address)), - } -} - -func prepareAddress(address string) string { - _, _, err := net.SplitHostPort(address) - if err != nil && errtools.IsPortMissingErr(err) { - return net.JoinHostPort(address, "22") - } - return address -} - -func (c *Client) Connect(ctx context.Context) { - if c.connected.Load() { - return - } - - defer c.connected.Store(false) - backoffTime := backoff.ExponentialWithCappedMax(100*time.Millisecond, 5*time.Second) - for { - c.connected.Store(true) - err := c.connect(ctx) - if err != nil { - c.log.Error("connect error:", err) - } - - select { - case <-ctx.Done(): - return - default: - } - - time.Sleep(backoffTime()) - } -} - -// connect once to the SSH server. if the connection breaks, we return error and the caller -// will try to re-connect -func (c *Client) connect(ctx context.Context) error { - c.log.Debug("connecting") - sshConfig := &ssh.ClientConfig{ - User: c.user, - Auth: c.authMethods, - HostKeyCallback: ssh.InsecureIgnoreHostKey(), - } - - var sshClient *ssh.Client - var errConnect error - - sshClient, errConnect = dialSSH(ctx, c.address, sshConfig) - if errConnect != nil { - return errConnect - } - - // always disconnect when function returns - defer sshClient.Close() - defer c.log.Debug("disconnecting") - - c.log.Debug("connected") - - <-ctx.Done() - return nil -} - -// // connect once to the SSH server. if the connection breaks, we return error and the caller -// // will try to re-connect -// func connectToSshAndServe2( -// ctx context.Context, -// address string, -// authConfig types.Auth, -// auth ssh.AuthMethod, -// log *zap.SugaredLogger, -// ) error { -// log = log.With(zap.String("sshServer", address)) -// log.Debug("connecting") -// sshConfig := &ssh.ClientConfig{ -// User: authConfig.User, -// Auth: []ssh.AuthMethod{auth}, -// HostKeyCallback: ssh.InsecureIgnoreHostKey(), -// } -// -// var sshClient *ssh.Client -// var errConnect error -// -// sshClient, errConnect = dialSSH(ctx, address, sshConfig) -// if errConnect != nil { -// return errConnect -// } -// -// // always disconnect when function returns -// defer sshClient.Close() -// defer log.Debug("disconnecting") -// -// log.Debug("connected") -// -// // each listener in reverseForwardOnePort() can return one error, so make sure channel -// // has enough buffering so even if we return from here, the goroutines won't block trying -// // to return an error -// listenerStopped := make(chan error, len(conf.Forwards)) -// -// for _, forward := range conf.Forwards { -// // TODO: "if any fails, tear down all workers" -style error handling would be better -// // handled with https://pkg.go.dev/golang.org/x/sync/errgroup?tab=doc -// if err := reverseForwardOnePort( -// forward, -// sshClient, -// listenerStopped, -// makeLogger("reverseForwardOnePort"), -// makeLogger, -// ); err != nil { -// // closes SSH connection if even one forward Listen() fails -// return err -// } -// } -// -// // we're connected and have succesfully started listening on all reverse forwards, wait -// // for either user to ask us to stop or any of the listeners to error -// select { -// case <-ctx.Done(): // cancel requested -// return nil -// case listenerFirstErr := <-listenerStopped: -// // one or more of the listeners encountered an error. react by closing the connection -// // assumes all the other listeners failed too so no teardown necessary -// select { -// case <-ctx.Done(): // pretty much errors are to be expected if cancellation triggered -// return nil -// default: -// return listenerFirstErr -// } -// } -// } - -// blocking flow: calls Listen() on the SSH connection, and if succeeds returns non-nil error -// -// nonblocking flow: if Accept() call fails, stops goroutine and returns error on ch listenerStopped -func reverseForwardOnePort( - forward Forward, - sshClient *ssh.Client, - listenerStopped chan<- error, - log *zap.SugaredLogger, -) error { - // reverse listen on remote server port - listener, err := sshClient.Listen("tcp", forward.Remote.String()) - if err != nil { - return err - } - - go func() { - defer listener.Close() - log.Debugf("forwarding %s -> %s", forward.Local.String(), forward.Remote.String()) - - // handle incoming connections on reverse forwarded tunnel - for { - client, err := listener.Accept() - if err != nil { - listenerStopped <- fmt.Errorf("error on Accept(): %w", err) - return - } - - // handle the connection in another goroutine, so we can support multiple concurrent - // connections on the same port - go handleReverseForwardConn(client, forward, log) - } - }() - - return nil -} - -func handleReverseForwardConn(client net.Conn, forward Forward, log *zap.SugaredLogger) { - defer client.Close() - - log.Debugf("%s connected", client.RemoteAddr()) - defer log.Debug("closed") - - remote, err := net.Dial("tcp", forward.Local.String()) - if err != nil { - log.Errorf("dial INTO local service error: %s", err.Error()) - return - } - - // pipe data in both directions: - // - client => remote - // - remote => client - // - // - in effect, we act as a proxy between the reverse tunnel's client and locally-dialed - // remote endpoint. - // - the "client" and "remote" strings we give Pipe() is just for error&log messages - // - this blocks until either of the parties' socket closes (or breaks) - if err := bidipipe.Pipe( - bidipipe.WithName("client", client), - bidipipe.WithName("remote", remote), - ); err != nil { - log.Error(err) - } -} - -func dialSSH(ctx context.Context, addr string, sshConfig *ssh.ClientConfig) (*ssh.Client, error) { - dialer := net.Dialer{ - Timeout: 10 * time.Second, - } - - conn, err := dialer.DialContext(ctx, "tcp", addr) - if err != nil { - return nil, err - } - - clConn, newChan, reqChan, err := ssh.NewClientConn(conn, addr, sshConfig) - if err != nil { - return nil, err - } - - return ssh.NewClient(clConn, newChan, reqChan), nil -} diff --git a/internal/server/driver/ssh/sshproto/auth.go b/internal/server/driver/ssh/sshtun/auth.go similarity index 99% rename from internal/server/driver/ssh/sshproto/auth.go rename to internal/server/driver/ssh/sshtun/auth.go index edb5bd6..80bfe71 100644 --- a/internal/server/driver/ssh/sshproto/auth.go +++ b/internal/server/driver/ssh/sshtun/auth.go @@ -1,4 +1,4 @@ -package sshproto +package sshtun import ( "errors" diff --git a/internal/server/driver/ssh/sshproto/config.go b/internal/server/driver/ssh/sshtun/config.go similarity index 96% rename from internal/server/driver/ssh/sshproto/config.go rename to internal/server/driver/ssh/sshtun/config.go index b4b987a..bb3c5bb 100644 --- a/internal/server/driver/ssh/sshproto/config.go +++ b/internal/server/driver/ssh/sshtun/config.go @@ -1,4 +1,4 @@ -package sshproto +package sshtun import ( "bytes" diff --git a/internal/server/driver/ssh/sshtun/forward.go b/internal/server/driver/ssh/sshtun/forward.go new file mode 100644 index 0000000..e7f0ec3 --- /dev/null +++ b/internal/server/driver/ssh/sshtun/forward.go @@ -0,0 +1,37 @@ +package sshtun + +import ( + "fmt" + "net" + "strconv" + + "github.com/Neur0toxine/sshpoke/pkg/errtools" +) + +type Forward struct { + // local service to be forwarded + Local Endpoint `json:"local"` + // remote forwarding port (on remote SSH server network) + Remote Endpoint `json:"remote"` +} + +func AddrToEndpoint(address string) Endpoint { + host, port, err := net.SplitHostPort(address) + if err != nil && errtools.IsPortMissingErr(err) { + return Endpoint{Host: host, Port: 22} + } + portNum, err := strconv.Atoi(port) + if err != nil { + portNum = 22 + } + return Endpoint{Host: host, Port: portNum} +} + +type Endpoint struct { + Host string `json:"host"` + Port int `json:"port"` +} + +func (endpoint *Endpoint) String() string { + return fmt.Sprintf("%s:%d", endpoint.Host, endpoint.Port) +} diff --git a/internal/server/driver/ssh/sshtun/printer.go b/internal/server/driver/ssh/sshtun/printer.go new file mode 100644 index 0000000..89ad3c8 --- /dev/null +++ b/internal/server/driver/ssh/sshtun/printer.go @@ -0,0 +1,21 @@ +package sshtun + +import ( + "bufio" + + "go.uber.org/zap" + "golang.org/x/crypto/ssh" +) + +func StdoutPrinterSessionCallback(log *zap.SugaredLogger) SessionCallback { + return func(session *ssh.Session) { + stdout, err := session.StdoutPipe() + if err != nil { + return + } + scan := bufio.NewScanner(stdout) + for scan.Scan() { + log.Debug(scan.Text()) + } + } +} diff --git a/internal/server/driver/ssh/sshtun/sish_compat.go b/internal/server/driver/ssh/sshtun/sish_compat.go new file mode 100644 index 0000000..9f8c2d4 --- /dev/null +++ b/internal/server/driver/ssh/sshtun/sish_compat.go @@ -0,0 +1,76 @@ +package sshtun + +import ( + "errors" + "net" + "reflect" + "sync" + "unsafe" + + "golang.org/x/crypto/ssh" +) + +type sishHostListener struct { + parent *ssh.Client +} + +func newSishHostListener(parent *ssh.Client) *sishHostListener { + return &sishHostListener{parent: parent} +} + +func (c *sishHostListener) ListenFakeRemoteHost(ep Endpoint) (net.Listener, error) { + c.doHandleForwardsOnce() + m := channelForwardMsg{ + ep.Host, + uint32(ep.Port), + } + // send message + ok, resp, err := c.parent.SendRequest("tcpip-forward", true, ssh.Marshal(&m)) + if err != nil { + return nil, err + } + if !ok { + return nil, errors.New("ssh: tcpip-forward request denied by peer") + } + + laddr := &net.TCPAddr{ + IP: net.ParseIP("127.0.0.1"), + Port: ep.Port, + } + if ep.Port == 0 { + var p struct { + Port uint32 + } + if err := ssh.Unmarshal(resp, &p); err != nil { + return nil, err + } + + laddr.Port = int(p.Port) + } + c.registerForward(laddr) + return nil, nil +} + +func (c *sishHostListener) registerForward(addr *net.TCPAddr) { + cl := reflect.ValueOf(c.parent).Elem() + forwardsUn := cl.FieldByName("forwards") + forwards := reflect.NewAt(forwardsUn.Type(), unsafe.Pointer(forwardsUn.UnsafeAddr())).Elem() + forwardVal := forwards.MethodByName("add").Call([]reflect.Value{reflect.ValueOf(addr)})[0] + _ = forwardVal +} + +func (c *sishHostListener) doHandleForwardsOnce() { + cl := reflect.ValueOf(c.parent) + clVal := cl.Elem() + onceField := clVal.FieldByName("handleForwardsOnce") + once := reflect.NewAt(onceField.Type(), unsafe.Pointer(onceField.UnsafeAddr())).Interface().(*sync.Once) + handleForwards := clVal.MethodByName("handleForwards") + once.Do(func() { + handleForwards.Call(nil) + }) +} + +type channelForwardMsg struct { + addr string + rport uint32 +} diff --git a/internal/server/driver/ssh/sshtun/ssh.go b/internal/server/driver/ssh/sshtun/ssh.go new file mode 100644 index 0000000..04f910c --- /dev/null +++ b/internal/server/driver/ssh/sshtun/ssh.go @@ -0,0 +1,210 @@ +package sshtun + +import ( + "context" + "fmt" + "net" + "os" + "sync" + "sync/atomic" + "time" + + "github.com/function61/gokit/app/backoff" + "github.com/function61/gokit/io/bidipipe" + "go.uber.org/zap" + "golang.org/x/crypto/ssh" +) + +type SessionCallback func(*ssh.Session) + +var NoopSessionCallback SessionCallback = func(*ssh.Session) {} + +type Tunnel struct { + user string + address Endpoint + forward Forward + authMethods []ssh.AuthMethod + log *zap.SugaredLogger + connected atomic.Bool + fakeRemoteHost bool +} + +func New(address, user string, fakeRemoteHost bool, + forward Forward, auth []ssh.AuthMethod, log *zap.SugaredLogger) *Tunnel { + return &Tunnel{ + address: AddrToEndpoint(address), + user: user, + fakeRemoteHost: fakeRemoteHost, + forward: forward, + authMethods: auth, + log: log.With(zap.String("sshServer", address)), + } +} + +func (c *Tunnel) Connect(ctx context.Context, sessionCb SessionCallback) { + if c.connected.Load() { + return + } + + defer c.connected.Store(false) + backoffTime := backoff.ExponentialWithCappedMax(100*time.Millisecond, 5*time.Second) + for { + c.connected.Store(true) + err := c.connect(ctx, sessionCb) + if err != nil { + c.log.Error("connect error:", err) + } + + select { + case <-ctx.Done(): + return + default: + } + + time.Sleep(backoffTime()) + } +} + +// connect once to the SSH server. if the connection breaks, we return error and the caller +// will try to re-connect +func (c *Tunnel) connect(ctx context.Context, sessionCb SessionCallback) error { + c.log.Debug("connecting") + sshConfig := &ssh.ClientConfig{ + User: c.user, + Auth: c.authMethods, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + } + + var sshClient *ssh.Client + var errConnect error + + sshClient, errConnect = dialSSH(ctx, c.address.String(), sshConfig) + if errConnect != nil { + return errConnect + } + + defer sshClient.Close() + defer c.log.Debug("disconnecting") + + c.log.Debug("connected") + listenerStopped := make(chan error) + + sess, err := sshClient.NewSession() + if err != nil { + c.log.Errorf("session error: %s", err) + return err + } + defer sess.Close() + + var wg sync.WaitGroup + if sessionCb == nil { + sessionCb = func(*ssh.Session) {} + } + wg.Add(2) + go func() { + defer wg.Done() + sessionCb(sess) + }() + reverseErr := make(chan error) + go func() { + defer wg.Done() + reverseErr <- c.reverseForwardOnePort(sshClient, listenerStopped) + }() + + if err := <-reverseErr; err != nil { + return err + } + + select { + case <-ctx.Done(): + return nil + case listenerFirstErr := <-listenerStopped: + select { + case <-ctx.Done(): + return nil + default: + return listenerFirstErr + } + } +} + +// blocking flow: calls Listen() on the SSH connection, and if succeeds returns non-nil error +// +// nonblocking flow: if Accept() call fails, stops goroutine and returns error on ch listenerStopped +func (c *Tunnel) reverseForwardOnePort(sshClient *ssh.Client, listenerStopped chan<- error) error { + if c.fakeRemoteHost { + newSishHostListener(sshClient).ListenFakeRemoteHost(c.forward.Remote) + time.Sleep(time.Second) + os.Exit(0) + } + listener, err := sshClient.Listen("tcp", c.forward.Remote.String()) + if err != nil { + return err + } + + go func() { + defer listener.Close() + c.log.Debugf("forwarding %s <- %s", c.forward.Local.String(), c.forward.Remote.String()) + + for { + client, err := listener.Accept() + if err != nil { + listenerStopped <- fmt.Errorf("error on Accept(): %w", err) + return + } + + go handleReverseForwardConn(client, c.forward, c.log) + } + }() + + return nil +} + +func (c *Tunnel) listenTCPWithoutResolving() { +} + +func handleReverseForwardConn(client net.Conn, forward Forward, log *zap.SugaredLogger) { + defer client.Close() + + log.Debugf("%s connected", client.RemoteAddr()) + defer log.Debug("closed") + + remote, err := net.Dial("tcp", forward.Local.String()) + if err != nil { + log.Errorf("dial INTO local service error: %s", err.Error()) + return + } + + // pipe data in both directions: + // - client => remote + // - remote => client + // + // - in effect, we act as a proxy between the reverse tunnel's client and locally-dialed + // remote endpoint. + // - the "client" and "remote" strings we give Pipe() is just for error&log messages + // - this blocks until either of the parties' socket closes (or breaks) + if err := bidipipe.Pipe( + bidipipe.WithName("client", client), + bidipipe.WithName("remote", remote), + ); err != nil { + log.Error(err) + } +} + +func dialSSH(ctx context.Context, addr string, sshConfig *ssh.ClientConfig) (*ssh.Client, error) { + dialer := net.Dialer{ + Timeout: 10 * time.Second, + } + + conn, err := dialer.DialContext(ctx, "tcp", addr) + if err != nil { + return nil, err + } + + clConn, newChan, reqChan, err := ssh.NewClientConn(conn, addr, sshConfig) + if err != nil { + return nil, err + } + + return ssh.NewClient(clConn, newChan, reqChan), nil +} diff --git a/internal/server/proto/sshtun/tunnel.go b/internal/server/proto/sshtun/tunnel.go deleted file mode 100644 index a3b169f..0000000 --- a/internal/server/proto/sshtun/tunnel.go +++ /dev/null @@ -1,249 +0,0 @@ -// Copyright 2017, The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE.md file. - -package sshtun - -import ( - "context" - "fmt" - "io" - "net" - "sync" - "sync/atomic" - "time" - - "golang.org/x/crypto/ssh" -) - -type TunnelMode uint8 - -func (t TunnelMode) String() string { - switch t { - case TunnelForward: - return "->" - case TunnelReverse: - return "<-" - default: - return "" - } -} - -const ( - TunnelForward TunnelMode = iota - TunnelReverse -) - -type logger interface { - Printf(string, ...interface{}) -} - -type Tunnel struct { - Auth []ssh.AuthMethod - HostKeys ssh.HostKeyCallback - Mode TunnelMode - User string - HostAddr string - BindAddr string - DialAddr string - RetryInterval time.Duration - KeepAlive KeepAliveConfig - Logger logger -} - -type KeepAliveConfig struct { - // Interval is the amount of time in seconds to wait before the - // Tunnel client will send a keep-alive message to ensure some minimum - // traffic on the SSH connection. - Interval uint - - // CountMax is the maximum number of consecutive failed responses to - // keep-alive messages the client is willing to tolerate before considering - // the SSH connection as dead. - CountMax uint -} - -func (t Tunnel) String() string { - var left, right string - switch t.Mode { - case TunnelForward: - left, right = t.BindAddr, t.DialAddr - case TunnelReverse: - left, right = t.DialAddr, t.BindAddr - } - return fmt.Sprintf("%s@%s | %s %s %s", t.User, t.HostAddr, left, t.Mode, right) -} - -func (t Tunnel) Bind(ctx context.Context, wg *sync.WaitGroup) { - defer wg.Done() - - for { - var once sync.Once // Only print errors once per session - func() { - // Connect to the server host via SSH. - cl, err := ssh.Dial("tcp", t.HostAddr, &ssh.ClientConfig{ - User: t.User, - Auth: t.Auth, - HostKeyCallback: t.HostKeys, - Timeout: 5 * time.Second, - }) - if err != nil { - once.Do(func() { t.Logger.Printf("(%v) SSH dial error: %v", t, err) }) - return - } - wg.Add(1) - go t.keepAliveMonitor(&once, wg, cl) - defer cl.Close() - - // Attempt to bind to the inbound socket. - var ln net.Listener - switch t.Mode { - case TunnelForward: - ln, err = net.Listen("tcp", t.BindAddr) - case TunnelReverse: - ln, err = cl.Listen("tcp", t.BindAddr) - } - if err != nil { - once.Do(func() { t.Logger.Printf("(%v) bind error: %v", t, err) }) - return - } - - // The socket is bound. Make sure we close it eventually. - bindCtx, cancel := context.WithCancel(ctx) - defer cancel() - go func() { - cl.Wait() - cancel() - }() - go func() { - <-bindCtx.Done() - once.Do(func() {}) // Suppress future errors - ln.Close() - }() - - t.Logger.Printf("(%v) bound Tunnel", t) - defer t.Logger.Printf("(%v) collapsed Tunnel", t) - - // Accept all incoming connections. - for { - cn1, err := ln.Accept() - if err != nil { - once.Do(func() { t.Logger.Printf("(%v) accept error: %v", t, err) }) - return - } - wg.Add(1) - go t.dialTunnel(bindCtx, wg, cl, cn1) - } - }() - - select { - case <-ctx.Done(): - return - case <-time.After(t.RetryInterval): - t.Logger.Printf("(%v) retrying...", t) - } - } -} - -func (t Tunnel) dialTunnel(ctx context.Context, wg *sync.WaitGroup, client *ssh.Client, cn1 net.Conn) { - defer wg.Done() - - // The inbound connection is established. Make sure we close it eventually. - connCtx, cancel := context.WithCancel(ctx) - defer cancel() - go func() { - <-connCtx.Done() - cn1.Close() - }() - - // Establish the outbound connection. - var cn2 net.Conn - var err error - switch t.Mode { - case TunnelForward: - cn2, err = client.Dial("tcp", t.DialAddr) - case TunnelReverse: - cn2, err = net.Dial("tcp", t.DialAddr) - } - if err != nil { - t.Logger.Printf("(%v) dial error: %v", t, err) - return - } - - go func() { - <-connCtx.Done() - cn2.Close() - }() - - t.Logger.Printf("(%v) connection established", t) - defer t.Logger.Printf("(%v) connection closed", t) - - // Copy bytes from one connection to the other until one side closes. - var once sync.Once - var wg2 sync.WaitGroup - wg2.Add(2) - go func() { - defer wg2.Done() - defer cancel() - if _, err := io.Copy(cn1, cn2); err != nil { - once.Do(func() { t.Logger.Printf("(%v) connection error: %v", t, err) }) - } - once.Do(func() {}) // Suppress future errors - }() - go func() { - defer wg2.Done() - defer cancel() - if _, err := io.Copy(cn2, cn1); err != nil { - once.Do(func() { t.Logger.Printf("(%v) connection error: %v", t, err) }) - } - once.Do(func() {}) // Suppress future errors - }() - wg2.Wait() -} - -// keepAliveMonitor periodically sends messages to invoke a response. -// If the server does not respond after some period of time, -// assume that the underlying net.Conn abruptly died. -func (t Tunnel) keepAliveMonitor(once *sync.Once, wg *sync.WaitGroup, client *ssh.Client) { - defer wg.Done() - if t.KeepAlive.Interval == 0 || t.KeepAlive.CountMax == 0 { - return - } - - // Detect when the SSH connection is closed. - wait := make(chan error, 1) - wg.Add(1) - go func() { - defer wg.Done() - wait <- client.Wait() - }() - - // Repeatedly check if the remote server is still alive. - var aliveCount int32 - ticker := time.NewTicker(time.Duration(t.KeepAlive.Interval) * time.Second) - defer ticker.Stop() - for { - select { - case err := <-wait: - if err != nil && err != io.EOF { - once.Do(func() { t.Logger.Printf("(%v) SSH error: %v", t, err) }) - } - return - case <-ticker.C: - if n := atomic.AddInt32(&aliveCount, 1); n > int32(t.KeepAlive.CountMax) { - once.Do(func() { t.Logger.Printf("(%v) SSH keep-alive termination", t) }) - client.Close() - return - } - } - - wg.Add(1) - go func() { - defer wg.Done() - _, _, err := client.SendRequest("keepalive@openssh.com", true, nil) - if err == nil { - atomic.StoreInt32(&aliveCount, 0) - } - }() - } -} diff --git a/internal/server/proto/sshtun/tunnel_test.go b/internal/server/proto/sshtun/tunnel_test.go deleted file mode 100644 index 2561d62..0000000 --- a/internal/server/proto/sshtun/tunnel_test.go +++ /dev/null @@ -1,509 +0,0 @@ -// Copyright 2017, The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE.md file. - -package sshtun - -import ( - "bytes" - "context" - "crypto/md5" - "crypto/rsa" - "encoding/binary" - "fmt" - "io" - "io/ioutil" - "math/rand" - "net" - "reflect" - "strconv" - "sync" - "testing" - "time" - - "golang.org/x/crypto/ssh" -) - -type testLogger struct { - *testing.T // Already has Fatalf method -} - -func (t testLogger) Printf(f string, x ...interface{}) { t.Logf(f, x...) } - -func TestTunnel(t *testing.T) { - rootWG := new(sync.WaitGroup) - defer rootWG.Wait() - rootCtx, cancelAll := context.WithCancel(context.Background()) - defer cancelAll() - - // Open all of the TCP sockets needed for the test. - tcpLn0 := openListener(t) // Start of the chain - tcpLn1 := openListener(t) // Mid-point of the chain - tcpLn2 := openListener(t) // End of the chain - srvLn0 := openListener(t) // Socket for SSH server in reverse Mode - srvLn1 := openListener(t) // Socket for SSH server in forward Mode - - tcpLn0.Close() // To be later binded by the reverse Tunnel - tcpLn1.Close() // To be later binded by the forward Tunnel - go closeWhenDone(rootCtx, tcpLn2) - go closeWhenDone(rootCtx, srvLn0) - go closeWhenDone(rootCtx, srvLn1) - - // Generate keys for both the servers and clients. - clientPriv0, clientPub0 := generateKeys(t) - clientPriv1, clientPub1 := generateKeys(t) - serverPriv0, serverPub0 := generateKeys(t) - serverPriv1, serverPub1 := generateKeys(t) - - // Start the SSH servers. - rootWG.Add(2) - go func() { - defer rootWG.Done() - runServer(t, rootCtx, srvLn0, serverPriv0, clientPub0, clientPub1) - }() - go func() { - defer rootWG.Done() - runServer(t, rootCtx, srvLn1, serverPriv1, clientPub0, clientPub1) - }() - - wg := new(sync.WaitGroup) - defer wg.Wait() - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // Create the Tunnel configurations. - tn0 := Tunnel{ - Auth: []ssh.AuthMethod{ssh.PublicKeys(clientPriv0)}, - HostKeys: ssh.FixedHostKey(serverPub0), - Mode: TunnelReverse, // Reverse Tunnel - User: "user0", - HostAddr: srvLn0.Addr().String(), - BindAddr: tcpLn0.Addr().String(), - DialAddr: tcpLn1.Addr().String(), - Logger: testLogger{t}, - } - tn1 := Tunnel{ - Auth: []ssh.AuthMethod{ssh.PublicKeys(clientPriv1)}, - HostKeys: ssh.FixedHostKey(serverPub1), - Mode: TunnelForward, // Forward Tunnel - User: "user1", - HostAddr: srvLn1.Addr().String(), - BindAddr: tcpLn1.Addr().String(), - DialAddr: tcpLn2.Addr().String(), - Logger: testLogger{t}, - } - - // Start the SSH client tunnels. - wg.Add(2) - go tn0.Bind(ctx, wg) - go tn1.Bind(ctx, wg) - - t.Log("test started") - done := make(chan bool, 10) - - // Start all the transmitters. - for i := 0; i < cap(done); i++ { - i := i - go func() { - for { - rnd := rand.New(rand.NewSource(int64(i))) - hash := md5.New() - size := uint32((1 << 10) + rnd.Intn(1<<20)) - buf4 := make([]byte, 4) - binary.LittleEndian.PutUint32(buf4, size) - - cnStart, err := net.Dial("tcp", tcpLn0.Addr().String()) - if err != nil { - time.Sleep(10 * time.Millisecond) - continue - } - defer cnStart.Close() - if _, err := cnStart.Write(buf4); err != nil { - t.Errorf("write size error: %v", err) - break - } - r := io.LimitReader(rnd, int64(size)) - w := io.MultiWriter(cnStart, hash) - if _, err := io.Copy(w, r); err != nil { - t.Errorf("copy error: %v", err) - break - } - if _, err := cnStart.Write(hash.Sum(nil)); err != nil { - t.Errorf("write hash error: %v", err) - break - } - if err := cnStart.Close(); err != nil { - t.Errorf("close error: %v", err) - break - } - break - } - }() - } - - // Start all the receivers. - for i := 0; i < cap(done); i++ { - go func() { - for { - hash := md5.New() - buf4 := make([]byte, 4) - - cnEnd, err := tcpLn2.Accept() - if err != nil { - time.Sleep(10 * time.Millisecond) - continue - } - defer cnEnd.Close() - - if _, err := io.ReadFull(cnEnd, buf4); err != nil { - t.Errorf("read size error: %v", err) - break - } - size := binary.LittleEndian.Uint32(buf4) - r := io.LimitReader(cnEnd, int64(size)) - if _, err := io.Copy(hash, r); err != nil { - t.Errorf("copy error: %v", err) - break - } - wantHash, err := ioutil.ReadAll(cnEnd) - if err != nil { - t.Errorf("read hash error: %v", err) - break - } - if err := cnEnd.Close(); err != nil { - t.Errorf("close error: %v", err) - break - } - - if gotHash := hash.Sum(nil); !bytes.Equal(gotHash, wantHash) { - t.Errorf("hash mismatch:\ngot %x\nwant %x", gotHash, wantHash) - } - break - } - done <- true - }() - } - - for i := 0; i < cap(done); i++ { - select { - case <-done: - case <-time.After(10 * time.Second): - t.Errorf("timed out: %d remaining", cap(done)-i) - return - } - } - t.Log("test complete") -} - -// generateKeys generates a random pair of SSH private and public keys. -func generateKeys(t *testing.T) (priv ssh.Signer, pub ssh.PublicKey) { - rnd := rand.New(rand.NewSource(time.Now().Unix())) - rsaKey, err := rsa.GenerateKey(rnd, 1024) - if err != nil { - t.Fatalf("unable to generate RSA key pair: %v", err) - } - priv, err = ssh.NewSignerFromKey(rsaKey) - if err != nil { - t.Fatalf("unable to generate signer: %v", err) - } - pub, err = ssh.NewPublicKey(&rsaKey.PublicKey) - if err != nil { - t.Fatalf("unable to generate public key: %v", err) - } - return priv, pub -} - -func openListener(t *testing.T) net.Listener { - ln, err := net.Listen("tcp", ":0") - if err != nil { - t.Fatalf("listen error: %v", err) - } - return ln -} - -// runServer starts an SSH server capable of handling forward and reverse -// TCP tunnels. This function blocks for the entire duration that the -// server is running and can be stopped by canceling the context. -// -// The server listens on the provided Listener and will present to clients -// a certificate from serverKey and will only accept users that match -// the provided clientKeys. Only users of the name "User%d" are allowed where -// the ID number is the index for the specified client key provided. -func runServer(t *testing.T, ctx context.Context, ln net.Listener, serverKey ssh.Signer, clientKeys ...ssh.PublicKey) { - wg := new(sync.WaitGroup) - defer wg.Wait() - - // Generate SSH server configuration. - conf := ssh.ServerConfig{ - PublicKeyCallback: func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) { - var uid int - _, err := fmt.Sscanf(c.User(), "User%d", &uid) - if err != nil || uid >= len(clientKeys) || !bytes.Equal(clientKeys[uid].Marshal(), pubKey.Marshal()) { - return nil, fmt.Errorf("unknown public key for %q", c.User()) - } - return nil, nil - }, - } - conf.AddHostKey(serverKey) - - // Handle every SSH client connection. - for { - tcpCn, err := ln.Accept() - if err != nil { - if !isDone(ctx) { - t.Errorf("accept error: %v", err) - } - return - } - wg.Add(1) - go handleServerConn(t, ctx, wg, tcpCn, &conf) - } -} - -// handleServerConn handles a single SSH connection. -func handleServerConn(t *testing.T, ctx context.Context, wg *sync.WaitGroup, tcpCn net.Conn, conf *ssh.ServerConfig) { - defer wg.Done() - go closeWhenDone(ctx, tcpCn) - defer tcpCn.Close() - - sshCn, chans, reqs, err := ssh.NewServerConn(tcpCn, conf) - if err != nil { - t.Errorf("new connection error: %v", err) - return - } - go closeWhenDone(ctx, sshCn) - defer sshCn.Close() - - wg.Add(1) - go handleServerChannels(t, ctx, wg, sshCn, chans) - - wg.Add(1) - go handleServerRequests(t, ctx, wg, sshCn, reqs) - - if err := sshCn.Wait(); err != nil && err != io.EOF && !isDone(ctx) { - t.Errorf("connection error: %v", err) - } -} - -// handleServerChannels handles new channels on a SSH connection. -// The client initiates a new channel when forwarding a TCP dial. -func handleServerChannels(t *testing.T, ctx context.Context, wg *sync.WaitGroup, sshCn ssh.Conn, chans <-chan ssh.NewChannel) { - defer wg.Done() - for nc := range chans { - if nc.ChannelType() != "direct-tcpip" { - nc.Reject(ssh.UnknownChannelType, "not implemented") - continue - } - var args struct { - DstHost string - DstPort uint32 - SrcHost string - SrcPort uint32 - } - if !unmarshalData(nc.ExtraData(), &args) { - nc.Reject(ssh.Prohibited, "invalid request") - continue - } - - // Open a connection for both sides. - cn, err := net.Dial("tcp", net.JoinHostPort(args.DstHost, strconv.Itoa(int(args.DstPort)))) - if err != nil { - nc.Reject(ssh.ConnectionFailed, err.Error()) - continue - } - ch, reqs, err := nc.Accept() - if err != nil { - t.Errorf("accept channel error: %v", err) - cn.Close() - continue - } - go ssh.DiscardRequests(reqs) - - wg.Add(1) - go bidirCopyAndClose(t, ctx, wg, cn, ch) - } -} - -// handleServerRequests handles new requests on a SSH connection. -// The client initiates a new request for binding a local TCP socket. -func handleServerRequests(t *testing.T, ctx context.Context, wg *sync.WaitGroup, sshCn ssh.Conn, reqs <-chan *ssh.Request) { - defer wg.Done() - for r := range reqs { - if !r.WantReply { - continue - } - if r.Type != "tcpip-forward" { - r.Reply(false, nil) - continue - } - var args struct { - Host string - Port uint32 - } - if !unmarshalData(r.Payload, &args) { - r.Reply(false, nil) - continue - } - ln, err := net.Listen("tcp", net.JoinHostPort(args.Host, strconv.Itoa(int(args.Port)))) - if err != nil { - r.Reply(false, nil) - continue - } - - var resp struct{ Port uint32 } - _, resp.Port = splitHostPort(ln.Addr().String()) - if err := r.Reply(true, marshalData(resp)); err != nil { - t.Errorf("request reply error: %v", err) - ln.Close() - continue - } - - wg.Add(1) - go handleLocalListener(t, ctx, wg, sshCn, ln, args.Host) - - } -} - -// handleLocalListener handles every new connection on the provided socket. -// All local connections will be forwarded to the client via a new channel. -func handleLocalListener(t *testing.T, ctx context.Context, wg *sync.WaitGroup, sshCn ssh.Conn, ln net.Listener, host string) { - defer wg.Done() - go closeWhenDone(ctx, ln) - defer ln.Close() - - for { - // Open a connection for both sides. - cn, err := ln.Accept() - if err != nil { - if !isDone(ctx) { - t.Errorf("accept error: %v", err) - } - return - } - var args struct { - DstHost string - DstPort uint32 - SrcHost string - SrcPort uint32 - } - args.DstHost, args.DstPort = splitHostPort(cn.LocalAddr().String()) - args.SrcHost, args.SrcPort = splitHostPort(cn.RemoteAddr().String()) - args.DstHost = host // This must match on client side! - ch, reqs, err := sshCn.OpenChannel("forwarded-tcpip", marshalData(args)) - if err != nil { - t.Errorf("open channel error: %v", err) - cn.Close() - continue - } - go ssh.DiscardRequests(reqs) - - wg.Add(1) - go bidirCopyAndClose(t, ctx, wg, cn, ch) - } -} - -// bidirCopyAndClose performs a bi-directional copy on both connections -// until either side closes the connection or the context is canceled. -// This will close both connections before returning. -func bidirCopyAndClose(t *testing.T, ctx context.Context, wg *sync.WaitGroup, c1, c2 io.ReadWriteCloser) { - defer wg.Done() - go closeWhenDone(ctx, c1) - go closeWhenDone(ctx, c2) - defer c1.Close() - defer c2.Close() - - errc := make(chan error, 2) - go func() { - _, err := io.Copy(c1, c2) - errc <- err - }() - go func() { - _, err := io.Copy(c2, c1) - errc <- err - }() - if err := <-errc; err != nil && err != io.EOF && !isDone(ctx) { - t.Errorf("copy error: %v", err) - } -} - -// unmarshalData parses b into s, where s is a pointer to a struct. -// Only unexported fields of type uint32 or string are allowed. -func unmarshalData(b []byte, s interface{}) bool { - v := reflect.ValueOf(s) - if !v.IsValid() || v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct { - panic("destination must be pointer to struct") - } - v = v.Elem() - for i := 0; i < v.NumField(); i++ { - switch v.Type().Field(i).Type.Kind() { - case reflect.Uint32: - if len(b) < 4 { - return false - } - v.Field(i).Set(reflect.ValueOf(binary.BigEndian.Uint32(b))) - b = b[4:] - case reflect.String: - if len(b) < 4 { - return false - } - n := binary.BigEndian.Uint32(b) - b = b[4:] - if uint64(len(b)) < uint64(n) { - return false - } - v.Field(i).Set(reflect.ValueOf(string(b[:n]))) - b = b[n:] - default: - panic("invalid field type: " + v.Type().Field(i).Type.String()) - } - } - return len(b) == 0 -} - -// marshalData serializes s into b, where s is a struct (or a pointer to one). -// Only unexported fields of type uint32 or string are allowed. -func marshalData(s interface{}) (b []byte) { - v := reflect.ValueOf(s) - if v.IsValid() && v.Kind() == reflect.Ptr { - v = v.Elem() - } - if !v.IsValid() || v.Kind() != reflect.Struct { - panic("source must be a struct") - } - var arr32 [4]byte - for i := 0; i < v.NumField(); i++ { - switch v.Type().Field(i).Type.Kind() { - case reflect.Uint32: - binary.BigEndian.PutUint32(arr32[:], uint32(v.Field(i).Uint())) - b = append(b, arr32[:]...) - case reflect.String: - binary.BigEndian.PutUint32(arr32[:], uint32(v.Field(i).Len())) - b = append(b, arr32[:]...) - b = append(b, v.Field(i).String()...) - default: - panic("invalid field type: " + v.Type().Field(i).Type.String()) - } - } - return b - -} - -func splitHostPort(s string) (string, uint32) { - host, port, _ := net.SplitHostPort(s) - p, _ := strconv.Atoi(port) - return host, uint32(p) -} - -func closeWhenDone(ctx context.Context, c io.Closer) { - <-ctx.Done() - c.Close() -} - -func isDone(ctx context.Context) bool { - select { - case <-ctx.Done(): - return true - default: - return false - } -} diff --git a/pkg/convert/convert.go b/pkg/convert/convert.go index d75ee2b..d202ce5 100644 --- a/pkg/convert/convert.go +++ b/pkg/convert/convert.go @@ -11,13 +11,13 @@ func MessageToAppEvent(event *pb.EventMessage) dto.Event { return dto.Event{ Type: MessageEventTypeToApp(event.GetType()), Container: dto.Container{ - ID: event.GetContainer().GetId(), - Names: event.GetContainer().GetNames(), - IP: net.ParseIP(event.GetContainer().GetIp()), - Port: uint16(event.GetContainer().GetPort()), - Server: event.GetContainer().GetServer(), - Prefix: event.GetContainer().GetPrefix(), - Domain: event.GetContainer().GetDomain(), + ID: event.GetContainer().GetId(), + Names: event.GetContainer().GetNames(), + IP: net.ParseIP(event.GetContainer().GetIp()), + Port: uint16(event.GetContainer().GetPort()), + Server: event.GetContainer().GetServer(), + RemoteHost: event.GetContainer().GetRemoteHost(), + Domain: event.GetContainer().GetDomain(), }, } } @@ -26,13 +26,13 @@ func AppEventToMessage(event dto.Event) *pb.EventMessage { return &pb.EventMessage{ Type: AppEventTypeToMessage(event.Type), Container: &pb.Container{ - Id: event.Container.ID, - Names: event.Container.Names, - Ip: event.Container.IP.String(), - Port: uint32(event.Container.Port), - Server: event.Container.Server, - Prefix: event.Container.Prefix, - Domain: event.Container.Domain, + Id: event.Container.ID, + Names: event.Container.Names, + Ip: event.Container.IP.String(), + Port: uint32(event.Container.Port), + Server: event.Container.Server, + RemoteHost: event.Container.RemoteHost, + Domain: event.Container.Domain, }, } } diff --git a/pkg/dto/models.go b/pkg/dto/models.go index 8f097d4..eb4ee99 100644 --- a/pkg/dto/models.go +++ b/pkg/dto/models.go @@ -36,11 +36,11 @@ type EventStatus struct { } type Container struct { - ID string `json:"id"` - Names []string `json:"names"` - IP net.IP `json:"ip"` - Port uint16 `json:"port"` - Server string `json:"-"` - Prefix string `json:"prefix"` - Domain string `json:"domain"` + ID string `json:"id"` + Names []string `json:"names"` + IP net.IP `json:"ip"` + Port uint16 `json:"port"` + Server string `json:"-"` + RemoteHost string `json:"remote_host"` + Domain string `json:"domain"` } diff --git a/pkg/errtools/IsPortMissingErr.go b/pkg/errtools/is_port_missing_err.go similarity index 100% rename from pkg/errtools/IsPortMissingErr.go rename to pkg/errtools/is_port_missing_err.go diff --git a/pkg/plugin/pb/pb.proto b/pkg/plugin/pb/pb.proto index 338749b..044ecf8 100644 --- a/pkg/plugin/pb/pb.proto +++ b/pkg/plugin/pb/pb.proto @@ -23,7 +23,7 @@ message Container { string ip = 3; uint32 port = 4; string server = 5; - string prefix = 6; + string remote_host = 6; string domain = 7; }