Fix: refactor tun implement

This commit is contained in:
Kr328 2021-05-26 15:33:59 +08:00
parent ac35f2a5f4
commit 5e34221a09
11 changed files with 253 additions and 235 deletions

View File

@ -96,15 +96,17 @@ Java_com_github_kr328_clash_core_bridge_Bridge_nativeNotifyInstalledAppChanged(J
} }
JNIEXPORT void JNICALL JNIEXPORT void JNICALL
Java_com_github_kr328_clash_core_bridge_Bridge_nativeStartTun(JNIEnv *env, jobject thiz, jint fd, Java_com_github_kr328_clash_core_bridge_Bridge_nativeStartTun(JNIEnv *env, jobject thiz,
jint mtu, jstring dns, jint fd, jint mtu,
jstring gateway, jstring dns,
jobject cb) { jobject cb) {
TRACE_METHOD(); TRACE_METHOD();
scoped_string _gateway = get_string(gateway);
scoped_string _dns = get_string(dns); scoped_string _dns = get_string(dns);
jobject _interface = new_global(cb); jobject _interface = new_global(cb);
startTun(fd, mtu, _dns, _interface); startTun(fd, mtu, _gateway, _dns, _interface);
} }
JNIEXPORT void JNICALL JNIEXPORT void JNICALL

View File

@ -11,8 +11,6 @@ import (
"cfa/tun" "cfa/tun"
"golang.org/x/sync/semaphore" "golang.org/x/sync/semaphore"
"github.com/Dreamacro/clash/log"
) )
type remoteTun struct { type remoteTun struct {
@ -23,7 +21,7 @@ type remoteTun struct {
} }
func (t *remoteTun) markSocket(fd int) { func (t *remoteTun) markSocket(fd int) {
_ = t.limit.Acquire(context.Background(), 1) _ = t.limit.Acquire(context.TODO(), 1)
defer t.limit.Release(1) defer t.limit.Release(1)
if t.closed { if t.closed {
@ -34,7 +32,7 @@ func (t *remoteTun) markSocket(fd int) {
} }
func (t *remoteTun) querySocketUid(protocol int, source, target string) int { func (t *remoteTun) querySocketUid(protocol int, source, target string) int {
_ = t.limit.Acquire(context.Background(), 1) _ = t.limit.Acquire(context.TODO(), 1)
defer t.limit.Release(1) defer t.limit.Release(1)
if t.closed { if t.closed {
@ -50,26 +48,27 @@ func (t *remoteTun) stop() {
t.closed = true t.closed = true
C.release_object(t.callback) app.ApplyTunContext(nil, nil)
log.Infoln("Android tun device destroyed") C.release_object(t.callback)
} }
//export startTun //export startTun
func startTun(fd, mtu C.int, dns C.c_string, callback unsafe.Pointer) C.int { func startTun(fd, mtu C.int, gateway, dns C.c_string, callback unsafe.Pointer) C.int {
f := int(fd) f := int(fd)
m := int(mtu) m := int(mtu)
g := C.GoString(gateway)
d := C.GoString(dns) d := C.GoString(dns)
remote := &remoteTun{callback: callback, closed: false, limit: semaphore.NewWeighted(4)} remote := &remoteTun{callback: callback, closed: false, limit: semaphore.NewWeighted(4)}
if tun.Start(f, m, d) != nil {
return 1
}
app.ApplyTunContext(remote.markSocket, remote.querySocketUid) app.ApplyTunContext(remote.markSocket, remote.querySocketUid)
log.Infoln("Android tun device created") if tun.Start(f, m, g, d, remote.stop) != nil {
app.ApplyTunContext(nil, nil)
return 1
}
return 0 return 0
} }

View File

@ -1,19 +1,12 @@
package tun package tun
import ( import (
"encoding/binary"
"io"
"net" "net"
"time"
"github.com/Dreamacro/clash/component/resolver" "github.com/Dreamacro/clash/component/resolver"
"github.com/kr328/tun2socket"
D "github.com/miekg/dns" D "github.com/miekg/dns"
) )
const defaultDnsReadTimeout = time.Second * 30
func shouldHijackDns(dns net.IP, target net.IP, targetPort int) bool { func shouldHijackDns(dns net.IP, target net.IP, targetPort int) bool {
if targetPort != 53 { if targetPort != 53 {
return false return false
@ -22,58 +15,7 @@ func shouldHijackDns(dns net.IP, target net.IP, targetPort int) bool {
return net.IPv4zero.Equal(dns) || target.Equal(dns) return net.IPv4zero.Equal(dns) || target.Equal(dns)
} }
func hijackUDPDns(pkt []byte, lAddr, rAddr net.Addr, udp tun2socket.UDP) { func relayDns(payload []byte) ([]byte, error) {
go func() {
answer, err := relayDnsPacket(pkt)
if err != nil {
return
}
_, _ = udp.WriteTo(answer, lAddr, rAddr)
recycleUDP(pkt)
}()
}
func hijackTCPDns(conn net.Conn) {
go func() {
defer conn.Close()
for {
if err := conn.SetReadDeadline(time.Now().Add(defaultDnsReadTimeout)); err != nil {
return
}
var length uint16
if binary.Read(conn, binary.BigEndian, &length) != nil {
return
}
data := make([]byte, length)
_, err := io.ReadFull(conn, data)
if err != nil {
return
}
rb, err := relayDnsPacket(data)
if err != nil {
continue
}
if binary.Write(conn, binary.BigEndian, uint16(len(rb))) != nil {
return
}
if _, err := conn.Write(rb); err != nil {
return
}
}
}()
}
func relayDnsPacket(payload []byte) ([]byte, error) {
msg := &D.Msg{} msg := &D.Msg{}
if err := msg.Unpack(payload); err != nil { if err := msg.Unpack(payload); err != nil {
return nil, err return nil, err
@ -84,14 +26,6 @@ func relayDnsPacket(payload []byte) ([]byte, error) {
return nil, err return nil, err
} }
for _, ans := range r.Answer {
header := ans.Header()
if header.Class == D.ClassINET && (header.Rrtype == D.TypeA || header.Rrtype == D.TypeAAAA) {
header.Ttl = 1
}
}
r.SetRcode(msg, r.Rcode) r.SetRcode(msg, r.Rcode)
return r.Pack() return r.Pack()

View File

@ -0,0 +1,39 @@
package tun
import "github.com/Dreamacro/clash/log"
func (a *adapter) rx() {
log.Infoln("[ATUN] Device rx started")
defer log.Infoln("[ATUN] Device rx exited")
defer a.once.Do(a.stop)
defer a.close()
buf := make([]byte, a.mtu)
for {
n, err := a.device.Read(buf)
if err != nil {
return
}
_, _ = a.stack.Link().Write(buf[:n])
}
}
func (a *adapter) tx() {
log.Infoln("[ATUN] Device tx started")
defer log.Infoln("[ATUN] Device tx exited")
defer a.once.Do(a.stop)
defer a.close()
buf := make([]byte, a.mtu)
for {
n, err := a.stack.Link().Read(buf)
if err != nil {
return
}
_, _ = a.device.Write(buf[:n])
}
}

View File

@ -1,27 +1,100 @@
package tun package tun
import ( import (
"encoding/binary"
"io"
"net" "net"
"strconv" "strconv"
"time"
C "github.com/Dreamacro/clash/constant" C "github.com/Dreamacro/clash/constant"
CTX "github.com/Dreamacro/clash/context" "github.com/Dreamacro/clash/context"
"github.com/Dreamacro/clash/log"
"github.com/Dreamacro/clash/tunnel" "github.com/Dreamacro/clash/tunnel"
) )
func handleTCP(conn net.Conn, source *net.TCPAddr, target *net.TCPAddr) { const defaultDnsReadTimeout = time.Second * 30
func (a *adapter) tcp() {
log.Infoln("[ATUN] TCP listener started")
defer log.Infoln("[ATUN] TCP listener exited")
defer a.stack.Close()
for {
conn, err := a.stack.TCP().Accept()
if err != nil {
return
}
sAddr := conn.LocalAddr().(*net.TCPAddr)
tAddr := conn.RemoteAddr().(*net.TCPAddr)
// handle dns messages
if a.hijackTCPDNS(conn, tAddr) {
continue
}
// drop all connections connect to gateway
if a.gateway.Contains(tAddr.IP) {
continue
}
metadata := &C.Metadata{ metadata := &C.Metadata{
NetWork: C.TCP, NetWork: C.TCP,
Type: C.SOCKS, Type: C.SOCKS,
SrcIP: source.IP, SrcIP: sAddr.IP,
DstIP: target.IP, DstIP: tAddr.IP,
SrcPort: strconv.Itoa(source.Port), SrcPort: strconv.Itoa(sAddr.Port),
DstPort: strconv.Itoa(target.Port), DstPort: strconv.Itoa(tAddr.Port),
AddrType: C.AtypIPv4, AddrType: C.AtypIPv4,
Host: "", Host: "",
RawSrcAddr: source, RawSrcAddr: sAddr,
RawDstAddr: target, RawDstAddr: tAddr,
} }
tunnel.Add(CTX.NewConnContext(conn, metadata)) tunnel.Add(context.NewConnContext(conn, metadata))
}
}
func (a *adapter) hijackTCPDNS(conn net.Conn, tAddr *net.TCPAddr) bool {
if !shouldHijackDns(a.dns, tAddr.IP, tAddr.Port) {
return false
}
go func() {
defer conn.Close()
for {
if err := conn.SetReadDeadline(time.Now().Add(defaultDnsReadTimeout)); err != nil {
return
}
var length uint16
if binary.Read(conn, binary.BigEndian, &length) != nil {
return
}
data := make([]byte, length)
_, err := io.ReadFull(conn, data)
if err != nil {
return
}
rb, err := relayDns(data)
if err != nil {
continue
}
if binary.Write(conn, binary.BigEndian, uint16(len(rb))) != nil {
return
}
if _, err := conn.Write(rb); err != nil {
return
}
}
}()
return true
} }

View File

@ -4,33 +4,40 @@ import (
"net" "net"
"os" "os"
"sync" "sync"
"syscall"
"github.com/kr328/tun2socket" "github.com/kr328/tun2socket"
) )
type context struct { type adapter struct {
device *os.File device *os.File
stack tun2socket.Stack stack tun2socket.Stack
gateway *net.IPNet
dns net.IP
mtu int
once sync.Once
stop func()
} }
var lock sync.Mutex var lock sync.Mutex
var tun *context var instance *adapter
func (ctx *context) close() { func (a *adapter) close() {
_ = ctx.stack.Close() _ = a.stack.Close()
_ = ctx.device.Close() _ = a.device.Close()
} }
func Start(fd, mtu int, dns string) error { func Start(fd, mtu int, gateway, dns string, stop func()) error {
lock.Lock() lock.Lock()
defer lock.Unlock() defer lock.Unlock()
stopLocked() if instance != nil {
instance.close()
}
dnsIP := net.ParseIP(dns) _ = syscall.SetNonblock(fd, true)
device := os.NewFile(uintptr(fd), "/dev/tun") device := os.NewFile(uintptr(fd), "/dev/tun")
stack, err := tun2socket.NewStack(mtu) stack, err := tun2socket.NewStack(mtu)
if err != nil { if err != nil {
_ = device.Close() _ = device.Close()
@ -38,100 +45,23 @@ func Start(fd, mtu int, dns string) error {
return err return err
} }
ctx := &context{ dn := net.ParseIP(dns)
_, gw, _ := net.ParseCIDR(gateway)
instance = &adapter{
device: device, device: device,
stack: stack, stack: stack,
gateway: gw,
dns: dn,
mtu: mtu,
once: sync.Once{},
stop: stop,
} }
go func() { go instance.rx()
// device -> lwip go instance.tx()
go instance.tcp()
defer ctx.close() go instance.udp()
buf := make([]byte, mtu)
for {
n, err := device.Read(buf)
if err != nil {
return
}
_, _ = stack.Link().Write(buf[:n])
}
}()
go func() {
// lwip -> device
defer ctx.close()
buf := make([]byte, mtu)
for {
n, err := stack.Link().Read(buf)
if err != nil {
return
}
_, _ = device.Write(buf[:n])
}
}()
go func() {
// lwip tcp
defer ctx.close()
for {
conn, err := stack.TCP().Accept()
if err != nil {
return
}
source := conn.LocalAddr().(*net.TCPAddr)
target := conn.RemoteAddr().(*net.TCPAddr)
if shouldHijackDns(dnsIP, target.IP, target.Port) {
hijackTCPDns(conn)
continue
}
handleTCP(conn, source, target)
}
}()
go func() {
// lwip udp
defer ctx.close()
for {
buf := allocUDP(mtu)
n, lAddr, rAddr, err := stack.UDP().ReadFrom(buf)
if err != nil {
return
}
source := lAddr.(*net.UDPAddr)
target := rAddr.(*net.UDPAddr)
if n == 0 {
continue
}
if shouldHijackDns(dnsIP, target.IP, target.Port) {
hijackUDPDns(buf[:n], source, target, stack.UDP())
continue
}
handleUDP(buf[:n], source, target, stack.UDP())
}
}()
tun = ctx
return nil return nil
} }
@ -140,13 +70,9 @@ func Stop() {
lock.Lock() lock.Lock()
defer lock.Unlock() defer lock.Unlock()
stopLocked() if instance != nil {
} instance.close()
func stopLocked() {
if tun != nil {
tun.close()
} }
tun = nil instance = nil
} }

View File

@ -3,6 +3,7 @@ package tun
import ( import (
"net" "net"
"github.com/Dreamacro/clash/log"
"github.com/Dreamacro/clash/transport/socks5" "github.com/Dreamacro/clash/transport/socks5"
"github.com/kr328/tun2socket" "github.com/kr328/tun2socket"
@ -12,48 +13,88 @@ import (
"github.com/Dreamacro/clash/tunnel" "github.com/Dreamacro/clash/tunnel"
) )
type udpPacket struct { type packet struct {
source *net.UDPAddr stack tun2socket.Stack
local *net.UDPAddr
data []byte data []byte
udp tun2socket.UDP
} }
func (u *udpPacket) Data() []byte { func (pkt *packet) Data() []byte {
return u.data return pkt.data
} }
func (u *udpPacket) WriteBack(b []byte, addr net.Addr) (n int, err error) { func (pkt *packet) WriteBack(b []byte, addr net.Addr) (n int, err error) {
return u.udp.WriteTo(b, u.source, addr) return pkt.stack.UDP().WriteTo(b, pkt.local, addr)
} }
func (u *udpPacket) Drop() { func (pkt *packet) Drop() {
recycleUDP(u.data) pool.Put(pkt.data)
} }
func (u *udpPacket) LocalAddr() net.Addr { func (pkt *packet) LocalAddr() net.Addr {
return &net.UDPAddr{ return &net.UDPAddr{
IP: u.source.IP, IP: pkt.local.IP,
Port: u.source.Port, Port: pkt.local.Port,
Zone: "", Zone: "",
} }
} }
func handleUDP(payload []byte, source *net.UDPAddr, target *net.UDPAddr, udp tun2socket.UDP) { func (a *adapter) udp() {
pkt := &udpPacket{ log.Infoln("[ATUN] UDP receiver started")
source: source, defer log.Infoln("[ATUN] UDP receiver exited")
data: payload, defer a.stack.Close()
udp: udp,
for {
buf := pool.Get(a.mtu)
n, lAddr, rAddr, err := a.stack.UDP().ReadFrom(buf)
if err != nil {
return
} }
adapter := adapters.NewPacket(socks5.ParseAddrToSocksAddr(target), pkt, C.SOCKS) sAddr := lAddr.(*net.UDPAddr)
tAddr := rAddr.(*net.UDPAddr)
// handle dns messages
if a.hijackUDPDNS(buf[:n], sAddr, tAddr) {
continue
}
// drop all packets send to gateway
if a.gateway.Contains(tAddr.IP) {
pool.Put(buf)
continue
}
pkt := &packet{
stack: a.stack,
local: sAddr,
data: buf[:n],
}
adapter := adapters.NewPacket(socks5.ParseAddrToSocksAddr(tAddr), pkt, C.SOCKS)
tunnel.AddPacket(adapter) tunnel.AddPacket(adapter)
}
} }
func allocUDP(size int) []byte { func (a *adapter) hijackUDPDNS(pkt []byte, sAddr, tAddr *net.UDPAddr) bool {
return pool.Get(size) if !shouldHijackDns(a.dns, tAddr.IP, tAddr.Port) {
} return false
}
func recycleUDP(payload []byte) { go func() {
_ = pool.Put(payload) answer, err := relayDns(pkt)
if err != nil {
return
}
_, _ = a.stack.UDP().WriteTo(answer, sAddr, tAddr)
pool.Put(pkt)
}()
return true
} }

View File

@ -61,11 +61,12 @@ object Clash {
fun startTun( fun startTun(
fd: Int, fd: Int,
mtu: Int, mtu: Int,
gateway: String,
dns: String, dns: String,
markSocket: (Int) -> Boolean, markSocket: (Int) -> Boolean,
querySocketUid: (protocol: Int, source: InetSocketAddress, target: InetSocketAddress) -> Int querySocketUid: (protocol: Int, source: InetSocketAddress, target: InetSocketAddress) -> Int
) { ) {
Bridge.nativeStartTun(fd, mtu, dns, object : TunInterface { Bridge.nativeStartTun(fd, mtu, gateway, dns, object : TunInterface {
override fun markSocket(fd: Int) { override fun markSocket(fd: Int) {
markSocket(fd) markSocket(fd)
} }

View File

@ -17,7 +17,7 @@ object Bridge {
external fun nativeQueryTrafficTotal(): Long external fun nativeQueryTrafficTotal(): Long
external fun nativeNotifyDnsChanged(dnsList: String) external fun nativeNotifyDnsChanged(dnsList: String)
external fun nativeNotifyInstalledAppChanged(uidList: String) external fun nativeNotifyInstalledAppChanged(uidList: String)
external fun nativeStartTun(fd: Int, mtu: Int, dns: String, cb: TunInterface) external fun nativeStartTun(fd: Int, mtu: Int, gateway: String, dns: String, cb: TunInterface)
external fun nativeStopTun() external fun nativeStopTun()
external fun nativeStartHttp(listenAt: String): String? external fun nativeStartHttp(listenAt: String): String?
external fun nativeStopHttp() external fun nativeStopHttp()

View File

@ -217,6 +217,7 @@ class TunService : VpnService(), CoroutineScope by CoroutineScope(Dispatchers.De
fd = establish()?.detachFd() fd = establish()?.detachFd()
?: throw NullPointerException("Establish VPN rejected by system"), ?: throw NullPointerException("Establish VPN rejected by system"),
mtu = TUN_MTU, mtu = TUN_MTU,
gateway = "$TUN_GATEWAY/$TUN_SUBNET_PREFIX",
dns = if (store.dnsHijacking) NET_ANY else TUN_DNS, dns = if (store.dnsHijacking) NET_ANY else TUN_DNS,
) )
} }

View File

@ -16,7 +16,8 @@ class TunModule(private val vpn: VpnService) : Module<Unit>(vpn) {
data class TunDevice( data class TunDevice(
val fd: Int, val fd: Int,
val mtu: Int, val mtu: Int,
val dns: String val gateway: String,
val dns: String,
) )
private val connectivity = service.getSystemService<ConnectivityManager>()!! private val connectivity = service.getSystemService<ConnectivityManager>()!!
@ -56,6 +57,7 @@ class TunModule(private val vpn: VpnService) : Module<Unit>(vpn) {
Clash.startTun( Clash.startTun(
fd = device.fd, fd = device.fd,
mtu = device.mtu, mtu = device.mtu,
gateway = device.gateway,
dns = device.dns, dns = device.dns,
markSocket = vpn::protect, markSocket = vpn::protect,
querySocketUid = this::queryUid querySocketUid = this::queryUid