// Copyright 2017, The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE.md file. package sshtun import ( "bytes" "context" "crypto/md5" "crypto/rsa" "encoding/binary" "fmt" "io" "io/ioutil" "math/rand" "net" "reflect" "strconv" "sync" "testing" "time" "golang.org/x/crypto/ssh" ) type testLogger struct { *testing.T // Already has Fatalf method } func (t testLogger) Printf(f string, x ...interface{}) { t.Logf(f, x...) } func TestTunnel(t *testing.T) { rootWG := new(sync.WaitGroup) defer rootWG.Wait() rootCtx, cancelAll := context.WithCancel(context.Background()) defer cancelAll() // Open all of the TCP sockets needed for the test. tcpLn0 := openListener(t) // Start of the chain tcpLn1 := openListener(t) // Mid-point of the chain tcpLn2 := openListener(t) // End of the chain srvLn0 := openListener(t) // Socket for SSH server in reverse Mode srvLn1 := openListener(t) // Socket for SSH server in forward Mode tcpLn0.Close() // To be later binded by the reverse Tunnel tcpLn1.Close() // To be later binded by the forward Tunnel go closeWhenDone(rootCtx, tcpLn2) go closeWhenDone(rootCtx, srvLn0) go closeWhenDone(rootCtx, srvLn1) // Generate keys for both the servers and clients. clientPriv0, clientPub0 := generateKeys(t) clientPriv1, clientPub1 := generateKeys(t) serverPriv0, serverPub0 := generateKeys(t) serverPriv1, serverPub1 := generateKeys(t) // Start the SSH servers. rootWG.Add(2) go func() { defer rootWG.Done() runServer(t, rootCtx, srvLn0, serverPriv0, clientPub0, clientPub1) }() go func() { defer rootWG.Done() runServer(t, rootCtx, srvLn1, serverPriv1, clientPub0, clientPub1) }() wg := new(sync.WaitGroup) defer wg.Wait() ctx, cancel := context.WithCancel(context.Background()) defer cancel() // Create the Tunnel configurations. 
tn0 := Tunnel{ Auth: []ssh.AuthMethod{ssh.PublicKeys(clientPriv0)}, HostKeys: ssh.FixedHostKey(serverPub0), Mode: TunnelReverse, // Reverse Tunnel User: "user0", HostAddr: srvLn0.Addr().String(), BindAddr: tcpLn0.Addr().String(), DialAddr: tcpLn1.Addr().String(), Logger: testLogger{t}, } tn1 := Tunnel{ Auth: []ssh.AuthMethod{ssh.PublicKeys(clientPriv1)}, HostKeys: ssh.FixedHostKey(serverPub1), Mode: TunnelForward, // Forward Tunnel User: "user1", HostAddr: srvLn1.Addr().String(), BindAddr: tcpLn1.Addr().String(), DialAddr: tcpLn2.Addr().String(), Logger: testLogger{t}, } // Start the SSH client tunnels. wg.Add(2) go tn0.Bind(ctx, wg) go tn1.Bind(ctx, wg) t.Log("test started") done := make(chan bool, 10) // Start all the transmitters. for i := 0; i < cap(done); i++ { i := i go func() { for { rnd := rand.New(rand.NewSource(int64(i))) hash := md5.New() size := uint32((1 << 10) + rnd.Intn(1<<20)) buf4 := make([]byte, 4) binary.LittleEndian.PutUint32(buf4, size) cnStart, err := net.Dial("tcp", tcpLn0.Addr().String()) if err != nil { time.Sleep(10 * time.Millisecond) continue } defer cnStart.Close() if _, err := cnStart.Write(buf4); err != nil { t.Errorf("write size error: %v", err) break } r := io.LimitReader(rnd, int64(size)) w := io.MultiWriter(cnStart, hash) if _, err := io.Copy(w, r); err != nil { t.Errorf("copy error: %v", err) break } if _, err := cnStart.Write(hash.Sum(nil)); err != nil { t.Errorf("write hash error: %v", err) break } if err := cnStart.Close(); err != nil { t.Errorf("close error: %v", err) break } break } }() } // Start all the receivers. 
for i := 0; i < cap(done); i++ { go func() { for { hash := md5.New() buf4 := make([]byte, 4) cnEnd, err := tcpLn2.Accept() if err != nil { time.Sleep(10 * time.Millisecond) continue } defer cnEnd.Close() if _, err := io.ReadFull(cnEnd, buf4); err != nil { t.Errorf("read size error: %v", err) break } size := binary.LittleEndian.Uint32(buf4) r := io.LimitReader(cnEnd, int64(size)) if _, err := io.Copy(hash, r); err != nil { t.Errorf("copy error: %v", err) break } wantHash, err := ioutil.ReadAll(cnEnd) if err != nil { t.Errorf("read hash error: %v", err) break } if err := cnEnd.Close(); err != nil { t.Errorf("close error: %v", err) break } if gotHash := hash.Sum(nil); !bytes.Equal(gotHash, wantHash) { t.Errorf("hash mismatch:\ngot %x\nwant %x", gotHash, wantHash) } break } done <- true }() } for i := 0; i < cap(done); i++ { select { case <-done: case <-time.After(10 * time.Second): t.Errorf("timed out: %d remaining", cap(done)-i) return } } t.Log("test complete") } // generateKeys generates a random pair of SSH private and public keys. func generateKeys(t *testing.T) (priv ssh.Signer, pub ssh.PublicKey) { rnd := rand.New(rand.NewSource(time.Now().Unix())) rsaKey, err := rsa.GenerateKey(rnd, 1024) if err != nil { t.Fatalf("unable to generate RSA key pair: %v", err) } priv, err = ssh.NewSignerFromKey(rsaKey) if err != nil { t.Fatalf("unable to generate signer: %v", err) } pub, err = ssh.NewPublicKey(&rsaKey.PublicKey) if err != nil { t.Fatalf("unable to generate public key: %v", err) } return priv, pub } func openListener(t *testing.T) net.Listener { ln, err := net.Listen("tcp", ":0") if err != nil { t.Fatalf("listen error: %v", err) } return ln } // runServer starts an SSH server capable of handling forward and reverse // TCP tunnels. This function blocks for the entire duration that the // server is running and can be stopped by canceling the context. 
// // The server listens on the provided Listener and will present to clients // a certificate from serverKey and will only accept users that match // the provided clientKeys. Only users of the name "User%d" are allowed where // the ID number is the index for the specified client key provided. func runServer(t *testing.T, ctx context.Context, ln net.Listener, serverKey ssh.Signer, clientKeys ...ssh.PublicKey) { wg := new(sync.WaitGroup) defer wg.Wait() // Generate SSH server configuration. conf := ssh.ServerConfig{ PublicKeyCallback: func(c ssh.ConnMetadata, pubKey ssh.PublicKey) (*ssh.Permissions, error) { var uid int _, err := fmt.Sscanf(c.User(), "User%d", &uid) if err != nil || uid >= len(clientKeys) || !bytes.Equal(clientKeys[uid].Marshal(), pubKey.Marshal()) { return nil, fmt.Errorf("unknown public key for %q", c.User()) } return nil, nil }, } conf.AddHostKey(serverKey) // Handle every SSH client connection. for { tcpCn, err := ln.Accept() if err != nil { if !isDone(ctx) { t.Errorf("accept error: %v", err) } return } wg.Add(1) go handleServerConn(t, ctx, wg, tcpCn, &conf) } } // handleServerConn handles a single SSH connection. func handleServerConn(t *testing.T, ctx context.Context, wg *sync.WaitGroup, tcpCn net.Conn, conf *ssh.ServerConfig) { defer wg.Done() go closeWhenDone(ctx, tcpCn) defer tcpCn.Close() sshCn, chans, reqs, err := ssh.NewServerConn(tcpCn, conf) if err != nil { t.Errorf("new connection error: %v", err) return } go closeWhenDone(ctx, sshCn) defer sshCn.Close() wg.Add(1) go handleServerChannels(t, ctx, wg, sshCn, chans) wg.Add(1) go handleServerRequests(t, ctx, wg, sshCn, reqs) if err := sshCn.Wait(); err != nil && err != io.EOF && !isDone(ctx) { t.Errorf("connection error: %v", err) } } // handleServerChannels handles new channels on a SSH connection. // The client initiates a new channel when forwarding a TCP dial. 
func handleServerChannels(t *testing.T, ctx context.Context, wg *sync.WaitGroup, sshCn ssh.Conn, chans <-chan ssh.NewChannel) { defer wg.Done() for nc := range chans { if nc.ChannelType() != "direct-tcpip" { nc.Reject(ssh.UnknownChannelType, "not implemented") continue } var args struct { DstHost string DstPort uint32 SrcHost string SrcPort uint32 } if !unmarshalData(nc.ExtraData(), &args) { nc.Reject(ssh.Prohibited, "invalid request") continue } // Open a connection for both sides. cn, err := net.Dial("tcp", net.JoinHostPort(args.DstHost, strconv.Itoa(int(args.DstPort)))) if err != nil { nc.Reject(ssh.ConnectionFailed, err.Error()) continue } ch, reqs, err := nc.Accept() if err != nil { t.Errorf("accept channel error: %v", err) cn.Close() continue } go ssh.DiscardRequests(reqs) wg.Add(1) go bidirCopyAndClose(t, ctx, wg, cn, ch) } } // handleServerRequests handles new requests on a SSH connection. // The client initiates a new request for binding a local TCP socket. func handleServerRequests(t *testing.T, ctx context.Context, wg *sync.WaitGroup, sshCn ssh.Conn, reqs <-chan *ssh.Request) { defer wg.Done() for r := range reqs { if !r.WantReply { continue } if r.Type != "tcpip-forward" { r.Reply(false, nil) continue } var args struct { Host string Port uint32 } if !unmarshalData(r.Payload, &args) { r.Reply(false, nil) continue } ln, err := net.Listen("tcp", net.JoinHostPort(args.Host, strconv.Itoa(int(args.Port)))) if err != nil { r.Reply(false, nil) continue } var resp struct{ Port uint32 } _, resp.Port = splitHostPort(ln.Addr().String()) if err := r.Reply(true, marshalData(resp)); err != nil { t.Errorf("request reply error: %v", err) ln.Close() continue } wg.Add(1) go handleLocalListener(t, ctx, wg, sshCn, ln, args.Host) } } // handleLocalListener handles every new connection on the provided socket. // All local connections will be forwarded to the client via a new channel. 
func handleLocalListener(t *testing.T, ctx context.Context, wg *sync.WaitGroup, sshCn ssh.Conn, ln net.Listener, host string) { defer wg.Done() go closeWhenDone(ctx, ln) defer ln.Close() for { // Open a connection for both sides. cn, err := ln.Accept() if err != nil { if !isDone(ctx) { t.Errorf("accept error: %v", err) } return } var args struct { DstHost string DstPort uint32 SrcHost string SrcPort uint32 } args.DstHost, args.DstPort = splitHostPort(cn.LocalAddr().String()) args.SrcHost, args.SrcPort = splitHostPort(cn.RemoteAddr().String()) args.DstHost = host // This must match on client side! ch, reqs, err := sshCn.OpenChannel("forwarded-tcpip", marshalData(args)) if err != nil { t.Errorf("open channel error: %v", err) cn.Close() continue } go ssh.DiscardRequests(reqs) wg.Add(1) go bidirCopyAndClose(t, ctx, wg, cn, ch) } } // bidirCopyAndClose performs a bi-directional copy on both connections // until either side closes the connection or the context is canceled. // This will close both connections before returning. func bidirCopyAndClose(t *testing.T, ctx context.Context, wg *sync.WaitGroup, c1, c2 io.ReadWriteCloser) { defer wg.Done() go closeWhenDone(ctx, c1) go closeWhenDone(ctx, c2) defer c1.Close() defer c2.Close() errc := make(chan error, 2) go func() { _, err := io.Copy(c1, c2) errc <- err }() go func() { _, err := io.Copy(c2, c1) errc <- err }() if err := <-errc; err != nil && err != io.EOF && !isDone(ctx) { t.Errorf("copy error: %v", err) } } // unmarshalData parses b into s, where s is a pointer to a struct. // Only unexported fields of type uint32 or string are allowed. 
//
// Fields are decoded in declaration order using the SSH wire encoding
// (RFC 4251, section 5). A uint32 is 4 big-endian bytes. A string is a uint32
// byte length followed by that many bytes. Fields of named types whose
// underlying kind is uint32 or string are accepted as well. It returns false
// if b is too short or has bytes left over. It panics if s is not a non-nil
// pointer to a struct, or if a field cannot be set or has an unsupported kind.
func unmarshalData(b []byte, s interface{}) bool {
	v := reflect.ValueOf(s)
	if !v.IsValid() || v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct {
		panic("destination must be pointer to struct")
	}
	v = v.Elem()
	for i := 0; i < v.NumField(); i++ {
		f := v.Field(i)
		if !f.CanSet() {
			panic("unexported field: " + v.Type().Field(i).Name)
		}
		// SetUint/SetString, unlike Set, also accept named types such as
		// `type port uint32`.
		switch f.Kind() {
		case reflect.Uint32:
			if len(b) < 4 {
				return false
			}
			f.SetUint(uint64(binary.BigEndian.Uint32(b)))
			b = b[4:]
		case reflect.String:
			if len(b) < 4 {
				return false
			}
			n := binary.BigEndian.Uint32(b)
			b = b[4:]
			if uint64(len(b)) < uint64(n) {
				return false
			}
			f.SetString(string(b[:n]))
			b = b[n:]
		default:
			panic("invalid field type: " + f.Type().String())
		}
	}
	return len(b) == 0
}

// marshalData serializes s, a struct or a pointer to one, into the SSH wire
// encoding that unmarshalData parses. Fields are written in declaration
// order: each uint32 as 4 big-endian bytes, and each string as a uint32 byte
// length followed by its bytes. Fields of named types whose underlying kind
// is uint32 or string are accepted. It panics if s is not a struct (or a
// non-nil pointer to one) or if a field has an unsupported kind.
func marshalData(s interface{}) (b []byte) {
	v := reflect.ValueOf(s)
	if v.IsValid() && v.Kind() == reflect.Ptr {
		v = v.Elem()
	}
	if !v.IsValid() || v.Kind() != reflect.Struct {
		panic("source must be a struct")
	}
	var arr32 [4]byte
	for i := 0; i < v.NumField(); i++ {
		f := v.Field(i)
		switch f.Kind() {
		case reflect.Uint32:
			binary.BigEndian.PutUint32(arr32[:], uint32(f.Uint()))
			b = append(b, arr32[:]...)
		case reflect.String:
			binary.BigEndian.PutUint32(arr32[:], uint32(f.Len()))
			b = append(b, arr32[:]...)
			b = append(b, f.String()...)
		default:
			panic("invalid field type: " + f.Type().String())
		}
	}
	return b
}

// splitHostPort splits an address such as "[::1]:22" into its host and
// numeric port. Errors are deliberately ignored. An address that cannot be
// split yields ("", 0), and a port that is not a decimal number yields 0. It
// is meant only for addresses produced by net.Addr.String.
func splitHostPort(s string) (string, uint32) {
	host, port, _ := net.SplitHostPort(s)
	p, _ := strconv.Atoi(port)
	return host, uint32(p)
}

// closeWhenDone blocks until ctx is canceled and then closes c. The Close
// error is ignored: this is best-effort cleanup, and c may already have been
// closed by its owner.
func closeWhenDone(ctx context.Context, c io.Closer) {
	<-ctx.Done()
	c.Close()
}

// isDone reports, without blocking, whether ctx has been canceled.
func isDone(ctx context.Context) bool {
	select {
	case <-ctx.Done():
		return true
	default:
		return false
	}
}