diff --git a/client.go b/client.go index 56622ec..0c1fb43 100644 --- a/client.go +++ b/client.go @@ -19,7 +19,7 @@ const ( func NewClient(addr string, server string, target string, timeout int, key int, tcpmode int, tcpmode_buffersize int, tcpmode_maxwin int, tcpmode_resend_timems int, tcpmode_compress int, - tcpmode_stat int, open_sock5 int) (*Client, error) { + tcpmode_stat int, open_sock5 int, maxconn int) (*Client, error) { var ipaddr *net.UDPAddr var tcpaddr *net.TCPAddr @@ -62,6 +62,7 @@ func NewClient(addr string, server string, target string, timeout int, key int, tcpmode_compress: tcpmode_compress, tcpmode_stat: tcpmode_stat, open_sock5: open_sock5, + maxconn: maxconn, }, nil } @@ -70,6 +71,7 @@ type Client struct { rtt time.Duration interval *time.Ticker workResultLock sync.WaitGroup + maxconn int id int sequence int @@ -272,9 +274,15 @@ func (p *Client) AcceptTcpConn(conn *net.TCPConn, targetAddr string) { p.workResultLock.Add(1) defer p.workResultLock.Done() - uuid := UniqueId() tcpsrcaddr := conn.RemoteAddr().(*net.TCPAddr) + if p.localIdToConnMapSize >= p.maxconn { + loggo.Info("too many connections %d, client accept new local tcp fail %s", p.localIdToConnMapSize, tcpsrcaddr.String()) + return + } + + uuid := UniqueId() + fm := NewFrameMgr(p.tcpmode_buffersize, p.tcpmode_maxwin, p.tcpmode_resend_timems, p.tcpmode_compress, p.tcpmode_stat) now := time.Now() @@ -484,6 +492,10 @@ func (p *Client) Accept() error { now := time.Now() clientConn := p.getClientConnByAddr(srcaddr.String()) if clientConn == nil { + if p.localIdToConnMapSize >= p.maxconn { + loggo.Info("too many connections %d, client accept new local udp fail %s", p.localIdToConnMapSize, srcaddr.String()) + continue + } uuid := UniqueId() clientConn = &ClientConn{ipaddr: srcaddr, id: uuid, activeRecvTime: now, activeSendTime: now, close: false} p.addClientConn(uuid, srcaddr.String(), clientConn) diff --git a/cmd/main.go b/cmd/main.go index a909d56..077b97c 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -71,6 +71,9 @@ Usage: -sock5 开启sock5转发,默认0 Turn on sock5 forwarding, default 0 is off + + -maxconn 最大连接数,默认1000 + the max num of connections, default 1000 ` func main() { @@ -90,6 +93,7 @@ func main() { tcpmode_stat := flag.Int("tcp_stat", 0, "print tcp stat") loglevel := flag.String("loglevel", "info", "log level") open_sock5 := flag.Int("sock5", 0, "sock5 mode") + maxconn := flag.Int("maxconn", 0, "max num of connections") flag.Usage = func() { fmt.Printf(usage) } @@ -132,7 +136,7 @@ func main() { loggo.Info("key %d", *key) if *t == "server" { - s, err := pingtunnel.NewServer(*key) + s, err := pingtunnel.NewServer(*key, *maxconn) if err != nil { loggo.Error("ERROR: %s", err.Error()) return @@ -143,8 +147,7 @@ func main() { loggo.Error("Run ERROR: %s", err.Error()) return } - } - if *t == "client" { + } else if *t == "client" { loggo.Info("type %s", *t) loggo.Info("listen %s", *listen) @@ -161,7 +164,7 @@ func main() { c, err := pingtunnel.NewClient(*listen, *server, *target, *timeout, *key, *tcpmode, *tcpmode_buffersize, *tcpmode_maxwin, *tcpmode_resend_timems, *tcpmode_compress, - *tcpmode_stat, *open_sock5) + *tcpmode_stat, *open_sock5, *maxconn) if err != nil { loggo.Error("ERROR: %s", err.Error()) return @@ -173,6 +176,8 @@ func main() { loggo.Error("Run ERROR: %s", err.Error()) return } + } else { + return } for { time.Sleep(time.Hour) diff --git a/server.go b/server.go index 6cab990..1fcb7f3 100644 --- a/server.go +++ b/server.go @@ -10,10 +10,11 @@ import ( "time" ) -func NewServer(key int) (*Server, error) { +func NewServer(key int, maxconn int) (*Server, error) { return &Server{ - exit: false, - key: key, + exit: false, + key: key, + maxconn: maxconn, }, nil } @@ -22,15 +23,17 @@ type Server struct { key int interval *time.Ticker workResultLock sync.WaitGroup + maxconn int conn *icmp.PacketConn localConnMap sync.Map - sendPacket uint64 - recvPacket uint64 - sendPacketSize uint64 - recvPacketSize uint64 + sendPacket uint64 + recvPacket uint64 + sendPacketSize uint64 + recvPacketSize uint64 + localConnMapSize int echoId int echoSeq int @@ -118,6 +121,11 @@ func (p *Server) processPacket(packet *Packet) { localConn := p.getServerConnById(id) if localConn == nil { + if p.localConnMapSize >= p.maxconn { + loggo.Info("too many connections %d, server connected target fail %s", p.localConnMapSize, packet.my.Target) + return + } + if packet.my.Tcpmode > 0 { addr := packet.my.Target @@ -449,8 +457,13 @@ func (p *Server) checkTimeoutConn() { } func (p *Server) showNet() { - loggo.Info("send %dPacket/s %dKB/s recv %dPacket/s %dKB/s", - p.sendPacket, p.sendPacketSize/1024, p.recvPacket, p.recvPacketSize/1024) + p.localConnMapSize = 0 + p.localConnMap.Range(func(key, value interface{}) bool { + p.localConnMapSize++ + return true + }) + loggo.Info("send %dPacket/s %dKB/s recv %dPacket/s %dKB/s %dConnections", + p.sendPacket, p.sendPacketSize/1024, p.recvPacket, p.recvPacketSize/1024, p.localConnMapSize) p.sendPacket = 0 p.recvPacket = 0 p.sendPacketSize = 0