diff --git a/core/src/main/cpp/main.c b/core/src/main/cpp/main.c index 63a429bb..b4b814d2 100644 --- a/core/src/main/cpp/main.c +++ b/core/src/main/cpp/main.c @@ -96,15 +96,17 @@ Java_com_github_kr328_clash_core_bridge_Bridge_nativeNotifyInstalledAppChanged(J } JNIEXPORT void JNICALL -Java_com_github_kr328_clash_core_bridge_Bridge_nativeStartTun(JNIEnv *env, jobject thiz, jint fd, - jint mtu, jstring dns, +Java_com_github_kr328_clash_core_bridge_Bridge_nativeStartTun(JNIEnv *env, jobject thiz, + jint fd, jint mtu, + jstring gateway, jstring dns, jobject cb) { TRACE_METHOD(); + scoped_string _gateway = get_string(gateway); scoped_string _dns = get_string(dns); jobject _interface = new_global(cb); - startTun(fd, mtu, _dns, _interface); + startTun(fd, mtu, _gateway, _dns, _interface); } JNIEXPORT void JNICALL diff --git a/core/src/main/golang/tun.go b/core/src/main/golang/tun.go index 888724c2..e0a5859e 100644 --- a/core/src/main/golang/tun.go +++ b/core/src/main/golang/tun.go @@ -11,8 +11,6 @@ import ( "cfa/tun" "golang.org/x/sync/semaphore" - - "github.com/Dreamacro/clash/log" ) type remoteTun struct { @@ -23,7 +21,7 @@ type remoteTun struct { } func (t *remoteTun) markSocket(fd int) { - _ = t.limit.Acquire(context.Background(), 1) + _ = t.limit.Acquire(context.TODO(), 1) defer t.limit.Release(1) if t.closed { @@ -34,7 +32,7 @@ func (t *remoteTun) markSocket(fd int) { } func (t *remoteTun) querySocketUid(protocol int, source, target string) int { - _ = t.limit.Acquire(context.Background(), 1) + _ = t.limit.Acquire(context.TODO(), 1) defer t.limit.Release(1) if t.closed { @@ -50,26 +48,27 @@ func (t *remoteTun) stop() { t.closed = true - C.release_object(t.callback) + app.ApplyTunContext(nil, nil) - log.Infoln("Android tun device destroyed") + C.release_object(t.callback) } //export startTun -func startTun(fd, mtu C.int, dns C.c_string, callback unsafe.Pointer) C.int { +func startTun(fd, mtu C.int, gateway, dns C.c_string, callback unsafe.Pointer) C.int { f := int(fd) m := int(mtu) + g := C.GoString(gateway) d := C.GoString(dns) remote := &remoteTun{callback: callback, closed: false, limit: semaphore.NewWeighted(4)} - if tun.Start(f, m, d) != nil { - return 1 - } - app.ApplyTunContext(remote.markSocket, remote.querySocketUid) - log.Infoln("Android tun device created") + if tun.Start(f, m, g, d, remote.stop) != nil { + app.ApplyTunContext(nil, nil) + + return 1 + } return 0 } diff --git a/core/src/main/golang/tun/dns.go b/core/src/main/golang/tun/dns.go index 15330346..a5180f42 100644 --- a/core/src/main/golang/tun/dns.go +++ b/core/src/main/golang/tun/dns.go @@ -1,19 +1,12 @@ package tun import ( - "encoding/binary" - "io" "net" - "time" "github.com/Dreamacro/clash/component/resolver" - "github.com/kr328/tun2socket" - D "github.com/miekg/dns" ) -const defaultDnsReadTimeout = time.Second * 30 - func shouldHijackDns(dns net.IP, target net.IP, targetPort int) bool { if targetPort != 53 { return false @@ -22,58 +15,7 @@ func shouldHijackDns(dns net.IP, target net.IP, targetPort int) bool { return net.IPv4zero.Equal(dns) || target.Equal(dns) } -func hijackUDPDns(pkt []byte, lAddr, rAddr net.Addr, udp tun2socket.UDP) { - go func() { - answer, err := relayDnsPacket(pkt) - - if err != nil { - return - } - - _, _ = udp.WriteTo(answer, lAddr, rAddr) - - recycleUDP(pkt) - }() -} - -func hijackTCPDns(conn net.Conn) { - go func() { - defer conn.Close() - - for { - if err := conn.SetReadDeadline(time.Now().Add(defaultDnsReadTimeout)); err != nil { - return - } - - var length uint16 - if binary.Read(conn, binary.BigEndian, &length) != nil { - return - } - - data := make([]byte, length) - - _, err := io.ReadFull(conn, data) - if err != nil { - return - } - - rb, err := relayDnsPacket(data) - if err != nil { - continue - } - - if binary.Write(conn, binary.BigEndian, uint16(len(rb))) != nil { - return - } - - if _, err := conn.Write(rb); err != nil { - return - } - } - }() -} - -func relayDnsPacket(payload []byte) ([]byte, error) { +func relayDns(payload []byte) ([]byte, error) { msg := &D.Msg{} if err := msg.Unpack(payload); err != nil { return nil, err @@ -84,14 +26,6 @@ func relayDnsPacket(payload []byte) ([]byte, error) { return nil, err } - for _, ans := range r.Answer { - header := ans.Header() - - if header.Class == D.ClassINET && (header.Rrtype == D.TypeA || header.Rrtype == D.TypeAAAA) { - header.Ttl = 1 - } - } - r.SetRcode(msg, r.Rcode) return r.Pack() diff --git a/core/src/main/golang/tun/link.go b/core/src/main/golang/tun/link.go new file mode 100644 index 00000000..5ca750e7 --- /dev/null +++ b/core/src/main/golang/tun/link.go @@ -0,0 +1,39 @@ +package tun + +import "github.com/Dreamacro/clash/log" + +func (a *adapter) rx() { + log.Infoln("[ATUN] Device rx started") + defer log.Infoln("[ATUN] Device rx exited") + defer a.once.Do(a.stop) + defer a.close() + + buf := make([]byte, a.mtu) + + for { + n, err := a.device.Read(buf) + if err != nil { + return + } + + _, _ = a.stack.Link().Write(buf[:n]) + } +} + +func (a *adapter) tx() { + log.Infoln("[ATUN] Device tx started") + defer log.Infoln("[ATUN] Device tx exited") + defer a.once.Do(a.stop) + defer a.close() + + buf := make([]byte, a.mtu) + + for { + n, err := a.stack.Link().Read(buf) + if err != nil { + return + } + + _, _ = a.device.Write(buf[:n]) + } +} diff --git a/core/src/main/golang/tun/tcp.go b/core/src/main/golang/tun/tcp.go index c74fe5cd..a072c8fb 100644 --- a/core/src/main/golang/tun/tcp.go +++ b/core/src/main/golang/tun/tcp.go @@ -1,27 +1,100 @@ package tun import ( + "encoding/binary" + "io" "net" "strconv" + "time" C "github.com/Dreamacro/clash/constant" - CTX "github.com/Dreamacro/clash/context" + "github.com/Dreamacro/clash/context" + "github.com/Dreamacro/clash/log" "github.com/Dreamacro/clash/tunnel" ) -func handleTCP(conn net.Conn, source *net.TCPAddr, target *net.TCPAddr) { - metadata := &C.Metadata{ - NetWork: C.TCP, - Type: C.SOCKS, - SrcIP: source.IP, - DstIP: target.IP, - SrcPort: strconv.Itoa(source.Port), - DstPort: strconv.Itoa(target.Port), - AddrType: C.AtypIPv4, - Host: "", - RawSrcAddr: source, - RawDstAddr: target, +const defaultDnsReadTimeout = time.Second * 30 + +func (a *adapter) tcp() { + log.Infoln("[ATUN] TCP listener started") + defer log.Infoln("[ATUN] TCP listener exited") + defer a.stack.Close() + + for { + conn, err := a.stack.TCP().Accept() + if err != nil { + return + } + + sAddr := conn.LocalAddr().(*net.TCPAddr) + tAddr := conn.RemoteAddr().(*net.TCPAddr) + + // handle dns messages + if a.hijackTCPDNS(conn, tAddr) { + continue + } + + // drop all connections connect to gateway + if a.gateway.Contains(tAddr.IP) { + continue + } + + metadata := &C.Metadata{ + NetWork: C.TCP, + Type: C.SOCKS, + SrcIP: sAddr.IP, + DstIP: tAddr.IP, + SrcPort: strconv.Itoa(sAddr.Port), + DstPort: strconv.Itoa(tAddr.Port), + AddrType: C.AtypIPv4, + Host: "", + RawSrcAddr: sAddr, + RawDstAddr: tAddr, + } + + tunnel.Add(context.NewConnContext(conn, metadata)) + } +} + +func (a *adapter) hijackTCPDNS(conn net.Conn, tAddr *net.TCPAddr) bool { + if !shouldHijackDns(a.dns, tAddr.IP, tAddr.Port) { + return false } - tunnel.Add(CTX.NewConnContext(conn, metadata)) + go func() { + defer conn.Close() + + for { + if err := conn.SetReadDeadline(time.Now().Add(defaultDnsReadTimeout)); err != nil { + return + } + + var length uint16 + if binary.Read(conn, binary.BigEndian, &length) != nil { + return + } + + data := make([]byte, length) + + _, err := io.ReadFull(conn, data) + if err != nil { + return + } + + rb, err := relayDns(data) + if err != nil { + continue + } + + if binary.Write(conn, binary.BigEndian, uint16(len(rb))) != nil { + return + } + + if _, err := conn.Write(rb); err != nil { + return + } + } + }() + + return true } diff --git a/core/src/main/golang/tun/tun.go b/core/src/main/golang/tun/tun.go index b5a2d7cc..6a2c097a 100644 --- a/core/src/main/golang/tun/tun.go +++ b/core/src/main/golang/tun/tun.go @@ -4,33 +4,40 @@ import ( "net" "os" "sync" + "syscall" "github.com/kr328/tun2socket" ) -type context struct { - device *os.File - stack tun2socket.Stack +type adapter struct { + device *os.File + stack tun2socket.Stack + gateway *net.IPNet + dns net.IP + mtu int + once sync.Once + stop func() } var lock sync.Mutex -var tun *context +var instance *adapter -func (ctx *context) close() { - _ = ctx.stack.Close() - _ = ctx.device.Close() +func (a *adapter) close() { + _ = a.stack.Close() + _ = a.device.Close() } -func Start(fd, mtu int, dns string) error { +func Start(fd, mtu int, gateway, dns string, stop func()) error { lock.Lock() defer lock.Unlock() - stopLocked() + if instance != nil { + instance.close() + } - dnsIP := net.ParseIP(dns) + _ = syscall.SetNonblock(fd, true) device := os.NewFile(uintptr(fd), "/dev/tun") - stack, err := tun2socket.NewStack(mtu) if err != nil { _ = device.Close() @@ -38,100 +45,23 @@ func Start(fd, mtu int, dns string) error { return err } - ctx := &context{ - device: device, - stack: stack, + dn := net.ParseIP(dns) + _, gw, _ := net.ParseCIDR(gateway) + + instance = &adapter{ + device: device, + stack: stack, + gateway: gw, + dns: dn, + mtu: mtu, + once: sync.Once{}, + stop: stop, } - go func() { - // device -> lwip - - defer ctx.close() - - buf := make([]byte, mtu) - - for { - n, err := device.Read(buf) - if err != nil { - return - } - - _, _ = stack.Link().Write(buf[:n]) - } - }() - - go func() { - // lwip -> device - - defer ctx.close() - - buf := make([]byte, mtu) - - for { - n, err := stack.Link().Read(buf) - if err != nil { - return - } - - _, _ = device.Write(buf[:n]) - } - }() - - go func() { - // lwip tcp - - defer ctx.close() - - for { - conn, err := stack.TCP().Accept() - if err != nil { - return - } - - source := conn.LocalAddr().(*net.TCPAddr) - target := conn.RemoteAddr().(*net.TCPAddr) - - if shouldHijackDns(dnsIP, target.IP, target.Port) { - hijackTCPDns(conn) - - continue - } - - handleTCP(conn, source, target) - } - }() - - go func() { - // lwip udp - - defer ctx.close() - - for { - buf := allocUDP(mtu) - - n, lAddr, rAddr, err := stack.UDP().ReadFrom(buf) - if err != nil { - return - } - - source := lAddr.(*net.UDPAddr) - target := rAddr.(*net.UDPAddr) - - if n == 0 { - continue - } - - if shouldHijackDns(dnsIP, target.IP, target.Port) { - hijackUDPDns(buf[:n], source, target, stack.UDP()) - - continue - } - - handleUDP(buf[:n], source, target, stack.UDP()) - } - }() - - tun = ctx + go instance.rx() + go instance.tx() + go instance.tcp() + go instance.udp() return nil } @@ -140,13 +70,9 @@ func Stop() { lock.Lock() defer lock.Unlock() - stopLocked() -} - -func stopLocked() { - if tun != nil { - tun.close() + if instance != nil { + instance.close() } - tun = nil + instance = nil } diff --git a/core/src/main/golang/tun/udp.go b/core/src/main/golang/tun/udp.go index 98d02972..a5e26f13 100644 --- a/core/src/main/golang/tun/udp.go +++ b/core/src/main/golang/tun/udp.go @@ -3,6 +3,7 @@ package tun import ( "net" + "github.com/Dreamacro/clash/log" "github.com/Dreamacro/clash/transport/socks5" "github.com/kr328/tun2socket" @@ -12,48 +13,88 @@ import ( "github.com/Dreamacro/clash/tunnel" ) -type udpPacket struct { - source *net.UDPAddr - data []byte - udp tun2socket.UDP +type packet struct { + stack tun2socket.Stack + local *net.UDPAddr + data []byte } -func (u *udpPacket) Data() []byte { - return u.data +func (pkt *packet) Data() []byte { + return pkt.data } -func (u *udpPacket) WriteBack(b []byte, addr net.Addr) (n int, err error) { - return u.udp.WriteTo(b, u.source, addr) +func (pkt *packet) WriteBack(b []byte, addr net.Addr) (n int, err error) { + return pkt.stack.UDP().WriteTo(b, pkt.local, addr) } -func (u *udpPacket) Drop() { - recycleUDP(u.data) +func (pkt *packet) Drop() { + pool.Put(pkt.data) } -func (u *udpPacket) LocalAddr() net.Addr { +func (pkt *packet) LocalAddr() net.Addr { return &net.UDPAddr{ - IP: u.source.IP, - Port: u.source.Port, + IP: pkt.local.IP, + Port: pkt.local.Port, Zone: "", } } -func handleUDP(payload []byte, source *net.UDPAddr, target *net.UDPAddr, udp tun2socket.UDP) { - pkt := &udpPacket{ - source: source, - data: payload, - udp: udp, +func (a *adapter) udp() { + log.Infoln("[ATUN] UDP receiver started") + defer log.Infoln("[ATUN] UDP receiver exited") + defer a.stack.Close() + + for { + buf := pool.Get(a.mtu) + + n, lAddr, rAddr, err := a.stack.UDP().ReadFrom(buf) + if err != nil { + return + } + + sAddr := lAddr.(*net.UDPAddr) + tAddr := rAddr.(*net.UDPAddr) + + // handle dns messages + if a.hijackUDPDNS(buf[:n], sAddr, tAddr) { + continue + } + + // drop all packets send to gateway + if a.gateway.Contains(tAddr.IP) { + pool.Put(buf) + + continue + } + + pkt := &packet{ + stack: a.stack, + local: sAddr, + data: buf[:n], + } + + adapter := adapters.NewPacket(socks5.ParseAddrToSocksAddr(tAddr), pkt, C.SOCKS) + + tunnel.AddPacket(adapter) + } +} + +func (a *adapter) hijackUDPDNS(pkt []byte, sAddr, tAddr *net.UDPAddr) bool { + if !shouldHijackDns(a.dns, tAddr.IP, tAddr.Port) { + return false } - adapter := adapters.NewPacket(socks5.ParseAddrToSocksAddr(target), pkt, C.SOCKS) + go func() { + answer, err := relayDns(pkt) - tunnel.AddPacket(adapter) -} + if err != nil { + return + } -func allocUDP(size int) []byte { - return pool.Get(size) -} + _, _ = a.stack.UDP().WriteTo(answer, sAddr, tAddr) -func recycleUDP(payload []byte) { - _ = pool.Put(payload) + pool.Put(pkt) + }() + + return true } diff --git a/core/src/main/java/com/github/kr328/clash/core/Clash.kt b/core/src/main/java/com/github/kr328/clash/core/Clash.kt index bf58ae52..cd65a2de 100644 --- a/core/src/main/java/com/github/kr328/clash/core/Clash.kt +++ b/core/src/main/java/com/github/kr328/clash/core/Clash.kt @@ -61,11 +61,12 @@ object Clash { fun startTun( fd: Int, mtu: Int, + gateway: String, dns: String, markSocket: (Int) -> Boolean, querySocketUid: (protocol: Int, source: InetSocketAddress, target: InetSocketAddress) -> Int ) { - Bridge.nativeStartTun(fd, mtu, dns, object : TunInterface { + Bridge.nativeStartTun(fd, mtu, gateway, dns, object : TunInterface { override fun markSocket(fd: Int) { markSocket(fd) } diff --git a/core/src/main/java/com/github/kr328/clash/core/bridge/Bridge.kt b/core/src/main/java/com/github/kr328/clash/core/bridge/Bridge.kt index 29b67642..7794cc36 100644 --- a/core/src/main/java/com/github/kr328/clash/core/bridge/Bridge.kt +++ b/core/src/main/java/com/github/kr328/clash/core/bridge/Bridge.kt @@ -17,7 +17,7 @@ object Bridge { external fun nativeQueryTrafficTotal(): Long external fun nativeNotifyDnsChanged(dnsList: String) external fun nativeNotifyInstalledAppChanged(uidList: String) - external fun nativeStartTun(fd: Int, mtu: Int, dns: String, cb: TunInterface) + external fun nativeStartTun(fd: Int, mtu: Int, gateway: String, dns: String, cb: TunInterface) external fun nativeStopTun() external fun nativeStartHttp(listenAt: String): String? external fun nativeStopHttp() diff --git a/service/src/main/java/com/github/kr328/clash/service/TunService.kt b/service/src/main/java/com/github/kr328/clash/service/TunService.kt index 3bf09a8d..d8e3a4b6 100644 --- a/service/src/main/java/com/github/kr328/clash/service/TunService.kt +++ b/service/src/main/java/com/github/kr328/clash/service/TunService.kt @@ -217,6 +217,7 @@ class TunService : VpnService(), CoroutineScope by CoroutineScope(Dispatchers.De fd = establish()?.detachFd() ?: throw NullPointerException("Establish VPN rejected by system"), mtu = TUN_MTU, + gateway = "$TUN_GATEWAY/$TUN_SUBNET_PREFIX", dns = if (store.dnsHijacking) NET_ANY else TUN_DNS, ) } diff --git a/service/src/main/java/com/github/kr328/clash/service/clash/module/TunModule.kt b/service/src/main/java/com/github/kr328/clash/service/clash/module/TunModule.kt index bc4c8ebc..2e91316c 100644 --- a/service/src/main/java/com/github/kr328/clash/service/clash/module/TunModule.kt +++ b/service/src/main/java/com/github/kr328/clash/service/clash/module/TunModule.kt @@ -16,7 +16,8 @@ class TunModule(private val vpn: VpnService) : Module(vpn) { data class TunDevice( val fd: Int, val mtu: Int, - val dns: String + val gateway: String, + val dns: String, ) private val connectivity = service.getSystemService()!! @@ -56,6 +57,7 @@ class TunModule(private val vpn: VpnService) : Module(vpn) { Clash.startTun( fd = device.fd, mtu = device.mtu, + gateway = device.gateway, dns = device.dns, markSocket = vpn::protect, querySocketUid = this::queryUid