ssh: mode parameter support, banner support
This commit is contained in:
parent
273482e3fa
commit
408f88ebd2
@ -2,6 +2,7 @@ package ssh
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"path"
|
||||
"strconv"
|
||||
@ -16,6 +17,8 @@ import (
|
||||
"github.com/Neur0toxine/sshpoke/pkg/proto/ssh"
|
||||
)
|
||||
|
||||
var ErrAlreadyInUse = errors.New("domain is already in use")
|
||||
|
||||
type SSH struct {
|
||||
base.Base
|
||||
params Params
|
||||
@ -45,9 +48,15 @@ func New(ctx context.Context, name string, params config.DriverParams) (base.Dri
|
||||
}
|
||||
|
||||
func (d *SSH) forward(val sshtun.Forward) conn {
|
||||
tun := sshtun.New(d.params.Address, d.params.Auth.User, d.params.FakeRemoteHost, val, d.auth, d.Log())
|
||||
tun := sshtun.New(d.params.Address,
|
||||
d.params.Auth.User,
|
||||
d.params.FakeRemoteHost,
|
||||
val,
|
||||
d.auth,
|
||||
d.Log())
|
||||
ctx, cancel := context.WithCancel(d.Context())
|
||||
go tun.Connect(ctx, sshtun.StdoutPrinterSessionCallback(d.Log().With("ssh-output", val.Remote.String())))
|
||||
tunDbgLog := d.Log().With("ssh-output", val.Remote.String())
|
||||
go tun.Connect(ctx, sshtun.StdoutPrinterBannerCallback(tunDbgLog), sshtun.StdoutPrinterSessionCallback(tunDbgLog))
|
||||
return conn{ctx: ctx, cancel: cancel, tun: tun}
|
||||
}
|
||||
|
||||
@ -79,6 +88,9 @@ func (d *SSH) Handle(event dto.Event) error {
|
||||
d.rw.Lock()
|
||||
switch event.Type {
|
||||
case dto.EventStart:
|
||||
if d.params.Mode == types.DomainModeSingle && len(d.conns) > 0 {
|
||||
return ErrAlreadyInUse
|
||||
}
|
||||
conn := d.forward(sshtun.Forward{
|
||||
Local: sshtun.AddrToEndpoint(net.JoinHostPort(event.Container.IP.String(), strconv.Itoa(int(event.Container.Port)))),
|
||||
Remote: d.remoteEndpoint(event.Container.RemoteHost),
|
||||
|
@ -18,7 +18,7 @@ type Forward struct {
|
||||
func AddrToEndpoint(address string) Endpoint {
|
||||
host, port, err := net.SplitHostPort(address)
|
||||
if err != nil && errtools.IsPortMissingErr(err) {
|
||||
return Endpoint{Host: host, Port: 22}
|
||||
return Endpoint{Host: address, Port: 22}
|
||||
}
|
||||
portNum, err := strconv.Atoi(port)
|
||||
if err != nil {
|
||||
|
@ -19,3 +19,10 @@ func StdoutPrinterSessionCallback(log *zap.SugaredLogger) SessionCallback {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func StdoutPrinterBannerCallback(log *zap.SugaredLogger) ssh.BannerCallback {
|
||||
return func(msg string) error {
|
||||
log.Debug(msg)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
@ -41,7 +41,7 @@ func New(address, user string, fakeRemoteHost bool,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Tunnel) Connect(ctx context.Context, sessionCb SessionCallback) {
|
||||
func (c *Tunnel) Connect(ctx context.Context, bannerCb ssh.BannerCallback, sessionCb SessionCallback) {
|
||||
if c.connected.Load() {
|
||||
return
|
||||
}
|
||||
@ -50,7 +50,7 @@ func (c *Tunnel) Connect(ctx context.Context, sessionCb SessionCallback) {
|
||||
backoffTime := backoff.ExponentialWithCappedMax(100*time.Millisecond, 5*time.Second)
|
||||
for {
|
||||
c.connected.Store(true)
|
||||
err := c.connect(ctx, sessionCb)
|
||||
err := c.connect(ctx, bannerCb, sessionCb)
|
||||
if err != nil {
|
||||
c.log.Error("connect error:", err)
|
||||
}
|
||||
@ -67,12 +67,13 @@ func (c *Tunnel) Connect(ctx context.Context, sessionCb SessionCallback) {
|
||||
|
||||
// connect once to the SSH server. if the connection breaks, we return error and the caller
|
||||
// will try to re-connect
|
||||
func (c *Tunnel) connect(ctx context.Context, sessionCb SessionCallback) error {
|
||||
func (c *Tunnel) connect(ctx context.Context, bannerCb ssh.BannerCallback, sessionCb SessionCallback) error {
|
||||
c.log.Debug("connecting")
|
||||
sshConfig := &ssh.ClientConfig{
|
||||
User: c.user,
|
||||
Auth: c.authMethods,
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
BannerCallback: bannerCb,
|
||||
}
|
||||
|
||||
var sshClient *ssh.Client
|
||||
@ -89,6 +90,7 @@ func (c *Tunnel) connect(ctx context.Context, sessionCb SessionCallback) error {
|
||||
c.log.Debug("connected")
|
||||
listenerStopped := make(chan error)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
sess, err := sshClient.NewSession()
|
||||
if err != nil {
|
||||
c.log.Errorf("session error: %s", err)
|
||||
@ -96,15 +98,16 @@ func (c *Tunnel) connect(ctx context.Context, sessionCb SessionCallback) error {
|
||||
}
|
||||
defer sess.Close()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
if sessionCb == nil {
|
||||
sessionCb = func(*ssh.Session) {}
|
||||
}
|
||||
wg.Add(2)
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
sessionCb(sess)
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
reverseErr := make(chan error)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
Loading…
Reference in New Issue
Block a user