ssh: mode parameter support, banner support

This commit is contained in:
Pavel 2023-11-18 23:02:41 +03:00
parent 273482e3fa
commit 408f88ebd2
4 changed files with 30 additions and 8 deletions

View File

@ -2,6 +2,7 @@ package ssh
import (
"context"
"errors"
"net"
"path"
"strconv"
@ -16,6 +17,8 @@ import (
"github.com/Neur0toxine/sshpoke/pkg/proto/ssh"
)
var ErrAlreadyInUse = errors.New("domain is already in use")
type SSH struct {
base.Base
params Params
@ -45,9 +48,15 @@ func New(ctx context.Context, name string, params config.DriverParams) (base.Dri
}
func (d *SSH) forward(val sshtun.Forward) conn {
tun := sshtun.New(d.params.Address, d.params.Auth.User, d.params.FakeRemoteHost, val, d.auth, d.Log())
tun := sshtun.New(d.params.Address,
d.params.Auth.User,
d.params.FakeRemoteHost,
val,
d.auth,
d.Log())
ctx, cancel := context.WithCancel(d.Context())
go tun.Connect(ctx, sshtun.StdoutPrinterSessionCallback(d.Log().With("ssh-output", val.Remote.String())))
tunDbgLog := d.Log().With("ssh-output", val.Remote.String())
go tun.Connect(ctx, sshtun.StdoutPrinterBannerCallback(tunDbgLog), sshtun.StdoutPrinterSessionCallback(tunDbgLog))
return conn{ctx: ctx, cancel: cancel, tun: tun}
}
@ -79,6 +88,9 @@ func (d *SSH) Handle(event dto.Event) error {
d.rw.Lock()
switch event.Type {
case dto.EventStart:
if d.params.Mode == types.DomainModeSingle && len(d.conns) > 0 {
return ErrAlreadyInUse
}
conn := d.forward(sshtun.Forward{
Local: sshtun.AddrToEndpoint(net.JoinHostPort(event.Container.IP.String(), strconv.Itoa(int(event.Container.Port)))),
Remote: d.remoteEndpoint(event.Container.RemoteHost),

View File

@ -18,7 +18,7 @@ type Forward struct {
func AddrToEndpoint(address string) Endpoint {
host, port, err := net.SplitHostPort(address)
if err != nil && errtools.IsPortMissingErr(err) {
return Endpoint{Host: host, Port: 22}
return Endpoint{Host: address, Port: 22}
}
portNum, err := strconv.Atoi(port)
if err != nil {

View File

@ -19,3 +19,10 @@ func StdoutPrinterSessionCallback(log *zap.SugaredLogger) SessionCallback {
}
}
}
func StdoutPrinterBannerCallback(log *zap.SugaredLogger) ssh.BannerCallback {
return func(msg string) error {
log.Debug(msg)
return nil
}
}

View File

@ -41,7 +41,7 @@ func New(address, user string, fakeRemoteHost bool,
}
}
func (c *Tunnel) Connect(ctx context.Context, sessionCb SessionCallback) {
func (c *Tunnel) Connect(ctx context.Context, bannerCb ssh.BannerCallback, sessionCb SessionCallback) {
if c.connected.Load() {
return
}
@ -50,7 +50,7 @@ func (c *Tunnel) Connect(ctx context.Context, sessionCb SessionCallback) {
backoffTime := backoff.ExponentialWithCappedMax(100*time.Millisecond, 5*time.Second)
for {
c.connected.Store(true)
err := c.connect(ctx, sessionCb)
err := c.connect(ctx, bannerCb, sessionCb)
if err != nil {
c.log.Error("connect error:", err)
}
@ -67,12 +67,13 @@ func (c *Tunnel) Connect(ctx context.Context, sessionCb SessionCallback) {
// connect once to the SSH server. if the connection breaks, we return error and the caller
// will try to re-connect
func (c *Tunnel) connect(ctx context.Context, sessionCb SessionCallback) error {
func (c *Tunnel) connect(ctx context.Context, bannerCb ssh.BannerCallback, sessionCb SessionCallback) error {
c.log.Debug("connecting")
sshConfig := &ssh.ClientConfig{
User: c.user,
Auth: c.authMethods,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
BannerCallback: bannerCb,
}
var sshClient *ssh.Client
@ -89,6 +90,7 @@ func (c *Tunnel) connect(ctx context.Context, sessionCb SessionCallback) error {
c.log.Debug("connected")
listenerStopped := make(chan error)
var wg sync.WaitGroup
sess, err := sshClient.NewSession()
if err != nil {
c.log.Errorf("session error: %s", err)
@ -96,15 +98,16 @@ func (c *Tunnel) connect(ctx context.Context, sessionCb SessionCallback) error {
}
defer sess.Close()
var wg sync.WaitGroup
if sessionCb == nil {
sessionCb = func(*ssh.Session) {}
}
wg.Add(2)
wg.Add(1)
go func() {
defer wg.Done()
sessionCb(sess)
}()
wg.Add(1)
reverseErr := make(chan error)
go func() {
defer wg.Done()