diff --git a/client.go b/client.go index f57e8bb..df94a3a 100644 --- a/client.go +++ b/client.go @@ -186,14 +186,10 @@ func (p *Client) AcceptTcp() error { conn, err := p.tcplistenConn.AcceptTCP() if err != nil { - if neterr, ok := err.(*net.OpError); ok { - if neterr.Timeout() { - // Read timeout - continue - } else { - loggo.Error("Error accept tcp %s", err) - continue - } + nerr, ok := err.(net.Error) + if !ok || !nerr.Timeout() { + loggo.Error("Error accept tcp %s", err) + continue } } @@ -209,6 +205,7 @@ func (p *Client) AcceptTcpConn(conn *net.TCPConn) { fm := NewFrameMgr(p.tcpmode_buffersize, p.tcpmode_maxwin, p.tcpmode_resend_timems) + now := time.Now() clientConn := &ClientConn{tcpaddr: tcpsrcaddr, id: uuid, activeTime: now, close: false, fm: fm} p.localAddrToConnMap[tcpsrcaddr.String()] = clientConn @@ -223,14 +220,10 @@ func (p *Client) AcceptTcpConn(conn *net.TCPConn) { conn.SetReadDeadline(time.Now().Add(time.Millisecond * 100)) n, err := conn.Read(bytes) if err != nil { - if neterr, ok := err.(*net.OpError); ok { - if neterr.Timeout() { - // Read timeout - n = 0 - } else { - loggo.Error("Error read tcp %s %s %s", uuid, tcpsrcaddr.String(), err) - break - } + nerr, ok := err.(net.Error) + if !ok || !nerr.Timeout() { + loggo.Error("Error read tcp %s %s %s", uuid, tcpsrcaddr.String(), err) + break } } if n > 0 { @@ -280,16 +273,15 @@ func (p *Client) Accept() error { p.listenConn.SetReadDeadline(time.Now().Add(time.Millisecond * 100)) n, srcaddr, err := p.listenConn.ReadFromUDP(bytes) if err != nil { - if neterr, ok := err.(*net.OpError); ok { - if neterr.Timeout() { - // Read timeout - continue - } else { - loggo.Error("Error read udp %s", err) - continue - } + nerr, ok := err.(net.Error) + if !ok || !nerr.Timeout() { + loggo.Error("Error read udp %s", err) + continue } } + if n <= 0 { + continue + } now := time.Now() clientConn := p.localAddrToConnMap[srcaddr.String()] diff --git a/framemgr.go b/framemgr.go index 773e3dd..d101b7b 100644 --- a/framemgr.go +++ b/framemgr.go @@ -55,6 +55,8 @@ func (fm *FrameMgr) Update() { fm.processRecvList(tmpreq, tmpack, tmpackto) fm.calSendList() + + fm.combineWindowToRecvBuffer() } func (fm *FrameMgr) cutSendBufferToWindow() { @@ -72,7 +74,7 @@ func (fm *FrameMgr) cutSendBufferToWindow() { fm.sendb.Read(f.Data) fm.sendid++ - if fm.sendid > FRAME_MAX_ID { + if fm.sendid >= FRAME_MAX_ID { fm.sendid = 0 } @@ -86,7 +88,7 @@ func (fm *FrameMgr) cutSendBufferToWindow() { fm.sendb.Read(f.Data) fm.sendid++ - if fm.sendid > FRAME_MAX_ID { + if fm.sendid >= FRAME_MAX_ID { fm.sendid = 0 } @@ -180,6 +182,16 @@ func (fm *FrameMgr) processRecvList(tmpreq map[int32]int, tmpack map[int32]int, func (fm *FrameMgr) addToRecvWin(rf *Frame) { + begin := fm.recvid + end := fm.recvid + fm.windowsize + id := (int)(rf.Id) + if id < begin { + id += FRAME_MAX_ID + } + if id > end || id < begin { + return + } + for e := fm.recvwin.Front(); e != nil; e = e.Next() { f := e.Value.(*Frame) if f.Id == rf.Id { @@ -189,15 +201,66 @@ func (fm *FrameMgr) addToRecvWin(rf *Frame) { for e := fm.recvwin.Front(); e != nil; e = e.Next() { f := e.Value.(*Frame) - if rf.Id > (int32)(fm.recvid) && rf.Id < f.Id { + if fm.compareId(rf, f) < 0 { fm.recvwin.InsertBefore(rf, e) return } } - if fm.recvwin.Len() > 0 { - fm.recvwin.PushBack(rf) - } else { - fm.recvwin.PushBack(rf) + fm.recvwin.PushBack(rf) +} + +func (fm *FrameMgr) compareId(lf *Frame, rf *Frame) int { + + l := (int)(lf.Id) + r := (int)(rf.Id) + if l < fm.recvid { + l += FRAME_MAX_ID + } + if r < fm.recvid { + r += FRAME_MAX_ID + } + + return l - r +} + +func (fm *FrameMgr) combineWindowToRecvBuffer() { + + id := fm.recvid + + for { + done := false + for e := fm.recvwin.Front(); e != nil; e = e.Next() { + f := e.Value.(*Frame) + if f.Id == (int32)(id) { + left := fm.recvb.Capacity() - fm.recvb.Size() + if left >= len(f.Data) { + fm.recvb.Write(f.Data) + fm.recvwin.Remove(e) + done = true + break + } + } + } + if !done { + break + } else { + fm.recvid++ + if fm.recvid >= FRAME_MAX_ID { + fm.recvid = 0 + } + } } } + +func (fm *FrameMgr) GetRecvBufferSize() int { + return fm.recvb.Size() +} + +func (fm *FrameMgr) GetRecvReadLineBuffer() []byte { + return fm.recvb.GetReadLineBuffer() +} + +func (fm *FrameMgr) SkipRecvBuffer(size int) { + fm.recvb.SkipRead(size) +} diff --git a/pingtunnel.go b/pingtunnel.go index 30a4513..8cae0cb 100644 --- a/pingtunnel.go +++ b/pingtunnel.go @@ -136,5 +136,5 @@ func GetMd5String(s string) string { const ( FRAME_MAX_SIZE int = 888 - FRAME_MAX_ID int = 999 + FRAME_MAX_ID int = 10000 ) diff --git a/pingtunnel_test.go b/pingtunnel_test.go index 23891a3..20d5a73 100644 --- a/pingtunnel_test.go +++ b/pingtunnel_test.go @@ -6,7 +6,7 @@ import ( "testing" ) -func Test0001(test *testing.T) { +func Test0001(t *testing.T) { my := &MyMsg{} my.Id = "12345" @@ -22,4 +22,32 @@ func Test0001(test *testing.T) { proto.Unmarshal(dst[0:4], my1) fmt.Println("my1 = ", my1) + + fm := FrameMgr{} + fm.recvid = 0 + fm.windowsize = 100 + lr := &Frame{} + rr := &Frame{} + lr.Id = 1 + rr.Id = 2 + fmt.Println("fm.compareId(lr, rr) = ", fm.compareId(lr, rr)) + + lr.Id = 99 + rr.Id = 8 + fmt.Println("fm.compareId(lr, rr) = ", fm.compareId(lr, rr)) + + fm.recvid = 9000 + lr.Id = 9998 + rr.Id = 9999 + fmt.Println("fm.compareId(lr, rr) = ", fm.compareId(lr, rr)) + + fm.recvid = 9000 + lr.Id = 9998 + rr.Id = 8 + fmt.Println("fm.compareId(lr, rr) = ", fm.compareId(lr, rr)) + + fm.recvid = 0 + lr.Id = 9998 + rr.Id = 8 + fmt.Println("fm.compareId(lr, rr) = ", fm.compareId(lr, rr)) } diff --git a/server.go b/server.go index ab172b4..d4b333c 100644 --- a/server.go +++ b/server.go @@ -188,14 +188,10 @@ func (p *Server) RecvTCP(conn *ServerConn, id string, src *net.IPAddr) { conn.tcpconn.SetReadDeadline(time.Now().Add(time.Millisecond * 100)) n, err := conn.tcpconn.Read(bytes) if err != nil { - if neterr, ok := err.(*net.OpError); ok { - if neterr.Timeout() { - // Read timeout - n = 0 - } else { - loggo.Error("Error read tcp %s %s %s", conn.id, conn.tcpaddrTarget.String(), err) - break - } + nerr, ok := err.(net.Error) + if !ok || !nerr.Timeout() { + loggo.Error("Error read tcp %s %s %s", conn.id, conn.tcpaddrTarget.String(), err) + break } } if n > 0 { @@ -226,7 +222,27 @@ func (p *Server) RecvTCP(conn *ServerConn, id string, src *net.IPAddr) { p.sendPacket++ p.sendPacketSize += (uint64)(len(mb)) } + + if conn.fm.GetRecvBufferSize() > 0 { + rr := conn.fm.GetRecvReadLineBuffer() + + conn.tcpconn.SetWriteDeadline(time.Now().Add(time.Millisecond * 100)) + n, err := conn.tcpconn.Write(rr) + if err != nil { + nerr, ok := err.(net.Error) + if !ok || !nerr.Timeout() { + loggo.Error("Error write tcp %s %s %s", conn.id, conn.tcpaddrTarget.String(), err) + break + } + } + if n > 0 { + conn.fm.SkipRecvBuffer(n) + } + } } + + loggo.Info("close tcp conn %s %s", conn.id, conn.tcpaddrTarget.String()) + p.Close(conn) } func (p *Server) Recv(conn *ServerConn, id string, src *net.IPAddr) { @@ -239,15 +255,11 @@ func (p *Server) Recv(conn *ServerConn, id string, src *net.IPAddr) { conn.conn.SetReadDeadline(time.Now().Add(time.Millisecond * 100)) n, _, err := conn.conn.ReadFromUDP(bytes) if err != nil { - if neterr, ok := err.(*net.OpError); ok { - if neterr.Timeout() { - // Read timeout - continue - } else { - loggo.Error("ReadFromUDP Error read udp %s", err) - conn.close = true - return - } + nerr, ok := err.(net.Error) + if !ok || !nerr.Timeout() { + loggo.Error("ReadFromUDP Error read udp %s", err) + conn.close = true + return } }