diff --git a/client.go b/client.go index 28411ba..c78808b 100644 --- a/client.go +++ b/client.go @@ -65,8 +65,9 @@ func NewClient(addr string, server string, target string, timeout int, key int, } type Client struct { - exit bool - rtt time.Duration + exit bool + rtt time.Duration + interval *time.Ticker id int sequence int @@ -147,7 +148,6 @@ func (p *Client) Run() error { loggo.Error("Error listening for ICMP packets: %s", err.Error()) return err } - defer conn.Close() p.conn = conn if p.tcpmode > 0 { @@ -156,7 +156,6 @@ func (p *Client) Run() error { loggo.Error("Error listening for tcp packets: %s", err.Error()) return err } - defer tcplistenConn.Close() p.tcplistenConn = tcplistenConn } else { listener, err := net.ListenUDP("udp", p.ipaddr) @@ -164,7 +163,6 @@ func (p *Client) Run() error { loggo.Error("Error listening for udp packets: %s", err.Error()) return err } - defer listener.Close() p.listenConn = listener } @@ -180,13 +178,12 @@ func (p *Client) Run() error { recv := make(chan *Packet, 10000) go recvICMP(*p.conn, recv) - interval := time.NewTicker(time.Second) - defer interval.Stop() + p.interval = time.NewTicker(time.Second) go func() { for !p.exit { select { - case <-interval.C: + case <-p.interval.C: p.checkTimeoutConn() p.ping() p.showNet() @@ -201,6 +198,14 @@ func (p *Client) Run() error { func (p *Client) Stop() { p.exit = true + p.conn.Close() + if p.tcplistenConn != nil { + p.tcplistenConn.Close() + } + if p.listenConn != nil { + p.listenConn.Close() + } + p.interval.Stop() } func (p *Client) AcceptTcp() error { @@ -269,7 +274,7 @@ func (p *Client) AcceptTcpConn(conn *net.TCPConn, targetAddr string) { diffclose := now.Sub(startConnectTime) if diffclose > time.Second*(time.Duration(p.timeout)) { loggo.Info("can not connect remote tcp %s %s", uuid, tcpsrcaddr.String()) - p.Close(clientConn) + p.close(clientConn) return } } @@ -416,7 +421,7 @@ func (p *Client) AcceptTcpConn(conn *net.TCPConn, targetAddr string) { loggo.Info("close tcp conn %s %s", clientConn.id, clientConn.tcpaddr.String()) conn.Close() - p.Close(clientConn) + p.close(clientConn) } func (p *Client) Accept() error { @@ -521,7 +526,7 @@ func (p *Client) processPacket(packet *Packet) { p.recvPacketSize += (uint64)(len(packet.my.Data)) } -func (p *Client) Close(clientConn *ClientConn) { +func (p *Client) close(clientConn *ClientConn) { if p.localIdToConnMap[clientConn.id] != nil { delete(p.localIdToConnMap, clientConn.id) delete(p.localAddrToConnMap, clientConn.ipaddr.String()) @@ -546,7 +551,7 @@ func (p *Client) checkTimeoutConn() { for id, conn := range p.localIdToConnMap { if conn.close { loggo.Info("close inactive conn %s %s", id, conn.ipaddr.String()) - p.Close(conn) + p.close(conn) } } } diff --git a/server.go b/server.go index cb9d276..78b228b 100644 --- a/server.go +++ b/server.go @@ -17,8 +17,9 @@ func NewServer(key int) (*Server, error) { } type Server struct { - exit bool - key int + exit bool + key int + interval *time.Ticker conn *icmp.PacketConn @@ -62,13 +63,12 @@ func (p *Server) Run() error { recv := make(chan *Packet, 10000) go recvICMP(*p.conn, recv) - interval := time.NewTicker(time.Second) - defer interval.Stop() + p.interval = time.NewTicker(time.Second) go func() { for !p.exit { select { - case <-interval.C: + case <-p.interval.C: p.checkTimeoutConn() p.showNet() case r := <-recv: @@ -80,6 +80,12 @@ func (p *Server) Run() error { return nil } +func (p *Server) Stop() { + p.exit = true + p.conn.Close() + p.interval.Stop() +} + func (p *Server) processPacket(packet *Packet) { if packet.my.Key != (int32)(p.key) { @@ -212,7 +218,7 @@ func (p *Server) RecvTCP(conn *ServerConn, id string, src *net.IPAddr) { diffclose := now.Sub(startConnectTime) if diffclose > time.Second*(time.Duration(conn.timeout)) { loggo.Info("can not connect remote tcp %s %s", conn.id, conn.tcpaddrTarget.String()) - p.Close(conn) + p.close(conn) return } } @@ -359,7 +365,7 @@ func (p *Server) RecvTCP(conn *ServerConn, id string, src *net.IPAddr) { time.Sleep(time.Second) loggo.Info("close tcp conn %s %s", conn.id, conn.tcpaddrTarget.String()) - p.Close(conn) + p.close(conn) } func (p *Server) Recv(conn *ServerConn, id string, src *net.IPAddr) { @@ -393,7 +399,7 @@ func (p *Server) Recv(conn *ServerConn, id string, src *net.IPAddr) { } } -func (p *Server) Close(conn *ServerConn) { +func (p *Server) close(conn *ServerConn) { if p.localConnMap[conn.id] != nil { if conn.conn != nil { conn.conn.Close() @@ -425,7 +431,7 @@ func (p *Server) checkTimeoutConn() { } if conn.close { loggo.Info("close inactive conn %s %s", id, conn.ipaddrTarget.String()) - p.Close(conn) + p.close(conn) } } }