// Copyright 2017, The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE.md file. package sshtun import ( "context" "fmt" "io" "net" "sync" "sync/atomic" "time" "golang.org/x/crypto/ssh" ) type TunnelMode uint8 func (t TunnelMode) String() string { switch t { case TunnelForward: return "->" case TunnelReverse: return "<-" default: return "" } } const ( TunnelForward TunnelMode = iota TunnelReverse ) type logger interface { Printf(string, ...interface{}) } type Tunnel struct { Auth []ssh.AuthMethod HostKeys ssh.HostKeyCallback Mode TunnelMode User string HostAddr string BindAddr string DialAddr string RetryInterval time.Duration KeepAlive KeepAliveConfig Logger logger } type KeepAliveConfig struct { // Interval is the amount of time in seconds to wait before the // Tunnel client will send a keep-alive message to ensure some minimum // traffic on the SSH connection. Interval uint // CountMax is the maximum number of consecutive failed responses to // keep-alive messages the client is willing to tolerate before considering // the SSH connection as dead. CountMax uint } func (t Tunnel) String() string { var left, right string switch t.Mode { case TunnelForward: left, right = t.BindAddr, t.DialAddr case TunnelReverse: left, right = t.DialAddr, t.BindAddr } return fmt.Sprintf("%s@%s | %s %s %s", t.User, t.HostAddr, left, t.Mode, right) } func (t Tunnel) Bind(ctx context.Context, wg *sync.WaitGroup) { defer wg.Done() for { var once sync.Once // Only print errors once per session func() { // Connect to the server host via SSH. cl, err := ssh.Dial("tcp", t.HostAddr, &ssh.ClientConfig{ User: t.User, Auth: t.Auth, HostKeyCallback: t.HostKeys, Timeout: 5 * time.Second, }) if err != nil { once.Do(func() { t.Logger.Printf("(%v) SSH dial error: %v", t, err) }) return } wg.Add(1) go t.keepAliveMonitor(&once, wg, cl) defer cl.Close() // Attempt to bind to the inbound socket. var ln net.Listener switch t.Mode { case TunnelForward: ln, err = net.Listen("tcp", t.BindAddr) case TunnelReverse: ln, err = cl.Listen("tcp", t.BindAddr) } if err != nil { once.Do(func() { t.Logger.Printf("(%v) bind error: %v", t, err) }) return } // The socket is bound. Make sure we close it eventually. bindCtx, cancel := context.WithCancel(ctx) defer cancel() go func() { cl.Wait() cancel() }() go func() { <-bindCtx.Done() once.Do(func() {}) // Suppress future errors ln.Close() }() t.Logger.Printf("(%v) binded Tunnel", t) defer t.Logger.Printf("(%v) collapsed Tunnel", t) // Accept all incoming connections. for { cn1, err := ln.Accept() if err != nil { once.Do(func() { t.Logger.Printf("(%v) accept error: %v", t, err) }) return } wg.Add(1) go t.dialTunnel(bindCtx, wg, cl, cn1) } }() select { case <-ctx.Done(): return case <-time.After(t.RetryInterval): t.Logger.Printf("(%v) retrying...", t) } } } func (t Tunnel) dialTunnel(ctx context.Context, wg *sync.WaitGroup, client *ssh.Client, cn1 net.Conn) { defer wg.Done() // The inbound connection is established. Make sure we close it eventually. connCtx, cancel := context.WithCancel(ctx) defer cancel() go func() { <-connCtx.Done() cn1.Close() }() // Establish the outbound connection. var cn2 net.Conn var err error switch t.Mode { case TunnelForward: cn2, err = client.Dial("tcp", t.DialAddr) case TunnelReverse: cn2, err = net.Dial("tcp", t.DialAddr) } if err != nil { t.Logger.Printf("(%v) dial error: %v", t, err) return } go func() { <-connCtx.Done() cn2.Close() }() t.Logger.Printf("(%v) connection established", t) defer t.Logger.Printf("(%v) connection closed", t) // Copy bytes from one connection to the other until one side closes. var once sync.Once var wg2 sync.WaitGroup wg2.Add(2) go func() { defer wg2.Done() defer cancel() if _, err := io.Copy(cn1, cn2); err != nil { once.Do(func() { t.Logger.Printf("(%v) connection error: %v", t, err) }) } once.Do(func() {}) // Suppress future errors }() go func() { defer wg2.Done() defer cancel() if _, err := io.Copy(cn2, cn1); err != nil { once.Do(func() { t.Logger.Printf("(%v) connection error: %v", t, err) }) } once.Do(func() {}) // Suppress future errors }() wg2.Wait() } // keepAliveMonitor periodically sends messages to invoke a response. // If the server does not respond after some period of time, // assume that the underlying net.Conn abruptly died. func (t Tunnel) keepAliveMonitor(once *sync.Once, wg *sync.WaitGroup, client *ssh.Client) { defer wg.Done() if t.KeepAlive.Interval == 0 || t.KeepAlive.CountMax == 0 { return } // Detect when the SSH connection is closed. wait := make(chan error, 1) wg.Add(1) go func() { defer wg.Done() wait <- client.Wait() }() // Repeatedly check if the remote server is still alive. var aliveCount int32 ticker := time.NewTicker(time.Duration(t.KeepAlive.Interval) * time.Second) defer ticker.Stop() for { select { case err := <-wait: if err != nil && err != io.EOF { once.Do(func() { t.Logger.Printf("(%v) SSH error: %v", t, err) }) } return case <-ticker.C: if n := atomic.AddInt32(&aliveCount, 1); n > int32(t.KeepAlive.CountMax) { once.Do(func() { t.Logger.Printf("(%v) SSH keep-alive termination", t) }) client.Close() return } } wg.Add(1) go func() { defer wg.Done() _, _, err := client.SendRequest("keepalive@openssh.com", true, nil) if err == nil { atomic.StoreInt32(&aliveCount, 0) } }() } }