diff --git a/app/dispatcher/default.go b/app/dispatcher/default.go index 42fca674..dbf58dad 100644 --- a/app/dispatcher/default.go +++ b/app/dispatcher/default.go @@ -106,7 +106,7 @@ func init() { common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { d := new(DefaultDispatcher) if err := core.RequireFeatures(ctx, func(om outbound.Manager, router routing.Router, pm policy.Manager, sm stats.Manager, dc dns.Client) error { - core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) { + core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) { // FakeDNSEngine is optional d.fdns = fdns }) return d.Init(config.(*Config), om, router, pm, sm, dc) diff --git a/app/dns/nameserver.go b/app/dns/nameserver.go index f3ce2778..ecba9aff 100644 --- a/app/dns/nameserver.go +++ b/app/dns/nameserver.go @@ -35,7 +35,7 @@ type Client struct { var errExpectedIPNonMatch = errors.New("expectIPs not match") // NewServer creates a name server object according to the network destination url. -func NewServer(dest net.Destination, dispatcher routing.Dispatcher, queryStrategy QueryStrategy) (Server, error) { +func NewServer(dest net.Destination, dispatcher routing.Dispatcher, queryStrategy QueryStrategy, fd dns.FakeDNSEngine) (Server, error) { if address := dest.Address; address.Family().IsDomain() { u, err := url.Parse(address.Domain()) if err != nil { @@ -55,7 +55,7 @@ func NewServer(dest net.Destination, dispatcher routing.Dispatcher, queryStrateg case strings.EqualFold(u.Scheme, "tcp+local"): // DNS-over-TCP Local mode return NewTCPLocalNameServer(u, queryStrategy) case strings.EqualFold(u.String(), "fakedns"): - return NewFakeDNSServer(), nil + return NewFakeDNSServer(fd), nil } } if dest.Network == net.Network_Unknown { @@ -78,9 +78,13 @@ func NewClient( ) (*Client, error) { client := &Client{} + var fd dns.FakeDNSEngine err := core.RequireFeatures(ctx, func(dispatcher routing.Dispatcher) error { + core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) { // FakeDNSEngine is optional + fd = fdns + }) // Create a new server for each client for now - server, err := NewServer(ns.Address.AsDestination(), dispatcher, ns.GetQueryStrategy()) + server, err := NewServer(ns.Address.AsDestination(), dispatcher, ns.GetQueryStrategy(), fd) if err != nil { return errors.New("failed to create nameserver").Base(err).AtWarning() } diff --git a/app/dns/nameserver_fakedns.go b/app/dns/nameserver_fakedns.go index 531417da..ae7a1a7d 100644 --- a/app/dns/nameserver_fakedns.go +++ b/app/dns/nameserver_fakedns.go @@ -5,7 +5,6 @@ import ( "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/net" - "github.com/xtls/xray-core/core" "github.com/xtls/xray-core/features/dns" ) @@ -13,8 +12,8 @@ type FakeDNSServer struct { fakeDNSEngine dns.FakeDNSEngine } -func NewFakeDNSServer() *FakeDNSServer { - return &FakeDNSServer{} +func NewFakeDNSServer(fd dns.FakeDNSEngine) *FakeDNSServer { + return &FakeDNSServer{fakeDNSEngine: fd} } func (FakeDNSServer) Name() string { @@ -23,12 +22,9 @@ func (FakeDNSServer) Name() string { func (f *FakeDNSServer) QueryIP(ctx context.Context, domain string, _ net.IP, opt dns.IPOption, _ bool) ([]net.IP, error) { if f.fakeDNSEngine == nil { - if err := core.RequireFeatures(ctx, func(fd dns.FakeDNSEngine) { - f.fakeDNSEngine = fd - }); err != nil { - return nil, errors.New("Unable to locate a fake DNS Engine").Base(err).AtError() - } + return nil, errors.New("Unable to locate a fake DNS Engine").AtError() } + var ips []net.Address if fkr0, ok := f.fakeDNSEngine.(dns.FakeDNSEngineRev0); ok { ips = fkr0.GetFakeIPForDomain3(domain, opt.IPv4Enable, opt.IPv6Enable) diff --git a/app/observatory/burst/burstobserver.go b/app/observatory/burst/burstobserver.go index f2204c00..472351cc 100644 --- a/app/observatory/burst/burstobserver.go +++ b/app/observatory/burst/burstobserver.go @@ -12,6 +12,7 @@ import ( "github.com/xtls/xray-core/core" "github.com/xtls/xray-core/features/extension" "github.com/xtls/xray-core/features/outbound" + "github.com/xtls/xray-core/features/routing" "google.golang.org/protobuf/proto" ) @@ -88,13 +89,15 @@ func (o *Observer) Close() error { func New(ctx context.Context, config *Config) (*Observer, error) { var outboundManager outbound.Manager - err := core.RequireFeatures(ctx, func(om outbound.Manager) { + var dispatcher routing.Dispatcher + err := core.RequireFeatures(ctx, func(om outbound.Manager, rd routing.Dispatcher) { outboundManager = om + dispatcher = rd }) if err != nil { return nil, errors.New("Cannot get depended features").Base(err) } - hp := NewHealthPing(ctx, config.PingConfig) + hp := NewHealthPing(ctx, dispatcher, config.PingConfig) return &Observer{ config: config, ctx: ctx, diff --git a/app/observatory/burst/healthping.go b/app/observatory/burst/healthping.go index cd4d5fc0..f0842602 100644 --- a/app/observatory/burst/healthping.go +++ b/app/observatory/burst/healthping.go @@ -9,6 +9,7 @@ import ( "github.com/xtls/xray-core/common/dice" "github.com/xtls/xray-core/common/errors" + "github.com/xtls/xray-core/features/routing" ) // HealthPingSettings holds settings for health Checker @@ -23,6 +24,7 @@ type HealthPingSettings struct { // HealthPing is the health checker for balancers type HealthPing struct { ctx context.Context + dispatcher routing.Dispatcher access sync.Mutex ticker *time.Ticker tickerClose chan struct{} @@ -32,7 +34,7 @@ type HealthPing struct { } // NewHealthPing creates a new HealthPing with settings -func NewHealthPing(ctx context.Context, config *HealthPingConfig) *HealthPing { +func NewHealthPing(ctx context.Context, dispatcher routing.Dispatcher, config *HealthPingConfig) *HealthPing { settings := &HealthPingSettings{} if config != nil { settings = &HealthPingSettings{ @@ -65,6 +67,7 @@ func NewHealthPing(ctx context.Context, config *HealthPingConfig) *HealthPing { } return &HealthPing{ ctx: ctx, + dispatcher: dispatcher, Settings: settings, Results: nil, } @@ -149,6 +152,7 @@ func (h *HealthPing) doCheck(tags []string, duration time.Duration, rounds int) handler := tag client := newPingClient( h.ctx, + h.dispatcher, h.Settings.Destination, h.Settings.Timeout, handler, diff --git a/app/observatory/burst/ping.go b/app/observatory/burst/ping.go index de1465b9..5ea1433a 100644 --- a/app/observatory/burst/ping.go +++ b/app/observatory/burst/ping.go @@ -6,6 +6,7 @@ import ( "time" "github.com/xtls/xray-core/common/net" + "github.com/xtls/xray-core/features/routing" "github.com/xtls/xray-core/transport/internet/tagged" ) @@ -14,10 +15,10 @@ type pingClient struct { httpClient *http.Client } -func newPingClient(ctx context.Context, destination string, timeout time.Duration, handler string) *pingClient { +func newPingClient(ctx context.Context, dispatcher routing.Dispatcher, destination string, timeout time.Duration, handler string) *pingClient { return &pingClient{ destination: destination, - httpClient: newHTTPClient(ctx, handler, timeout), + httpClient: newHTTPClient(ctx, dispatcher, handler, timeout), } } @@ -28,7 +29,7 @@ func newDirectPingClient(destination string, timeout time.Duration) *pingClient } } -func newHTTPClient(ctxv context.Context, handler string, timeout time.Duration) *http.Client { +func newHTTPClient(ctxv context.Context, dispatcher routing.Dispatcher, handler string, timeout time.Duration) *http.Client { tr := &http.Transport{ DisableKeepAlives: true, DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { @@ -36,7 +37,7 @@ func newHTTPClient(ctxv context.Context, handler string, timeout time.Duration) if err != nil { return nil, err } - return tagged.Dialer(ctxv, dest, handler) + return tagged.Dialer(ctxv, dispatcher, dest, handler) }, } return &http.Client{ diff --git a/app/observatory/observer.go b/app/observatory/observer.go index f29856db..657396f6 100644 --- a/app/observatory/observer.go +++ b/app/observatory/observer.go @@ -18,6 +18,7 @@ import ( "github.com/xtls/xray-core/core" "github.com/xtls/xray-core/features/extension" "github.com/xtls/xray-core/features/outbound" + "github.com/xtls/xray-core/features/routing" "github.com/xtls/xray-core/transport/internet/tagged" "google.golang.org/protobuf/proto" ) @@ -32,6 +33,7 @@ type Observer struct { finished *done.Instance ohm outbound.Manager + dispatcher routing.Dispatcher } func (o *Observer) GetObservation(ctx context.Context) (proto.Message, error) { @@ -131,7 +133,7 @@ func (o *Observer) probe(outbound string) ProbeResult { return errors.New("cannot understand address").Base(err) } trackedCtx := session.TrackedConnectionError(o.ctx, errorCollectorForRequest) - conn, err := tagged.Dialer(trackedCtx, dest, outbound) + conn, err := tagged.Dialer(trackedCtx, o.dispatcher, dest, outbound) if err != nil { return errors.New("cannot dial remote address ", dest).Base(err) } @@ -215,8 +217,10 @@ func (o *Observer) findStatusLocationLockHolderOnly(outbound string) int { func New(ctx context.Context, config *Config) (*Observer, error) { var outboundManager outbound.Manager - err := core.RequireFeatures(ctx, func(om outbound.Manager) { + var dispatcher routing.Dispatcher + err := core.RequireFeatures(ctx, func(om outbound.Manager, rd routing.Dispatcher) { outboundManager = om + dispatcher = rd }) if err != nil { return nil, errors.New("Cannot get depended features").Base(err) @@ -225,6 +229,7 @@ func New(ctx context.Context, config *Config) (*Observer, error) { config: config, ctx: ctx, ohm: outboundManager, + dispatcher: dispatcher, }, nil } diff --git a/app/router/balancing.go b/app/router/balancing.go index 14f8e21f..7d8bb022 100644 --- a/app/router/balancing.go +++ b/app/router/balancing.go @@ -5,7 +5,6 @@ import ( sync "sync" "github.com/xtls/xray-core/app/observatory" - "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/core" "github.com/xtls/xray-core/features/extension" @@ -31,6 +30,11 @@ type RoundRobinStrategy struct { func (s *RoundRobinStrategy) InjectContext(ctx context.Context) { s.ctx = ctx + if len(s.FallbackTag) > 0 { + core.RequireFeaturesAsync(s.ctx, func(observatory extension.Observatory) { + s.observatory = observatory + }) + } } func (s *RoundRobinStrategy) GetPrincipleTarget(strings []string) []string { @@ -38,12 +42,6 @@ func (s *RoundRobinStrategy) GetPrincipleTarget(strings []string) []string { } func (s *RoundRobinStrategy) PickOutbound(tags []string) string { - if len(s.FallbackTag) > 0 && s.observatory == nil { - common.Must(core.RequireFeatures(s.ctx, func(observatory extension.Observatory) error { - s.observatory = observatory - return nil - })) - } if s.observatory != nil { observeReport, err := s.observatory.GetObservation(s.ctx) if err == nil { diff --git a/app/router/strategy_leastload.go b/app/router/strategy_leastload.go index bfdfd878..a4ef1c12 100644 --- a/app/router/strategy_leastload.go +++ b/app/router/strategy_leastload.go @@ -7,7 +7,6 @@ import ( "time" "github.com/xtls/xray-core/app/observatory" - "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/dice" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/core" @@ -58,8 +57,11 @@ type node struct { RTTDeviationCost time.Duration } -func (l *LeastLoadStrategy) InjectContext(ctx context.Context) { - l.ctx = ctx +func (s *LeastLoadStrategy) InjectContext(ctx context.Context) { + s.ctx = ctx + core.RequireFeaturesAsync(s.ctx, func(observatory extension.Observatory) { + s.observer = observatory + }) } func (s *LeastLoadStrategy) PickOutbound(candidates []string) string { @@ -136,10 +138,8 @@ func (s *LeastLoadStrategy) selectLeastLoad(nodes []*node) []*node { func (s *LeastLoadStrategy) getNodes(candidates []string, maxRTT time.Duration) []*node { if s.observer == nil { - common.Must(core.RequireFeatures(s.ctx, func(observatory extension.Observatory) error { - s.observer = observatory - return nil - })) + errors.LogError(s.ctx, "observer is nil") + return make([]*node, 0) } observeResult, err := s.observer.GetObservation(s.ctx) if err != nil { diff --git a/app/router/strategy_leastping.go b/app/router/strategy_leastping.go index 28efe386..b13d1a7d 100644 --- a/app/router/strategy_leastping.go +++ b/app/router/strategy_leastping.go @@ -4,7 +4,6 @@ import ( "context" "github.com/xtls/xray-core/app/observatory" - "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/core" "github.com/xtls/xray-core/features/extension" @@ -21,19 +20,19 @@ func (l *LeastPingStrategy) GetPrincipleTarget(strings []string) []string { func (l *LeastPingStrategy) InjectContext(ctx context.Context) { l.ctx = ctx + core.RequireFeaturesAsync(l.ctx, func(observatory extension.Observatory) { + l.observatory = observatory + }) } func (l *LeastPingStrategy) PickOutbound(strings []string) string { if l.observatory == nil { - common.Must(core.RequireFeatures(l.ctx, func(observatory extension.Observatory) error { - l.observatory = observatory - return nil - })) + errors.LogError(l.ctx, "observer is nil") + return "" } - observeReport, err := l.observatory.GetObservation(l.ctx) if err != nil { - errors.LogInfoInner(l.ctx, err, "cannot get observe report") + errors.LogInfoInner(l.ctx, err, "cannot get observer report") return "" } outboundsList := outboundList(strings) diff --git a/app/router/strategy_random.go b/app/router/strategy_random.go index ed82ff9d..9f4cdd77 100644 --- a/app/router/strategy_random.go +++ b/app/router/strategy_random.go @@ -4,7 +4,6 @@ import ( "context" "github.com/xtls/xray-core/app/observatory" - "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/dice" "github.com/xtls/xray-core/core" "github.com/xtls/xray-core/features/extension" @@ -20,6 +19,11 @@ type RandomStrategy struct { func (s *RandomStrategy) InjectContext(ctx context.Context) { s.ctx = ctx + if len(s.FallbackTag) > 0 { + core.RequireFeaturesAsync(s.ctx, func(observatory extension.Observatory) { + s.observatory = observatory + }) + } } func (s *RandomStrategy) GetPrincipleTarget(strings []string) []string { @@ -27,12 +31,6 @@ func (s *RandomStrategy) GetPrincipleTarget(strings []string) []string { } func (s *RandomStrategy) PickOutbound(candidates []string) string { - if len(s.FallbackTag) > 0 && s.observatory == nil { - common.Must(core.RequireFeatures(s.ctx, func(observatory extension.Observatory) error { - s.observatory = observatory - return nil - })) - } if s.observatory != nil { observeReport, err := s.observatory.GetObservation(s.ctx) if err == nil { diff --git a/core/xray.go b/core/xray.go index 0e1f0830..5ab10603 100644 --- a/core/xray.go +++ b/core/xray.go @@ -4,6 +4,7 @@ import ( "context" "reflect" "sync" + "time" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/errors" @@ -156,6 +157,12 @@ func RequireFeatures(ctx context.Context, callback interface{}) error { return v.RequireFeatures(callback) } +// RequireFeaturesAsync registers a callback, which will be called when all dependent features are registered. The order of app init doesn't matter +func RequireFeaturesAsync(ctx context.Context, callback interface{}) { + v := MustFromContext(ctx) + v.RequireFeaturesAsync(callback) +} + // New returns a new Xray instance based on given configuration. // The instance is not started at this point. // To ensure Xray instance works properly, the config must contain one Dispatcher, one InboundHandlerManager and one OutboundHandlerManager. Other features are optional. @@ -290,6 +297,36 @@ func (s *Instance) RequireFeatures(callback interface{}) error { return nil } +// RequireFeaturesAsync registers a callback, which will be called when all dependent features are registered. The order of app init doesn't matter +func (s *Instance) RequireFeaturesAsync(callback interface{}) { + callbackType := reflect.TypeOf(callback) + if callbackType.Kind() != reflect.Func { + panic("not a function") + } + + var featureTypes []reflect.Type + for i := 0; i < callbackType.NumIn(); i++ { + featureTypes = append(featureTypes, reflect.PtrTo(callbackType.In(i))) + } + + r := resolution{ + deps: featureTypes, + callback: callback, + } + go func() { + var finished = false + for i := 0; !finished; i++ { + if i > 100000 { + errors.LogError(s.ctx, "RequireFeaturesAsync failed after count ", i) + break; + } + finished, _ = r.resolve(s.features) + time.Sleep(time.Millisecond) + } + s.featureResolutions = append(s.featureResolutions, r) + }() +} + // AddFeature registers a feature into current Instance. func (s *Instance) AddFeature(feature features.Feature) error { s.features = append(s.features, feature) diff --git a/proxy/dns/dns.go b/proxy/dns/dns.go index ed063197..790c80c1 100644 --- a/proxy/dns/dns.go +++ b/proxy/dns/dns.go @@ -27,7 +27,7 @@ func init() { common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { h := new(Handler) if err := core.RequireFeatures(ctx, func(dnsClient dns.Client, policyManager policy.Manager) error { - core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) { + core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) { // FakeDNSEngine is optional h.fdns = fdns }) return h.Init(config.(*Config), dnsClient, policyManager) diff --git a/transport/internet/tagged/tagged.go b/transport/internet/tagged/tagged.go index 430f9640..2cd9dcd2 100644 --- a/transport/internet/tagged/tagged.go +++ b/transport/internet/tagged/tagged.go @@ -4,8 +4,9 @@ import ( "context" "github.com/xtls/xray-core/common/net" + "github.com/xtls/xray-core/features/routing" ) -type DialFunc func(ctx context.Context, dest net.Destination, tag string) (net.Conn, error) +type DialFunc func(ctx context.Context, dispatcher routing.Dispatcher, dest net.Destination, tag string) (net.Conn, error) var Dialer DialFunc diff --git a/transport/internet/tagged/taggedimpl/impl.go b/transport/internet/tagged/taggedimpl/impl.go index 29caec7c..2a773401 100644 --- a/transport/internet/tagged/taggedimpl/impl.go +++ b/transport/internet/tagged/taggedimpl/impl.go @@ -12,17 +12,10 @@ import ( "github.com/xtls/xray-core/transport/internet/tagged" ) -func DialTaggedOutbound(ctx context.Context, dest net.Destination, tag string) (net.Conn, error) { - var dispatcher routing.Dispatcher +func DialTaggedOutbound(ctx context.Context, dispatcher routing.Dispatcher, dest net.Destination, tag string) (net.Conn, error) { if core.FromContext(ctx) == nil { return nil, errors.New("Instance context variable is not in context, dial denied. ") } - if err := core.RequireFeatures(ctx, func(dispatcherInstance routing.Dispatcher) { - dispatcher = dispatcherInstance - }); err != nil { - return nil, errors.New("Required Feature dispatcher not resolved").Base(err) - } - content := new(session.Content) content.SkipDNSResolve = true