ssh: mode parameter support, banner support
This commit is contained in:
parent
273482e3fa
commit
408f88ebd2
@ -2,6 +2,7 @@ package ssh
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"path"
|
"path"
|
||||||
"strconv"
|
"strconv"
|
||||||
@ -16,6 +17,8 @@ import (
|
|||||||
"github.com/Neur0toxine/sshpoke/pkg/proto/ssh"
|
"github.com/Neur0toxine/sshpoke/pkg/proto/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var ErrAlreadyInUse = errors.New("domain is already in use")
|
||||||
|
|
||||||
type SSH struct {
|
type SSH struct {
|
||||||
base.Base
|
base.Base
|
||||||
params Params
|
params Params
|
||||||
@ -45,9 +48,15 @@ func New(ctx context.Context, name string, params config.DriverParams) (base.Dri
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *SSH) forward(val sshtun.Forward) conn {
|
func (d *SSH) forward(val sshtun.Forward) conn {
|
||||||
tun := sshtun.New(d.params.Address, d.params.Auth.User, d.params.FakeRemoteHost, val, d.auth, d.Log())
|
tun := sshtun.New(d.params.Address,
|
||||||
|
d.params.Auth.User,
|
||||||
|
d.params.FakeRemoteHost,
|
||||||
|
val,
|
||||||
|
d.auth,
|
||||||
|
d.Log())
|
||||||
ctx, cancel := context.WithCancel(d.Context())
|
ctx, cancel := context.WithCancel(d.Context())
|
||||||
go tun.Connect(ctx, sshtun.StdoutPrinterSessionCallback(d.Log().With("ssh-output", val.Remote.String())))
|
tunDbgLog := d.Log().With("ssh-output", val.Remote.String())
|
||||||
|
go tun.Connect(ctx, sshtun.StdoutPrinterBannerCallback(tunDbgLog), sshtun.StdoutPrinterSessionCallback(tunDbgLog))
|
||||||
return conn{ctx: ctx, cancel: cancel, tun: tun}
|
return conn{ctx: ctx, cancel: cancel, tun: tun}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -79,6 +88,9 @@ func (d *SSH) Handle(event dto.Event) error {
|
|||||||
d.rw.Lock()
|
d.rw.Lock()
|
||||||
switch event.Type {
|
switch event.Type {
|
||||||
case dto.EventStart:
|
case dto.EventStart:
|
||||||
|
if d.params.Mode == types.DomainModeSingle && len(d.conns) > 0 {
|
||||||
|
return ErrAlreadyInUse
|
||||||
|
}
|
||||||
conn := d.forward(sshtun.Forward{
|
conn := d.forward(sshtun.Forward{
|
||||||
Local: sshtun.AddrToEndpoint(net.JoinHostPort(event.Container.IP.String(), strconv.Itoa(int(event.Container.Port)))),
|
Local: sshtun.AddrToEndpoint(net.JoinHostPort(event.Container.IP.String(), strconv.Itoa(int(event.Container.Port)))),
|
||||||
Remote: d.remoteEndpoint(event.Container.RemoteHost),
|
Remote: d.remoteEndpoint(event.Container.RemoteHost),
|
||||||
|
@ -18,7 +18,7 @@ type Forward struct {
|
|||||||
func AddrToEndpoint(address string) Endpoint {
|
func AddrToEndpoint(address string) Endpoint {
|
||||||
host, port, err := net.SplitHostPort(address)
|
host, port, err := net.SplitHostPort(address)
|
||||||
if err != nil && errtools.IsPortMissingErr(err) {
|
if err != nil && errtools.IsPortMissingErr(err) {
|
||||||
return Endpoint{Host: host, Port: 22}
|
return Endpoint{Host: address, Port: 22}
|
||||||
}
|
}
|
||||||
portNum, err := strconv.Atoi(port)
|
portNum, err := strconv.Atoi(port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -19,3 +19,10 @@ func StdoutPrinterSessionCallback(log *zap.SugaredLogger) SessionCallback {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func StdoutPrinterBannerCallback(log *zap.SugaredLogger) ssh.BannerCallback {
|
||||||
|
return func(msg string) error {
|
||||||
|
log.Debug(msg)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -41,7 +41,7 @@ func New(address, user string, fakeRemoteHost bool,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Tunnel) Connect(ctx context.Context, sessionCb SessionCallback) {
|
func (c *Tunnel) Connect(ctx context.Context, bannerCb ssh.BannerCallback, sessionCb SessionCallback) {
|
||||||
if c.connected.Load() {
|
if c.connected.Load() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -50,7 +50,7 @@ func (c *Tunnel) Connect(ctx context.Context, sessionCb SessionCallback) {
|
|||||||
backoffTime := backoff.ExponentialWithCappedMax(100*time.Millisecond, 5*time.Second)
|
backoffTime := backoff.ExponentialWithCappedMax(100*time.Millisecond, 5*time.Second)
|
||||||
for {
|
for {
|
||||||
c.connected.Store(true)
|
c.connected.Store(true)
|
||||||
err := c.connect(ctx, sessionCb)
|
err := c.connect(ctx, bannerCb, sessionCb)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.log.Error("connect error:", err)
|
c.log.Error("connect error:", err)
|
||||||
}
|
}
|
||||||
@ -67,12 +67,13 @@ func (c *Tunnel) Connect(ctx context.Context, sessionCb SessionCallback) {
|
|||||||
|
|
||||||
// connect once to the SSH server. if the connection breaks, we return error and the caller
|
// connect once to the SSH server. if the connection breaks, we return error and the caller
|
||||||
// will try to re-connect
|
// will try to re-connect
|
||||||
func (c *Tunnel) connect(ctx context.Context, sessionCb SessionCallback) error {
|
func (c *Tunnel) connect(ctx context.Context, bannerCb ssh.BannerCallback, sessionCb SessionCallback) error {
|
||||||
c.log.Debug("connecting")
|
c.log.Debug("connecting")
|
||||||
sshConfig := &ssh.ClientConfig{
|
sshConfig := &ssh.ClientConfig{
|
||||||
User: c.user,
|
User: c.user,
|
||||||
Auth: c.authMethods,
|
Auth: c.authMethods,
|
||||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||||
|
BannerCallback: bannerCb,
|
||||||
}
|
}
|
||||||
|
|
||||||
var sshClient *ssh.Client
|
var sshClient *ssh.Client
|
||||||
@ -89,6 +90,7 @@ func (c *Tunnel) connect(ctx context.Context, sessionCb SessionCallback) error {
|
|||||||
c.log.Debug("connected")
|
c.log.Debug("connected")
|
||||||
listenerStopped := make(chan error)
|
listenerStopped := make(chan error)
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
sess, err := sshClient.NewSession()
|
sess, err := sshClient.NewSession()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.log.Errorf("session error: %s", err)
|
c.log.Errorf("session error: %s", err)
|
||||||
@ -96,15 +98,16 @@ func (c *Tunnel) connect(ctx context.Context, sessionCb SessionCallback) error {
|
|||||||
}
|
}
|
||||||
defer sess.Close()
|
defer sess.Close()
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
if sessionCb == nil {
|
if sessionCb == nil {
|
||||||
sessionCb = func(*ssh.Session) {}
|
sessionCb = func(*ssh.Session) {}
|
||||||
}
|
}
|
||||||
wg.Add(2)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
sessionCb(sess)
|
sessionCb(sess)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
reverseErr := make(chan error)
|
reverseErr := make(chan error)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
|
Loading…
Reference in New Issue
Block a user