diff --git a/app/dns/nameserver.go b/app/dns/nameserver.go index f3ce2778..1cca1453 100644 --- a/app/dns/nameserver.go +++ b/app/dns/nameserver.go @@ -35,7 +35,7 @@ type Client struct { var errExpectedIPNonMatch = errors.New("expectIPs not match") // NewServer creates a name server object according to the network destination url. -func NewServer(dest net.Destination, dispatcher routing.Dispatcher, queryStrategy QueryStrategy) (Server, error) { +func NewServer(dest net.Destination, dispatcher routing.Dispatcher, queryStrategy QueryStrategy, fd dns.FakeDNSEngine) (Server, error) { if address := dest.Address; address.Family().IsDomain() { u, err := url.Parse(address.Domain()) if err != nil { @@ -55,7 +55,7 @@ func NewServer(dest net.Destination, dispatcher routing.Dispatcher, queryStrateg case strings.EqualFold(u.Scheme, "tcp+local"): // DNS-over-TCP Local mode return NewTCPLocalNameServer(u, queryStrategy) case strings.EqualFold(u.String(), "fakedns"): - return NewFakeDNSServer(), nil + return NewFakeDNSServer(fd), nil } } if dest.Network == net.Network_Unknown { @@ -78,9 +78,9 @@ func NewClient( ) (*Client, error) { client := &Client{} - err := core.RequireFeatures(ctx, func(dispatcher routing.Dispatcher) error { + err := core.RequireFeatures(ctx, func(dispatcher routing.Dispatcher, fd dns.FakeDNSEngine) error { // Create a new server for each client for now - server, err := NewServer(ns.Address.AsDestination(), dispatcher, ns.GetQueryStrategy()) + server, err := NewServer(ns.Address.AsDestination(), dispatcher, ns.GetQueryStrategy(), fd) if err != nil { return errors.New("failed to create nameserver").Base(err).AtWarning() } diff --git a/app/dns/nameserver_fakedns.go b/app/dns/nameserver_fakedns.go index 531417da..4b9d93f9 100644 --- a/app/dns/nameserver_fakedns.go +++ b/app/dns/nameserver_fakedns.go @@ -5,7 +5,6 @@ import ( "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/net" - "github.com/xtls/xray-core/core" "github.com/xtls/xray-core/features/dns" ) @@ -13,8 +12,8 @@ type FakeDNSServer struct { fakeDNSEngine dns.FakeDNSEngine } -func NewFakeDNSServer() *FakeDNSServer { - return &FakeDNSServer{} +func NewFakeDNSServer(fd dns.FakeDNSEngine) *FakeDNSServer { + return &FakeDNSServer{fakeDNSEngine: fd} } func (FakeDNSServer) Name() string { @@ -22,13 +21,6 @@ func (FakeDNSServer) Name() string { } func (f *FakeDNSServer) QueryIP(ctx context.Context, domain string, _ net.IP, opt dns.IPOption, _ bool) ([]net.IP, error) { - if f.fakeDNSEngine == nil { - if err := core.RequireFeatures(ctx, func(fd dns.FakeDNSEngine) { - f.fakeDNSEngine = fd - }); err != nil { - return nil, errors.New("Unable to locate a fake DNS Engine").Base(err).AtError() - } - } var ips []net.Address if fkr0, ok := f.fakeDNSEngine.(dns.FakeDNSEngineRev0); ok { ips = fkr0.GetFakeIPForDomain3(domain, opt.IPv4Enable, opt.IPv6Enable) diff --git a/app/observatory/burst/burstobserver.go b/app/observatory/burst/burstobserver.go index f2204c00..472351cc 100644 --- a/app/observatory/burst/burstobserver.go +++ b/app/observatory/burst/burstobserver.go @@ -12,6 +12,7 @@ import ( "github.com/xtls/xray-core/core" "github.com/xtls/xray-core/features/extension" "github.com/xtls/xray-core/features/outbound" + "github.com/xtls/xray-core/features/routing" "google.golang.org/protobuf/proto" ) @@ -88,13 +89,15 @@ func (o *Observer) Close() error { func New(ctx context.Context, config *Config) (*Observer, error) { var outboundManager outbound.Manager - err := core.RequireFeatures(ctx, func(om outbound.Manager) { + var dispatcher routing.Dispatcher + err := core.RequireFeatures(ctx, func(om outbound.Manager, rd routing.Dispatcher) { outboundManager = om + dispatcher = rd }) if err != nil { return nil, errors.New("Cannot get depended features").Base(err) } - hp := NewHealthPing(ctx, config.PingConfig) + hp := NewHealthPing(ctx, dispatcher, config.PingConfig) return &Observer{ config: config, ctx: ctx, diff --git a/app/observatory/burst/healthping.go b/app/observatory/burst/healthping.go index cd4d5fc0..f0842602 100644 --- a/app/observatory/burst/healthping.go +++ b/app/observatory/burst/healthping.go @@ -9,6 +9,7 @@ import ( "github.com/xtls/xray-core/common/dice" "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/features/routing" ) // HealthPingSettings holds settings for health Checker @@ -23,6 +24,7 @@ type HealthPingSettings struct { // HealthPing is the health checker for balancers type HealthPing struct { ctx context.Context + dispatcher routing.Dispatcher access sync.Mutex ticker *time.Ticker tickerClose chan struct{} @@ -32,7 +34,7 @@ type HealthPing struct { } // NewHealthPing creates a new HealthPing with settings -func NewHealthPing(ctx context.Context, config *HealthPingConfig) *HealthPing { +func NewHealthPing(ctx context.Context, dispatcher routing.Dispatcher, config *HealthPingConfig) *HealthPing { settings := &HealthPingSettings{} if config != nil { settings = &HealthPingSettings{ @@ -65,6 +67,7 @@ func NewHealthPing(ctx context.Context, config *HealthPingConfig) *HealthPing { } return &HealthPing{ ctx: ctx, + dispatcher: dispatcher, Settings: settings, Results: nil, } @@ -149,6 +152,7 @@ func (h *HealthPing) doCheck(tags []string, duration time.Duration, rounds int) handler := tag client := newPingClient( h.ctx, + h.dispatcher, h.Settings.Destination, h.Settings.Timeout, handler, diff --git a/app/observatory/burst/ping.go b/app/observatory/burst/ping.go index de1465b9..5ea1433a 100644 --- a/app/observatory/burst/ping.go +++ b/app/observatory/burst/ping.go @@ -6,6 +6,7 @@ import ( "time" "github.com/xtls/xray-core/common/net" + "github.com/xtls/xray-core/features/routing" "github.com/xtls/xray-core/transport/internet/tagged" ) @@ -14,10 +15,10 @@ type pingClient struct { httpClient *http.Client } -func newPingClient(ctx context.Context, destination string, timeout time.Duration, handler string) *pingClient { +func newPingClient(ctx context.Context, dispatcher routing.Dispatcher, destination string, timeout time.Duration, handler string) *pingClient { return &pingClient{ destination: destination, - httpClient: newHTTPClient(ctx, handler, timeout), + httpClient: newHTTPClient(ctx, dispatcher, handler, timeout), } } @@ -28,7 +29,7 @@ func newDirectPingClient(destination string, timeout time.Duration) *pingClient } } -func newHTTPClient(ctxv context.Context, handler string, timeout time.Duration) *http.Client { +func newHTTPClient(ctxv context.Context, dispatcher routing.Dispatcher, handler string, timeout time.Duration) *http.Client { tr := &http.Transport{ DisableKeepAlives: true, DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { @@ -36,7 +37,7 @@ func newHTTPClient(ctxv context.Context, handler string, timeout time.Duration) if err != nil { return nil, err } - return tagged.Dialer(ctxv, dest, handler) + return tagged.Dialer(ctxv, dispatcher, dest, handler) }, } return &http.Client{ diff --git a/app/observatory/observer.go b/app/observatory/observer.go index f29856db..657396f6 100644 --- a/app/observatory/observer.go +++ b/app/observatory/observer.go @@ -18,6 +18,7 @@ import ( "github.com/xtls/xray-core/core" "github.com/xtls/xray-core/features/extension" "github.com/xtls/xray-core/features/outbound" + "github.com/xtls/xray-core/features/routing" "github.com/xtls/xray-core/transport/internet/tagged" "google.golang.org/protobuf/proto" ) @@ -32,6 +33,7 @@ type Observer struct { finished *done.Instance ohm outbound.Manager + dispatcher routing.Dispatcher } func (o *Observer) GetObservation(ctx context.Context) (proto.Message, error) { @@ -131,7 +133,7 @@ func (o *Observer) probe(outbound string) ProbeResult { return errors.New("cannot understand address").Base(err) } trackedCtx := session.TrackedConnectionError(o.ctx, errorCollectorForRequest) - conn, err := tagged.Dialer(trackedCtx, dest, outbound) + conn, err := tagged.Dialer(trackedCtx, o.dispatcher, dest, outbound) if err != nil { return errors.New("cannot dial remote address ", dest).Base(err) } @@ -215,8 +217,10 @@ func (o *Observer) findStatusLocationLockHolderOnly(outbound string) int { func New(ctx context.Context, config *Config) (*Observer, error) { var outboundManager outbound.Manager - err := core.RequireFeatures(ctx, func(om outbound.Manager) { + var dispatcher routing.Dispatcher + err := core.RequireFeatures(ctx, func(om outbound.Manager, rd routing.Dispatcher) { outboundManager = om + dispatcher = rd }) if err != nil { return nil, errors.New("Cannot get depended features").Base(err) @@ -225,6 +229,7 @@ func New(ctx context.Context, config *Config) (*Observer, error) { config: config, ctx: ctx, ohm: outboundManager, + dispatcher: dispatcher, }, nil } diff --git a/app/router/balancing.go b/app/router/balancing.go index 14f8e21f..5f8cb1c2 100644 --- a/app/router/balancing.go +++ b/app/router/balancing.go @@ -31,6 +31,12 @@ type RoundRobinStrategy struct { func (s *RoundRobinStrategy) InjectContext(ctx context.Context) { s.ctx = ctx + if len(s.FallbackTag) > 0 { + common.Must(core.RequireFeatures(s.ctx, func(observatory extension.Observatory) error { + s.observatory = observatory + return nil + })) + } } func (s *RoundRobinStrategy) GetPrincipleTarget(strings []string) []string { @@ -38,12 +44,6 @@ func (s *RoundRobinStrategy) GetPrincipleTarget(strings []string) []string { } func (s *RoundRobinStrategy) PickOutbound(tags []string) string { - if len(s.FallbackTag) > 0 && s.observatory == nil { - common.Must(core.RequireFeatures(s.ctx, func(observatory extension.Observatory) error { - s.observatory = observatory - return nil - })) - } if s.observatory != nil { observeReport, err := s.observatory.GetObservation(s.ctx) if err == nil { diff --git a/app/router/strategy_leastload.go b/app/router/strategy_leastload.go index bfdfd878..e4620725 100644 --- a/app/router/strategy_leastload.go +++ b/app/router/strategy_leastload.go @@ -58,8 +58,12 @@ type node struct { RTTDeviationCost time.Duration } -func (l *LeastLoadStrategy) InjectContext(ctx context.Context) { - l.ctx = ctx +func (s *LeastLoadStrategy) InjectContext(ctx context.Context) { + s.ctx = ctx + common.Must(core.RequireFeatures(s.ctx, func(observatory extension.Observatory) error { + s.observer = observatory + return nil + })) } func (s *LeastLoadStrategy) PickOutbound(candidates []string) string { @@ -135,12 +139,6 @@ func (s *LeastLoadStrategy) selectLeastLoad(nodes []*node) []*node { } func (s *LeastLoadStrategy) getNodes(candidates []string, maxRTT time.Duration) []*node { - if s.observer == nil { - common.Must(core.RequireFeatures(s.ctx, func(observatory extension.Observatory) error { - s.observer = observatory - return nil - })) - } observeResult, err := s.observer.GetObservation(s.ctx) if err != nil { errors.LogInfoInner(s.ctx, err, "cannot get observation") diff --git a/app/router/strategy_leastping.go b/app/router/strategy_leastping.go index 28efe386..f4e53ed7 100644 --- a/app/router/strategy_leastping.go +++ b/app/router/strategy_leastping.go @@ -21,16 +21,13 @@ func (l *LeastPingStrategy) GetPrincipleTarget(strings []string) []string { func (l *LeastPingStrategy) InjectContext(ctx context.Context) { l.ctx = ctx + common.Must(core.RequireFeatures(l.ctx, func(observatory extension.Observatory) error { + l.observatory = observatory + return nil + })) } func (l *LeastPingStrategy) PickOutbound(strings []string) string { - if l.observatory == nil { - common.Must(core.RequireFeatures(l.ctx, func(observatory extension.Observatory) error { - l.observatory = observatory - return nil - })) - } - observeReport, err := l.observatory.GetObservation(l.ctx) if err != nil { errors.LogInfoInner(l.ctx, err, "cannot get observe report") diff --git a/app/router/strategy_random.go b/app/router/strategy_random.go index ed82ff9d..ea9b7add 100644 --- a/app/router/strategy_random.go +++ b/app/router/strategy_random.go @@ -20,6 +20,12 @@ type RandomStrategy struct { func (s *RandomStrategy) InjectContext(ctx context.Context) { s.ctx = ctx + if len(s.FallbackTag) > 0 { + common.Must(core.RequireFeatures(s.ctx, func(observatory extension.Observatory) error { + s.observatory = observatory + return nil + })) + } } func (s *RandomStrategy) GetPrincipleTarget(strings []string) []string { @@ -27,12 +33,6 @@ func (s *RandomStrategy) GetPrincipleTarget(strings []string) []string { } func (s *RandomStrategy) PickOutbound(candidates []string) string { - if len(s.FallbackTag) > 0 && s.observatory == nil { - common.Must(core.RequireFeatures(s.ctx, func(observatory extension.Observatory) error { - s.observatory = observatory - return nil - })) - } if s.observatory != nil { observeReport, err := s.observatory.GetObservation(s.ctx) if err == nil { diff --git a/transport/internet/tagged/tagged.go b/transport/internet/tagged/tagged.go index 430f9640..2cd9dcd2 100644 --- a/transport/internet/tagged/tagged.go +++ b/transport/internet/tagged/tagged.go @@ -4,8 +4,9 @@ import ( "context" "github.com/xtls/xray-core/common/net" + "github.com/xtls/xray-core/features/routing" ) -type DialFunc func(ctx context.Context, dest net.Destination, tag string) (net.Conn, error) +type DialFunc func(ctx context.Context, dispatcher routing.Dispatcher, dest net.Destination, tag string) (net.Conn, error) var Dialer DialFunc diff --git a/transport/internet/tagged/taggedimpl/impl.go b/transport/internet/tagged/taggedimpl/impl.go index 29caec7c..2a773401 100644 --- a/transport/internet/tagged/taggedimpl/impl.go +++ b/transport/internet/tagged/taggedimpl/impl.go @@ -12,17 +12,10 @@ import ( "github.com/xtls/xray-core/transport/internet/tagged" ) -func DialTaggedOutbound(ctx context.Context, dest net.Destination, tag string) (net.Conn, error) { - var dispatcher routing.Dispatcher +func DialTaggedOutbound(ctx context.Context, dispatcher routing.Dispatcher, dest net.Destination, tag string) (net.Conn, error) { if core.FromContext(ctx) == nil { return nil, errors.New("Instance context variable is not in context, dial denied. ") } - if err := core.RequireFeatures(ctx, func(dispatcherInstance routing.Dispatcher) { - dispatcher = dispatcherInstance - }); err != nil { - return nil, errors.New("Required Feature dispatcher not resolved").Base(err) - } - content := new(session.Content) content.SkipDNSResolve = true