ssh: mode parameter support, banner support

This commit is contained in:
Pavel 2023-11-18 23:02:41 +03:00
parent 273482e3fa
commit 408f88ebd2
4 changed files with 30 additions and 8 deletions

View File

@ -2,6 +2,7 @@ package ssh
import ( import (
"context" "context"
"errors"
"net" "net"
"path" "path"
"strconv" "strconv"
@ -16,6 +17,8 @@ import (
"github.com/Neur0toxine/sshpoke/pkg/proto/ssh" "github.com/Neur0toxine/sshpoke/pkg/proto/ssh"
) )
var ErrAlreadyInUse = errors.New("domain is already in use")
type SSH struct { type SSH struct {
base.Base base.Base
params Params params Params
@ -45,9 +48,15 @@ func New(ctx context.Context, name string, params config.DriverParams) (base.Dri
} }
func (d *SSH) forward(val sshtun.Forward) conn { func (d *SSH) forward(val sshtun.Forward) conn {
tun := sshtun.New(d.params.Address, d.params.Auth.User, d.params.FakeRemoteHost, val, d.auth, d.Log()) tun := sshtun.New(d.params.Address,
d.params.Auth.User,
d.params.FakeRemoteHost,
val,
d.auth,
d.Log())
ctx, cancel := context.WithCancel(d.Context()) ctx, cancel := context.WithCancel(d.Context())
go tun.Connect(ctx, sshtun.StdoutPrinterSessionCallback(d.Log().With("ssh-output", val.Remote.String()))) tunDbgLog := d.Log().With("ssh-output", val.Remote.String())
go tun.Connect(ctx, sshtun.StdoutPrinterBannerCallback(tunDbgLog), sshtun.StdoutPrinterSessionCallback(tunDbgLog))
return conn{ctx: ctx, cancel: cancel, tun: tun} return conn{ctx: ctx, cancel: cancel, tun: tun}
} }
@ -79,6 +88,9 @@ func (d *SSH) Handle(event dto.Event) error {
d.rw.Lock() d.rw.Lock()
switch event.Type { switch event.Type {
case dto.EventStart: case dto.EventStart:
if d.params.Mode == types.DomainModeSingle && len(d.conns) > 0 {
return ErrAlreadyInUse
}
conn := d.forward(sshtun.Forward{ conn := d.forward(sshtun.Forward{
Local: sshtun.AddrToEndpoint(net.JoinHostPort(event.Container.IP.String(), strconv.Itoa(int(event.Container.Port)))), Local: sshtun.AddrToEndpoint(net.JoinHostPort(event.Container.IP.String(), strconv.Itoa(int(event.Container.Port)))),
Remote: d.remoteEndpoint(event.Container.RemoteHost), Remote: d.remoteEndpoint(event.Container.RemoteHost),

View File

@ -18,7 +18,7 @@ type Forward struct {
func AddrToEndpoint(address string) Endpoint { func AddrToEndpoint(address string) Endpoint {
host, port, err := net.SplitHostPort(address) host, port, err := net.SplitHostPort(address)
if err != nil && errtools.IsPortMissingErr(err) { if err != nil && errtools.IsPortMissingErr(err) {
return Endpoint{Host: host, Port: 22} return Endpoint{Host: address, Port: 22}
} }
portNum, err := strconv.Atoi(port) portNum, err := strconv.Atoi(port)
if err != nil { if err != nil {

View File

@ -19,3 +19,10 @@ func StdoutPrinterSessionCallback(log *zap.SugaredLogger) SessionCallback {
} }
} }
} }
func StdoutPrinterBannerCallback(log *zap.SugaredLogger) ssh.BannerCallback {
return func(msg string) error {
log.Debug(msg)
return nil
}
}

View File

@ -41,7 +41,7 @@ func New(address, user string, fakeRemoteHost bool,
} }
} }
func (c *Tunnel) Connect(ctx context.Context, sessionCb SessionCallback) { func (c *Tunnel) Connect(ctx context.Context, bannerCb ssh.BannerCallback, sessionCb SessionCallback) {
if c.connected.Load() { if c.connected.Load() {
return return
} }
@ -50,7 +50,7 @@ func (c *Tunnel) Connect(ctx context.Context, sessionCb SessionCallback) {
backoffTime := backoff.ExponentialWithCappedMax(100*time.Millisecond, 5*time.Second) backoffTime := backoff.ExponentialWithCappedMax(100*time.Millisecond, 5*time.Second)
for { for {
c.connected.Store(true) c.connected.Store(true)
err := c.connect(ctx, sessionCb) err := c.connect(ctx, bannerCb, sessionCb)
if err != nil { if err != nil {
c.log.Error("connect error:", err) c.log.Error("connect error:", err)
} }
@ -67,12 +67,13 @@ func (c *Tunnel) Connect(ctx context.Context, sessionCb SessionCallback) {
// connect once to the SSH server. if the connection breaks, we return error and the caller // connect once to the SSH server. if the connection breaks, we return error and the caller
// will try to re-connect // will try to re-connect
func (c *Tunnel) connect(ctx context.Context, sessionCb SessionCallback) error { func (c *Tunnel) connect(ctx context.Context, bannerCb ssh.BannerCallback, sessionCb SessionCallback) error {
c.log.Debug("connecting") c.log.Debug("connecting")
sshConfig := &ssh.ClientConfig{ sshConfig := &ssh.ClientConfig{
User: c.user, User: c.user,
Auth: c.authMethods, Auth: c.authMethods,
HostKeyCallback: ssh.InsecureIgnoreHostKey(), HostKeyCallback: ssh.InsecureIgnoreHostKey(),
BannerCallback: bannerCb,
} }
var sshClient *ssh.Client var sshClient *ssh.Client
@ -89,6 +90,7 @@ func (c *Tunnel) connect(ctx context.Context, sessionCb SessionCallback) error {
c.log.Debug("connected") c.log.Debug("connected")
listenerStopped := make(chan error) listenerStopped := make(chan error)
var wg sync.WaitGroup
sess, err := sshClient.NewSession() sess, err := sshClient.NewSession()
if err != nil { if err != nil {
c.log.Errorf("session error: %s", err) c.log.Errorf("session error: %s", err)
@ -96,15 +98,16 @@ func (c *Tunnel) connect(ctx context.Context, sessionCb SessionCallback) error {
} }
defer sess.Close() defer sess.Close()
var wg sync.WaitGroup
if sessionCb == nil { if sessionCb == nil {
sessionCb = func(*ssh.Session) {} sessionCb = func(*ssh.Session) {}
} }
wg.Add(2) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
sessionCb(sess) sessionCb(sess)
}() }()
wg.Add(1)
reverseErr := make(chan error) reverseErr := make(chan error)
go func() { go func() {
defer wg.Done() defer wg.Done()