sish support (wip)
This commit is contained in:
parent
401d9123c8
commit
11a5f48d68
@ -110,7 +110,7 @@ func (d *Docker) Listen() (chan dto.Event, error) {
|
|||||||
"container.ip", converted.IP.String(),
|
"container.ip", converted.IP.String(),
|
||||||
"container.port", converted.Port,
|
"container.port", converted.Port,
|
||||||
"container.server", converted.Server,
|
"container.server", converted.Server,
|
||||||
"container.prefix", converted.Prefix)
|
"container.remote_host", converted.RemoteHost)
|
||||||
output <- newEvent
|
output <- newEvent
|
||||||
case err := <-errSource:
|
case err := <-errSource:
|
||||||
if errors.Is(err, context.Canceled) {
|
if errors.Is(err, context.Canceled) {
|
||||||
|
@ -17,7 +17,7 @@ type labelsConfig struct {
|
|||||||
Network string `mapstructure:"sshpoke.network"`
|
Network string `mapstructure:"sshpoke.network"`
|
||||||
Server string `mapstructure:"sshpoke.server"`
|
Server string `mapstructure:"sshpoke.server"`
|
||||||
Port string `mapstructure:"sshpoke.port"`
|
Port string `mapstructure:"sshpoke.port"`
|
||||||
Prefix string `mapstructure:"sshpoke.prefix"`
|
RemoteHost string `mapstructure:"sshpoke.remote_host"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type boolStr string
|
type boolStr string
|
||||||
@ -82,7 +82,7 @@ func dockerContainerToInternal(container types.Container) (result dto.Container,
|
|||||||
IP: ip,
|
IP: ip,
|
||||||
Port: uint16(port),
|
Port: uint16(port),
|
||||||
Server: labels.Server,
|
Server: labels.Server,
|
||||||
Prefix: labels.Prefix,
|
RemoteHost: labels.RemoteHost,
|
||||||
}, true
|
}, true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2,13 +2,14 @@ package ssh
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"net"
|
||||||
"path"
|
"path"
|
||||||
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/Neur0toxine/sshpoke/internal/config"
|
"github.com/Neur0toxine/sshpoke/internal/config"
|
||||||
"github.com/Neur0toxine/sshpoke/internal/server/driver/base"
|
"github.com/Neur0toxine/sshpoke/internal/server/driver/base"
|
||||||
"github.com/Neur0toxine/sshpoke/internal/server/driver/ssh/sshproto"
|
"github.com/Neur0toxine/sshpoke/internal/server/driver/ssh/sshtun"
|
||||||
"github.com/Neur0toxine/sshpoke/internal/server/driver/ssh/types"
|
"github.com/Neur0toxine/sshpoke/internal/server/driver/ssh/types"
|
||||||
"github.com/Neur0toxine/sshpoke/internal/server/driver/util"
|
"github.com/Neur0toxine/sshpoke/internal/server/driver/util"
|
||||||
"github.com/Neur0toxine/sshpoke/pkg/dto"
|
"github.com/Neur0toxine/sshpoke/pkg/dto"
|
||||||
@ -18,23 +19,38 @@ import (
|
|||||||
type SSH struct {
|
type SSH struct {
|
||||||
base.Base
|
base.Base
|
||||||
params Params
|
params Params
|
||||||
proto *sshproto.Client
|
auth []ssh.AuthMethod
|
||||||
|
conns map[string]conn
|
||||||
|
rw sync.RWMutex
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type conn struct {
|
||||||
|
ctx context.Context
|
||||||
|
cancel func()
|
||||||
|
tun *sshtun.Tunnel
|
||||||
|
}
|
||||||
|
|
||||||
func New(ctx context.Context, name string, params config.DriverParams) (base.Driver, error) {
|
func New(ctx context.Context, name string, params config.DriverParams) (base.Driver, error) {
|
||||||
drv := &SSH{
|
drv := &SSH{
|
||||||
Base: base.New(ctx, name),
|
Base: base.New(ctx, name),
|
||||||
|
conns: make(map[string]conn),
|
||||||
}
|
}
|
||||||
if err := util.UnmarshalParams(params, &drv.params); err != nil {
|
if err := util.UnmarshalParams(params, &drv.params); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
drv.populateFromSSHConfig()
|
drv.populateFromSSHConfig()
|
||||||
drv.proto = sshproto.New(drv.params.Address, drv.params.Auth.User, drv.authenticators(), drv.Log())
|
drv.auth = drv.authenticators()
|
||||||
go drv.proto.Connect(drv.Context())
|
|
||||||
return drv, nil
|
return drv, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *SSH) forward(val sshtun.Forward) conn {
|
||||||
|
tun := sshtun.New(d.params.Address, d.params.Auth.User, d.params.DisableRemoteHostResolve, val, d.auth, d.Log())
|
||||||
|
ctx, cancel := context.WithCancel(d.Context())
|
||||||
|
go tun.Connect(ctx, sshtun.StdoutPrinterSessionCallback(d.Log().With("ssh-output", val.Remote.String())))
|
||||||
|
return conn{ctx: ctx, cancel: cancel, tun: tun}
|
||||||
|
}
|
||||||
|
|
||||||
func (d *SSH) populateFromSSHConfig() {
|
func (d *SSH) populateFromSSHConfig() {
|
||||||
if d.params.Auth.Directory == "" {
|
if d.params.Auth.Directory == "" {
|
||||||
return
|
return
|
||||||
@ -59,8 +75,43 @@ func (d *SSH) populateFromSSHConfig() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *SSH) Handle(event dto.Event) error {
|
func (d *SSH) Handle(event dto.Event) error {
|
||||||
// TODO: Implement event handling & connections management.
|
defer d.rw.Unlock()
|
||||||
return errors.New("server handler is not implemented yet")
|
d.rw.Lock()
|
||||||
|
switch event.Type {
|
||||||
|
case dto.EventStart:
|
||||||
|
conn := d.forward(sshtun.Forward{
|
||||||
|
Local: sshtun.AddrToEndpoint(net.JoinHostPort(event.Container.IP.String(), strconv.Itoa(int(event.Container.Port)))),
|
||||||
|
Remote: d.remoteEndpoint(event.Container.RemoteHost),
|
||||||
|
})
|
||||||
|
d.conns[event.Container.ID] = conn
|
||||||
|
d.wg.Add(1)
|
||||||
|
case dto.EventStop:
|
||||||
|
conn, found := d.conns[event.Container.ID]
|
||||||
|
if !found {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
conn.cancel()
|
||||||
|
delete(d.conns, event.Container.ID)
|
||||||
|
d.wg.Done()
|
||||||
|
case dto.EventShutdown:
|
||||||
|
for id, conn := range d.conns {
|
||||||
|
conn.cancel()
|
||||||
|
delete(d.conns, id)
|
||||||
|
d.wg.Done()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *SSH) remoteEndpoint(remoteHost string) sshtun.Endpoint {
|
||||||
|
port := int(d.params.ForwardPort)
|
||||||
|
if port == 0 {
|
||||||
|
port = 80
|
||||||
|
}
|
||||||
|
return sshtun.Endpoint{
|
||||||
|
Host: remoteHost,
|
||||||
|
Port: port,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *SSH) Driver() config.DriverType {
|
func (d *SSH) Driver() config.DriverType {
|
||||||
@ -68,6 +119,7 @@ func (d *SSH) Driver() config.DriverType {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *SSH) WaitForShutdown() {
|
func (d *SSH) WaitForShutdown() {
|
||||||
|
go d.Handle(dto.Event{Type: dto.EventShutdown})
|
||||||
d.wg.Wait()
|
d.wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -82,19 +134,19 @@ func (d *SSH) authenticators() []ssh.AuthMethod {
|
|||||||
func (d *SSH) authenticator() ssh.AuthMethod {
|
func (d *SSH) authenticator() ssh.AuthMethod {
|
||||||
switch d.params.Auth.Type {
|
switch d.params.Auth.Type {
|
||||||
case types.AuthTypePasswordless:
|
case types.AuthTypePasswordless:
|
||||||
return sshproto.AuthPassword("")
|
return sshtun.AuthPassword("")
|
||||||
case types.AuthTypePassword:
|
case types.AuthTypePassword:
|
||||||
return sshproto.AuthPassword(d.params.Auth.Password)
|
return sshtun.AuthPassword(d.params.Auth.Password)
|
||||||
case types.AuthTypeKey:
|
case types.AuthTypeKey:
|
||||||
if d.params.Auth.Keyfile != "" {
|
if d.params.Auth.Keyfile != "" {
|
||||||
keyAuth, err := sshproto.AuthKeyFile(
|
keyAuth, err := sshtun.AuthKeyFile(
|
||||||
types.SmartPath(path.Join(d.params.Auth.Directory.String(), d.params.Auth.Keyfile)))
|
types.SmartPath(path.Join(d.params.Auth.Directory.String(), d.params.Auth.Keyfile)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return keyAuth
|
return keyAuth
|
||||||
}
|
}
|
||||||
dirAuth, err := sshproto.AuthKeyDir(d.params.Auth.Directory)
|
dirAuth, err := sshtun.AuthKeyDir(d.params.Auth.Directory)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -7,13 +7,15 @@ import (
|
|||||||
|
|
||||||
type Params struct {
|
type Params struct {
|
||||||
Address string `mapstructure:"address" validate:"required"`
|
Address string `mapstructure:"address" validate:"required"`
|
||||||
|
DefaultHost *string `mapstructure:"default_host,omitempty"`
|
||||||
|
ForwardPort uint16 `mapstructure:"forward_port"`
|
||||||
Auth types.Auth `mapstructure:"auth"`
|
Auth types.Auth `mapstructure:"auth"`
|
||||||
KeepAlive types.KeepAlive `mapstructure:"keepalive"`
|
KeepAlive types.KeepAlive `mapstructure:"keepalive"`
|
||||||
Domain string `mapstructure:"domain"`
|
Domain string `mapstructure:"domain"`
|
||||||
DomainProto string `mapstructure:"domain_proto"`
|
DomainProto string `mapstructure:"domain_proto"`
|
||||||
DomainExtractRegex string `mapstructure:"domain_extract_regex" validate:"validregexp"`
|
DomainExtractRegex string `mapstructure:"domain_extract_regex" validate:"validregexp"`
|
||||||
Mode types.DomainMode `mapstructure:"mode" validate:"required,oneof=single multi"`
|
Mode types.DomainMode `mapstructure:"mode" validate:"required,oneof=single multi"`
|
||||||
Prefix bool `mapstructure:"prefix"`
|
DisableRemoteHostResolve bool `mapstructure:"disable_remote_host_resolve"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Params) Validate() error {
|
func (p *Params) Validate() error {
|
||||||
|
@ -1,19 +0,0 @@
|
|||||||
package sshproto
|
|
||||||
|
|
||||||
import "fmt"
|
|
||||||
|
|
||||||
type Forward struct {
|
|
||||||
// local service to be forwarded
|
|
||||||
Local Endpoint `json:"local"`
|
|
||||||
// remote forwarding port (on remote SSH server network)
|
|
||||||
Remote Endpoint `json:"remote"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Endpoint struct {
|
|
||||||
Host string `json:"host"`
|
|
||||||
Port int `json:"port"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (endpoint *Endpoint) String() string {
|
|
||||||
return fmt.Sprintf("%s:%d", endpoint.Host, endpoint.Port)
|
|
||||||
}
|
|
@ -1,242 +0,0 @@
|
|||||||
package sshproto
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/Neur0toxine/sshpoke/pkg/errtools"
|
|
||||||
"github.com/function61/gokit/app/backoff"
|
|
||||||
"github.com/function61/gokit/io/bidipipe"
|
|
||||||
"go.uber.org/zap"
|
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Client struct {
|
|
||||||
user string
|
|
||||||
address string
|
|
||||||
authMethods []ssh.AuthMethod
|
|
||||||
log *zap.SugaredLogger
|
|
||||||
connected atomic.Bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func New(address, user string, auth []ssh.AuthMethod, log *zap.SugaredLogger) *Client {
|
|
||||||
return &Client{
|
|
||||||
address: prepareAddress(address),
|
|
||||||
user: user,
|
|
||||||
authMethods: auth,
|
|
||||||
log: log.With(zap.String("sshServer", address)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func prepareAddress(address string) string {
|
|
||||||
_, _, err := net.SplitHostPort(address)
|
|
||||||
if err != nil && errtools.IsPortMissingErr(err) {
|
|
||||||
return net.JoinHostPort(address, "22")
|
|
||||||
}
|
|
||||||
return address
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) Connect(ctx context.Context) {
|
|
||||||
if c.connected.Load() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
defer c.connected.Store(false)
|
|
||||||
backoffTime := backoff.ExponentialWithCappedMax(100*time.Millisecond, 5*time.Second)
|
|
||||||
for {
|
|
||||||
c.connected.Store(true)
|
|
||||||
err := c.connect(ctx)
|
|
||||||
if err != nil {
|
|
||||||
c.log.Error("connect error:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
time.Sleep(backoffTime())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// connect once to the SSH server. if the connection breaks, we return error and the caller
|
|
||||||
// will try to re-connect
|
|
||||||
func (c *Client) connect(ctx context.Context) error {
|
|
||||||
c.log.Debug("connecting")
|
|
||||||
sshConfig := &ssh.ClientConfig{
|
|
||||||
User: c.user,
|
|
||||||
Auth: c.authMethods,
|
|
||||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
|
||||||
}
|
|
||||||
|
|
||||||
var sshClient *ssh.Client
|
|
||||||
var errConnect error
|
|
||||||
|
|
||||||
sshClient, errConnect = dialSSH(ctx, c.address, sshConfig)
|
|
||||||
if errConnect != nil {
|
|
||||||
return errConnect
|
|
||||||
}
|
|
||||||
|
|
||||||
// always disconnect when function returns
|
|
||||||
defer sshClient.Close()
|
|
||||||
defer c.log.Debug("disconnecting")
|
|
||||||
|
|
||||||
c.log.Debug("connected")
|
|
||||||
|
|
||||||
<-ctx.Done()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// // connect once to the SSH server. if the connection breaks, we return error and the caller
|
|
||||||
// // will try to re-connect
|
|
||||||
// func connectToSshAndServe2(
|
|
||||||
// ctx context.Context,
|
|
||||||
// address string,
|
|
||||||
// authConfig types.Auth,
|
|
||||||
// auth ssh.AuthMethod,
|
|
||||||
// log *zap.SugaredLogger,
|
|
||||||
// ) error {
|
|
||||||
// log = log.With(zap.String("sshServer", address))
|
|
||||||
// log.Debug("connecting")
|
|
||||||
// sshConfig := &ssh.ClientConfig{
|
|
||||||
// User: authConfig.User,
|
|
||||||
// Auth: []ssh.AuthMethod{auth},
|
|
||||||
// HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// var sshClient *ssh.Client
|
|
||||||
// var errConnect error
|
|
||||||
//
|
|
||||||
// sshClient, errConnect = dialSSH(ctx, address, sshConfig)
|
|
||||||
// if errConnect != nil {
|
|
||||||
// return errConnect
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// // always disconnect when function returns
|
|
||||||
// defer sshClient.Close()
|
|
||||||
// defer log.Debug("disconnecting")
|
|
||||||
//
|
|
||||||
// log.Debug("connected")
|
|
||||||
//
|
|
||||||
// // each listener in reverseForwardOnePort() can return one error, so make sure channel
|
|
||||||
// // has enough buffering so even if we return from here, the goroutines won't block trying
|
|
||||||
// // to return an error
|
|
||||||
// listenerStopped := make(chan error, len(conf.Forwards))
|
|
||||||
//
|
|
||||||
// for _, forward := range conf.Forwards {
|
|
||||||
// // TODO: "if any fails, tear down all workers" -style error handling would be better
|
|
||||||
// // handled with https://pkg.go.dev/golang.org/x/sync/errgroup?tab=doc
|
|
||||||
// if err := reverseForwardOnePort(
|
|
||||||
// forward,
|
|
||||||
// sshClient,
|
|
||||||
// listenerStopped,
|
|
||||||
// makeLogger("reverseForwardOnePort"),
|
|
||||||
// makeLogger,
|
|
||||||
// ); err != nil {
|
|
||||||
// // closes SSH connection if even one forward Listen() fails
|
|
||||||
// return err
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
//
|
|
||||||
// // we're connected and have succesfully started listening on all reverse forwards, wait
|
|
||||||
// // for either user to ask us to stop or any of the listeners to error
|
|
||||||
// select {
|
|
||||||
// case <-ctx.Done(): // cancel requested
|
|
||||||
// return nil
|
|
||||||
// case listenerFirstErr := <-listenerStopped:
|
|
||||||
// // one or more of the listeners encountered an error. react by closing the connection
|
|
||||||
// // assumes all the other listeners failed too so no teardown necessary
|
|
||||||
// select {
|
|
||||||
// case <-ctx.Done(): // pretty much errors are to be expected if cancellation triggered
|
|
||||||
// return nil
|
|
||||||
// default:
|
|
||||||
// return listenerFirstErr
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
// blocking flow: calls Listen() on the SSH connection, and if succeeds returns non-nil error
|
|
||||||
//
|
|
||||||
// nonblocking flow: if Accept() call fails, stops goroutine and returns error on ch listenerStopped
|
|
||||||
func reverseForwardOnePort(
|
|
||||||
forward Forward,
|
|
||||||
sshClient *ssh.Client,
|
|
||||||
listenerStopped chan<- error,
|
|
||||||
log *zap.SugaredLogger,
|
|
||||||
) error {
|
|
||||||
// reverse listen on remote server port
|
|
||||||
listener, err := sshClient.Listen("tcp", forward.Remote.String())
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer listener.Close()
|
|
||||||
log.Debugf("forwarding %s -> %s", forward.Local.String(), forward.Remote.String())
|
|
||||||
|
|
||||||
// handle incoming connections on reverse forwarded tunnel
|
|
||||||
for {
|
|
||||||
client, err := listener.Accept()
|
|
||||||
if err != nil {
|
|
||||||
listenerStopped <- fmt.Errorf("error on Accept(): %w", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// handle the connection in another goroutine, so we can support multiple concurrent
|
|
||||||
// connections on the same port
|
|
||||||
go handleReverseForwardConn(client, forward, log)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func handleReverseForwardConn(client net.Conn, forward Forward, log *zap.SugaredLogger) {
|
|
||||||
defer client.Close()
|
|
||||||
|
|
||||||
log.Debugf("%s connected", client.RemoteAddr())
|
|
||||||
defer log.Debug("closed")
|
|
||||||
|
|
||||||
remote, err := net.Dial("tcp", forward.Local.String())
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("dial INTO local service error: %s", err.Error())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// pipe data in both directions:
|
|
||||||
// - client => remote
|
|
||||||
// - remote => client
|
|
||||||
//
|
|
||||||
// - in effect, we act as a proxy between the reverse tunnel's client and locally-dialed
|
|
||||||
// remote endpoint.
|
|
||||||
// - the "client" and "remote" strings we give Pipe() is just for error&log messages
|
|
||||||
// - this blocks until either of the parties' socket closes (or breaks)
|
|
||||||
if err := bidipipe.Pipe(
|
|
||||||
bidipipe.WithName("client", client),
|
|
||||||
bidipipe.WithName("remote", remote),
|
|
||||||
); err != nil {
|
|
||||||
log.Error(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func dialSSH(ctx context.Context, addr string, sshConfig *ssh.ClientConfig) (*ssh.Client, error) {
|
|
||||||
dialer := net.Dialer{
|
|
||||||
Timeout: 10 * time.Second,
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := dialer.DialContext(ctx, "tcp", addr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
clConn, newChan, reqChan, err := ssh.NewClientConn(conn, addr, sshConfig)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return ssh.NewClient(clConn, newChan, reqChan), nil
|
|
||||||
}
|
|
@ -1,4 +1,4 @@
|
|||||||
package sshproto
|
package sshtun
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
@ -1,4 +1,4 @@
|
|||||||
package sshproto
|
package sshtun
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
37
internal/server/driver/ssh/sshtun/forward.go
Normal file
37
internal/server/driver/ssh/sshtun/forward.go
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
package sshtun
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/Neur0toxine/sshpoke/pkg/errtools"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Forward struct {
|
||||||
|
// local service to be forwarded
|
||||||
|
Local Endpoint `json:"local"`
|
||||||
|
// remote forwarding port (on remote SSH server network)
|
||||||
|
Remote Endpoint `json:"remote"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func AddrToEndpoint(address string) Endpoint {
|
||||||
|
host, port, err := net.SplitHostPort(address)
|
||||||
|
if err != nil && errtools.IsPortMissingErr(err) {
|
||||||
|
return Endpoint{Host: host, Port: 22}
|
||||||
|
}
|
||||||
|
portNum, err := strconv.Atoi(port)
|
||||||
|
if err != nil {
|
||||||
|
portNum = 22
|
||||||
|
}
|
||||||
|
return Endpoint{Host: host, Port: portNum}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Endpoint struct {
|
||||||
|
Host string `json:"host"`
|
||||||
|
Port int `json:"port"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (endpoint *Endpoint) String() string {
|
||||||
|
return fmt.Sprintf("%s:%d", endpoint.Host, endpoint.Port)
|
||||||
|
}
|
21
internal/server/driver/ssh/sshtun/printer.go
Normal file
21
internal/server/driver/ssh/sshtun/printer.go
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
package sshtun
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
func StdoutPrinterSessionCallback(log *zap.SugaredLogger) SessionCallback {
|
||||||
|
return func(session *ssh.Session) {
|
||||||
|
stdout, err := session.StdoutPipe()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
scan := bufio.NewScanner(stdout)
|
||||||
|
for scan.Scan() {
|
||||||
|
log.Debug(scan.Text())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
76
internal/server/driver/ssh/sshtun/sish_compat.go
Normal file
76
internal/server/driver/ssh/sshtun/sish_compat.go
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
package sshtun
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"reflect"
|
||||||
|
"sync"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
type sishHostListener struct {
|
||||||
|
parent *ssh.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSishHostListener(parent *ssh.Client) *sishHostListener {
|
||||||
|
return &sishHostListener{parent: parent}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *sishHostListener) ListenFakeRemoteHost(ep Endpoint) (net.Listener, error) {
|
||||||
|
c.doHandleForwardsOnce()
|
||||||
|
m := channelForwardMsg{
|
||||||
|
ep.Host,
|
||||||
|
uint32(ep.Port),
|
||||||
|
}
|
||||||
|
// send message
|
||||||
|
ok, resp, err := c.parent.SendRequest("tcpip-forward", true, ssh.Marshal(&m))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("ssh: tcpip-forward request denied by peer")
|
||||||
|
}
|
||||||
|
|
||||||
|
laddr := &net.TCPAddr{
|
||||||
|
IP: net.ParseIP("127.0.0.1"),
|
||||||
|
Port: ep.Port,
|
||||||
|
}
|
||||||
|
if ep.Port == 0 {
|
||||||
|
var p struct {
|
||||||
|
Port uint32
|
||||||
|
}
|
||||||
|
if err := ssh.Unmarshal(resp, &p); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
laddr.Port = int(p.Port)
|
||||||
|
}
|
||||||
|
c.registerForward(laddr)
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *sishHostListener) registerForward(addr *net.TCPAddr) {
|
||||||
|
cl := reflect.ValueOf(c.parent).Elem()
|
||||||
|
forwardsUn := cl.FieldByName("forwards")
|
||||||
|
forwards := reflect.NewAt(forwardsUn.Type(), unsafe.Pointer(forwardsUn.UnsafeAddr())).Elem()
|
||||||
|
forwardVal := forwards.MethodByName("add").Call([]reflect.Value{reflect.ValueOf(addr)})[0]
|
||||||
|
_ = forwardVal
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *sishHostListener) doHandleForwardsOnce() {
|
||||||
|
cl := reflect.ValueOf(c.parent)
|
||||||
|
clVal := cl.Elem()
|
||||||
|
onceField := clVal.FieldByName("handleForwardsOnce")
|
||||||
|
once := reflect.NewAt(onceField.Type(), unsafe.Pointer(onceField.UnsafeAddr())).Interface().(*sync.Once)
|
||||||
|
handleForwards := clVal.MethodByName("handleForwards")
|
||||||
|
once.Do(func() {
|
||||||
|
handleForwards.Call(nil)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
type channelForwardMsg struct {
|
||||||
|
addr string
|
||||||
|
rport uint32
|
||||||
|
}
|
210
internal/server/driver/ssh/sshtun/ssh.go
Normal file
210
internal/server/driver/ssh/sshtun/ssh.go
Normal file
@ -0,0 +1,210 @@
|
|||||||
|
package sshtun
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/function61/gokit/app/backoff"
|
||||||
|
"github.com/function61/gokit/io/bidipipe"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SessionCallback func(*ssh.Session)
|
||||||
|
|
||||||
|
var NoopSessionCallback SessionCallback = func(*ssh.Session) {}
|
||||||
|
|
||||||
|
type Tunnel struct {
|
||||||
|
user string
|
||||||
|
address Endpoint
|
||||||
|
forward Forward
|
||||||
|
authMethods []ssh.AuthMethod
|
||||||
|
log *zap.SugaredLogger
|
||||||
|
connected atomic.Bool
|
||||||
|
fakeRemoteHost bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(address, user string, fakeRemoteHost bool,
|
||||||
|
forward Forward, auth []ssh.AuthMethod, log *zap.SugaredLogger) *Tunnel {
|
||||||
|
return &Tunnel{
|
||||||
|
address: AddrToEndpoint(address),
|
||||||
|
user: user,
|
||||||
|
fakeRemoteHost: fakeRemoteHost,
|
||||||
|
forward: forward,
|
||||||
|
authMethods: auth,
|
||||||
|
log: log.With(zap.String("sshServer", address)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Tunnel) Connect(ctx context.Context, sessionCb SessionCallback) {
|
||||||
|
if c.connected.Load() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
defer c.connected.Store(false)
|
||||||
|
backoffTime := backoff.ExponentialWithCappedMax(100*time.Millisecond, 5*time.Second)
|
||||||
|
for {
|
||||||
|
c.connected.Store(true)
|
||||||
|
err := c.connect(ctx, sessionCb)
|
||||||
|
if err != nil {
|
||||||
|
c.log.Error("connect error:", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(backoffTime())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// connect once to the SSH server. if the connection breaks, we return error and the caller
|
||||||
|
// will try to re-connect
|
||||||
|
func (c *Tunnel) connect(ctx context.Context, sessionCb SessionCallback) error {
|
||||||
|
c.log.Debug("connecting")
|
||||||
|
sshConfig := &ssh.ClientConfig{
|
||||||
|
User: c.user,
|
||||||
|
Auth: c.authMethods,
|
||||||
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||||
|
}
|
||||||
|
|
||||||
|
var sshClient *ssh.Client
|
||||||
|
var errConnect error
|
||||||
|
|
||||||
|
sshClient, errConnect = dialSSH(ctx, c.address.String(), sshConfig)
|
||||||
|
if errConnect != nil {
|
||||||
|
return errConnect
|
||||||
|
}
|
||||||
|
|
||||||
|
defer sshClient.Close()
|
||||||
|
defer c.log.Debug("disconnecting")
|
||||||
|
|
||||||
|
c.log.Debug("connected")
|
||||||
|
listenerStopped := make(chan error)
|
||||||
|
|
||||||
|
sess, err := sshClient.NewSession()
|
||||||
|
if err != nil {
|
||||||
|
c.log.Errorf("session error: %s", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer sess.Close()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
if sessionCb == nil {
|
||||||
|
sessionCb = func(*ssh.Session) {}
|
||||||
|
}
|
||||||
|
wg.Add(2)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
sessionCb(sess)
|
||||||
|
}()
|
||||||
|
reverseErr := make(chan error)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
reverseErr <- c.reverseForwardOnePort(sshClient, listenerStopped)
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err := <-reverseErr; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil
|
||||||
|
case listenerFirstErr := <-listenerStopped:
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return listenerFirstErr
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// blocking flow: calls Listen() on the SSH connection, and if succeeds returns non-nil error
|
||||||
|
//
|
||||||
|
// nonblocking flow: if Accept() call fails, stops goroutine and returns error on ch listenerStopped
|
||||||
|
func (c *Tunnel) reverseForwardOnePort(sshClient *ssh.Client, listenerStopped chan<- error) error {
|
||||||
|
if c.fakeRemoteHost {
|
||||||
|
newSishHostListener(sshClient).ListenFakeRemoteHost(c.forward.Remote)
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
listener, err := sshClient.Listen("tcp", c.forward.Remote.String())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer listener.Close()
|
||||||
|
c.log.Debugf("forwarding %s <- %s", c.forward.Local.String(), c.forward.Remote.String())
|
||||||
|
|
||||||
|
for {
|
||||||
|
client, err := listener.Accept()
|
||||||
|
if err != nil {
|
||||||
|
listenerStopped <- fmt.Errorf("error on Accept(): %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
go handleReverseForwardConn(client, c.forward, c.log)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Tunnel) listenTCPWithoutResolving() {
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleReverseForwardConn(client net.Conn, forward Forward, log *zap.SugaredLogger) {
|
||||||
|
defer client.Close()
|
||||||
|
|
||||||
|
log.Debugf("%s connected", client.RemoteAddr())
|
||||||
|
defer log.Debug("closed")
|
||||||
|
|
||||||
|
remote, err := net.Dial("tcp", forward.Local.String())
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("dial INTO local service error: %s", err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// pipe data in both directions:
|
||||||
|
// - client => remote
|
||||||
|
// - remote => client
|
||||||
|
//
|
||||||
|
// - in effect, we act as a proxy between the reverse tunnel's client and locally-dialed
|
||||||
|
// remote endpoint.
|
||||||
|
// - the "client" and "remote" strings we give Pipe() is just for error&log messages
|
||||||
|
// - this blocks until either of the parties' socket closes (or breaks)
|
||||||
|
if err := bidipipe.Pipe(
|
||||||
|
bidipipe.WithName("client", client),
|
||||||
|
bidipipe.WithName("remote", remote),
|
||||||
|
); err != nil {
|
||||||
|
log.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func dialSSH(ctx context.Context, addr string, sshConfig *ssh.ClientConfig) (*ssh.Client, error) {
|
||||||
|
dialer := net.Dialer{
|
||||||
|
Timeout: 10 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := dialer.DialContext(ctx, "tcp", addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
clConn, newChan, reqChan, err := ssh.NewClientConn(conn, addr, sshConfig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return ssh.NewClient(clConn, newChan, reqChan), nil
|
||||||
|
}
|
@ -1,249 +0,0 @@
|
|||||||
// Copyright 2017, The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE.md file.
|
|
||||||
|
|
||||||
package sshtun
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
)
|
|
||||||
|
|
||||||
type TunnelMode uint8
|
|
||||||
|
|
||||||
func (t TunnelMode) String() string {
|
|
||||||
switch t {
|
|
||||||
case TunnelForward:
|
|
||||||
return "->"
|
|
||||||
case TunnelReverse:
|
|
||||||
return "<-"
|
|
||||||
default:
|
|
||||||
return "<?>"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
TunnelForward TunnelMode = iota
|
|
||||||
TunnelReverse
|
|
||||||
)
|
|
||||||
|
|
||||||
type logger interface {
|
|
||||||
Printf(string, ...interface{})
|
|
||||||
}
|
|
||||||
|
|
||||||
type Tunnel struct {
|
|
||||||
Auth []ssh.AuthMethod
|
|
||||||
HostKeys ssh.HostKeyCallback
|
|
||||||
Mode TunnelMode
|
|
||||||
User string
|
|
||||||
HostAddr string
|
|
||||||
BindAddr string
|
|
||||||
DialAddr string
|
|
||||||
RetryInterval time.Duration
|
|
||||||
KeepAlive KeepAliveConfig
|
|
||||||
Logger logger
|
|
||||||
}
|
|
||||||
|
|
||||||
type KeepAliveConfig struct {
|
|
||||||
// Interval is the amount of time in seconds to wait before the
|
|
||||||
// Tunnel client will send a keep-alive message to ensure some minimum
|
|
||||||
// traffic on the SSH connection.
|
|
||||||
Interval uint
|
|
||||||
|
|
||||||
// CountMax is the maximum number of consecutive failed responses to
|
|
||||||
// keep-alive messages the client is willing to tolerate before considering
|
|
||||||
// the SSH connection as dead.
|
|
||||||
CountMax uint
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t Tunnel) String() string {
|
|
||||||
var left, right string
|
|
||||||
switch t.Mode {
|
|
||||||
case TunnelForward:
|
|
||||||
left, right = t.BindAddr, t.DialAddr
|
|
||||||
case TunnelReverse:
|
|
||||||
left, right = t.DialAddr, t.BindAddr
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("%s@%s | %s %s %s", t.User, t.HostAddr, left, t.Mode, right)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t Tunnel) Bind(ctx context.Context, wg *sync.WaitGroup) {
|
|
||||||
defer wg.Done()
|
|
||||||
|
|
||||||
for {
|
|
||||||
var once sync.Once // Only print errors once per session
|
|
||||||
func() {
|
|
||||||
// Connect to the server host via SSH.
|
|
||||||
cl, err := ssh.Dial("tcp", t.HostAddr, &ssh.ClientConfig{
|
|
||||||
User: t.User,
|
|
||||||
Auth: t.Auth,
|
|
||||||
HostKeyCallback: t.HostKeys,
|
|
||||||
Timeout: 5 * time.Second,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
once.Do(func() { t.Logger.Printf("(%v) SSH dial error: %v", t, err) })
|
|
||||||
return
|
|
||||||
}
|
|
||||||
wg.Add(1)
|
|
||||||
go t.keepAliveMonitor(&once, wg, cl)
|
|
||||||
defer cl.Close()
|
|
||||||
|
|
||||||
// Attempt to bind to the inbound socket.
|
|
||||||
var ln net.Listener
|
|
||||||
switch t.Mode {
|
|
||||||
case TunnelForward:
|
|
||||||
ln, err = net.Listen("tcp", t.BindAddr)
|
|
||||||
case TunnelReverse:
|
|
||||||
ln, err = cl.Listen("tcp", t.BindAddr)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
once.Do(func() { t.Logger.Printf("(%v) bind error: %v", t, err) })
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// The socket is bound. Make sure we close it eventually.
|
|
||||||
bindCtx, cancel := context.WithCancel(ctx)
|
|
||||||
defer cancel()
|
|
||||||
go func() {
|
|
||||||
cl.Wait()
|
|
||||||
cancel()
|
|
||||||
}()
|
|
||||||
go func() {
|
|
||||||
<-bindCtx.Done()
|
|
||||||
once.Do(func() {}) // Suppress future errors
|
|
||||||
ln.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
t.Logger.Printf("(%v) bound Tunnel", t)
|
|
||||||
defer t.Logger.Printf("(%v) collapsed Tunnel", t)
|
|
||||||
|
|
||||||
// Accept all incoming connections.
|
|
||||||
for {
|
|
||||||
cn1, err := ln.Accept()
|
|
||||||
if err != nil {
|
|
||||||
once.Do(func() { t.Logger.Printf("(%v) accept error: %v", t, err) })
|
|
||||||
return
|
|
||||||
}
|
|
||||||
wg.Add(1)
|
|
||||||
go t.dialTunnel(bindCtx, wg, cl, cn1)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
case <-time.After(t.RetryInterval):
|
|
||||||
t.Logger.Printf("(%v) retrying...", t)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t Tunnel) dialTunnel(ctx context.Context, wg *sync.WaitGroup, client *ssh.Client, cn1 net.Conn) {
|
|
||||||
defer wg.Done()
|
|
||||||
|
|
||||||
// The inbound connection is established. Make sure we close it eventually.
|
|
||||||
connCtx, cancel := context.WithCancel(ctx)
|
|
||||||
defer cancel()
|
|
||||||
go func() {
|
|
||||||
<-connCtx.Done()
|
|
||||||
cn1.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Establish the outbound connection.
|
|
||||||
var cn2 net.Conn
|
|
||||||
var err error
|
|
||||||
switch t.Mode {
|
|
||||||
case TunnelForward:
|
|
||||||
cn2, err = client.Dial("tcp", t.DialAddr)
|
|
||||||
case TunnelReverse:
|
|
||||||
cn2, err = net.Dial("tcp", t.DialAddr)
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
t.Logger.Printf("(%v) dial error: %v", t, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
<-connCtx.Done()
|
|
||||||
cn2.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
t.Logger.Printf("(%v) connection established", t)
|
|
||||||
defer t.Logger.Printf("(%v) connection closed", t)
|
|
||||||
|
|
||||||
// Copy bytes from one connection to the other until one side closes.
|
|
||||||
var once sync.Once
|
|
||||||
var wg2 sync.WaitGroup
|
|
||||||
wg2.Add(2)
|
|
||||||
go func() {
|
|
||||||
defer wg2.Done()
|
|
||||||
defer cancel()
|
|
||||||
if _, err := io.Copy(cn1, cn2); err != nil {
|
|
||||||
once.Do(func() { t.Logger.Printf("(%v) connection error: %v", t, err) })
|
|
||||||
}
|
|
||||||
once.Do(func() {}) // Suppress future errors
|
|
||||||
}()
|
|
||||||
go func() {
|
|
||||||
defer wg2.Done()
|
|
||||||
defer cancel()
|
|
||||||
if _, err := io.Copy(cn2, cn1); err != nil {
|
|
||||||
once.Do(func() { t.Logger.Printf("(%v) connection error: %v", t, err) })
|
|
||||||
}
|
|
||||||
once.Do(func() {}) // Suppress future errors
|
|
||||||
}()
|
|
||||||
wg2.Wait()
|
|
||||||
}
|
|
||||||
|
|
||||||
// keepAliveMonitor periodically sends messages to invoke a response.
|
|
||||||
// If the server does not respond after some period of time,
|
|
||||||
// assume that the underlying net.Conn abruptly died.
|
|
||||||
func (t Tunnel) keepAliveMonitor(once *sync.Once, wg *sync.WaitGroup, client *ssh.Client) {
|
|
||||||
defer wg.Done()
|
|
||||||
if t.KeepAlive.Interval == 0 || t.KeepAlive.CountMax == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Detect when the SSH connection is closed.
|
|
||||||
wait := make(chan error, 1)
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
wait <- client.Wait()
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Repeatedly check if the remote server is still alive.
|
|
||||||
var aliveCount int32
|
|
||||||
ticker := time.NewTicker(time.Duration(t.KeepAlive.Interval) * time.Second)
|
|
||||||
defer ticker.Stop()
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case err := <-wait:
|
|
||||||
if err != nil && err != io.EOF {
|
|
||||||
once.Do(func() { t.Logger.Printf("(%v) SSH error: %v", t, err) })
|
|
||||||
}
|
|
||||||
return
|
|
||||||
case <-ticker.C:
|
|
||||||
if n := atomic.AddInt32(&aliveCount, 1); n > int32(t.KeepAlive.CountMax) {
|
|
||||||
once.Do(func() { t.Logger.Printf("(%v) SSH keep-alive termination", t) })
|
|
||||||
client.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
_, _, err := client.SendRequest("keepalive@openssh.com", true, nil)
|
|
||||||
if err == nil {
|
|
||||||
atomic.StoreInt32(&aliveCount, 0)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,509 +0,0 @@
|
|||||||
// Copyright 2017, The Go Authors. All rights reserved.
|
|
||||||
// Use of this source code is governed by a BSD-style
|
|
||||||
// license that can be found in the LICENSE.md file.
|
|
||||||
|
|
||||||
package sshtun
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"crypto/md5"
|
|
||||||
"crypto/rsa"
|
|
||||||
"encoding/binary"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"io/ioutil"
|
|
||||||
"math/rand"
|
|
||||||
"net"
|
|
||||||
"reflect"
|
|
||||||
"strconv"
|
|
||||||
"sync"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
)
|
|
||||||
|
|
||||||
type testLogger struct {
|
|
||||||
*testing.T // Already has Fatalf method
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t testLogger) Printf(f string, x ...interface{}) { t.Logf(f, x...) }
|
|
||||||
|
|
||||||
func TestTunnel(t *testing.T) {
|
|
||||||
rootWG := new(sync.WaitGroup)
|
|
||||||
defer rootWG.Wait()
|
|
||||||
rootCtx, cancelAll := context.WithCancel(context.Background())
|
|
||||||
defer cancelAll()
|
|
||||||
|
|
||||||
// Open all of the TCP sockets needed for the test.
|
|
||||||
tcpLn0 := openListener(t) // Start of the chain
|
|
||||||
tcpLn1 := openListener(t) // Mid-point of the chain
|
|
||||||
tcpLn2 := openListener(t) // End of the chain
|
|
||||||
srvLn0 := openListener(t) // Socket for SSH server in reverse Mode
|
|
||||||
srvLn1 := openListener(t) // Socket for SSH server in forward Mode
|
|
||||||
|
|
||||||
tcpLn0.Close() // To be later binded by the reverse Tunnel
|
|
||||||
tcpLn1.Close() // To be later binded by the forward Tunnel
|
|
||||||
go closeWhenDone(rootCtx, tcpLn2)
|
|
||||||
go closeWhenDone(rootCtx, srvLn0)
|
|
||||||
go closeWhenDone(rootCtx, srvLn1)
|
|
||||||
|
|
||||||
// Generate keys for both the servers and clients.
|
|
||||||
clientPriv0, clientPub0 := generateKeys(t)
|
|
||||||
clientPriv1, clientPub1 := generateKeys(t)
|
|
||||||
serverPriv0, serverPub0 := generateKeys(t)
|
|
||||||
serverPriv1, serverPub1 := generateKeys(t)
|
|
||||||
|
|
||||||
// Start the SSH servers.
|
|
||||||
rootWG.Add(2)
|
|
||||||
go func() {
|
|
||||||
defer rootWG.Done()
|
|
||||||
runServer(t, rootCtx, srvLn0, serverPriv0, clientPub0, clientPub1)
|
|
||||||
}()
|
|
||||||
go func() {
|
|
||||||
defer rootWG.Done()
|
|
||||||
runServer(t, rootCtx, srvLn1, serverPriv1, clientPub0, clientPub1)
|
|
||||||
}()
|
|
||||||
|
|
||||||
wg := new(sync.WaitGroup)
|
|
||||||
defer wg.Wait()
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
// Create the Tunnel configurations.
|
|
||||||
tn0 := Tunnel{
|
|
||||||
Auth: []ssh.AuthMethod{ssh.PublicKeys(clientPriv0)},
|
|
||||||
HostKeys: ssh.FixedHostKey(serverPub0),
|
|
||||||
Mode: TunnelReverse, // Reverse Tunnel
|
|
||||||
User: "user0",
|
|
||||||
HostAddr: srvLn0.Addr().String(),
|
|
||||||
BindAddr: tcpLn0.Addr().String(),
|
|
||||||
DialAddr: tcpLn1.Addr().String(),
|
|
||||||
Logger: testLogger{t},
|
|
||||||
}
|
|
||||||
tn1 := Tunnel{
|
|
||||||
Auth: []ssh.AuthMethod{ssh.PublicKeys(clientPriv1)},
|
|
||||||
HostKeys: ssh.FixedHostKey(serverPub1),
|
|
||||||
Mode: TunnelForward, // Forward Tunnel
|
|
||||||
User: "user1",
|
|
||||||
HostAddr: srvLn1.Addr().String(),
|
|
||||||
BindAddr: tcpLn1.Addr().String(),
|
|
||||||
DialAddr: tcpLn2.Addr().String(),
|
|
||||||
Logger: testLogger{t},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start the SSH client tunnels.
|
|
||||||
wg.Add(2)
|
|
||||||
go tn0.Bind(ctx, wg)
|
|
||||||
go tn1.Bind(ctx, wg)
|
|
||||||
|
|
||||||
t.Log("test started")
|
|
||||||
done := make(chan bool, 10)
|
|
||||||
|
|
||||||
// Start all the transmitters.
|
|
||||||
for i := 0; i < cap(done); i++ {
|
|
||||||
i := i
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
rnd := rand.New(rand.NewSource(int64(i)))
|
|
||||||
hash := md5.New()
|
|
||||||
size := uint32((1 << 10) + rnd.Intn(1<<20))
|
|
||||||
buf4 := make([]byte, 4)
|
|
||||||
binary.LittleEndian.PutUint32(buf4, size)
|
|
||||||
|
|
||||||
cnStart, err := net.Dial("tcp", tcpLn0.Addr().String())
|
|
||||||
if err != nil {
|
|
||||||
time.Sleep(10 * time.Millisecond)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
defer cnStart.Close()
|
|
||||||
if _, err := cnStart.Write(buf4); err != nil {
|
|
||||||
t.Errorf("write size error: %v", err)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
r := io.LimitReader(rnd, int64(size))
|
|
||||||
w := io.MultiWriter(cnStart, hash)
|
|
||||||
if _, err := io.Copy(w, r); err != nil {
|
|
||||||
t.Errorf("copy error: %v", err)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if _, err := cnStart.Write(hash.Sum(nil)); err != nil {
|
|
||||||
t.Errorf("write hash error: %v", err)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if err := cnStart.Close(); err != nil {
|
|
||||||
t.Errorf("close error: %v", err)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start all the receivers.
|
|
||||||
for i := 0; i < cap(done); i++ {
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
hash := md5.New()
|
|
||||||
buf4 := make([]byte, 4)
|
|
||||||
|
|
||||||
cnEnd, err := tcpLn2.Accept()
|
|
||||||
if err != nil {
|
|
||||||
time.Sleep(10 * time.Millisecond)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
defer cnEnd.Close()
|
|
||||||
|
|
||||||
if _, err := io.ReadFull(cnEnd, buf4); err != nil {
|
|
||||||
t.Errorf("read size error: %v", err)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
size := binary.LittleEndian.Uint32(buf4)
|
|
||||||
r := io.LimitReader(cnEnd, int64(size))
|
|
||||||
if _, err := io.Copy(hash, r); err != nil {
|
|
||||||
t.Errorf("copy error: %v", err)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
wantHash, err := ioutil.ReadAll(cnEnd)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("read hash error: %v", err)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if err := cnEnd.Close(); err != nil {
|
|
||||||
t.Errorf("close error: %v", err)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
if gotHash := hash.Sum(nil); !bytes.Equal(gotHash, wantHash) {
|
|
||||||
t.Errorf("hash mismatch:\ngot %x\nwant %x", gotHash, wantHash)
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
done <- true
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < cap(done); i++ {
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
case <-time.After(10 * time.Second):
|
|
||||||
t.Errorf("timed out: %d remaining", cap(done)-i)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
t.Log("test complete")
|
|
||||||
}
|
|
||||||
|
|
||||||
// generateKeys generates a random pair of SSH private and public keys.
|
|
||||||
func generateKeys(t *testing.T) (priv ssh.Signer, pub ssh.PublicKey) {
|
|
||||||
rnd := rand.New(rand.NewSource(time.Now().Unix()))
|
|
||||||
rsaKey, err := rsa.GenerateKey(rnd, 1024)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to generate RSA key pair: %v", err)
|
|
||||||
}
|
|
||||||
priv, err = ssh.NewSignerFromKey(rsaKey)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to generate signer: %v", err)
|
|
||||||
}
|
|
||||||
pub, err = ssh.NewPublicKey(&rsaKey.PublicKey)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unable to generate public key: %v", err)
|
|
||||||
}
|
|
||||||
return priv, pub
|
|
||||||
}
|
|
||||||
|
|
||||||
func openListener(t *testing.T) net.Listener {
|
|
||||||
ln, err := net.Listen("tcp", ":0")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("listen error: %v", err)
|
|
||||||
}
|
|
||||||
return ln
|
|
||||||
}
|
|
||||||
|
|
||||||
// runServer starts an SSH server capable of handling forward and reverse
|
|
||||||
// TCP tunnels. This function blocks for the entire duration that the
|
|
||||||
// server is running and can be stopped by canceling the context.
|
|
||||||
//
|
|
||||||
// The server listens on the provided Listener and will present to clients
|
|
||||||
// a certificate from serverKey and will only accept users that match
|
|
||||||
// the provided clientKeys. Only users of the name "User%d" are allowed where
|
|
||||||
// the ID number is the index for the specified client key provided.
|
|
||||||
func runServer(t *testing.T, ctx context.Context, ln net.Listener, serverKey ssh.Signer, clientKeys ...ssh.PublicKey) {
|
|
||||||
wg := new(sync.WaitGroup)
|
|
||||||
defer wg.Wait()
|
|
||||||
|
|
||||||
// Generate SSH server configuration.
|
|
||||||
conf := ssh.ServerConfig{
|
|
||||||
PublicKeyCallback: func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
|
|
||||||
var uid int
|
|
||||||
_, err := fmt.Sscanf(c.User(), "User%d", &uid)
|
|
||||||
if err != nil || uid >= len(clientKeys) || !bytes.Equal(clientKeys[uid].Marshal(), pubKey.Marshal()) {
|
|
||||||
return nil, fmt.Errorf("unknown public key for %q", c.User())
|
|
||||||
}
|
|
||||||
return nil, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
conf.AddHostKey(serverKey)
|
|
||||||
|
|
||||||
// Handle every SSH client connection.
|
|
||||||
for {
|
|
||||||
tcpCn, err := ln.Accept()
|
|
||||||
if err != nil {
|
|
||||||
if !isDone(ctx) {
|
|
||||||
t.Errorf("accept error: %v", err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
wg.Add(1)
|
|
||||||
go handleServerConn(t, ctx, wg, tcpCn, &conf)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleServerConn handles a single SSH connection.
|
|
||||||
func handleServerConn(t *testing.T, ctx context.Context, wg *sync.WaitGroup, tcpCn net.Conn, conf *ssh.ServerConfig) {
|
|
||||||
defer wg.Done()
|
|
||||||
go closeWhenDone(ctx, tcpCn)
|
|
||||||
defer tcpCn.Close()
|
|
||||||
|
|
||||||
sshCn, chans, reqs, err := ssh.NewServerConn(tcpCn, conf)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("new connection error: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
go closeWhenDone(ctx, sshCn)
|
|
||||||
defer sshCn.Close()
|
|
||||||
|
|
||||||
wg.Add(1)
|
|
||||||
go handleServerChannels(t, ctx, wg, sshCn, chans)
|
|
||||||
|
|
||||||
wg.Add(1)
|
|
||||||
go handleServerRequests(t, ctx, wg, sshCn, reqs)
|
|
||||||
|
|
||||||
if err := sshCn.Wait(); err != nil && err != io.EOF && !isDone(ctx) {
|
|
||||||
t.Errorf("connection error: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleServerChannels handles new channels on a SSH connection.
|
|
||||||
// The client initiates a new channel when forwarding a TCP dial.
|
|
||||||
func handleServerChannels(t *testing.T, ctx context.Context, wg *sync.WaitGroup, sshCn ssh.Conn, chans <-chan ssh.NewChannel) {
|
|
||||||
defer wg.Done()
|
|
||||||
for nc := range chans {
|
|
||||||
if nc.ChannelType() != "direct-tcpip" {
|
|
||||||
nc.Reject(ssh.UnknownChannelType, "not implemented")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
var args struct {
|
|
||||||
DstHost string
|
|
||||||
DstPort uint32
|
|
||||||
SrcHost string
|
|
||||||
SrcPort uint32
|
|
||||||
}
|
|
||||||
if !unmarshalData(nc.ExtraData(), &args) {
|
|
||||||
nc.Reject(ssh.Prohibited, "invalid request")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
// Open a connection for both sides.
|
|
||||||
cn, err := net.Dial("tcp", net.JoinHostPort(args.DstHost, strconv.Itoa(int(args.DstPort))))
|
|
||||||
if err != nil {
|
|
||||||
nc.Reject(ssh.ConnectionFailed, err.Error())
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
ch, reqs, err := nc.Accept()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("accept channel error: %v", err)
|
|
||||||
cn.Close()
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
go ssh.DiscardRequests(reqs)
|
|
||||||
|
|
||||||
wg.Add(1)
|
|
||||||
go bidirCopyAndClose(t, ctx, wg, cn, ch)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleServerRequests handles new requests on a SSH connection.
|
|
||||||
// The client initiates a new request for binding a local TCP socket.
|
|
||||||
func handleServerRequests(t *testing.T, ctx context.Context, wg *sync.WaitGroup, sshCn ssh.Conn, reqs <-chan *ssh.Request) {
|
|
||||||
defer wg.Done()
|
|
||||||
for r := range reqs {
|
|
||||||
if !r.WantReply {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if r.Type != "tcpip-forward" {
|
|
||||||
r.Reply(false, nil)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
var args struct {
|
|
||||||
Host string
|
|
||||||
Port uint32
|
|
||||||
}
|
|
||||||
if !unmarshalData(r.Payload, &args) {
|
|
||||||
r.Reply(false, nil)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
ln, err := net.Listen("tcp", net.JoinHostPort(args.Host, strconv.Itoa(int(args.Port))))
|
|
||||||
if err != nil {
|
|
||||||
r.Reply(false, nil)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
var resp struct{ Port uint32 }
|
|
||||||
_, resp.Port = splitHostPort(ln.Addr().String())
|
|
||||||
if err := r.Reply(true, marshalData(resp)); err != nil {
|
|
||||||
t.Errorf("request reply error: %v", err)
|
|
||||||
ln.Close()
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
wg.Add(1)
|
|
||||||
go handleLocalListener(t, ctx, wg, sshCn, ln, args.Host)
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleLocalListener handles every new connection on the provided socket.
|
|
||||||
// All local connections will be forwarded to the client via a new channel.
|
|
||||||
func handleLocalListener(t *testing.T, ctx context.Context, wg *sync.WaitGroup, sshCn ssh.Conn, ln net.Listener, host string) {
|
|
||||||
defer wg.Done()
|
|
||||||
go closeWhenDone(ctx, ln)
|
|
||||||
defer ln.Close()
|
|
||||||
|
|
||||||
for {
|
|
||||||
// Open a connection for both sides.
|
|
||||||
cn, err := ln.Accept()
|
|
||||||
if err != nil {
|
|
||||||
if !isDone(ctx) {
|
|
||||||
t.Errorf("accept error: %v", err)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
var args struct {
|
|
||||||
DstHost string
|
|
||||||
DstPort uint32
|
|
||||||
SrcHost string
|
|
||||||
SrcPort uint32
|
|
||||||
}
|
|
||||||
args.DstHost, args.DstPort = splitHostPort(cn.LocalAddr().String())
|
|
||||||
args.SrcHost, args.SrcPort = splitHostPort(cn.RemoteAddr().String())
|
|
||||||
args.DstHost = host // This must match on client side!
|
|
||||||
ch, reqs, err := sshCn.OpenChannel("forwarded-tcpip", marshalData(args))
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("open channel error: %v", err)
|
|
||||||
cn.Close()
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
go ssh.DiscardRequests(reqs)
|
|
||||||
|
|
||||||
wg.Add(1)
|
|
||||||
go bidirCopyAndClose(t, ctx, wg, cn, ch)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// bidirCopyAndClose performs a bi-directional copy on both connections
|
|
||||||
// until either side closes the connection or the context is canceled.
|
|
||||||
// This will close both connections before returning.
|
|
||||||
func bidirCopyAndClose(t *testing.T, ctx context.Context, wg *sync.WaitGroup, c1, c2 io.ReadWriteCloser) {
|
|
||||||
defer wg.Done()
|
|
||||||
go closeWhenDone(ctx, c1)
|
|
||||||
go closeWhenDone(ctx, c2)
|
|
||||||
defer c1.Close()
|
|
||||||
defer c2.Close()
|
|
||||||
|
|
||||||
errc := make(chan error, 2)
|
|
||||||
go func() {
|
|
||||||
_, err := io.Copy(c1, c2)
|
|
||||||
errc <- err
|
|
||||||
}()
|
|
||||||
go func() {
|
|
||||||
_, err := io.Copy(c2, c1)
|
|
||||||
errc <- err
|
|
||||||
}()
|
|
||||||
if err := <-errc; err != nil && err != io.EOF && !isDone(ctx) {
|
|
||||||
t.Errorf("copy error: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// unmarshalData parses b into s, where s is a pointer to a struct.
|
|
||||||
// Only unexported fields of type uint32 or string are allowed.
|
|
||||||
func unmarshalData(b []byte, s interface{}) bool {
|
|
||||||
v := reflect.ValueOf(s)
|
|
||||||
if !v.IsValid() || v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct {
|
|
||||||
panic("destination must be pointer to struct")
|
|
||||||
}
|
|
||||||
v = v.Elem()
|
|
||||||
for i := 0; i < v.NumField(); i++ {
|
|
||||||
switch v.Type().Field(i).Type.Kind() {
|
|
||||||
case reflect.Uint32:
|
|
||||||
if len(b) < 4 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
v.Field(i).Set(reflect.ValueOf(binary.BigEndian.Uint32(b)))
|
|
||||||
b = b[4:]
|
|
||||||
case reflect.String:
|
|
||||||
if len(b) < 4 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
n := binary.BigEndian.Uint32(b)
|
|
||||||
b = b[4:]
|
|
||||||
if uint64(len(b)) < uint64(n) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
v.Field(i).Set(reflect.ValueOf(string(b[:n])))
|
|
||||||
b = b[n:]
|
|
||||||
default:
|
|
||||||
panic("invalid field type: " + v.Type().Field(i).Type.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return len(b) == 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// marshalData serializes s into b, where s is a struct (or a pointer to one).
|
|
||||||
// Only unexported fields of type uint32 or string are allowed.
|
|
||||||
func marshalData(s interface{}) (b []byte) {
|
|
||||||
v := reflect.ValueOf(s)
|
|
||||||
if v.IsValid() && v.Kind() == reflect.Ptr {
|
|
||||||
v = v.Elem()
|
|
||||||
}
|
|
||||||
if !v.IsValid() || v.Kind() != reflect.Struct {
|
|
||||||
panic("source must be a struct")
|
|
||||||
}
|
|
||||||
var arr32 [4]byte
|
|
||||||
for i := 0; i < v.NumField(); i++ {
|
|
||||||
switch v.Type().Field(i).Type.Kind() {
|
|
||||||
case reflect.Uint32:
|
|
||||||
binary.BigEndian.PutUint32(arr32[:], uint32(v.Field(i).Uint()))
|
|
||||||
b = append(b, arr32[:]...)
|
|
||||||
case reflect.String:
|
|
||||||
binary.BigEndian.PutUint32(arr32[:], uint32(v.Field(i).Len()))
|
|
||||||
b = append(b, arr32[:]...)
|
|
||||||
b = append(b, v.Field(i).String()...)
|
|
||||||
default:
|
|
||||||
panic("invalid field type: " + v.Type().Field(i).Type.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return b
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func splitHostPort(s string) (string, uint32) {
|
|
||||||
host, port, _ := net.SplitHostPort(s)
|
|
||||||
p, _ := strconv.Atoi(port)
|
|
||||||
return host, uint32(p)
|
|
||||||
}
|
|
||||||
|
|
||||||
func closeWhenDone(ctx context.Context, c io.Closer) {
|
|
||||||
<-ctx.Done()
|
|
||||||
c.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func isDone(ctx context.Context) bool {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return true
|
|
||||||
default:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
@ -16,7 +16,7 @@ func MessageToAppEvent(event *pb.EventMessage) dto.Event {
|
|||||||
IP: net.ParseIP(event.GetContainer().GetIp()),
|
IP: net.ParseIP(event.GetContainer().GetIp()),
|
||||||
Port: uint16(event.GetContainer().GetPort()),
|
Port: uint16(event.GetContainer().GetPort()),
|
||||||
Server: event.GetContainer().GetServer(),
|
Server: event.GetContainer().GetServer(),
|
||||||
Prefix: event.GetContainer().GetPrefix(),
|
RemoteHost: event.GetContainer().GetRemoteHost(),
|
||||||
Domain: event.GetContainer().GetDomain(),
|
Domain: event.GetContainer().GetDomain(),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -31,7 +31,7 @@ func AppEventToMessage(event dto.Event) *pb.EventMessage {
|
|||||||
Ip: event.Container.IP.String(),
|
Ip: event.Container.IP.String(),
|
||||||
Port: uint32(event.Container.Port),
|
Port: uint32(event.Container.Port),
|
||||||
Server: event.Container.Server,
|
Server: event.Container.Server,
|
||||||
Prefix: event.Container.Prefix,
|
RemoteHost: event.Container.RemoteHost,
|
||||||
Domain: event.Container.Domain,
|
Domain: event.Container.Domain,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -41,6 +41,6 @@ type Container struct {
|
|||||||
IP net.IP `json:"ip"`
|
IP net.IP `json:"ip"`
|
||||||
Port uint16 `json:"port"`
|
Port uint16 `json:"port"`
|
||||||
Server string `json:"-"`
|
Server string `json:"-"`
|
||||||
Prefix string `json:"prefix"`
|
RemoteHost string `json:"remote_host"`
|
||||||
Domain string `json:"domain"`
|
Domain string `json:"domain"`
|
||||||
}
|
}
|
||||||
|
@ -23,7 +23,7 @@ message Container {
|
|||||||
string ip = 3;
|
string ip = 3;
|
||||||
uint32 port = 4;
|
uint32 port = 4;
|
||||||
string server = 5;
|
string server = 5;
|
||||||
string prefix = 6;
|
string remote_host = 6;
|
||||||
string domain = 7;
|
string domain = 7;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user