sshpoke/internal/server/driver/ssh/sshtun/keepalive.go

63 lines
1.4 KiB
Go
Raw Normal View History

package sshtun
import (
	"context"
	"errors"
	"io"
	"sync"
	"sync/atomic"
	"time"

	"github.com/Neur0toxine/sshpoke/pkg/proto/ssh"
)
// keepAlive periodically sends keep-alive requests over client and closes the
// client if too many ticks pass without a successful reply, on the assumption
// that the underlying net.Conn died abruptly.
//
// It returns when:
//   - ctx is cancelled (the client is closed first);
//   - client.Wait returns, i.e. the SSH connection has terminated;
//   - more than t.tunConfig.KeepAliveMax ticks have elapsed since the last
//     successful keep-alive reply (the client is closed first).
//
// If KeepAliveInterval or KeepAliveMax is not positive, keep-alives are
// disabled and it returns immediately. KeepAliveInterval is interpreted in
// seconds.
//
// wg.Done is called on return. Every goroutine started here is also registered
// on wg, so a caller's wg.Wait waits for those goroutines too.
func (t *Tunnel) keepAlive(ctx context.Context, client *ssh.Client, wg *sync.WaitGroup) {
	defer wg.Done()
	if t.tunConfig.KeepAliveInterval <= 0 || t.tunConfig.KeepAliveMax <= 0 {
		return
	}

	// Detect when the SSH connection is closed. The channel is buffered so the
	// waiter can always deliver its result and exit, even after this function
	// has already returned through another branch.
	wait := make(chan error, 1)
	wg.Add(1)
	go func() {
		defer wg.Done()
		wait <- client.Wait()
	}()

	// aliveCount counts ticks since the last successful keep-alive reply. It is
	// incremented by this loop and reset by the request goroutines, hence the
	// atomic operations.
	var aliveCount int32
	ticker := time.NewTicker(time.Duration(t.tunConfig.KeepAliveInterval) * time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			t.log.Debug("stopping keep-alive...")
			_ = client.Close() // best effort: we are shutting down anyway
			return
		case err := <-wait:
			// io.EOF (possibly wrapped) marks an orderly connection shutdown and
			// is not worth reporting.
			if err != nil && !errors.Is(err, io.EOF) {
				t.log.Error("ssh error:", err)
			}
			return
		case <-ticker.C:
			if n := atomic.AddInt32(&aliveCount, 1); n > int32(t.tunConfig.KeepAliveMax) {
				t.log.Error("keep-alive failed, closing connection...")
				_ = client.Close() // best effort: the connection is presumed dead
				return
			}
		}

		// Send the request off the loop so ticking continues while it is in
		// flight. With wantReply=true, SendRequest presumably blocks until the
		// server answers or the connection closes (as in
		// golang.org/x/crypto/ssh). Only the error is checked, so any reply,
		// accepted or rejected, counts as a sign of life and resets the counter.
		wg.Add(1)
		go func() {
			defer wg.Done()
			_, _, err := client.SendRequest("keepalive@openssh.com", true, nil)
			if err == nil {
				atomic.StoreInt32(&aliveCount, 0)
			}
		}()
	}
}