diff --git a/client.go b/client.go index b1e43b5..9658a64 100644 --- a/client.go +++ b/client.go @@ -51,6 +51,7 @@ type ClientConn struct { ipaddr *net.UDPAddr id string activeTime time.Time + close bool } func (p *Client) Addr() string { @@ -137,7 +138,7 @@ func (p *Client) Accept() error { clientConn := p.localAddrToConnMap[srcaddr.String()] if clientConn == nil { uuid := UniqueId() - clientConn = &ClientConn{ipaddr: srcaddr, id: uuid, activeTime: now} + clientConn = &ClientConn{ipaddr: srcaddr, id: uuid, activeTime: now, close: false} p.localAddrToConnMap[srcaddr.String()] = clientConn p.localIdToConnMap[uuid] = clientConn fmt.Printf("client accept new local %s %s\n", uuid, srcaddr.String()) @@ -166,7 +167,7 @@ func (p *Client) processPacket(packet *Packet) { _, err := p.listenConn.WriteToUDP(packet.data, addr) if err != nil { fmt.Printf("WriteToUDP Error read udp %s\n", err) - p.Close(clientConn) + clientConn.close = true return } } @@ -180,9 +181,15 @@ func (p *Client) Close(clientConn *ClientConn) { func (p *Client) checkTimeoutConn() { now := time.Now() - for id, conn := range p.localIdToConnMap { + for _, conn := range p.localIdToConnMap { diff := now.Sub(conn.activeTime) if diff > time.Second*(time.Duration(p.timeout)) { + conn.close = true + } + } + + for id, conn := range p.localIdToConnMap { + if conn.close { fmt.Printf("close inactive conn %s %s\n", id, conn.ipaddr.String()) p.Close(conn) } diff --git a/server.go b/server.go index 31ceb8e..ee2ab94 100644 --- a/server.go +++ b/server.go @@ -26,6 +26,7 @@ type ServerConn struct { conn *net.UDPConn id string activeTime time.Time + close bool } func (p *Server) Run() { @@ -77,7 +78,7 @@ func (p *Server) processPacket(packet *Packet) { fmt.Printf("Error listening for udp packets: %s\n", err.Error()) return } - udpConn = &ServerConn{conn: targetConn, ipaddrTarget: ipaddrTarget, id: id, activeTime: now} + udpConn = &ServerConn{conn: targetConn, ipaddrTarget: ipaddrTarget, id: id, activeTime: now, close: false} p.localConnMap[id] = udpConn go p.Recv(udpConn, id, packet.src) } @@ -87,7 +88,7 @@ func (p *Server) processPacket(packet *Packet) { _, err := udpConn.conn.Write(packet.data) if err != nil { fmt.Printf("WriteToUDP Error %s\n", err) - p.Close(udpConn) + udpConn.close = true return } } @@ -108,7 +109,7 @@ func (p *Server) Recv(conn *ServerConn, id string, src *net.IPAddr) { continue } else { fmt.Printf("ReadFromUDP Error read udp %s\n", err) - p.Close(conn) + conn.close = true return } } @@ -131,12 +132,17 @@ func (p *Server) Close(conn *ServerConn) { func (p *Server) checkTimeoutConn() { now := time.Now() - for id, conn := range p.localConnMap { + for _, conn := range p.localConnMap { diff := now.Sub(conn.activeTime) if diff > time.Second*(time.Duration(p.timeout)) { + conn.close = true + } + } + + for id, conn := range p.localConnMap { + if conn.close { fmt.Printf("close inactive conn %s %s\n", id, conn.ipaddrTarget.String()) p.Close(conn) } } - }