510 lines
13 KiB
Go
510 lines
13 KiB
Go
// Copyright 2017, The Go Authors. All rights reserved.
|
|
// Use of this source code is governed by a BSD-style
|
|
// license that can be found in the LICENSE.md file.
|
|
|
|
package sshtun
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/md5"
|
|
"crypto/rsa"
|
|
"encoding/binary"
|
|
"fmt"
|
|
"io"
|
|
"io/ioutil"
|
|
"math/rand"
|
|
"net"
|
|
"reflect"
|
|
"strconv"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"golang.org/x/crypto/ssh"
|
|
)
|
|
|
|
// testLogger adapts a *testing.T so it can be used as the Logger of a Tunnel
// in TestTunnel. The embedded T already provides Fatalf; Printf is routed to
// the test log.
type testLogger struct {
	*testing.T
}

// Printf writes a formatted message to the test log via Logf.
func (l testLogger) Printf(format string, args ...interface{}) {
	l.T.Logf(format, args...)
}
|
|
|
|
func TestTunnel(t *testing.T) {
|
|
rootWG := new(sync.WaitGroup)
|
|
defer rootWG.Wait()
|
|
rootCtx, cancelAll := context.WithCancel(context.Background())
|
|
defer cancelAll()
|
|
|
|
// Open all of the TCP sockets needed for the test.
|
|
tcpLn0 := openListener(t) // Start of the chain
|
|
tcpLn1 := openListener(t) // Mid-point of the chain
|
|
tcpLn2 := openListener(t) // End of the chain
|
|
srvLn0 := openListener(t) // Socket for SSH server in reverse Mode
|
|
srvLn1 := openListener(t) // Socket for SSH server in forward Mode
|
|
|
|
tcpLn0.Close() // To be later binded by the reverse Tunnel
|
|
tcpLn1.Close() // To be later binded by the forward Tunnel
|
|
go closeWhenDone(rootCtx, tcpLn2)
|
|
go closeWhenDone(rootCtx, srvLn0)
|
|
go closeWhenDone(rootCtx, srvLn1)
|
|
|
|
// Generate keys for both the servers and clients.
|
|
clientPriv0, clientPub0 := generateKeys(t)
|
|
clientPriv1, clientPub1 := generateKeys(t)
|
|
serverPriv0, serverPub0 := generateKeys(t)
|
|
serverPriv1, serverPub1 := generateKeys(t)
|
|
|
|
// Start the SSH servers.
|
|
rootWG.Add(2)
|
|
go func() {
|
|
defer rootWG.Done()
|
|
runServer(t, rootCtx, srvLn0, serverPriv0, clientPub0, clientPub1)
|
|
}()
|
|
go func() {
|
|
defer rootWG.Done()
|
|
runServer(t, rootCtx, srvLn1, serverPriv1, clientPub0, clientPub1)
|
|
}()
|
|
|
|
wg := new(sync.WaitGroup)
|
|
defer wg.Wait()
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
// Create the Tunnel configurations.
|
|
tn0 := Tunnel{
|
|
Auth: []ssh.AuthMethod{ssh.PublicKeys(clientPriv0)},
|
|
HostKeys: ssh.FixedHostKey(serverPub0),
|
|
Mode: TunnelReverse, // Reverse Tunnel
|
|
User: "user0",
|
|
HostAddr: srvLn0.Addr().String(),
|
|
BindAddr: tcpLn0.Addr().String(),
|
|
DialAddr: tcpLn1.Addr().String(),
|
|
Logger: testLogger{t},
|
|
}
|
|
tn1 := Tunnel{
|
|
Auth: []ssh.AuthMethod{ssh.PublicKeys(clientPriv1)},
|
|
HostKeys: ssh.FixedHostKey(serverPub1),
|
|
Mode: TunnelForward, // Forward Tunnel
|
|
User: "user1",
|
|
HostAddr: srvLn1.Addr().String(),
|
|
BindAddr: tcpLn1.Addr().String(),
|
|
DialAddr: tcpLn2.Addr().String(),
|
|
Logger: testLogger{t},
|
|
}
|
|
|
|
// Start the SSH client tunnels.
|
|
wg.Add(2)
|
|
go tn0.Bind(ctx, wg)
|
|
go tn1.Bind(ctx, wg)
|
|
|
|
t.Log("test started")
|
|
done := make(chan bool, 10)
|
|
|
|
// Start all the transmitters.
|
|
for i := 0; i < cap(done); i++ {
|
|
i := i
|
|
go func() {
|
|
for {
|
|
rnd := rand.New(rand.NewSource(int64(i)))
|
|
hash := md5.New()
|
|
size := uint32((1 << 10) + rnd.Intn(1<<20))
|
|
buf4 := make([]byte, 4)
|
|
binary.LittleEndian.PutUint32(buf4, size)
|
|
|
|
cnStart, err := net.Dial("tcp", tcpLn0.Addr().String())
|
|
if err != nil {
|
|
time.Sleep(10 * time.Millisecond)
|
|
continue
|
|
}
|
|
defer cnStart.Close()
|
|
if _, err := cnStart.Write(buf4); err != nil {
|
|
t.Errorf("write size error: %v", err)
|
|
break
|
|
}
|
|
r := io.LimitReader(rnd, int64(size))
|
|
w := io.MultiWriter(cnStart, hash)
|
|
if _, err := io.Copy(w, r); err != nil {
|
|
t.Errorf("copy error: %v", err)
|
|
break
|
|
}
|
|
if _, err := cnStart.Write(hash.Sum(nil)); err != nil {
|
|
t.Errorf("write hash error: %v", err)
|
|
break
|
|
}
|
|
if err := cnStart.Close(); err != nil {
|
|
t.Errorf("close error: %v", err)
|
|
break
|
|
}
|
|
break
|
|
}
|
|
}()
|
|
}
|
|
|
|
// Start all the receivers.
|
|
for i := 0; i < cap(done); i++ {
|
|
go func() {
|
|
for {
|
|
hash := md5.New()
|
|
buf4 := make([]byte, 4)
|
|
|
|
cnEnd, err := tcpLn2.Accept()
|
|
if err != nil {
|
|
time.Sleep(10 * time.Millisecond)
|
|
continue
|
|
}
|
|
defer cnEnd.Close()
|
|
|
|
if _, err := io.ReadFull(cnEnd, buf4); err != nil {
|
|
t.Errorf("read size error: %v", err)
|
|
break
|
|
}
|
|
size := binary.LittleEndian.Uint32(buf4)
|
|
r := io.LimitReader(cnEnd, int64(size))
|
|
if _, err := io.Copy(hash, r); err != nil {
|
|
t.Errorf("copy error: %v", err)
|
|
break
|
|
}
|
|
wantHash, err := ioutil.ReadAll(cnEnd)
|
|
if err != nil {
|
|
t.Errorf("read hash error: %v", err)
|
|
break
|
|
}
|
|
if err := cnEnd.Close(); err != nil {
|
|
t.Errorf("close error: %v", err)
|
|
break
|
|
}
|
|
|
|
if gotHash := hash.Sum(nil); !bytes.Equal(gotHash, wantHash) {
|
|
t.Errorf("hash mismatch:\ngot %x\nwant %x", gotHash, wantHash)
|
|
}
|
|
break
|
|
}
|
|
done <- true
|
|
}()
|
|
}
|
|
|
|
for i := 0; i < cap(done); i++ {
|
|
select {
|
|
case <-done:
|
|
case <-time.After(10 * time.Second):
|
|
t.Errorf("timed out: %d remaining", cap(done)-i)
|
|
return
|
|
}
|
|
}
|
|
t.Log("test complete")
|
|
}
|
|
|
|
// generateKeys generates a random pair of SSH private and public keys.
|
|
func generateKeys(t *testing.T) (priv ssh.Signer, pub ssh.PublicKey) {
|
|
rnd := rand.New(rand.NewSource(time.Now().Unix()))
|
|
rsaKey, err := rsa.GenerateKey(rnd, 1024)
|
|
if err != nil {
|
|
t.Fatalf("unable to generate RSA key pair: %v", err)
|
|
}
|
|
priv, err = ssh.NewSignerFromKey(rsaKey)
|
|
if err != nil {
|
|
t.Fatalf("unable to generate signer: %v", err)
|
|
}
|
|
pub, err = ssh.NewPublicKey(&rsaKey.PublicKey)
|
|
if err != nil {
|
|
t.Fatalf("unable to generate public key: %v", err)
|
|
}
|
|
return priv, pub
|
|
}
|
|
|
|
// openListener opens a TCP listener on an OS-assigned ephemeral port (":0")
// and fails the test immediately if that is not possible. The caller owns
// the returned listener and is responsible for closing it.
func openListener(t *testing.T) net.Listener {
	t.Helper() // report failures at the caller's line
	ln, err := net.Listen("tcp", ":0")
	if err != nil {
		t.Fatalf("listen error: %v", err)
	}
	return ln
}
|
|
|
|
// runServer starts an SSH server capable of handling forward and reverse
|
|
// TCP tunnels. This function blocks for the entire duration that the
|
|
// server is running and can be stopped by canceling the context.
|
|
//
|
|
// The server listens on the provided Listener and will present to clients
|
|
// a certificate from serverKey and will only accept users that match
|
|
// the provided clientKeys. Only users of the name "User%d" are allowed where
|
|
// the ID number is the index for the specified client key provided.
|
|
func runServer(t *testing.T, ctx context.Context, ln net.Listener, serverKey ssh.Signer, clientKeys ...ssh.PublicKey) {
|
|
wg := new(sync.WaitGroup)
|
|
defer wg.Wait()
|
|
|
|
// Generate SSH server configuration.
|
|
conf := ssh.ServerConfig{
|
|
PublicKeyCallback: func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
|
|
var uid int
|
|
_, err := fmt.Sscanf(c.User(), "User%d", &uid)
|
|
if err != nil || uid >= len(clientKeys) || !bytes.Equal(clientKeys[uid].Marshal(), pubKey.Marshal()) {
|
|
return nil, fmt.Errorf("unknown public key for %q", c.User())
|
|
}
|
|
return nil, nil
|
|
},
|
|
}
|
|
conf.AddHostKey(serverKey)
|
|
|
|
// Handle every SSH client connection.
|
|
for {
|
|
tcpCn, err := ln.Accept()
|
|
if err != nil {
|
|
if !isDone(ctx) {
|
|
t.Errorf("accept error: %v", err)
|
|
}
|
|
return
|
|
}
|
|
wg.Add(1)
|
|
go handleServerConn(t, ctx, wg, tcpCn, &conf)
|
|
}
|
|
}
|
|
|
|
// handleServerConn handles a single SSH connection.
|
|
func handleServerConn(t *testing.T, ctx context.Context, wg *sync.WaitGroup, tcpCn net.Conn, conf *ssh.ServerConfig) {
|
|
defer wg.Done()
|
|
go closeWhenDone(ctx, tcpCn)
|
|
defer tcpCn.Close()
|
|
|
|
sshCn, chans, reqs, err := ssh.NewServerConn(tcpCn, conf)
|
|
if err != nil {
|
|
t.Errorf("new connection error: %v", err)
|
|
return
|
|
}
|
|
go closeWhenDone(ctx, sshCn)
|
|
defer sshCn.Close()
|
|
|
|
wg.Add(1)
|
|
go handleServerChannels(t, ctx, wg, sshCn, chans)
|
|
|
|
wg.Add(1)
|
|
go handleServerRequests(t, ctx, wg, sshCn, reqs)
|
|
|
|
if err := sshCn.Wait(); err != nil && err != io.EOF && !isDone(ctx) {
|
|
t.Errorf("connection error: %v", err)
|
|
}
|
|
}
|
|
|
|
// handleServerChannels handles new channels on a SSH connection.
|
|
// The client initiates a new channel when forwarding a TCP dial.
|
|
func handleServerChannels(t *testing.T, ctx context.Context, wg *sync.WaitGroup, sshCn ssh.Conn, chans <-chan ssh.NewChannel) {
|
|
defer wg.Done()
|
|
for nc := range chans {
|
|
if nc.ChannelType() != "direct-tcpip" {
|
|
nc.Reject(ssh.UnknownChannelType, "not implemented")
|
|
continue
|
|
}
|
|
var args struct {
|
|
DstHost string
|
|
DstPort uint32
|
|
SrcHost string
|
|
SrcPort uint32
|
|
}
|
|
if !unmarshalData(nc.ExtraData(), &args) {
|
|
nc.Reject(ssh.Prohibited, "invalid request")
|
|
continue
|
|
}
|
|
|
|
// Open a connection for both sides.
|
|
cn, err := net.Dial("tcp", net.JoinHostPort(args.DstHost, strconv.Itoa(int(args.DstPort))))
|
|
if err != nil {
|
|
nc.Reject(ssh.ConnectionFailed, err.Error())
|
|
continue
|
|
}
|
|
ch, reqs, err := nc.Accept()
|
|
if err != nil {
|
|
t.Errorf("accept channel error: %v", err)
|
|
cn.Close()
|
|
continue
|
|
}
|
|
go ssh.DiscardRequests(reqs)
|
|
|
|
wg.Add(1)
|
|
go bidirCopyAndClose(t, ctx, wg, cn, ch)
|
|
}
|
|
}
|
|
|
|
// handleServerRequests handles new requests on a SSH connection.
|
|
// The client initiates a new request for binding a local TCP socket.
|
|
func handleServerRequests(t *testing.T, ctx context.Context, wg *sync.WaitGroup, sshCn ssh.Conn, reqs <-chan *ssh.Request) {
|
|
defer wg.Done()
|
|
for r := range reqs {
|
|
if !r.WantReply {
|
|
continue
|
|
}
|
|
if r.Type != "tcpip-forward" {
|
|
r.Reply(false, nil)
|
|
continue
|
|
}
|
|
var args struct {
|
|
Host string
|
|
Port uint32
|
|
}
|
|
if !unmarshalData(r.Payload, &args) {
|
|
r.Reply(false, nil)
|
|
continue
|
|
}
|
|
ln, err := net.Listen("tcp", net.JoinHostPort(args.Host, strconv.Itoa(int(args.Port))))
|
|
if err != nil {
|
|
r.Reply(false, nil)
|
|
continue
|
|
}
|
|
|
|
var resp struct{ Port uint32 }
|
|
_, resp.Port = splitHostPort(ln.Addr().String())
|
|
if err := r.Reply(true, marshalData(resp)); err != nil {
|
|
t.Errorf("request reply error: %v", err)
|
|
ln.Close()
|
|
continue
|
|
}
|
|
|
|
wg.Add(1)
|
|
go handleLocalListener(t, ctx, wg, sshCn, ln, args.Host)
|
|
|
|
}
|
|
}
|
|
|
|
// handleLocalListener handles every new connection on the provided socket.
|
|
// All local connections will be forwarded to the client via a new channel.
|
|
func handleLocalListener(t *testing.T, ctx context.Context, wg *sync.WaitGroup, sshCn ssh.Conn, ln net.Listener, host string) {
|
|
defer wg.Done()
|
|
go closeWhenDone(ctx, ln)
|
|
defer ln.Close()
|
|
|
|
for {
|
|
// Open a connection for both sides.
|
|
cn, err := ln.Accept()
|
|
if err != nil {
|
|
if !isDone(ctx) {
|
|
t.Errorf("accept error: %v", err)
|
|
}
|
|
return
|
|
}
|
|
var args struct {
|
|
DstHost string
|
|
DstPort uint32
|
|
SrcHost string
|
|
SrcPort uint32
|
|
}
|
|
args.DstHost, args.DstPort = splitHostPort(cn.LocalAddr().String())
|
|
args.SrcHost, args.SrcPort = splitHostPort(cn.RemoteAddr().String())
|
|
args.DstHost = host // This must match on client side!
|
|
ch, reqs, err := sshCn.OpenChannel("forwarded-tcpip", marshalData(args))
|
|
if err != nil {
|
|
t.Errorf("open channel error: %v", err)
|
|
cn.Close()
|
|
continue
|
|
}
|
|
go ssh.DiscardRequests(reqs)
|
|
|
|
wg.Add(1)
|
|
go bidirCopyAndClose(t, ctx, wg, cn, ch)
|
|
}
|
|
}
|
|
|
|
// bidirCopyAndClose performs a bi-directional copy on both connections
|
|
// until either side closes the connection or the context is canceled.
|
|
// This will close both connections before returning.
|
|
func bidirCopyAndClose(t *testing.T, ctx context.Context, wg *sync.WaitGroup, c1, c2 io.ReadWriteCloser) {
|
|
defer wg.Done()
|
|
go closeWhenDone(ctx, c1)
|
|
go closeWhenDone(ctx, c2)
|
|
defer c1.Close()
|
|
defer c2.Close()
|
|
|
|
errc := make(chan error, 2)
|
|
go func() {
|
|
_, err := io.Copy(c1, c2)
|
|
errc <- err
|
|
}()
|
|
go func() {
|
|
_, err := io.Copy(c2, c1)
|
|
errc <- err
|
|
}()
|
|
if err := <-errc; err != nil && err != io.EOF && !isDone(ctx) {
|
|
t.Errorf("copy error: %v", err)
|
|
}
|
|
}
|
|
|
|
// unmarshalData decodes b into the struct pointed to by s, using the SSH
// wire encoding for uint32 and string (RFC 4251, section 5). A uint32 field
// is 4 big-endian bytes. A string field is a 4-byte big-endian length
// followed by that many bytes. Fields are decoded in declaration order.
//
// Every field must be exported, because reflection cannot set unexported
// fields. Each field must have an underlying kind of uint32 or string;
// named types such as `type port uint32` are accepted. Any other field kind
// panics, as does an s that is not a non-nil pointer to a struct.
//
// It reports whether b was well-formed. It returns false if b is too short
// for the fields or if bytes remain after the last field.
func unmarshalData(b []byte, s interface{}) bool {
	v := reflect.ValueOf(s)
	if !v.IsValid() || v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct {
		panic("destination must be pointer to struct")
	}
	v = v.Elem()
	for i := 0; i < v.NumField(); i++ {
		f := v.Field(i)
		switch f.Kind() {
		case reflect.Uint32:
			if len(b) < 4 {
				return false
			}
			// SetUint, unlike Set(reflect.ValueOf(uint32)), also works for
			// named types whose underlying type is uint32.
			f.SetUint(uint64(binary.BigEndian.Uint32(b)))
			b = b[4:]
		case reflect.String:
			if len(b) < 4 {
				return false
			}
			n := binary.BigEndian.Uint32(b)
			b = b[4:]
			// Compare in uint64 so a huge n cannot overflow int on 32-bit.
			if uint64(len(b)) < uint64(n) {
				return false
			}
			f.SetString(string(b[:n]))
			b = b[n:]
		default:
			panic("invalid field type: " + f.Type().String())
		}
	}
	return len(b) == 0
}
|
|
|
|
// marshalData encodes s, a struct or a pointer to one, using the SSH wire
// encoding for uint32 and string. A uint32 field becomes 4 big-endian bytes.
// A string field becomes a 4-byte big-endian length followed by its bytes.
// Fields are emitted in declaration order, and both exported and unexported
// fields are read.
//
// It panics if s is not a struct (or pointer to one), or if any field's kind
// is not uint32 or string. A struct with no fields encodes to a nil slice.
func marshalData(s interface{}) (b []byte) {
	rv := reflect.ValueOf(s)
	if rv.IsValid() && rv.Kind() == reflect.Ptr {
		rv = rv.Elem()
	}
	if !rv.IsValid() || rv.Kind() != reflect.Struct {
		panic("source must be a struct")
	}

	// appendWord appends x to b as 4 big-endian bytes.
	appendWord := func(x uint32) {
		var word [4]byte
		binary.BigEndian.PutUint32(word[:], x)
		b = append(b, word[:]...)
	}

	for idx := 0; idx < rv.NumField(); idx++ {
		field := rv.Field(idx)
		switch field.Kind() {
		case reflect.Uint32:
			appendWord(uint32(field.Uint()))
		case reflect.String:
			appendWord(uint32(field.Len()))
			b = append(b, field.String()...)
		default:
			panic("invalid field type: " + field.Type().String())
		}
	}
	return b
}
|
|
|
|
// splitHostPort splits a "host:port" address into its host and numeric port.
// Parse errors are deliberately ignored. A malformed address yields an empty
// host, and a non-numeric port yields 0.
func splitHostPort(s string) (host string, port uint32) {
	var portText string
	host, portText, _ = net.SplitHostPort(s)
	n, _ := strconv.Atoi(portText)
	port = uint32(n)
	return host, port
}
|
|
|
|
// closeWhenDone blocks until ctx is canceled and then closes c, ignoring any
// error from Close. Callers in this file run it in its own goroutine. If ctx
// can never be canceled, it blocks forever and c is never closed.
func closeWhenDone(ctx context.Context, c io.Closer) {
	defer c.Close()
	<-ctx.Done()
}
|
|
|
|
// isDone reports, without blocking, whether ctx has already been canceled
// (its Done channel is closed).
func isDone(ctx context.Context) (done bool) {
	select {
	case <-ctx.Done():
		done = true
	default:
		// Not canceled yet, or ctx can never be canceled.
	}
	return done
}
|