package pingtunnel import ( "encoding/binary" "errors" "io" "net" "strconv" "time" ) var ( errAddrType = errors.New("socks addr type not supported") errVer = errors.New("socks version not supported") errMethod = errors.New("socks only support 1 method now") errAuthExtraData = errors.New("socks authentication get extra data") errReqExtraData = errors.New("socks request get extra data") errCmd = errors.New("socks command not supported") ) const ( socksVer5 = 5 socksCmdConnect = 1 ) func sock5Handshake(conn net.Conn) (err error) { const ( idVer = 0 idNmethod = 1 ) // version identification and method selection message in theory can have // at most 256 methods, plus version and nmethod field in total 258 bytes // the current rfc defines only 3 authentication methods (plus 2 reserved), // so it won't be such long in practice buf := make([]byte, 258) var n int conn.SetReadDeadline(time.Now().Add(time.Millisecond * 100)) // make sure we get the nmethod field if n, err = io.ReadAtLeast(conn, buf, idNmethod+1); err != nil { return } if buf[idVer] != socksVer5 { return errVer } nmethod := int(buf[idNmethod]) msgLen := nmethod + 2 if n == msgLen { // handshake done, common case // do nothing, jump directly to send confirmation } else if n < msgLen { // has more methods to read, rare case if _, err = io.ReadFull(conn, buf[n:msgLen]); err != nil { return } } else { // error, should not get extra data return errAuthExtraData } // send confirmation: version 5, no authentication required _, err = conn.Write([]byte{socksVer5, 0}) return } func sock5GetRequest(conn net.Conn) (rawaddr []byte, host string, err error) { const ( idVer = 0 idCmd = 1 idType = 3 // address type index idIP0 = 4 // ip address start index idDmLen = 4 // domain address length index idDm0 = 5 // domain address start index typeIPv4 = 1 // type is ipv4 address typeDm = 3 // type is domain address typeIPv6 = 4 // type is ipv6 address lenIPv4 = 3 + 1 + net.IPv4len + 2 // 3(ver+cmd+rsv) + 1addrType + ipv4 + 2port lenIPv6 = 3 + 1 + net.IPv6len + 2 // 3(ver+cmd+rsv) + 1addrType + ipv6 + 2port lenDmBase = 3 + 1 + 1 + 2 // 3 + 1addrType + 1addrLen + 2port, plus addrLen ) // refer to getRequest in server.go for why set buffer size to 263 buf := make([]byte, 263) var n int conn.SetReadDeadline(time.Now().Add(time.Millisecond * 100)) // read till we get possible domain length field if n, err = io.ReadAtLeast(conn, buf, idDmLen+1); err != nil { return } // check version and cmd if buf[idVer] != socksVer5 { err = errVer return } if buf[idCmd] != socksCmdConnect { err = errCmd return } reqLen := -1 switch buf[idType] { case typeIPv4: reqLen = lenIPv4 case typeIPv6: reqLen = lenIPv6 case typeDm: reqLen = int(buf[idDmLen]) + lenDmBase default: err = errAddrType return } if n == reqLen { // common case, do nothing } else if n < reqLen { // rare case if _, err = io.ReadFull(conn, buf[n:reqLen]); err != nil { return } } else { err = errReqExtraData return } rawaddr = buf[idType:reqLen] switch buf[idType] { case typeIPv4: host = net.IP(buf[idIP0 : idIP0+net.IPv4len]).String() case typeIPv6: host = net.IP(buf[idIP0 : idIP0+net.IPv6len]).String() case typeDm: host = string(buf[idDm0 : idDm0+buf[idDmLen]]) } port := binary.BigEndian.Uint16(buf[reqLen-2 : reqLen]) host = net.JoinHostPort(host, strconv.Itoa(int(port))) return }