diff --git a/transport/internet/tls/ech.go b/transport/internet/tls/ech.go index 46518af4..086eb5c3 100644 --- a/transport/internet/tls/ech.go +++ b/transport/internet/tls/ech.go @@ -36,50 +36,79 @@ func ApplyECH(c *Config, config *tls.Config) error { } type record struct { - record []byte - expire time.Time + echConfig []byte + expire time.Time } var ( - dnsCache = make(map[string]record) - mutex sync.RWMutex + dnsCache = make(map[string]record) + // global Lock? I'm not sure if I need finer grained locks. + // If we do this, we will need to nest another layer of struct + dnsCacheLock sync.RWMutex + updating sync.Mutex ) + +// QueryRecord returns the ECH config for given domain. +// If the record is not in cache or expired, it will query the DOH server and update the cache. func QueryRecord(domain string, server string) ([]byte, error) { - mutex.Lock() - rec, found := dnsCache[domain] - if found && rec.expire.After(time.Now()) { - mutex.Unlock() - return rec.record, nil - } - mutex.Unlock() + dnsCacheLock.RLock() + rec, found := dnsCache[domain] + dnsCacheLock.RUnlock() + if found && rec.expire.After(time.Now()) { + errors.LogDebug(context.Background(), "Cache hit for domain: ", domain) + return rec.echConfig, nil + } - errors.LogDebug(context.Background(), "Trying to query ECH config for domain: ", domain, " with ECH server: ", server) - record, ttl, err := dohQuery(server, domain) - if err != nil { - return []byte{}, err - } + updating.Lock() + defer updating.Unlock() + // Try to get cache again after lock, in case another goroutine has updated it + // This might happen when the core tring is just stared and multiple goroutines are trying to query the same domain + dnsCacheLock.RLock() + rec, found = dnsCache[domain] + dnsCacheLock.RUnlock() + if found && rec.expire.After(time.Now()) { + errors.LogDebug(context.Background(), "ECH Config cache hit for domain: ", domain, " after trying to get update lock") + return rec.echConfig, nil + } - if ttl < 600 { - ttl = 600 - } + // Query ECH config from DOH server + errors.LogDebug(context.Background(), "Trying to query ECH config for domain: ", domain, " with ECH server: ", server) + echConfig, ttl, err := dohQuery(server, domain) + if err != nil { + return []byte{}, err + } - mutex.Lock() - defer mutex.Unlock() - rec.record = record - rec.expire = time.Now().Add(time.Second * time.Duration(ttl)) - dnsCache[domain] = rec - return record, nil + // Set minimum TTL to 600 seconds + if ttl < 600 { + ttl = 600 + } + + // Get write lock and update cache + dnsCacheLock.Lock() + defer dnsCacheLock.Unlock() + newRecored := record{ + echConfig: echConfig, + expire: time.Now().Add(time.Second * time.Duration(ttl)), + } + dnsCache[domain] = newRecored + return echConfig, nil } + +// dohQuery is the real func for sending type65 query for given domain to given DOH server. +// return ECH config, TTL and error func dohQuery(server string, domain string) ([]byte, uint32, error) { m := new(dns.Msg) m.SetQuestion(dns.Fqdn(domain), dns.TypeHTTPS) + // always 0 in DOH m.Id = 0 msg, err := m.Pack() if err != nil { return []byte{}, 0, err } + // All traffic sent by core should via xray's internet.DialSystem + // This involves the behavior of some Android VPN GUI clients tr := &http.Transport{ IdleConnTimeout: 90 * time.Second, ForceAttemptHTTP2: true,