mirror of
https://github.com/XTLS/Xray-core.git
synced 2024-11-21 20:56:06 +03:00
790f908f0b
Signed-off-by: curlwget <curlwget@icloud.com>
394 lines
9.1 KiB
Go
394 lines
9.1 KiB
Go
package mux
|
|
|
|
import (
|
|
"context"
|
|
"io"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/xtls/xray-core/common"
|
|
"github.com/xtls/xray-core/common/buf"
|
|
"github.com/xtls/xray-core/common/errors"
|
|
"github.com/xtls/xray-core/common/net"
|
|
"github.com/xtls/xray-core/common/protocol"
|
|
"github.com/xtls/xray-core/common/session"
|
|
"github.com/xtls/xray-core/common/signal/done"
|
|
"github.com/xtls/xray-core/common/task"
|
|
"github.com/xtls/xray-core/common/xudp"
|
|
"github.com/xtls/xray-core/proxy"
|
|
"github.com/xtls/xray-core/transport"
|
|
"github.com/xtls/xray-core/transport/internet"
|
|
"github.com/xtls/xray-core/transport/pipe"
|
|
)
|
|
|
|
type ClientManager struct {
|
|
Enabled bool // whether mux is enabled from user config
|
|
Picker WorkerPicker
|
|
}
|
|
|
|
func (m *ClientManager) Dispatch(ctx context.Context, link *transport.Link) error {
|
|
for i := 0; i < 16; i++ {
|
|
worker, err := m.Picker.PickAvailable()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if worker.Dispatch(ctx, link) {
|
|
return nil
|
|
}
|
|
}
|
|
|
|
return errors.New("unable to find an available mux client").AtWarning()
|
|
}
|
|
|
|
type WorkerPicker interface {
|
|
PickAvailable() (*ClientWorker, error)
|
|
}
|
|
|
|
type IncrementalWorkerPicker struct {
|
|
Factory ClientWorkerFactory
|
|
|
|
access sync.Mutex
|
|
workers []*ClientWorker
|
|
cleanupTask *task.Periodic
|
|
}
|
|
|
|
func (p *IncrementalWorkerPicker) cleanupFunc() error {
|
|
p.access.Lock()
|
|
defer p.access.Unlock()
|
|
|
|
if len(p.workers) == 0 {
|
|
return errors.New("no worker")
|
|
}
|
|
|
|
p.cleanup()
|
|
return nil
|
|
}
|
|
|
|
func (p *IncrementalWorkerPicker) cleanup() {
|
|
var activeWorkers []*ClientWorker
|
|
for _, w := range p.workers {
|
|
if !w.Closed() {
|
|
activeWorkers = append(activeWorkers, w)
|
|
}
|
|
}
|
|
p.workers = activeWorkers
|
|
}
|
|
|
|
func (p *IncrementalWorkerPicker) findAvailable() int {
|
|
for idx, w := range p.workers {
|
|
if !w.IsFull() {
|
|
return idx
|
|
}
|
|
}
|
|
|
|
return -1
|
|
}
|
|
|
|
func (p *IncrementalWorkerPicker) pickInternal() (*ClientWorker, bool, error) {
|
|
p.access.Lock()
|
|
defer p.access.Unlock()
|
|
|
|
idx := p.findAvailable()
|
|
if idx >= 0 {
|
|
n := len(p.workers)
|
|
if n > 1 && idx != n-1 {
|
|
p.workers[n-1], p.workers[idx] = p.workers[idx], p.workers[n-1]
|
|
}
|
|
return p.workers[idx], false, nil
|
|
}
|
|
|
|
p.cleanup()
|
|
|
|
worker, err := p.Factory.Create()
|
|
if err != nil {
|
|
return nil, false, err
|
|
}
|
|
p.workers = append(p.workers, worker)
|
|
|
|
if p.cleanupTask == nil {
|
|
p.cleanupTask = &task.Periodic{
|
|
Interval: time.Second * 30,
|
|
Execute: p.cleanupFunc,
|
|
}
|
|
}
|
|
|
|
return worker, true, nil
|
|
}
|
|
|
|
func (p *IncrementalWorkerPicker) PickAvailable() (*ClientWorker, error) {
|
|
worker, start, err := p.pickInternal()
|
|
if start {
|
|
common.Must(p.cleanupTask.Start())
|
|
}
|
|
|
|
return worker, err
|
|
}
|
|
|
|
type ClientWorkerFactory interface {
|
|
Create() (*ClientWorker, error)
|
|
}
|
|
|
|
type DialingWorkerFactory struct {
|
|
Proxy proxy.Outbound
|
|
Dialer internet.Dialer
|
|
Strategy ClientStrategy
|
|
}
|
|
|
|
func (f *DialingWorkerFactory) Create() (*ClientWorker, error) {
|
|
opts := []pipe.Option{pipe.WithSizeLimit(64 * 1024)}
|
|
uplinkReader, upLinkWriter := pipe.New(opts...)
|
|
downlinkReader, downlinkWriter := pipe.New(opts...)
|
|
|
|
c, err := NewClientWorker(transport.Link{
|
|
Reader: downlinkReader,
|
|
Writer: upLinkWriter,
|
|
}, f.Strategy)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
go func(p proxy.Outbound, d internet.Dialer, c common.Closable) {
|
|
outbounds := []*session.Outbound{{
|
|
Target: net.TCPDestination(muxCoolAddress, muxCoolPort),
|
|
}}
|
|
ctx := session.ContextWithOutbounds(context.Background(), outbounds)
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
|
|
if err := p.Process(ctx, &transport.Link{Reader: uplinkReader, Writer: downlinkWriter}, d); err != nil {
|
|
errors.LogInfoInner(ctx, err, "failed to handler mux client connection")
|
|
}
|
|
common.Must(c.Close())
|
|
cancel()
|
|
}(f.Proxy, f.Dialer, c.done)
|
|
|
|
return c, nil
|
|
}
|
|
|
|
// ClientStrategy holds the limits applied to a ClientWorker. A zero value in
// either field disables the corresponding limit.
type ClientStrategy struct {
	// MaxConcurrency is compared against the session manager's Size by
	// ClientWorker.IsFull.
	MaxConcurrency uint32
	// MaxConnection is compared against the session manager's Count by
	// ClientWorker.IsClosing.
	MaxConnection uint32
}
|
|
|
|
type ClientWorker struct {
|
|
sessionManager *SessionManager
|
|
link transport.Link
|
|
done *done.Instance
|
|
strategy ClientStrategy
|
|
}
|
|
|
|
var (
|
|
muxCoolAddress = net.DomainAddress("v1.mux.cool")
|
|
muxCoolPort = net.Port(9527)
|
|
)
|
|
|
|
// NewClientWorker creates a new mux.Client.
|
|
func NewClientWorker(stream transport.Link, s ClientStrategy) (*ClientWorker, error) {
|
|
c := &ClientWorker{
|
|
sessionManager: NewSessionManager(),
|
|
link: stream,
|
|
done: done.New(),
|
|
strategy: s,
|
|
}
|
|
|
|
go c.fetchOutput()
|
|
go c.monitor()
|
|
|
|
return c, nil
|
|
}
|
|
|
|
func (m *ClientWorker) TotalConnections() uint32 {
|
|
return uint32(m.sessionManager.Count())
|
|
}
|
|
|
|
func (m *ClientWorker) ActiveConnections() uint32 {
|
|
return uint32(m.sessionManager.Size())
|
|
}
|
|
|
|
// Closed returns true if this Client is closed.
|
|
func (m *ClientWorker) Closed() bool {
|
|
return m.done.Done()
|
|
}
|
|
|
|
func (m *ClientWorker) monitor() {
|
|
timer := time.NewTicker(time.Second * 16)
|
|
defer timer.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-m.done.Wait():
|
|
m.sessionManager.Close()
|
|
common.Close(m.link.Writer)
|
|
common.Interrupt(m.link.Reader)
|
|
return
|
|
case <-timer.C:
|
|
size := m.sessionManager.Size()
|
|
if size == 0 && m.sessionManager.CloseIfNoSession() {
|
|
common.Must(m.done.Close())
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func writeFirstPayload(reader buf.Reader, writer *Writer) error {
|
|
err := buf.CopyOnceTimeout(reader, writer, time.Millisecond*100)
|
|
if err == buf.ErrNotTimeoutReader || err == buf.ErrReadTimeout {
|
|
return writer.WriteMultiBuffer(buf.MultiBuffer{})
|
|
}
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func fetchInput(ctx context.Context, s *Session, output buf.Writer) {
|
|
outbounds := session.OutboundsFromContext(ctx)
|
|
ob := outbounds[len(outbounds)-1]
|
|
transferType := protocol.TransferTypeStream
|
|
if ob.Target.Network == net.Network_UDP {
|
|
transferType = protocol.TransferTypePacket
|
|
}
|
|
s.transferType = transferType
|
|
writer := NewWriter(s.ID, ob.Target, output, transferType, xudp.GetGlobalID(ctx))
|
|
defer s.Close(false)
|
|
defer writer.Close()
|
|
|
|
errors.LogInfo(ctx, "dispatching request to ", ob.Target)
|
|
if err := writeFirstPayload(s.input, writer); err != nil {
|
|
errors.LogInfoInner(ctx, err, "failed to write first payload")
|
|
writer.hasError = true
|
|
return
|
|
}
|
|
|
|
if err := buf.Copy(s.input, writer); err != nil {
|
|
errors.LogInfoInner(ctx, err, "failed to fetch all input")
|
|
writer.hasError = true
|
|
return
|
|
}
|
|
}
|
|
|
|
func (m *ClientWorker) IsClosing() bool {
|
|
sm := m.sessionManager
|
|
if m.strategy.MaxConnection > 0 && sm.Count() >= int(m.strategy.MaxConnection) {
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (m *ClientWorker) IsFull() bool {
|
|
if m.IsClosing() || m.Closed() {
|
|
return true
|
|
}
|
|
|
|
sm := m.sessionManager
|
|
if m.strategy.MaxConcurrency > 0 && sm.Size() >= int(m.strategy.MaxConcurrency) {
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
func (m *ClientWorker) Dispatch(ctx context.Context, link *transport.Link) bool {
|
|
if m.IsFull() || m.Closed() {
|
|
return false
|
|
}
|
|
|
|
sm := m.sessionManager
|
|
s := sm.Allocate()
|
|
if s == nil {
|
|
return false
|
|
}
|
|
s.input = link.Reader
|
|
s.output = link.Writer
|
|
go fetchInput(ctx, s, m.link.Writer)
|
|
return true
|
|
}
|
|
|
|
func (m *ClientWorker) handleStatueKeepAlive(meta *FrameMetadata, reader *buf.BufferedReader) error {
|
|
if meta.Option.Has(OptionData) {
|
|
return buf.Copy(NewStreamReader(reader), buf.Discard)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *ClientWorker) handleStatusNew(meta *FrameMetadata, reader *buf.BufferedReader) error {
|
|
if meta.Option.Has(OptionData) {
|
|
return buf.Copy(NewStreamReader(reader), buf.Discard)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *ClientWorker) handleStatusKeep(meta *FrameMetadata, reader *buf.BufferedReader) error {
|
|
if !meta.Option.Has(OptionData) {
|
|
return nil
|
|
}
|
|
|
|
s, found := m.sessionManager.Get(meta.SessionID)
|
|
if !found {
|
|
// Notify remote peer to close this session.
|
|
closingWriter := NewResponseWriter(meta.SessionID, m.link.Writer, protocol.TransferTypeStream)
|
|
closingWriter.Close()
|
|
|
|
return buf.Copy(NewStreamReader(reader), buf.Discard)
|
|
}
|
|
|
|
rr := s.NewReader(reader, &meta.Target)
|
|
err := buf.Copy(rr, s.output)
|
|
if err != nil && buf.IsWriteError(err) {
|
|
errors.LogInfoInner(context.Background(), err, "failed to write to downstream. closing session ", s.ID)
|
|
s.Close(false)
|
|
return buf.Copy(rr, buf.Discard)
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
func (m *ClientWorker) handleStatusEnd(meta *FrameMetadata, reader *buf.BufferedReader) error {
|
|
if s, found := m.sessionManager.Get(meta.SessionID); found {
|
|
s.Close(false)
|
|
}
|
|
if meta.Option.Has(OptionData) {
|
|
return buf.Copy(NewStreamReader(reader), buf.Discard)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (m *ClientWorker) fetchOutput() {
|
|
defer func() {
|
|
common.Must(m.done.Close())
|
|
}()
|
|
|
|
reader := &buf.BufferedReader{Reader: m.link.Reader}
|
|
|
|
var meta FrameMetadata
|
|
for {
|
|
err := meta.Unmarshal(reader)
|
|
if err != nil {
|
|
if errors.Cause(err) != io.EOF {
|
|
errors.LogInfoInner(context.Background(), err, "failed to read metadata")
|
|
}
|
|
break
|
|
}
|
|
|
|
switch meta.SessionStatus {
|
|
case SessionStatusKeepAlive:
|
|
err = m.handleStatueKeepAlive(&meta, reader)
|
|
case SessionStatusEnd:
|
|
err = m.handleStatusEnd(&meta, reader)
|
|
case SessionStatusNew:
|
|
err = m.handleStatusNew(&meta, reader)
|
|
case SessionStatusKeep:
|
|
err = m.handleStatusKeep(&meta, reader)
|
|
default:
|
|
status := meta.SessionStatus
|
|
errors.LogError(context.Background(), "unknown status: ", status)
|
|
return
|
|
}
|
|
|
|
if err != nil {
|
|
errors.LogInfoInner(context.Background(), err, "failed to process data")
|
|
return
|
|
}
|
|
}
|
|
}
|