yuhan6665 02cd3b8c74
Fix SplitHTTP race condition when creating new sessions (#3533)
Co-authored-by: nobody <nobody@nowhere.mars>
Co-authored-by: mmmray <142015632+mmmray@users.noreply.github.com>
2024-07-17 11:41:17 +00:00

329 lines
8.4 KiB
Go

package splithttp
import (
"context"
"crypto/tls"
"io"
gonet "net"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/net"
http_proto "github.com/xtls/xray-core/common/protocol/http"
"github.com/xtls/xray-core/common/signal/done"
"github.com/xtls/xray-core/transport/internet"
"github.com/xtls/xray-core/transport/internet/stat"
v2tls "github.com/xtls/xray-core/transport/internet/tls"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
)
type requestHandler struct {
host string
path string
ln *Listener
sessionMu *sync.Mutex
sessions sync.Map
localAddr gonet.TCPAddr
}
type httpSession struct {
uploadQueue *uploadQueue
// for as long as the GET request is not opened by the client, this will be
// open ("undone"), and the session may be expired within a certain TTL.
// after the client connects, this becomes "done" and the session lives as
// long as the GET request.
isFullyConnected *done.Instance
}
func (h *requestHandler) maybeReapSession(isFullyConnected *done.Instance, sessionId string) {
shouldReap := done.New()
go func() {
time.Sleep(30 * time.Second)
shouldReap.Close()
}()
select {
case <-isFullyConnected.Wait():
return
case <-shouldReap.Wait():
h.sessions.Delete(sessionId)
}
}
func (h *requestHandler) upsertSession(sessionId string) *httpSession {
// fast path
currentSessionAny, ok := h.sessions.Load(sessionId)
if ok {
return currentSessionAny.(*httpSession)
}
// slow path
h.sessionMu.Lock()
defer h.sessionMu.Unlock()
currentSessionAny, ok = h.sessions.Load(sessionId)
if ok {
return currentSessionAny.(*httpSession)
}
s := &httpSession{
uploadQueue: NewUploadQueue(int(2 * h.ln.config.GetNormalizedMaxConcurrentUploads())),
isFullyConnected: done.New(),
}
h.sessions.Store(sessionId, s)
go h.maybeReapSession(s.isFullyConnected, sessionId)
return s
}
func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
if len(h.host) > 0 && !internet.IsValidHTTPHost(request.Host, h.host) {
errors.LogInfo(context.Background(), "failed to validate host, request:", request.Host, ", config:", h.host)
writer.WriteHeader(http.StatusNotFound)
return
}
if !strings.HasPrefix(request.URL.Path, h.path) {
errors.LogInfo(context.Background(), "failed to validate path, request:", request.URL.Path, ", config:", h.path)
writer.WriteHeader(http.StatusNotFound)
return
}
sessionId := ""
subpath := strings.Split(request.URL.Path[len(h.path):], "/")
if len(subpath) > 0 {
sessionId = subpath[0]
}
if sessionId == "" {
errors.LogInfo(context.Background(), "no sessionid on request:", request.URL.Path)
writer.WriteHeader(http.StatusBadRequest)
return
}
forwardedAddrs := http_proto.ParseXForwardedFor(request.Header)
remoteAddr, err := gonet.ResolveTCPAddr("tcp", request.RemoteAddr)
if err != nil {
remoteAddr = &gonet.TCPAddr{}
}
if len(forwardedAddrs) > 0 && forwardedAddrs[0].Family().IsIP() {
remoteAddr = &net.TCPAddr{
IP: forwardedAddrs[0].IP(),
Port: int(0),
}
}
currentSession := h.upsertSession(sessionId)
if request.Method == "POST" {
seq := ""
if len(subpath) > 1 {
seq = subpath[1]
}
if seq == "" {
errors.LogInfo(context.Background(), "no seq on request:", request.URL.Path)
writer.WriteHeader(http.StatusBadRequest)
return
}
payload, err := io.ReadAll(request.Body)
if err != nil {
errors.LogInfoInner(context.Background(), err, "failed to upload")
writer.WriteHeader(http.StatusInternalServerError)
return
}
seqInt, err := strconv.ParseUint(seq, 10, 64)
if err != nil {
errors.LogInfoInner(context.Background(), err, "failed to upload")
writer.WriteHeader(http.StatusInternalServerError)
return
}
err = currentSession.uploadQueue.Push(Packet{
Payload: payload,
Seq: seqInt,
})
if err != nil {
errors.LogInfoInner(context.Background(), err, "failed to upload")
writer.WriteHeader(http.StatusInternalServerError)
return
}
writer.WriteHeader(http.StatusOK)
} else if request.Method == "GET" {
responseFlusher, ok := writer.(http.Flusher)
if !ok {
panic("expected http.ResponseWriter to be an http.Flusher")
}
// after GET is done, the connection is finished. disable automatic
// session reaping, and handle it in defer
currentSession.isFullyConnected.Close()
defer h.sessions.Delete(sessionId)
// magic header instructs nginx + apache to not buffer response body
writer.Header().Set("X-Accel-Buffering", "no")
// magic header to make the HTTP middle box consider this as SSE to disable buffer
writer.Header().Set("Content-Type", "text/event-stream")
writer.WriteHeader(http.StatusOK)
// send a chunk immediately to enable CDN streaming.
// many CDN buffer the response headers until the origin starts sending
// the body, with no way to turn it off.
writer.Write([]byte("ok"))
responseFlusher.Flush()
downloadDone := done.New()
conn := splitConn{
writer: &httpResponseBodyWriter{
responseWriter: writer,
downloadDone: downloadDone,
responseFlusher: responseFlusher,
},
reader: currentSession.uploadQueue,
remoteAddr: remoteAddr,
}
h.ln.addConn(stat.Connection(&conn))
// "A ResponseWriter may not be used after [Handler.ServeHTTP] has returned."
<-downloadDone.Wait()
} else {
writer.WriteHeader(http.StatusMethodNotAllowed)
}
}
type httpResponseBodyWriter struct {
sync.Mutex
responseWriter http.ResponseWriter
responseFlusher http.Flusher
downloadDone *done.Instance
}
func (c *httpResponseBodyWriter) Write(b []byte) (int, error) {
c.Lock()
defer c.Unlock()
if c.downloadDone.Done() {
return 0, io.ErrClosedPipe
}
n, err := c.responseWriter.Write(b)
if err == nil {
c.responseFlusher.Flush()
}
return n, err
}
func (c *httpResponseBodyWriter) Close() error {
c.Lock()
defer c.Unlock()
c.downloadDone.Close()
return nil
}
type Listener struct {
sync.Mutex
server http.Server
listener net.Listener
config *Config
addConn internet.ConnHandler
}
func ListenSH(ctx context.Context, address net.Address, port net.Port, streamSettings *internet.MemoryStreamConfig, addConn internet.ConnHandler) (internet.Listener, error) {
l := &Listener{
addConn: addConn,
}
shSettings := streamSettings.ProtocolSettings.(*Config)
l.config = shSettings
if l.config != nil {
if streamSettings.SocketSettings == nil {
streamSettings.SocketSettings = &internet.SocketConfig{}
}
}
var listener net.Listener
var err error
var localAddr = gonet.TCPAddr{}
if port == net.Port(0) { // unix
listener, err = internet.ListenSystem(ctx, &net.UnixAddr{
Name: address.Domain(),
Net: "unix",
}, streamSettings.SocketSettings)
if err != nil {
return nil, errors.New("failed to listen unix domain socket(for SH) on ", address).Base(err)
}
errors.LogInfo(ctx, "listening unix domain socket(for SH) on ", address)
} else { // tcp
localAddr = gonet.TCPAddr{
IP: address.IP(),
Port: int(port),
}
listener, err = internet.ListenSystem(ctx, &net.TCPAddr{
IP: address.IP(),
Port: int(port),
}, streamSettings.SocketSettings)
if err != nil {
return nil, errors.New("failed to listen TCP(for SH) on ", address, ":", port).Base(err)
}
errors.LogInfo(ctx, "listening TCP(for SH) on ", address, ":", port)
}
if config := v2tls.ConfigFromStreamSettings(streamSettings); config != nil {
if tlsConfig := config.GetTLSConfig(); tlsConfig != nil {
listener = tls.NewListener(listener, tlsConfig)
}
}
handler := &requestHandler{
host: shSettings.Host,
path: shSettings.GetNormalizedPath(),
ln: l,
sessionMu: &sync.Mutex{},
sessions: sync.Map{},
localAddr: localAddr,
}
// h2cHandler can handle both plaintext HTTP/1.1 and h2c
h2cHandler := h2c.NewHandler(handler, &http2.Server{})
l.listener = listener
l.server = http.Server{
Handler: h2cHandler,
ReadHeaderTimeout: time.Second * 4,
MaxHeaderBytes: 8192,
}
go func() {
if err := l.server.Serve(l.listener); err != nil {
errors.LogWarningInner(ctx, err, "failed to serve http for splithttp")
}
}()
return l, err
}
// Addr implements net.Listener.Addr().
func (ln *Listener) Addr() net.Addr {
return ln.listener.Addr()
}
// Close implements net.Listener.Close().
func (ln *Listener) Close() error {
return ln.listener.Close()
}
func init() {
common.Must(internet.RegisterTransportListener(protocolName, ListenSH))
}