63 lines
1.4 KiB
Go
63 lines
1.4 KiB
Go
|
package sshtun
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"io"
|
||
|
"sync"
|
||
|
"sync/atomic"
|
||
|
"time"
|
||
|
|
||
|
"github.com/Neur0toxine/sshpoke/pkg/proto/ssh"
|
||
|
)
|
||
|
|
||
|
// keepAlive periodically sends messages to invoke a response.
|
||
|
// If the server does not respond after some period of time,
|
||
|
// assume that the underlying net.Conn abruptly died.
|
||
|
func (t *Tunnel) keepAlive(ctx context.Context, client *ssh.Client, wg *sync.WaitGroup) {
|
||
|
defer wg.Done()
|
||
|
if t.sessConfig.KeepAliveInterval == 0 || t.sessConfig.KeepAliveMax == 0 {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// Detect when the SSH connection is closed.
|
||
|
wait := make(chan error, 1)
|
||
|
wg.Add(1)
|
||
|
go func() {
|
||
|
defer wg.Done()
|
||
|
wait <- client.Wait()
|
||
|
}()
|
||
|
|
||
|
// Repeatedly check if the remote server is still alive.
|
||
|
var aliveCount int32
|
||
|
ticker := time.NewTicker(time.Duration(t.sessConfig.KeepAliveInterval) * time.Second)
|
||
|
defer ticker.Stop()
|
||
|
for {
|
||
|
select {
|
||
|
case <-ctx.Done():
|
||
|
t.log.Debug("stopping keep-alive...")
|
||
|
_ = client.Close()
|
||
|
return
|
||
|
case err := <-wait:
|
||
|
if err != nil && err != io.EOF {
|
||
|
t.log.Error("ssh error:", err)
|
||
|
}
|
||
|
return
|
||
|
case <-ticker.C:
|
||
|
if n := atomic.AddInt32(&aliveCount, 1); n > int32(t.sessConfig.KeepAliveMax) {
|
||
|
t.log.Error("keep-alive failed, closing connection...")
|
||
|
_ = client.Close()
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
|
||
|
wg.Add(1)
|
||
|
go func() {
|
||
|
defer wg.Done()
|
||
|
_, _, err := client.SendRequest("keepalive@openssh.com", true, nil)
|
||
|
if err == nil {
|
||
|
atomic.StoreInt32(&aliveCount, 0)
|
||
|
}
|
||
|
}()
|
||
|
}
|
||
|
}
|