sshpoke/internal/server/proto/sshtun/tunnel.go

250 lines
5.7 KiB
Go

// Copyright 2017, The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE.md file.
package sshtun
import (
"context"
"fmt"
"io"
"net"
"sync"
"sync/atomic"
"time"
"golang.org/x/crypto/ssh"
)
// TunnelMode identifies on which side of the SSH connection a Tunnel listens.
type TunnelMode uint8

// The tunnel directions understood by Tunnel.
const (
	// TunnelForward listens locally and dials the target through the SSH client.
	TunnelForward TunnelMode = iota
	// TunnelReverse listens through the SSH client and dials the target locally.
	TunnelReverse
)

// String returns the arrow used when printing a tunnel: "->" for
// TunnelForward, "<-" for TunnelReverse and "<?>" for any other value.
func (t TunnelMode) String() string {
	if t == TunnelForward {
		return "->"
	}
	if t == TunnelReverse {
		return "<-"
	}
	return "<?>"
}
// logger is the only logging capability a Tunnel needs: a single
// Printf-style method taking a format string and its arguments.
type logger interface {
	Printf(format string, args ...interface{})
}
// Tunnel describes a single SSH tunnel: how to reach and authenticate against
// the SSH server, which address to listen on and which address to dial for
// every accepted connection.
type Tunnel struct {
// Auth is passed as ssh.ClientConfig.Auth when dialing HostAddr.
Auth []ssh.AuthMethod
// HostKeys is passed as ssh.ClientConfig.HostKeyCallback when dialing HostAddr.
HostKeys ssh.HostKeyCallback
// Mode selects where the listener and the outbound dial take place.
Mode TunnelMode
// User is the user name placed in the SSH client configuration.
User string
// HostAddr is the address of the SSH server, dialed over TCP.
HostAddr string
// BindAddr is the address listened on: locally in forward mode, through
// the SSH client in reverse mode.
BindAddr string
// DialAddr is the address dialed for each accepted connection: through the
// SSH client in forward mode, locally in reverse mode.
DialAddr string
// RetryInterval is how long Bind waits after a session ends before it
// starts the next one.
RetryInterval time.Duration
// KeepAlive configures the keep-alive monitor; the monitor does nothing
// when Interval or CountMax is zero.
KeepAlive KeepAliveConfig
// Logger receives all progress and error messages. The methods of Tunnel
// call it without a nil check.
Logger logger
}
// KeepAliveConfig controls the keep-alive probing performed on the SSH
// connection. The monitor does nothing when either field is zero.
type KeepAliveConfig struct {
// Interval is the number of seconds between consecutive keep-alive
// messages sent by the Tunnel client to ensure some minimum traffic on the
// SSH connection.
Interval uint
// CountMax is the maximum number of consecutive keep-alive messages that
// may go unanswered before the client considers the SSH connection dead
// and closes it.
CountMax uint
}
// String formats the tunnel as "user@host | left arrow right", where the
// arrow comes from t.Mode. In forward mode the bind address is on the left
// and the dial address on the right; in reverse mode they are swapped. For
// any other mode both endpoints are left empty.
func (t Tunnel) String() string {
	var src, dst string
	if t.Mode == TunnelForward {
		src, dst = t.BindAddr, t.DialAddr
	} else if t.Mode == TunnelReverse {
		src, dst = t.DialAddr, t.BindAddr
	}
	return fmt.Sprintf("%s@%s | %s %s %s", t.User, t.HostAddr, src, t.Mode, dst)
}
// Bind establishes the tunnel and keeps re-establishing it until ctx is
// cancelled.
//
// Each session dials the SSH server at HostAddr, binds BindAddr (locally in
// forward mode, through the SSH client in reverse mode) and hands every
// accepted connection to dialTunnel. When a session ends for any reason, Bind
// waits RetryInterval and starts a new one.
//
// Bind calls wg.Done on return, so the caller is expected to have called
// wg.Add(1) for it. The keep-alive monitor and every per-connection goroutine
// started here are registered on the same wg.
//
// A Tunnel whose Mode is neither TunnelForward nor TunnelReverse is rejected
// up front: the problem is logged and Bind returns without connecting.
func (t Tunnel) Bind(ctx context.Context, wg *sync.WaitGroup) {
	defer wg.Done()

	// An unknown mode would leave the listener nil and panic in Accept, and
	// retrying cannot repair a bad configuration, so give up immediately.
	if t.Mode != TunnelForward && t.Mode != TunnelReverse {
		t.Logger.Printf("(%v) unsupported tunnel mode: %d", t, t.Mode)
		return
	}

	for {
		var once sync.Once // Only print errors once per session
		func() {
			// Connect to the server host via SSH.
			cl, err := ssh.Dial("tcp", t.HostAddr, &ssh.ClientConfig{
				User:            t.User,
				Auth:            t.Auth,
				HostKeyCallback: t.HostKeys,
				Timeout:         5 * time.Second,
			})
			if err != nil {
				once.Do(func() { t.Logger.Printf("(%v) SSH dial error: %v", t, err) })
				return
			}
			wg.Add(1)
			go t.keepAliveMonitor(&once, wg, cl)
			defer cl.Close()

			// Attempt to bind to the inbound socket.
			var ln net.Listener
			switch t.Mode {
			case TunnelForward:
				ln, err = net.Listen("tcp", t.BindAddr)
			case TunnelReverse:
				ln, err = cl.Listen("tcp", t.BindAddr)
			}
			if err != nil {
				once.Do(func() { t.Logger.Printf("(%v) bind error: %v", t, err) })
				return
			}

			// The socket is bound. Make sure we close it eventually: either
			// when the caller cancels ctx or when the SSH connection ends.
			bindCtx, cancel := context.WithCancel(ctx)
			defer cancel()
			go func() {
				cl.Wait()
				cancel()
			}()
			go func() {
				<-bindCtx.Done()
				once.Do(func() {}) // Suppress future errors
				ln.Close()
			}()

			t.Logger.Printf("(%v) binded Tunnel", t)
			defer t.Logger.Printf("(%v) collapsed Tunnel", t)

			// Accept all incoming connections.
			for {
				cn1, err := ln.Accept()
				if err != nil {
					once.Do(func() { t.Logger.Printf("(%v) accept error: %v", t, err) })
					return
				}
				wg.Add(1)
				go t.dialTunnel(bindCtx, wg, cl, cn1)
			}
		}()

		// The session is over; stop if asked to, otherwise retry later.
		select {
		case <-ctx.Done():
			return
		case <-time.After(t.RetryInterval):
			t.Logger.Printf("(%v) retrying...", t)
		}
	}
}
// dialTunnel serves one accepted connection cn1: it opens the matching
// outbound connection to DialAddr (through the SSH client in forward mode,
// locally in reverse mode) and copies bytes in both directions until either
// side finishes or ctx is cancelled. Both connections are closed on return.
//
// dialTunnel calls wg.Done on return, so the caller must have called
// wg.Add(1) for it.
func (t Tunnel) dialTunnel(ctx context.Context, wg *sync.WaitGroup, client *ssh.Client, cn1 net.Conn) {
	defer wg.Done()

	// The inbound connection is established. Make sure we close it eventually.
	connCtx, cancel := context.WithCancel(ctx)
	defer cancel()
	go func() {
		<-connCtx.Done()
		cn1.Close()
	}()

	// Establish the outbound connection.
	var cn2 net.Conn
	var err error
	switch t.Mode {
	case TunnelForward:
		cn2, err = client.Dial("tcp", t.DialAddr)
	case TunnelReverse:
		// Dial with the connection context so that a slow or hanging dial is
		// abandoned as soon as the tunnel is torn down, instead of holding
		// the WaitGroup until the operating system gives up.
		var d net.Dialer
		cn2, err = d.DialContext(connCtx, "tcp", t.DialAddr)
	default:
		// Without this, cn2 would stay nil and the copy below would panic.
		err = fmt.Errorf("unsupported tunnel mode %d", t.Mode)
	}
	if err != nil {
		t.Logger.Printf("(%v) dial error: %v", t, err)
		return
	}
	go func() {
		<-connCtx.Done()
		cn2.Close()
	}()

	t.Logger.Printf("(%v) connection established", t)
	defer t.Logger.Printf("(%v) connection closed", t)

	// Copy bytes from one connection to the other until one side closes.
	// Whichever direction ends first cancels connCtx, which closes both
	// connections and thereby unblocks the other direction. The shared once
	// ensures that only the first failure is logged, not the follow-up error
	// caused by that close.
	var once sync.Once
	var wg2 sync.WaitGroup
	wg2.Add(2)
	go func() {
		defer wg2.Done()
		defer cancel()
		if _, err := io.Copy(cn1, cn2); err != nil {
			once.Do(func() { t.Logger.Printf("(%v) connection error: %v", t, err) })
		}
		once.Do(func() {}) // Suppress future errors
	}()
	go func() {
		defer wg2.Done()
		defer cancel()
		if _, err := io.Copy(cn2, cn1); err != nil {
			once.Do(func() { t.Logger.Printf("(%v) connection error: %v", t, err) })
		}
		once.Do(func() {}) // Suppress future errors
	}()
	wg2.Wait()
}
// keepAliveMonitor periodically sends messages to invoke a response.
// If the server does not respond after some period of time,
// assume that the underlying net.Conn abruptly died.
//
// Every KeepAlive.Interval seconds a "keepalive@openssh.com" request is sent
// on client. A counter is incremented on each tick and reset whenever a
// request completes without error; once it exceeds KeepAlive.CountMax the
// client is closed. The monitor also returns when client.Wait reports that
// the connection ended, logging the error through once unless it is io.EOF.
//
// The monitor does nothing when Interval or CountMax is zero. It calls
// wg.Done on return, so the caller must have called wg.Add(1) for it; the
// helper goroutines it starts are registered on wg as well.
func (t Tunnel) keepAliveMonitor(once *sync.Once, wg *sync.WaitGroup, client *ssh.Client) {
	defer wg.Done()
	if t.KeepAlive.Interval == 0 || t.KeepAlive.CountMax == 0 {
		return
	}

	// Detect when the SSH connection is closed. The channel is buffered so
	// that the goroutine can finish even after the monitor has returned.
	wait := make(chan error, 1)
	wg.Add(1)
	go func() {
		defer wg.Done()
		wait <- client.Wait()
	}()

	// Repeatedly check if the remote server is still alive.
	// The counter is kept as uint64 and compared against CountMax without
	// narrowing: converting a large CountMax to int32 would wrap to a
	// negative limit and terminate the connection on the very first tick.
	var missed uint64
	ticker := time.NewTicker(time.Duration(t.KeepAlive.Interval) * time.Second)
	defer ticker.Stop()
	for {
		select {
		case err := <-wait:
			if err != nil && err != io.EOF {
				once.Do(func() { t.Logger.Printf("(%v) SSH error: %v", t, err) })
			}
			return
		case <-ticker.C:
			if n := atomic.AddUint64(&missed, 1); n > uint64(t.KeepAlive.CountMax) {
				once.Do(func() { t.Logger.Printf("(%v) SSH keep-alive termination", t) })
				client.Close()
				return
			}
		}

		// Send the probe in the background so that a stalled server cannot
		// block the ticker loop; a completed request resets the counter.
		wg.Add(1)
		go func() {
			defer wg.Done()
			_, _, err := client.SendRequest("keepalive@openssh.com", true, nil)
			if err == nil {
				atomic.StoreUint64(&missed, 0)
			}
		}()
	}
}