diff --git a/src/pingtunnel/client.go b/src/pingtunnel/client.go index 11bdf56..1c2db53 100644 --- a/src/pingtunnel/client.go +++ b/src/pingtunnel/client.go @@ -1,17 +1,15 @@ package pingtunnel import ( - "encoding/json" "fmt" "golang.org/x/net/icmp" - "golang.org/x/net/ipv4" "net" - "syscall" + "time" ) func NewClient(addr string, target string) (*Client, error) { - ipaddr, err := net.ResolveTCPAddr("tcp", addr) + ipaddr, err := net.ResolveUDPAddr("udp", addr) if err != nil { return nil, err } @@ -30,21 +28,24 @@ func NewClient(addr string, target string) (*Client, error) { } type Client struct { - ipaddr *net.TCPAddr + ipaddr *net.UDPAddr addr string ipaddrTarget *net.IPAddr addrTarget string conn *icmp.PacketConn - listenConn *net.TCPListener + listenConn *net.UDPConn + + localConnToIdMap map[string]uint32 + localIdToConnMap map[uint32]*net.UDPAddr } func (p *Client) Addr() string { return p.addr } -func (p *Client) IPAddr() *net.TCPAddr { +func (p *Client) IPAddr() *net.UDPAddr { return p.ipaddr } @@ -66,81 +67,78 @@ func (p *Client) Run() { defer conn.Close() p.conn = conn - listener, err := net.ListenTCP("tcp", p.ipaddr) + listener, err := net.ListenUDP("udp", p.ipaddr) if err != nil { - fmt.Printf("Error listening for tcp packets: %s\n", err.Error()) + fmt.Printf("Error listening for udp packets: %s\n", err.Error()) return } - + defer listener.Close() p.listenConn = listener - p.Accept() + p.localConnToIdMap = make(map[string]uint32) + p.localIdToConnMap = make(map[uint32]*net.UDPAddr) + + go p.Accept() + + recv := make(chan *Packet, 1000) + go recvICMP(*p.conn, recv) + + for { + select { + case r := <-recv: + p.processPacket(r) + } + } } func (p *Client) Accept() error { fmt.Println("client waiting local accept") + bytes := make([]byte, 10240) + for { - localConn, err := p.listenConn.AcceptTCP() + p.conn.SetReadDeadline(time.Now().Add(time.Millisecond * 100)) + n, srcaddr, err := p.listenConn.ReadFromUDP(bytes) if err != nil { - fmt.Println(err) - continue - } - - localConn.SetLinger(0) - go p.handleConn(*localConn) - } -} - -func (p *Client) handleConn(conn net.TCPConn) { - defer conn.Close() - - uuid := UniqueId() - - fmt.Printf("client new conn %s %s", conn.RemoteAddr().String(), uuid) - - data, err := json.Marshal(RegisterData{localaddr: conn.RemoteAddr().String()}) - if err != nil { - fmt.Printf("Unable to marshal data %s\n", err) - return - } - - for { - p.sendICMP(uuid, REGISTER, data) - } -} - -func (p *Client) sendICMP(connId string, msgType MSGID, data []byte) error { - - body := &Msg{ - ID: connId, - TYPE: (int)(msgType), - Data: data, - } - - msg := &icmp.Message{ - Type: ipv4.ICMPTypeExtendedEchoRequest, - Code: 0, - Body: body, - } - - bytes, err := msg.Marshal(nil) - if err != nil { - return err - } - - for { - if _, err := (*p.conn).WriteTo(bytes, p.ipaddrTarget); err != nil { if neterr, ok := err.(*net.OpError); ok { - if neterr.Err == syscall.ENOBUFS { + if neterr.Timeout() { + // Read timeout + continue + } else { + fmt.Printf("Error read udp %s\n", err) continue } } - fmt.Printf("sendICMP error %s %s\n", p.ipaddrTarget.String(), err) } - break + + uuid := p.localConnToIdMap[srcaddr.String()] + if uuid == 0 { + uuid = UniqueId() + p.localConnToIdMap[srcaddr.String()] = uuid + p.localIdToConnMap[uuid] = srcaddr + fmt.Printf("client accept new local %d %s\n", uuid, srcaddr.String()) + } + + sendICMP(*p.conn, p.ipaddrTarget, uuid, (uint32)(DATA), bytes[:n]) + } +} + +func (p *Client) processPacket(packet *Packet) { + + fmt.Printf("processPacket %d %s %d\n", packet.id, packet.src.String(), len(packet.data)) + + addr := p.localIdToConnMap[packet.id] + if addr == nil { + fmt.Printf("processPacket no conn %d \n", packet.id) + return } - return nil + _, err := p.listenConn.WriteToUDP(packet.data, addr) + if err != nil { + fmt.Printf("WriteToUDP Error read udp %s\n", err) + p.localConnToIdMap[addr.String()] = 0 + p.localIdToConnMap[packet.id] = nil + return + } } diff --git a/src/pingtunnel/pingtunnel.go b/src/pingtunnel/pingtunnel.go index e4d144b..56d5a94 100644 --- a/src/pingtunnel/pingtunnel.go +++ b/src/pingtunnel/pingtunnel.go @@ -1,60 +1,135 @@ package pingtunnel import ( - "crypto/md5" - "crypto/rand" - "encoding/base64" "encoding/binary" - "encoding/hex" - "io" + "fmt" + "golang.org/x/net/icmp" + "golang.org/x/net/ipv4" + "net" + "sync/atomic" + "syscall" + "time" ) type MSGID int const ( - REGISTER MSGID = 1 + DATA MSGID = 0xDEADBEEF ) const ( - protocolICMP = 1 + protocolICMP = 1 ) -type Msg struct { - TYPE int - ID string // identifier - Data []byte // data +// An Echo represents an ICMP echo request or reply message body. +type MyMsg struct { + ID uint32 + TYPE uint32 + Data []byte } -func (p *Msg) Len(proto int) int { +// Len implements the Len method of MessageBody interface. +func (p *MyMsg) Len(proto int) int { if p == nil { return 0 } - return 4 + 32 + len(p.Data) + return 8 + len(p.Data) } -func (p *Msg) Marshal(proto int) ([]byte, error) { +// Marshal implements the Marshal method of MessageBody interface. +func (p *MyMsg) Marshal(proto int) ([]byte, error) { b := make([]byte, p.Len(proto)) - binary.BigEndian.PutUint32(b, uint32(p.TYPE)) - copy(b[4:], p.ID) - copy(b[4+32:], p.Data) + binary.BigEndian.PutUint32(b[:4], uint32(p.ID)) + binary.BigEndian.PutUint32(b[4:8], uint32(p.TYPE)) + copy(b[8:], p.Data) return b, nil } -func UniqueId() string { - b := make([]byte, 48) +// Marshal implements the Marshal method of MessageBody interface. +func (p *MyMsg) Unmarshal(b []byte) error { + p.ID = binary.BigEndian.Uint32(b[:4]) + p.TYPE = binary.BigEndian.Uint32(b[4:8]) + p.Data = make([]byte, len(b[8:])) + copy(p.Data, b[8:]) + return nil +} - if _, err := io.ReadFull(rand.Reader, b); err != nil { - return "" +var uuid uint32 + +func UniqueId() uint32 { + newValue := atomic.AddUint32(&uuid, 1) + return (uint32)(newValue) +} + +func sendICMP(conn icmp.PacketConn, target *net.IPAddr, connId uint32, msgType uint32, data []byte) { + + m := &MyMsg{ + ID: connId, + TYPE: msgType, + Data: data, } - return GetMd5String(base64.URLEncoding.EncodeToString(b)) + + msg := &icmp.Message{ + Type: ipv4.ICMPTypeExtendedEchoRequest, + Code: 0, + Body: m, + } + + bytes, err := msg.Marshal(nil) + if err != nil { + fmt.Printf("sendICMP Marshal error %s %s\n", target.String(), err) + return + } + + for { + if _, err := conn.WriteTo(bytes, target); err != nil { + if neterr, ok := err.(*net.OpError); ok { + if neterr.Err == syscall.ENOBUFS { + continue + } + } + fmt.Printf("sendICMP WriteTo error %s %s\n", target.String(), err) + } + break + } + + return } -func GetMd5String(s string) string { - h := md5.New() - h.Write([]byte(s)) - return hex.EncodeToString(h.Sum(nil)) +func recvICMP(conn icmp.PacketConn, recv chan<- *Packet) { + + bytes := make([]byte, 10240) + for { + conn.SetReadDeadline(time.Now().Add(time.Millisecond * 100)) + n, srcaddr, err := conn.ReadFrom(bytes) + + if err != nil { + if neterr, ok := err.(*net.OpError); ok { + if neterr.Timeout() { + // Read timeout + continue + } else { + fmt.Printf("Error read icmp message %s\n", err) + continue + } + } + } + + my := &MyMsg{ + } + my.Unmarshal(bytes[4:n]) + + if my.TYPE != (uint32)(DATA) { + fmt.Printf("processPacket diff type %d \n", my.TYPE) + continue + } + + recv <- &Packet{data: my.Data, id: my.ID, src: srcaddr.(*net.IPAddr)} + } } -type RegisterData struct { - localaddr string +type Packet struct { + data []byte + id uint32 + src *net.IPAddr } diff --git a/src/pingtunnel/server.go b/src/pingtunnel/server.go index f149dd2..7afc986 100644 --- a/src/pingtunnel/server.go +++ b/src/pingtunnel/server.go @@ -9,7 +9,7 @@ import ( func NewServer(target string) (*Server, error) { - ipaddrTarget, err := net.ResolveTCPAddr("tcp", target) + ipaddrTarget, err := net.ResolveUDPAddr("udp", target) if err != nil { return nil, err } @@ -21,17 +21,19 @@ func NewServer(target string) (*Server, error) { } type Server struct { - ipaddrTarget *net.TCPAddr + ipaddrTarget *net.UDPAddr addrTarget string - conn net.PacketConn + conn *icmp.PacketConn + + localConnMap map[uint32]*net.UDPConn } func (p *Server) TargetAddr() string { return p.addrTarget } -func (p *Server) TargetIPAddr() *net.TCPAddr { +func (p *Server) TargetIPAddr() *net.UDPAddr { return p.ipaddrTarget } @@ -44,43 +46,66 @@ func (p *Server) Run() { } p.conn = conn - p.Recv() -} + p.localConnMap = make(map[uint32]*net.UDPConn) -func (p *Server) Recv() error { + recv := make(chan *Packet, 1000) + go recvICMP(*p.conn, recv) for { - bytes := make([]byte, 512) - p.conn.SetReadDeadline(time.Now().Add(time.Millisecond * 100)) - n, srcaddr, err := p.conn.ReadFrom(bytes) + select { + case r := <-recv: + p.processPacket(r) + } + } +} +func (p *Server) processPacket(packet *Packet) { + + fmt.Printf("processPacket %d %s %d\n", packet.id, packet.src.String(), len(packet.data)) + + id := packet.id + udpConn := p.localConnMap[id] + if udpConn == nil { + targetConn, err := net.ListenUDP("udp", p.ipaddrTarget) + if err != nil { + fmt.Printf("Error listening for udp packets: %s\n", err.Error()) + return + } + udpConn = targetConn + p.localConnMap[id] = udpConn + go p.Recv(udpConn, id, packet.src) + } + + _, err := udpConn.WriteToUDP(packet.data, p.ipaddrTarget) + if err != nil { + fmt.Printf("WriteToUDP Error read udp %s\n", err) + p.localConnMap[id] = nil + return + } +} + +func (p *Server) Recv(conn *net.UDPConn, id uint32, src *net.IPAddr) { + + fmt.Println("server waiting target response") + + bytes := make([]byte, 10240) + + for { + p.conn.SetReadDeadline(time.Now().Add(time.Millisecond * 100)) + n, _, err := conn.ReadFromUDP(bytes) if err != nil { if neterr, ok := err.(*net.OpError); ok { if neterr.Timeout() { // Read timeout continue } else { - return err + fmt.Printf("ReadFromUDP Error read udp %s\n", err) + p.localConnMap[id] = nil + return } } } - var m *icmp.Message - if m, err = icmp.ParseMessage(protocolICMP, bytes[:n]); err != nil { - fmt.Println("Error parsing icmp message") - return err - } - - fmt.Printf("%d %d %d %s \n", m.Type, m.Code, n, srcaddr) + sendICMP(*p.conn, src, id, (uint32)(DATA), bytes[:n]) } } - -func (p *Server) listen(netProto string, source string) *icmp.PacketConn { - - conn, err := icmp.ListenPacket(netProto, source) - if err != nil { - fmt.Printf("Error listening for ICMP packets: %s\n", err.Error()) - return nil - } - return conn -}