Fix memory leak with RequireFeatures()

RequireFeatures() must be called at init
This commit is contained in:
yuhan6665 2024-11-29 01:12:24 -05:00
parent 98a72b6fb4
commit 2479c9d7c2
12 changed files with 53 additions and 59 deletions

View File

@ -35,7 +35,7 @@ type Client struct {
var errExpectedIPNonMatch = errors.New("expectIPs not match")
// NewServer creates a name server object according to the network destination url.
func NewServer(dest net.Destination, dispatcher routing.Dispatcher, queryStrategy QueryStrategy) (Server, error) {
func NewServer(dest net.Destination, dispatcher routing.Dispatcher, queryStrategy QueryStrategy, fd dns.FakeDNSEngine) (Server, error) {
if address := dest.Address; address.Family().IsDomain() {
u, err := url.Parse(address.Domain())
if err != nil {
@ -55,7 +55,7 @@ func NewServer(dest net.Destination, dispatcher routing.Dispatcher, queryStrateg
case strings.EqualFold(u.Scheme, "tcp+local"): // DNS-over-TCP Local mode
return NewTCPLocalNameServer(u, queryStrategy)
case strings.EqualFold(u.String(), "fakedns"):
return NewFakeDNSServer(), nil
return NewFakeDNSServer(fd), nil
}
}
if dest.Network == net.Network_Unknown {
@ -78,9 +78,9 @@ func NewClient(
) (*Client, error) {
client := &Client{}
err := core.RequireFeatures(ctx, func(dispatcher routing.Dispatcher) error {
err := core.RequireFeatures(ctx, func(dispatcher routing.Dispatcher, fd dns.FakeDNSEngine) error {
// Create a new server for each client for now
server, err := NewServer(ns.Address.AsDestination(), dispatcher, ns.GetQueryStrategy())
server, err := NewServer(ns.Address.AsDestination(), dispatcher, ns.GetQueryStrategy(), fd)
if err != nil {
return errors.New("failed to create nameserver").Base(err).AtWarning()
}

View File

@ -5,7 +5,6 @@ import (
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/core"
"github.com/xtls/xray-core/features/dns"
)
@ -13,8 +12,8 @@ type FakeDNSServer struct {
fakeDNSEngine dns.FakeDNSEngine
}
func NewFakeDNSServer() *FakeDNSServer {
return &FakeDNSServer{}
func NewFakeDNSServer(fd dns.FakeDNSEngine) *FakeDNSServer {
return &FakeDNSServer{fakeDNSEngine: fd}
}
func (FakeDNSServer) Name() string {
@ -22,13 +21,6 @@ func (FakeDNSServer) Name() string {
}
func (f *FakeDNSServer) QueryIP(ctx context.Context, domain string, _ net.IP, opt dns.IPOption, _ bool) ([]net.IP, error) {
if f.fakeDNSEngine == nil {
if err := core.RequireFeatures(ctx, func(fd dns.FakeDNSEngine) {
f.fakeDNSEngine = fd
}); err != nil {
return nil, errors.New("Unable to locate a fake DNS Engine").Base(err).AtError()
}
}
var ips []net.Address
if fkr0, ok := f.fakeDNSEngine.(dns.FakeDNSEngineRev0); ok {
ips = fkr0.GetFakeIPForDomain3(domain, opt.IPv4Enable, opt.IPv6Enable)

View File

@ -12,6 +12,7 @@ import (
"github.com/xtls/xray-core/core"
"github.com/xtls/xray-core/features/extension"
"github.com/xtls/xray-core/features/outbound"
"github.com/xtls/xray-core/features/routing"
"google.golang.org/protobuf/proto"
)
@ -88,13 +89,15 @@ func (o *Observer) Close() error {
func New(ctx context.Context, config *Config) (*Observer, error) {
var outboundManager outbound.Manager
err := core.RequireFeatures(ctx, func(om outbound.Manager) {
var dispatcher routing.Dispatcher
err := core.RequireFeatures(ctx, func(om outbound.Manager, rd routing.Dispatcher) {
outboundManager = om
dispatcher = rd
})
if err != nil {
return nil, errors.New("Cannot get depended features").Base(err)
}
hp := NewHealthPing(ctx, config.PingConfig)
hp := NewHealthPing(ctx, dispatcher, config.PingConfig)
return &Observer{
config: config,
ctx: ctx,

View File

@ -9,6 +9,7 @@ import (
"github.com/xtls/xray-core/common/dice"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/features/routing"
)
// HealthPingSettings holds settings for health Checker
@ -23,6 +24,7 @@ type HealthPingSettings struct {
// HealthPing is the health checker for balancers
type HealthPing struct {
ctx context.Context
dispatcher routing.Dispatcher
access sync.Mutex
ticker *time.Ticker
tickerClose chan struct{}
@ -32,7 +34,7 @@ type HealthPing struct {
}
// NewHealthPing creates a new HealthPing with settings
func NewHealthPing(ctx context.Context, config *HealthPingConfig) *HealthPing {
func NewHealthPing(ctx context.Context, dispatcher routing.Dispatcher, config *HealthPingConfig) *HealthPing {
settings := &HealthPingSettings{}
if config != nil {
settings = &HealthPingSettings{
@ -65,6 +67,7 @@ func NewHealthPing(ctx context.Context, config *HealthPingConfig) *HealthPing {
}
return &HealthPing{
ctx: ctx,
dispatcher: dispatcher,
Settings: settings,
Results: nil,
}
@ -149,6 +152,7 @@ func (h *HealthPing) doCheck(tags []string, duration time.Duration, rounds int)
handler := tag
client := newPingClient(
h.ctx,
h.dispatcher,
h.Settings.Destination,
h.Settings.Timeout,
handler,

View File

@ -6,6 +6,7 @@ import (
"time"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/features/routing"
"github.com/xtls/xray-core/transport/internet/tagged"
)
@ -14,10 +15,10 @@ type pingClient struct {
httpClient *http.Client
}
func newPingClient(ctx context.Context, destination string, timeout time.Duration, handler string) *pingClient {
func newPingClient(ctx context.Context, dispatcher routing.Dispatcher, destination string, timeout time.Duration, handler string) *pingClient {
return &pingClient{
destination: destination,
httpClient: newHTTPClient(ctx, handler, timeout),
httpClient: newHTTPClient(ctx, dispatcher, handler, timeout),
}
}
@ -28,7 +29,7 @@ func newDirectPingClient(destination string, timeout time.Duration) *pingClient
}
}
func newHTTPClient(ctxv context.Context, handler string, timeout time.Duration) *http.Client {
func newHTTPClient(ctxv context.Context, dispatcher routing.Dispatcher, handler string, timeout time.Duration) *http.Client {
tr := &http.Transport{
DisableKeepAlives: true,
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
@ -36,7 +37,7 @@ func newHTTPClient(ctxv context.Context, handler string, timeout time.Duration)
if err != nil {
return nil, err
}
return tagged.Dialer(ctxv, dest, handler)
return tagged.Dialer(ctxv, dispatcher, dest, handler)
},
}
return &http.Client{

View File

@ -18,6 +18,7 @@ import (
"github.com/xtls/xray-core/core"
"github.com/xtls/xray-core/features/extension"
"github.com/xtls/xray-core/features/outbound"
"github.com/xtls/xray-core/features/routing"
"github.com/xtls/xray-core/transport/internet/tagged"
"google.golang.org/protobuf/proto"
)
@ -32,6 +33,7 @@ type Observer struct {
finished *done.Instance
ohm outbound.Manager
dispatcher routing.Dispatcher
}
func (o *Observer) GetObservation(ctx context.Context) (proto.Message, error) {
@ -131,7 +133,7 @@ func (o *Observer) probe(outbound string) ProbeResult {
return errors.New("cannot understand address").Base(err)
}
trackedCtx := session.TrackedConnectionError(o.ctx, errorCollectorForRequest)
conn, err := tagged.Dialer(trackedCtx, dest, outbound)
conn, err := tagged.Dialer(trackedCtx, o.dispatcher, dest, outbound)
if err != nil {
return errors.New("cannot dial remote address ", dest).Base(err)
}
@ -215,8 +217,10 @@ func (o *Observer) findStatusLocationLockHolderOnly(outbound string) int {
func New(ctx context.Context, config *Config) (*Observer, error) {
var outboundManager outbound.Manager
err := core.RequireFeatures(ctx, func(om outbound.Manager) {
var dispatcher routing.Dispatcher
err := core.RequireFeatures(ctx, func(om outbound.Manager, rd routing.Dispatcher) {
outboundManager = om
dispatcher = rd
})
if err != nil {
return nil, errors.New("Cannot get depended features").Base(err)
@ -225,6 +229,7 @@ func New(ctx context.Context, config *Config) (*Observer, error) {
config: config,
ctx: ctx,
ohm: outboundManager,
dispatcher: dispatcher,
}, nil
}

View File

@ -31,6 +31,12 @@ type RoundRobinStrategy struct {
func (s *RoundRobinStrategy) InjectContext(ctx context.Context) {
s.ctx = ctx
if len(s.FallbackTag) > 0 {
common.Must(core.RequireFeatures(s.ctx, func(observatory extension.Observatory) error {
s.observatory = observatory
return nil
}))
}
}
func (s *RoundRobinStrategy) GetPrincipleTarget(strings []string) []string {
@ -38,12 +44,6 @@ func (s *RoundRobinStrategy) GetPrincipleTarget(strings []string) []string {
}
func (s *RoundRobinStrategy) PickOutbound(tags []string) string {
if len(s.FallbackTag) > 0 && s.observatory == nil {
common.Must(core.RequireFeatures(s.ctx, func(observatory extension.Observatory) error {
s.observatory = observatory
return nil
}))
}
if s.observatory != nil {
observeReport, err := s.observatory.GetObservation(s.ctx)
if err == nil {

View File

@ -58,8 +58,12 @@ type node struct {
RTTDeviationCost time.Duration
}
func (l *LeastLoadStrategy) InjectContext(ctx context.Context) {
l.ctx = ctx
func (s *LeastLoadStrategy) InjectContext(ctx context.Context) {
s.ctx = ctx
common.Must(core.RequireFeatures(s.ctx, func(observatory extension.Observatory) error {
s.observer = observatory
return nil
}))
}
func (s *LeastLoadStrategy) PickOutbound(candidates []string) string {
@ -135,12 +139,6 @@ func (s *LeastLoadStrategy) selectLeastLoad(nodes []*node) []*node {
}
func (s *LeastLoadStrategy) getNodes(candidates []string, maxRTT time.Duration) []*node {
if s.observer == nil {
common.Must(core.RequireFeatures(s.ctx, func(observatory extension.Observatory) error {
s.observer = observatory
return nil
}))
}
observeResult, err := s.observer.GetObservation(s.ctx)
if err != nil {
errors.LogInfoInner(s.ctx, err, "cannot get observation")

View File

@ -21,16 +21,13 @@ func (l *LeastPingStrategy) GetPrincipleTarget(strings []string) []string {
func (l *LeastPingStrategy) InjectContext(ctx context.Context) {
l.ctx = ctx
common.Must(core.RequireFeatures(l.ctx, func(observatory extension.Observatory) error {
l.observatory = observatory
return nil
}))
}
func (l *LeastPingStrategy) PickOutbound(strings []string) string {
if l.observatory == nil {
common.Must(core.RequireFeatures(l.ctx, func(observatory extension.Observatory) error {
l.observatory = observatory
return nil
}))
}
observeReport, err := l.observatory.GetObservation(l.ctx)
if err != nil {
errors.LogInfoInner(l.ctx, err, "cannot get observe report")

View File

@ -20,6 +20,12 @@ type RandomStrategy struct {
func (s *RandomStrategy) InjectContext(ctx context.Context) {
s.ctx = ctx
if len(s.FallbackTag) > 0 {
common.Must(core.RequireFeatures(s.ctx, func(observatory extension.Observatory) error {
s.observatory = observatory
return nil
}))
}
}
func (s *RandomStrategy) GetPrincipleTarget(strings []string) []string {
@ -27,12 +33,6 @@ func (s *RandomStrategy) GetPrincipleTarget(strings []string) []string {
}
func (s *RandomStrategy) PickOutbound(candidates []string) string {
if len(s.FallbackTag) > 0 && s.observatory == nil {
common.Must(core.RequireFeatures(s.ctx, func(observatory extension.Observatory) error {
s.observatory = observatory
return nil
}))
}
if s.observatory != nil {
observeReport, err := s.observatory.GetObservation(s.ctx)
if err == nil {

View File

@ -4,8 +4,9 @@ import (
"context"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/features/routing"
)
type DialFunc func(ctx context.Context, dest net.Destination, tag string) (net.Conn, error)
type DialFunc func(ctx context.Context, dispatcher routing.Dispatcher, dest net.Destination, tag string) (net.Conn, error)
var Dialer DialFunc

View File

@ -12,17 +12,10 @@ import (
"github.com/xtls/xray-core/transport/internet/tagged"
)
func DialTaggedOutbound(ctx context.Context, dest net.Destination, tag string) (net.Conn, error) {
var dispatcher routing.Dispatcher
func DialTaggedOutbound(ctx context.Context, dispatcher routing.Dispatcher, dest net.Destination, tag string) (net.Conn, error) {
if core.FromContext(ctx) == nil {
return nil, errors.New("Instance context variable is not in context, dial denied. ")
}
if err := core.RequireFeatures(ctx, func(dispatcherInstance routing.Dispatcher) {
dispatcher = dispatcherInstance
}); err != nil {
return nil, errors.New("Required Feature dispatcher not resolved").Base(err)
}
content := new(session.Content)
content.SkipDNSResolve = true