2024-06-18 08:36:36 +03:00
package splithttp
import (
"context"
"crypto/tls"
"io"
gonet "net"
"net/http"
"strconv"
2024-06-21 02:30:51 +03:00
"strings"
2024-06-18 08:36:36 +03:00
"sync"
"time"
2024-07-19 20:53:47 +03:00
"github.com/quic-go/quic-go"
"github.com/quic-go/quic-go/http3"
2024-06-18 08:36:36 +03:00
"github.com/xtls/xray-core/common"
2024-06-29 21:32:57 +03:00
"github.com/xtls/xray-core/common/errors"
2024-06-18 08:36:36 +03:00
"github.com/xtls/xray-core/common/net"
http_proto "github.com/xtls/xray-core/common/protocol/http"
"github.com/xtls/xray-core/common/signal/done"
"github.com/xtls/xray-core/transport/internet"
"github.com/xtls/xray-core/transport/internet/stat"
v2tls "github.com/xtls/xray-core/transport/internet/tls"
2024-06-23 20:05:37 +03:00
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
2024-06-18 08:36:36 +03:00
)
type requestHandler struct {
2024-07-29 09:32:04 +03:00
config * Config
2024-06-18 08:36:36 +03:00
host string
path string
ln * Listener
2024-07-17 14:41:17 +03:00
sessionMu * sync . Mutex
2024-06-18 08:36:36 +03:00
sessions sync . Map
localAddr gonet . TCPAddr
}
2024-06-21 02:30:51 +03:00
type httpSession struct {
2024-07-11 10:56:20 +03:00
uploadQueue * uploadQueue
2024-06-21 02:30:51 +03:00
// for as long as the GET request is not opened by the client, this will be
// open ("undone"), and the session may be expired within a certain TTL.
// after the client connects, this becomes "done" and the session lives as
// long as the GET request.
isFullyConnected * done . Instance
}
// maybeReapSession removes the session stored under sessionId if it has not
// become fully connected within 30 seconds.
//
// It blocks until either isFullyConnected is done (the session is kept and
// its lifetime is handled elsewhere) or the timeout fires (the session is
// deleted from h.sessions), so it is meant to be run in its own goroutine.
func (h *requestHandler) maybeReapSession(isFullyConnected *done.Instance, sessionId string) {
	// A timer that is stopped on return releases its resources as soon as the
	// session connects, instead of keeping a sleeping goroutine around for the
	// whole timeout.
	reapTimer := time.NewTimer(30 * time.Second)
	defer reapTimer.Stop()

	select {
	case <-isFullyConnected.Wait():
		return
	case <-reapTimer.C:
		h.sessions.Delete(sessionId)
	}
}
func ( h * requestHandler ) upsertSession ( sessionId string ) * httpSession {
2024-07-17 14:41:17 +03:00
// fast path
2024-06-21 02:30:51 +03:00
currentSessionAny , ok := h . sessions . Load ( sessionId )
if ok {
return currentSessionAny . ( * httpSession )
}
2024-07-17 14:41:17 +03:00
// slow path
h . sessionMu . Lock ( )
defer h . sessionMu . Unlock ( )
currentSessionAny , ok = h . sessions . Load ( sessionId )
if ok {
return currentSessionAny . ( * httpSession )
}
2024-06-21 02:30:51 +03:00
s := & httpSession {
2024-07-29 07:35:17 +03:00
uploadQueue : NewUploadQueue ( int ( h . ln . config . GetNormalizedMaxConcurrentUploads ( true ) . To ) ) ,
2024-06-21 02:30:51 +03:00
isFullyConnected : done . New ( ) ,
}
h . sessions . Store ( sessionId , s )
go h . maybeReapSession ( s . isFullyConnected , sessionId )
return s
}
2024-06-18 08:36:36 +03:00
func ( h * requestHandler ) ServeHTTP ( writer http . ResponseWriter , request * http . Request ) {
2024-07-07 00:12:49 +03:00
if len ( h . host ) > 0 && ! internet . IsValidHTTPHost ( request . Host , h . host ) {
2024-06-29 21:32:57 +03:00
errors . LogInfo ( context . Background ( ) , "failed to validate host, request:" , request . Host , ", config:" , h . host )
2024-06-18 08:36:36 +03:00
writer . WriteHeader ( http . StatusNotFound )
return
}
2024-06-21 02:30:51 +03:00
if ! strings . HasPrefix ( request . URL . Path , h . path ) {
2024-06-29 21:32:57 +03:00
errors . LogInfo ( context . Background ( ) , "failed to validate path, request:" , request . URL . Path , ", config:" , h . path )
2024-06-18 08:36:36 +03:00
writer . WriteHeader ( http . StatusNotFound )
return
}
2024-06-21 02:30:51 +03:00
sessionId := ""
subpath := strings . Split ( request . URL . Path [ len ( h . path ) : ] , "/" )
if len ( subpath ) > 0 {
sessionId = subpath [ 0 ]
}
2024-06-18 08:36:36 +03:00
if sessionId == "" {
2024-06-29 21:32:57 +03:00
errors . LogInfo ( context . Background ( ) , "no sessionid on request:" , request . URL . Path )
2024-06-18 08:36:36 +03:00
writer . WriteHeader ( http . StatusBadRequest )
return
}
forwardedAddrs := http_proto . ParseXForwardedFor ( request . Header )
remoteAddr , err := gonet . ResolveTCPAddr ( "tcp" , request . RemoteAddr )
if err != nil {
remoteAddr = & gonet . TCPAddr { }
}
if len ( forwardedAddrs ) > 0 && forwardedAddrs [ 0 ] . Family ( ) . IsIP ( ) {
remoteAddr = & net . TCPAddr {
IP : forwardedAddrs [ 0 ] . IP ( ) ,
Port : int ( 0 ) ,
}
}
2024-06-21 02:30:51 +03:00
currentSession := h . upsertSession ( sessionId )
2024-07-29 07:35:17 +03:00
maxUploadSize := int ( h . ln . config . GetNormalizedMaxUploadSize ( true ) . To )
2024-06-21 02:30:51 +03:00
2024-06-18 08:36:36 +03:00
if request . Method == "POST" {
2024-06-21 02:30:51 +03:00
seq := ""
if len ( subpath ) > 1 {
seq = subpath [ 1 ]
2024-06-18 08:36:36 +03:00
}
if seq == "" {
2024-06-29 21:32:57 +03:00
errors . LogInfo ( context . Background ( ) , "no seq on request:" , request . URL . Path )
2024-06-18 08:36:36 +03:00
writer . WriteHeader ( http . StatusBadRequest )
return
}
payload , err := io . ReadAll ( request . Body )
2024-07-29 07:35:17 +03:00
if len ( payload ) > maxUploadSize {
errors . LogInfo ( context . Background ( ) , "Too large upload. maxUploadSize is set to" , maxUploadSize , "but request had size" , len ( payload ) , ". Adjust maxUploadSize on the server to be at least as large as client." )
writer . WriteHeader ( http . StatusRequestEntityTooLarge )
return
}
2024-06-18 08:36:36 +03:00
if err != nil {
2024-06-29 21:32:57 +03:00
errors . LogInfoInner ( context . Background ( ) , err , "failed to upload" )
2024-06-18 08:36:36 +03:00
writer . WriteHeader ( http . StatusInternalServerError )
return
}
seqInt , err := strconv . ParseUint ( seq , 10 , 64 )
if err != nil {
2024-06-29 21:32:57 +03:00
errors . LogInfoInner ( context . Background ( ) , err , "failed to upload" )
2024-06-18 08:36:36 +03:00
writer . WriteHeader ( http . StatusInternalServerError )
return
}
2024-06-21 02:30:51 +03:00
err = currentSession . uploadQueue . Push ( Packet {
2024-06-18 08:36:36 +03:00
Payload : payload ,
Seq : seqInt ,
} )
if err != nil {
2024-06-29 21:32:57 +03:00
errors . LogInfoInner ( context . Background ( ) , err , "failed to upload" )
2024-06-18 08:36:36 +03:00
writer . WriteHeader ( http . StatusInternalServerError )
return
}
writer . WriteHeader ( http . StatusOK )
} else if request . Method == "GET" {
responseFlusher , ok := writer . ( http . Flusher )
if ! ok {
panic ( "expected http.ResponseWriter to be an http.Flusher" )
}
2024-06-21 02:30:51 +03:00
// after GET is done, the connection is finished. disable automatic
// session reaping, and handle it in defer
currentSession . isFullyConnected . Close ( )
2024-06-18 08:36:36 +03:00
defer h . sessions . Delete ( sessionId )
// magic header instructs nginx + apache to not buffer response body
writer . Header ( ) . Set ( "X-Accel-Buffering" , "no" )
2024-07-29 09:32:04 +03:00
if ! h . config . NoSSEHeader {
// magic header to make the HTTP middle box consider this as SSE to disable buffer
writer . Header ( ) . Set ( "Content-Type" , "text/event-stream" )
}
2024-07-11 10:56:20 +03:00
2024-06-18 08:36:36 +03:00
writer . WriteHeader ( http . StatusOK )
// send a chunk immediately to enable CDN streaming.
// many CDN buffer the response headers until the origin starts sending
// the body, with no way to turn it off.
writer . Write ( [ ] byte ( "ok" ) )
responseFlusher . Flush ( )
downloadDone := done . New ( )
conn := splitConn {
writer : & httpResponseBodyWriter {
responseWriter : writer ,
downloadDone : downloadDone ,
responseFlusher : responseFlusher ,
} ,
2024-06-21 02:30:51 +03:00
reader : currentSession . uploadQueue ,
2024-06-18 08:36:36 +03:00
remoteAddr : remoteAddr ,
}
h . ln . addConn ( stat . Connection ( & conn ) )
// "A ResponseWriter may not be used after [Handler.ServeHTTP] has returned."
<- downloadDone . Wait ( )
} else {
writer . WriteHeader ( http . StatusMethodNotAllowed )
}
}
// httpResponseBodyWriter is the write side of a download (GET) stream. It
// writes to the HTTP response body and flushes after every successful write.
type httpResponseBodyWriter struct {
// The embedded mutex is held by Write and Close, so the two never run
// concurrently.
sync . Mutex
// responseWriter is the response the payload bytes are written to.
responseWriter http . ResponseWriter
// responseFlusher is flushed after each successful write.
responseFlusher http . Flusher
// downloadDone is closed by Close; once it is done, Write returns
// io.ErrClosedPipe without touching responseWriter.
downloadDone * done . Instance
}
// Write sends b to the HTTP response body and flushes it to the client.
//
// It returns io.ErrClosedPipe without writing anything once the writer has
// been closed. When the underlying write fails, the byte count and error are
// returned as-is and no flush is attempted.
func (c *httpResponseBodyWriter) Write(b []byte) (int, error) {
	c.Mutex.Lock()
	defer c.Mutex.Unlock()

	if c.downloadDone.Done() {
		return 0, io.ErrClosedPipe
	}

	written, writeErr := c.responseWriter.Write(b)
	if writeErr != nil {
		return written, writeErr
	}

	// Push the chunk out immediately; the stream must not sit in a buffer.
	c.responseFlusher.Flush()
	return written, writeErr
}
// Close marks the download as finished by closing downloadDone. It takes the
// same lock as Write, and subsequent calls to Write return io.ErrClosedPipe.
// The returned error is always nil.
func (c *httpResponseBodyWriter) Close() error {
	c.Mutex.Lock()
	defer c.Mutex.Unlock()

	c.downloadDone.Close()

	return nil
}
type Listener struct {
sync . Mutex
2024-07-19 20:53:47 +03:00
server http . Server
h3server * http3 . Server
listener net . Listener
h3listener * quic . EarlyListener
config * Config
addConn internet . ConnHandler
isH3 bool
2024-06-18 08:36:36 +03:00
}
func ListenSH ( ctx context . Context , address net . Address , port net . Port , streamSettings * internet . MemoryStreamConfig , addConn internet . ConnHandler ) ( internet . Listener , error ) {
l := & Listener {
addConn : addConn ,
}
shSettings := streamSettings . ProtocolSettings . ( * Config )
l . config = shSettings
if l . config != nil {
if streamSettings . SocketSettings == nil {
streamSettings . SocketSettings = & internet . SocketConfig { }
}
}
var listener net . Listener
var err error
var localAddr = gonet . TCPAddr { }
2024-07-19 20:53:47 +03:00
handler := & requestHandler {
2024-07-29 09:32:04 +03:00
config : shSettings ,
2024-07-19 20:53:47 +03:00
host : shSettings . Host ,
2024-07-29 07:35:17 +03:00
path : shSettings . GetNormalizedPath ( "" , false ) ,
2024-07-19 20:53:47 +03:00
ln : l ,
sessionMu : & sync . Mutex { } ,
sessions : sync . Map { } ,
localAddr : localAddr ,
}
tlsConfig := getTLSConfig ( streamSettings )
l . isH3 = len ( tlsConfig . NextProtos ) == 1 && tlsConfig . NextProtos [ 0 ] == "h3"
2024-06-18 08:36:36 +03:00
if port == net . Port ( 0 ) { // unix
listener , err = internet . ListenSystem ( ctx , & net . UnixAddr {
Name : address . Domain ( ) ,
Net : "unix" ,
} , streamSettings . SocketSettings )
if err != nil {
2024-06-29 21:32:57 +03:00
return nil , errors . New ( "failed to listen unix domain socket(for SH) on " , address ) . Base ( err )
2024-06-18 08:36:36 +03:00
}
2024-06-29 21:32:57 +03:00
errors . LogInfo ( ctx , "listening unix domain socket(for SH) on " , address )
2024-07-19 20:53:47 +03:00
} else if l . isH3 { // quic
Conn , err := internet . ListenSystemPacket ( context . Background ( ) , & net . UDPAddr {
IP : address . IP ( ) ,
Port : int ( port ) ,
} , streamSettings . SocketSettings )
if err != nil {
2024-07-21 03:29:50 +03:00
return nil , errors . New ( "failed to listen UDP(for SH3) on " , address , ":" , port ) . Base ( err )
2024-07-19 20:53:47 +03:00
}
2024-07-21 03:29:50 +03:00
h3listener , err := quic . ListenEarly ( Conn , tlsConfig , nil )
2024-07-19 20:53:47 +03:00
if err != nil {
return nil , errors . New ( "failed to listen QUIC(for SH3) on " , address , ":" , port ) . Base ( err )
}
l . h3listener = h3listener
errors . LogInfo ( ctx , "listening QUIC(for SH3) on " , address , ":" , port )
l . h3server = & http3 . Server {
Handler : handler ,
}
go func ( ) {
if err := l . h3server . ServeListener ( l . h3listener ) ; err != nil {
errors . LogWarningInner ( ctx , err , "failed to serve http3 for splithttp" )
}
} ( )
2024-06-18 08:36:36 +03:00
} else { // tcp
localAddr = gonet . TCPAddr {
IP : address . IP ( ) ,
Port : int ( port ) ,
}
listener , err = internet . ListenSystem ( ctx , & net . TCPAddr {
IP : address . IP ( ) ,
Port : int ( port ) ,
} , streamSettings . SocketSettings )
if err != nil {
2024-06-29 21:32:57 +03:00
return nil , errors . New ( "failed to listen TCP(for SH) on " , address , ":" , port ) . Base ( err )
2024-06-18 08:36:36 +03:00
}
2024-06-29 21:32:57 +03:00
errors . LogInfo ( ctx , "listening TCP(for SH) on " , address , ":" , port )
2024-07-21 03:29:50 +03:00
}
// tcp/unix (h1/h2)
if listener != nil {
if config := v2tls . ConfigFromStreamSettings ( streamSettings ) ; config != nil {
if tlsConfig := config . GetTLSConfig ( ) ; tlsConfig != nil {
listener = tls . NewListener ( listener , tlsConfig )
}
}
2024-07-22 23:19:31 +03:00
// h2cHandler can handle both plaintext HTTP/1.1 and h2c
h2cHandler := h2c . NewHandler ( handler , & http2 . Server { } )
2024-07-21 03:29:50 +03:00
l . listener = listener
2024-07-22 23:19:31 +03:00
l . server = http . Server {
Handler : h2cHandler ,
ReadHeaderTimeout : time . Second * 4 ,
MaxHeaderBytes : 8192 ,
}
2024-07-21 03:29:50 +03:00
2024-07-19 20:53:47 +03:00
go func ( ) {
if err := l . server . Serve ( l . listener ) ; err != nil {
errors . LogWarningInner ( ctx , err , "failed to serve http for splithttp" )
}
} ( )
}
2024-06-18 08:36:36 +03:00
return l , err
}
// Addr implements net.Listener.Addr().
//
// It returns the address of the QUIC listener in HTTP/3 mode and of the
// stream listener otherwise. A nil address is returned when neither listener
// has been set up.
func (ln *Listener) Addr() net.Addr {
	// In HTTP/3 mode ln.listener is never assigned, so it must not be
	// dereferenced unconditionally.
	if ln.h3listener != nil {
		return ln.h3listener.Addr()
	}
	if ln.listener != nil {
		return ln.listener.Addr()
	}
	return nil
}
// Close implements net.Listener.Close().
func ( ln * Listener ) Close ( ) error {
2024-07-19 20:53:47 +03:00
if ln . h3server != nil {
if err := ln . h3server . Close ( ) ; err != nil {
return err
}
} else if ln . listener != nil {
return ln . listener . Close ( )
}
return errors . New ( "listener does not have an HTTP/3 server or a net.listener" )
}
func getTLSConfig ( streamSettings * internet . MemoryStreamConfig ) * tls . Config {
config := v2tls . ConfigFromStreamSettings ( streamSettings )
if config == nil {
return & tls . Config { }
}
return config . GetTLSConfig ( )
2024-06-18 08:36:36 +03:00
}
func init ( ) {
common . Must ( internet . RegisterTransportListener ( protocolName , ListenSH ) )
}