sshpoke/internal/server/proto/sshtun/tunnel_test.go

510 lines
13 KiB
Go

// Copyright 2017, The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE.md file.
package sshtun
import (
"bytes"
"context"
"crypto/md5"
"crypto/rsa"
"encoding/binary"
"fmt"
"io"
"io/ioutil"
"math/rand"
"net"
"reflect"
"strconv"
"sync"
"testing"
"time"
"golang.org/x/crypto/ssh"
)
// testLogger adapts a *testing.T for use as the Logger of the tunnels in
// TestTunnel, so that tunnel log output is attributed to the running test.
type testLogger struct {
	*testing.T // embedded: also supplies Fatalf, Errorf, Logf, ...
}

// Printf forwards a formatted log line to the embedded test's Logf.
func (l testLogger) Printf(format string, args ...interface{}) {
	l.Logf(format, args...)
}
// TestTunnel wires two SSH tunnels in series and checks that payloads sent
// into the start of the chain arrive intact at the end:
//
//	transmitters -> tcpLn0 =(tn0, reverse, SSH server on srvLn0)=> tcpLn1
//	             =(tn1, forward, SSH server on srvLn1)=> tcpLn2 -> receivers
//
// Each transmitter writes a 4-byte little-endian length, that many
// pseudo-random bytes, and the MD5 of those bytes. Each receiver recomputes
// the MD5 and compares it with the trailer.
func TestTunnel(t *testing.T) {
	rootWG := new(sync.WaitGroup)
	defer rootWG.Wait()
	rootCtx, cancelAll := context.WithCancel(context.Background())
	defer cancelAll()

	// Open all of the TCP sockets needed for the test.
	tcpLn0 := openListener(t) // Start of the chain
	tcpLn1 := openListener(t) // Mid-point of the chain
	tcpLn2 := openListener(t) // End of the chain
	srvLn0 := openListener(t) // Socket for SSH server in reverse mode
	srvLn1 := openListener(t) // Socket for SSH server in forward mode
	tcpLn0.Close()            // To be bound later by the reverse tunnel
	tcpLn1.Close()            // To be bound later by the forward tunnel
	go closeWhenDone(rootCtx, tcpLn2)
	go closeWhenDone(rootCtx, srvLn0)
	go closeWhenDone(rootCtx, srvLn1)

	// Generate keys for both the servers and clients.
	clientPriv0, clientPub0 := generateKeys(t)
	clientPriv1, clientPub1 := generateKeys(t)
	serverPriv0, serverPub0 := generateKeys(t)
	serverPriv1, serverPub1 := generateKeys(t)

	// Start the SSH servers.
	rootWG.Add(2)
	go func() {
		defer rootWG.Done()
		runServer(t, rootCtx, srvLn0, serverPriv0, clientPub0, clientPub1)
	}()
	go func() {
		defer rootWG.Done()
		runServer(t, rootCtx, srvLn1, serverPriv1, clientPub0, clientPub1)
	}()

	wg := new(sync.WaitGroup)
	defer wg.Wait()
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Create the tunnel configurations.
	tn0 := Tunnel{
		Auth:     []ssh.AuthMethod{ssh.PublicKeys(clientPriv0)},
		HostKeys: ssh.FixedHostKey(serverPub0),
		Mode:     TunnelReverse, // Reverse tunnel
		User:     "user0",
		HostAddr: srvLn0.Addr().String(),
		BindAddr: tcpLn0.Addr().String(),
		DialAddr: tcpLn1.Addr().String(),
		Logger:   testLogger{t},
	}
	tn1 := Tunnel{
		Auth:     []ssh.AuthMethod{ssh.PublicKeys(clientPriv1)},
		HostKeys: ssh.FixedHostKey(serverPub1),
		Mode:     TunnelForward, // Forward tunnel
		User:     "user1",
		HostAddr: srvLn1.Addr().String(),
		BindAddr: tcpLn1.Addr().String(),
		DialAddr: tcpLn2.Addr().String(),
		Logger:   testLogger{t},
	}

	// Start the SSH client tunnels.
	wg.Add(2)
	go tn0.Bind(ctx, wg)
	go tn1.Bind(ctx, wg)
	t.Log("test started")

	done := make(chan bool, 10)

	// send writes one framed payload, generated deterministically from seed,
	// to cn and then closes cn. Closing inside this closure (rather than
	// deferring in a retry loop) releases the connection promptly.
	send := func(cn net.Conn, seed int64) error {
		defer cn.Close()
		rnd := rand.New(rand.NewSource(seed))
		hash := md5.New()
		size := uint32((1 << 10) + rnd.Intn(1<<20))
		buf4 := make([]byte, 4)
		binary.LittleEndian.PutUint32(buf4, size)
		if _, err := cn.Write(buf4); err != nil {
			return fmt.Errorf("write size error: %v", err)
		}
		payload := io.LimitReader(rnd, int64(size))
		if _, err := io.Copy(io.MultiWriter(cn, hash), payload); err != nil {
			return fmt.Errorf("copy error: %v", err)
		}
		if _, err := cn.Write(hash.Sum(nil)); err != nil {
			return fmt.Errorf("write hash error: %v", err)
		}
		if err := cn.Close(); err != nil {
			return fmt.Errorf("close error: %v", err)
		}
		return nil
	}

	// recv reads one framed payload from cn, verifies its length and MD5
	// trailer, and closes cn.
	recv := func(cn net.Conn) error {
		defer cn.Close()
		hash := md5.New()
		buf4 := make([]byte, 4)
		if _, err := io.ReadFull(cn, buf4); err != nil {
			return fmt.Errorf("read size error: %v", err)
		}
		size := binary.LittleEndian.Uint32(buf4)
		n, err := io.Copy(hash, io.LimitReader(cn, int64(size)))
		if err != nil {
			return fmt.Errorf("copy error: %v", err)
		}
		// io.Copy reports no error on an early EOF, so check the length
		// explicitly; otherwise part of the hash would be misread as payload.
		if n != int64(size) {
			return fmt.Errorf("short payload: got %d bytes, want %d", n, size)
		}
		wantHash, err := ioutil.ReadAll(cn)
		if err != nil {
			return fmt.Errorf("read hash error: %v", err)
		}
		if err := cn.Close(); err != nil {
			return fmt.Errorf("close error: %v", err)
		}
		if gotHash := hash.Sum(nil); !bytes.Equal(gotHash, wantHash) {
			return fmt.Errorf("hash mismatch:\ngot %x\nwant %x", gotHash, wantHash)
		}
		return nil
	}

	// Start all the transmitters. Dialing fails until the reverse tunnel has
	// bound tcpLn0, so retry until it succeeds. Stop once the tunnels are shut
	// down; otherwise a failed run would leave these goroutines spinning.
	for i := 0; i < cap(done); i++ {
		seed := int64(i)
		go func() {
			for {
				cn, err := net.Dial("tcp", tcpLn0.Addr().String())
				if err != nil {
					if isDone(ctx) {
						return
					}
					time.Sleep(10 * time.Millisecond)
					continue
				}
				// Errors caused by shutdown are expected and not reported.
				if err := send(cn, seed); err != nil && !isDone(ctx) {
					t.Error(err)
				}
				return
			}
		}()
	}

	// Start all the receivers. Each handles exactly one connection and then
	// signals done, whether or not that connection was valid.
	for i := 0; i < cap(done); i++ {
		go func() {
			for {
				cn, err := tcpLn2.Accept()
				if err != nil {
					// tcpLn2 is closed once rootCtx is canceled; after that,
					// Accept fails forever, so stop instead of spinning.
					if isDone(rootCtx) {
						return
					}
					time.Sleep(10 * time.Millisecond)
					continue
				}
				if err := recv(cn); err != nil && !isDone(ctx) {
					t.Error(err)
				}
				done <- true
				return
			}
		}()
	}

	for i := 0; i < cap(done); i++ {
		select {
		case <-done:
		case <-time.After(10 * time.Second):
			t.Errorf("timed out: %d remaining", cap(done)-i)
			return
		}
	}
	t.Log("test complete")
}
// generateKeys generates a fresh 1024-bit RSA key pair. It returns the
// private half as an SSH signer and the public half as an SSH public key,
// and fails the test if any step errors.
//
// The randomness comes from math/rand, which is NOT cryptographically
// secure. That is acceptable only because the keys are used solely by the
// in-process test servers and clients.
func generateKeys(t *testing.T) (priv ssh.Signer, pub ssh.PublicKey) {
	t.Helper()
	// A seconds-granularity seed repeats across back-to-back calls, which
	// could yield identical key pairs for "different" parties. Mixing in a
	// draw from the global source (whose state advances on every call) keeps
	// seeds distinct even within a single clock tick.
	seed := time.Now().UnixNano() ^ rand.Int63()
	rsaKey, err := rsa.GenerateKey(rand.New(rand.NewSource(seed)), 1024)
	if err != nil {
		t.Fatalf("unable to generate RSA key pair: %v", err)
	}
	priv, err = ssh.NewSignerFromKey(rsaKey)
	if err != nil {
		t.Fatalf("unable to generate signer: %v", err)
	}
	pub, err = ssh.NewPublicKey(&rsaKey.PublicKey)
	if err != nil {
		t.Fatalf("unable to generate public key: %v", err)
	}
	return priv, pub
}
func openListener(t *testing.T) net.Listener {
ln, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatalf("listen error: %v", err)
}
return ln
}
// runServer starts an SSH server capable of handling forward and reverse
// TCP tunnels. This function blocks for the entire duration that the
// server is running and can be stopped by canceling the context, which
// closes ln. It returns only after every connection handler it started
// has finished.
//
// The server listens on the provided Listener and presents serverKey as its
// host key. It accepts only users named "user%d" (for example "user0"),
// where the number is the index into clientKeys of the public key that user
// must authenticate with. The tunnels in TestTunnel log in as "user0" and
// "user1".
func runServer(t *testing.T, ctx context.Context, ln net.Listener, serverKey ssh.Signer, clientKeys ...ssh.PublicKey) {
	wg := new(sync.WaitGroup)
	defer wg.Wait()

	// Closing ln on cancellation unblocks the Accept loop below, so the
	// context alone is enough to stop the server as documented.
	go closeWhenDone(ctx, ln)

	// Generate SSH server configuration.
	conf := ssh.ServerConfig{
		PublicKeyCallback: func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) {
			var uid int
			// Sscanf matches literal format text exactly (case-sensitive), so
			// the prefix must agree with the user names the clients send.
			_, err := fmt.Sscanf(c.User(), "user%d", &uid)
			// uid < 0 guards against names like "user-1", which would
			// otherwise index clientKeys out of range and panic.
			if err != nil || uid < 0 || uid >= len(clientKeys) || !bytes.Equal(clientKeys[uid].Marshal(), pubKey.Marshal()) {
				return nil, fmt.Errorf("unknown public key for %q", c.User())
			}
			return nil, nil
		},
	}
	conf.AddHostKey(serverKey)

	// Handle every SSH client connection.
	for {
		tcpCn, err := ln.Accept()
		if err != nil {
			if !isDone(ctx) {
				t.Errorf("accept error: %v", err)
			}
			return
		}
		wg.Add(1)
		go handleServerConn(t, ctx, wg, tcpCn, &conf)
	}
}
// handleServerConn performs the SSH server handshake on tcpCn using conf and
// then serves the resulting connection until it ends or ctx is canceled.
// New channels are dispatched to handleServerChannels and global requests to
// handleServerRequests; both goroutines are tracked by wg. The caller must
// have called wg.Add(1) for this function itself.
//
// Errors that occur after ctx is canceled are expected during shutdown and
// are not reported. This matches the other handlers in this file.
func handleServerConn(t *testing.T, ctx context.Context, wg *sync.WaitGroup, tcpCn net.Conn, conf *ssh.ServerConfig) {
	defer wg.Done()
	// Closing the raw socket on cancellation aborts an in-progress handshake.
	go closeWhenDone(ctx, tcpCn)
	defer tcpCn.Close()

	sshCn, chans, reqs, err := ssh.NewServerConn(tcpCn, conf)
	if err != nil {
		if !isDone(ctx) {
			t.Errorf("new connection error: %v", err)
		}
		return
	}
	go closeWhenDone(ctx, sshCn)
	defer sshCn.Close()

	wg.Add(1)
	go handleServerChannels(t, ctx, wg, sshCn, chans)
	wg.Add(1)
	go handleServerRequests(t, ctx, wg, sshCn, reqs)

	if err := sshCn.Wait(); err != nil && err != io.EOF && !isDone(ctx) {
		t.Errorf("connection error: %v", err)
	}
}
// handleServerChannels serves the channel-open requests arriving on sshCn.
// A client opens a "direct-tcpip" channel to have the server dial a TCP
// address on its behalf (forward tunneling). Every other channel type is
// rejected. Each accepted channel is spliced to the dialed connection by a
// bidirCopyAndClose goroutine tracked by wg. The caller must have called
// wg.Add(1) for this function itself.
func handleServerChannels(t *testing.T, ctx context.Context, wg *sync.WaitGroup, sshCn ssh.Conn, chans <-chan ssh.NewChannel) {
	defer wg.Done()
	for newCh := range chans {
		if kind := newCh.ChannelType(); kind != "direct-tcpip" {
			newCh.Reject(ssh.UnknownChannelType, "not implemented")
			continue
		}

		// Open payload of a direct-tcpip channel (RFC 4254, section 7.2).
		var payload struct {
			DstHost string
			DstPort uint32
			SrcHost string
			SrcPort uint32
		}
		if ok := unmarshalData(newCh.ExtraData(), &payload); !ok {
			newCh.Reject(ssh.Prohibited, "invalid request")
			continue
		}

		// Dial the target first so a failure can be reported as a rejection.
		target := net.JoinHostPort(payload.DstHost, strconv.Itoa(int(payload.DstPort)))
		tcpCn, dialErr := net.Dial("tcp", target)
		if dialErr != nil {
			newCh.Reject(ssh.ConnectionFailed, dialErr.Error())
			continue
		}

		sshCh, chReqs, acceptErr := newCh.Accept()
		if acceptErr != nil {
			t.Errorf("accept channel error: %v", acceptErr)
			tcpCn.Close()
			continue
		}
		go ssh.DiscardRequests(chReqs)
		wg.Add(1)
		go bidirCopyAndClose(t, ctx, wg, tcpCn, sshCh)
	}
}
// handleServerRequests serves global requests on sshCn. A client sends a
// "tcpip-forward" request to ask the server to listen on a TCP address and
// relay inbound connections back over SSH (reverse tunneling). For each such
// request, a listener is opened, the bound port is replied, and the listener
// is handed to handleLocalListener (tracked by wg).
//
// Requests that do not want a reply are ignored. Any other request type,
// a malformed payload, or a failed listen is refused. The caller must have
// called wg.Add(1) for this function itself.
func handleServerRequests(t *testing.T, ctx context.Context, wg *sync.WaitGroup, sshCn ssh.Conn, reqs <-chan *ssh.Request) {
	defer wg.Done()
	for req := range reqs {
		switch {
		case !req.WantReply:
			continue
		case req.Type != "tcpip-forward":
			req.Reply(false, nil)
			continue
		}

		// tcpip-forward payload: address to bind, port to bind
		// (RFC 4254, section 7.1).
		var fwd struct {
			Host string
			Port uint32
		}
		if !unmarshalData(req.Payload, &fwd) {
			req.Reply(false, nil)
			continue
		}

		bindAddr := net.JoinHostPort(fwd.Host, strconv.Itoa(int(fwd.Port)))
		ln, listenErr := net.Listen("tcp", bindAddr)
		if listenErr != nil {
			req.Reply(false, nil)
			continue
		}

		// Report the port actually bound (meaningful when port 0 was requested).
		var reply struct{ Port uint32 }
		_, reply.Port = splitHostPort(ln.Addr().String())
		if replyErr := req.Reply(true, marshalData(reply)); replyErr != nil {
			t.Errorf("request reply error: %v", replyErr)
			ln.Close()
			continue
		}

		wg.Add(1)
		go handleLocalListener(t, ctx, wg, sshCn, ln, fwd.Host)
	}
}
// handleLocalListener accepts connections on ln and relays each one to the
// client over a new "forwarded-tcpip" channel on sshCn. ln is a listener
// opened for a tcpip-forward request. host is the bind address the client
// originally requested. It is echoed back as the connected address so the
// client can match the channel to its forwarding request.
//
// ln is closed when ctx is canceled or when this function returns. Each relay
// runs in a bidirCopyAndClose goroutine tracked by wg, and the caller must
// have called wg.Add(1) for this function itself.
func handleLocalListener(t *testing.T, ctx context.Context, wg *sync.WaitGroup, sshCn ssh.Conn, ln net.Listener, host string) {
	defer wg.Done()
	go closeWhenDone(ctx, ln)
	defer ln.Close()

	for {
		conn, err := ln.Accept()
		if err != nil {
			if !isDone(ctx) {
				t.Errorf("accept error: %v", err)
			}
			return
		}

		// forwarded-tcpip open payload: connected address/port, then
		// originator address/port (RFC 4254, section 7.2).
		var payload struct {
			DstHost string
			DstPort uint32
			SrcHost string
			SrcPort uint32
		}
		_, payload.DstPort = splitHostPort(conn.LocalAddr().String())
		payload.DstHost = host // This must match on client side!
		payload.SrcHost, payload.SrcPort = splitHostPort(conn.RemoteAddr().String())

		ch, chReqs, err := sshCn.OpenChannel("forwarded-tcpip", marshalData(payload))
		if err != nil {
			t.Errorf("open channel error: %v", err)
			conn.Close()
			continue
		}
		go ssh.DiscardRequests(chReqs)
		wg.Add(1)
		go bidirCopyAndClose(t, ctx, wg, conn, ch)
	}
}
// bidirCopyAndClose splices c1 and c2 together by copying in both
// directions. As soon as either copy finishes (or ctx is canceled), both
// connections are closed, which also ends the other copy. Only the first
// copy's result is examined. It is reported as a test error only when it is
// neither nil nor io.EOF and ctx has not been canceled. The caller must have
// called wg.Add(1) for this function.
func bidirCopyAndClose(t *testing.T, ctx context.Context, wg *sync.WaitGroup, c1, c2 io.ReadWriteCloser) {
	defer wg.Done()
	go closeWhenDone(ctx, c1)
	go closeWhenDone(ctx, c2)
	defer c1.Close()
	defer c2.Close()

	// Buffered for both copiers, so the one that finishes second never blocks.
	errc := make(chan error, 2)
	pipe := func(dst io.Writer, src io.Reader) {
		_, copyErr := io.Copy(dst, src)
		errc <- copyErr
	}
	go pipe(c1, c2)
	go pipe(c2, c1)

	first := <-errc
	if first == nil || first == io.EOF || isDone(ctx) {
		return
	}
	t.Errorf("copy error: %v", first)
}
// unmarshalData decodes b into the struct pointed to by s using the SSH wire
// encoding (RFC 4251, section 5). Fields are read in declaration order: a
// uint32 field as 4 big-endian bytes, a string field as a 4-byte big-endian
// length followed by that many bytes.
//
// Only exported fields are supported, since reflection cannot set unexported
// ones. Each field's kind must be uint32 or string; named types such as
// `type port uint32` are accepted. Any other field kind panics, as does an s
// that is not a pointer to a struct.
//
// It reports false if b is too short for the fields or has bytes left over
// afterward. On failure, s may have been partially written.
func unmarshalData(b []byte, s interface{}) bool {
	v := reflect.ValueOf(s)
	if !v.IsValid() || v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct {
		panic("destination must be pointer to struct")
	}
	v = v.Elem()
	for i := 0; i < v.NumField(); i++ {
		field := v.Field(i)
		switch field.Kind() {
		case reflect.Uint32:
			if len(b) < 4 {
				return false
			}
			// SetUint (unlike Set with a plain uint32 Value) also works for
			// named types whose underlying kind is uint32.
			field.SetUint(uint64(binary.BigEndian.Uint32(b)))
			b = b[4:]
		case reflect.String:
			if len(b) < 4 {
				return false
			}
			n := binary.BigEndian.Uint32(b)
			b = b[4:]
			// Compare as uint64 so a huge n cannot overflow int on 32-bit
			// platforms.
			if uint64(len(b)) < uint64(n) {
				return false
			}
			field.SetString(string(b[:n]))
			b = b[n:]
		default:
			panic("invalid field type: " + field.Type().String())
		}
	}
	return len(b) == 0
}
// marshalData encodes s in SSH wire format (RFC 4251, section 5). s is a
// struct or a pointer to one, and this is the inverse of unmarshalData.
// Fields are emitted in declaration order: a uint32-kind field as 4
// big-endian bytes, a string-kind field as a 4-byte big-endian length
// followed by its raw bytes. Any other field kind panics, as does a non-struct
// s. A struct with no fields yields a nil slice.
func marshalData(s interface{}) (b []byte) {
	v := reflect.ValueOf(s)
	if v.IsValid() && v.Kind() == reflect.Ptr {
		v = v.Elem()
	}
	if !v.IsValid() || v.Kind() != reflect.Struct {
		panic("source must be a struct")
	}

	// put32 appends x to b as 4 big-endian bytes.
	put32 := func(x uint32) {
		var word [4]byte
		binary.BigEndian.PutUint32(word[:], x)
		b = append(b, word[:]...)
	}
	for i, n := 0, v.NumField(); i < n; i++ {
		f := v.Field(i)
		switch f.Kind() {
		case reflect.Uint32:
			put32(uint32(f.Uint()))
		case reflect.String:
			str := f.String()
			put32(uint32(len(str)))
			b = append(b, str...)
		default:
			panic("invalid field type: " + f.Type().String())
		}
	}
	return b
}
// splitHostPort splits a "host:port" address into its host and numeric port.
// It is deliberately best-effort: parse errors are ignored, so malformed
// input yields an empty host and/or whatever strconv.Atoi returned,
// converted to uint32.
func splitHostPort(s string) (host string, port uint32) {
	host, portText, _ := net.SplitHostPort(s)
	n, _ := strconv.Atoi(portText)
	port = uint32(n)
	return host, port
}
// closeWhenDone blocks until ctx is done and then closes c, ignoring the
// Close error. It is meant to run in its own goroutine, tying a resource's
// lifetime to a context.
func closeWhenDone(ctx context.Context, c io.Closer) {
	defer c.Close()
	<-ctx.Done()
}
// isDone reports, without blocking, whether ctx's Done channel is closed.
// A context whose Done channel is nil (e.g. context.Background) is never done.
func isDone(ctx context.Context) (done bool) {
	select {
	case <-ctx.Done():
		done = true
	default:
	}
	return done
}