Add SplitHTTP Browser Dialer support (#3484)

This commit is contained in:
mmmray 2024-07-11 09:56:20 +02:00 committed by GitHub
parent 308f0c64c3
commit c8f6ba9ff0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 536 additions and 294 deletions

View File

@ -0,0 +1,121 @@
package browser_dialer
import (
"bytes"
"context"
_ "embed"
"encoding/base64"
"net/http"
"time"
"github.com/gorilla/websocket"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/platform"
"github.com/xtls/xray-core/common/uuid"
)
//go:embed dialer.html
var webpage []byte
var conns chan *websocket.Conn
var upgrader = &websocket.Upgrader{
ReadBufferSize: 0,
WriteBufferSize: 0,
HandshakeTimeout: time.Second * 4,
CheckOrigin: func(r *http.Request) bool {
return true
},
}
func init() {
addr := platform.NewEnvFlag(platform.BrowserDialerAddress).GetValue(func() string { return "" })
if addr != "" {
token := uuid.New()
csrfToken := token.String()
webpage = bytes.ReplaceAll(webpage, []byte("csrfToken"), []byte(csrfToken))
conns = make(chan *websocket.Conn, 256)
go http.ListenAndServe(addr, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/websocket" {
if r.URL.Query().Get("token") == csrfToken {
if conn, err := upgrader.Upgrade(w, r, nil); err == nil {
conns <- conn
} else {
errors.LogError(context.Background(), "Browser dialer http upgrade unexpected error")
}
}
} else {
w.Write(webpage)
}
}))
}
}
func HasBrowserDialer() bool {
return conns != nil
}
func DialWS(uri string, ed []byte) (*websocket.Conn, error) {
data := []byte("WS " + uri)
if ed != nil {
data = append(data, " "+base64.RawURLEncoding.EncodeToString(ed)...)
}
return dialRaw(data)
}
func DialGet(uri string) (*websocket.Conn, error) {
data := []byte("GET " + uri)
return dialRaw(data)
}
func DialPost(uri string, payload []byte) error {
data := []byte("POST " + uri)
conn, err := dialRaw(data)
if err != nil {
return err
}
err = conn.WriteMessage(websocket.BinaryMessage, payload)
if err != nil {
return err
}
err = CheckOK(conn)
if err != nil {
return err
}
conn.Close()
return nil
}
func dialRaw(data []byte) (*websocket.Conn, error) {
var conn *websocket.Conn
for {
conn = <-conns
if conn.WriteMessage(websocket.TextMessage, data) != nil {
conn.Close()
} else {
break
}
}
err := CheckOK(conn)
if err != nil {
return nil, err
}
return conn, nil
}
func CheckOK(conn *websocket.Conn) error {
if _, p, err := conn.ReadMessage(); err != nil {
conn.Close()
return err
} else if s := string(p); s != "ok" {
conn.Close()
return errors.New(s)
}
return nil
}

View File

@ -0,0 +1,136 @@
<!DOCTYPE html>
<html>
<head>
<title>Browser Dialer</title>
</head>
<body>
<script>
// Copyright (c) 2021 XRAY. Mozilla Public License 2.0.
var url = "ws://" + window.location.host + "/websocket?token=csrfToken";
var clientIdleCount = 0;
var upstreamGetCount = 0;
var upstreamWsCount = 0;
var upstreamPostCount = 0;
setInterval(check, 1000);
function check() {
if (clientIdleCount > 0) {
return;
}
clientIdleCount += 1;
console.log("Prepare", url);
var ws = new WebSocket(url);
// arraybuffer is significantly faster in chrome than default
// blob, tested with chrome 123
ws.binaryType = "arraybuffer";
ws.onmessage = function (event) {
clientIdleCount -= 1;
let [method, url, protocol] = event.data.split(" ");
if (method == "WS") {
upstreamWsCount += 1;
console.log("Dial WS", url, protocol);
const wss = new WebSocket(url, protocol);
wss.binaryType = "arraybuffer";
var opened = false;
ws.onmessage = function (event) {
wss.send(event.data)
}
wss.onopen = function (event) {
opened = true;
ws.send("ok")
}
wss.onmessage = function (event) {
ws.send(event.data)
}
wss.onclose = function (event) {
upstreamWsCount -= 1;
console.log("Dial WS DONE, remaining: ", upstreamWsCount);
ws.close()
}
wss.onerror = function (event) {
!opened && ws.send("fail")
wss.close()
}
ws.onclose = function (event) {
wss.close()
}
} else if (method == "GET") {
(async () => {
console.log("Dial GET", url);
ws.send("ok");
const controller = new AbortController();
/*
Aborting a streaming response in JavaScript
requires two levers to be pulled:
First, the streaming read itself has to be cancelled using
reader.cancel(), only then controller.abort() will actually work.
If controller.abort() alone is called while a
reader.read() is ongoing, it will block until the server closes the
response, the page is refreshed or the network connection is lost.
*/
let reader = null;
ws.onclose = (event) => {
try {
reader && reader.cancel();
} catch(e) {}
try {
controller.abort();
} catch(e) {}
}
try {
upstreamGetCount += 1;
const response = await fetch(url, {signal: controller.signal});
const body = await response.body;
reader = body.getReader();
while (true) {
const { done, value } = await reader.read();
ws.send(value);
if (done) break;
}
} finally {
upstreamGetCount -= 1;
console.log("Dial GET DONE, remaining: ", upstreamGetCount);
ws.close();
}
})()
} else if (method == "POST") {
upstreamPostCount += 1;
console.log("Dial POST", url);
ws.send("ok");
ws.onmessage = async (event) => {
try {
const response = await fetch(
url,
{method: "POST", body: event.data}
);
if (response.ok) {
ws.send("ok");
} else {
console.error("bad status code");
ws.send("fail");
}
} finally {
upstreamPostCount -= 1;
console.log("Dial POST DONE, remaining: ", upstreamPostCount);
ws.close();
}
};
}
check()
}
ws.onerror = function (event) {
ws.close()
}
}
</script>
</body>
</html>

View File

@ -0,0 +1,39 @@
package splithttp
import (
"context"
"io"
"io/ioutil"
gonet "net"
"github.com/xtls/xray-core/transport/internet/browser_dialer"
"github.com/xtls/xray-core/transport/internet/websocket"
)
// implements splithttp.DialerClient in terms of browser dialer
// has no fields because everything is global state :O)
type BrowserDialerClient struct{}
func (c *BrowserDialerClient) OpenDownload(ctx context.Context, baseURL string) (io.ReadCloser, gonet.Addr, gonet.Addr, error) {
conn, err := browser_dialer.DialGet(baseURL)
dummyAddr := &gonet.IPAddr{}
if err != nil {
return nil, dummyAddr, dummyAddr, err
}
return websocket.NewConnection(conn, dummyAddr, nil), conn.RemoteAddr(), conn.LocalAddr(), nil
}
func (c *BrowserDialerClient) SendUploadRequest(ctx context.Context, url string, payload io.ReadWriteCloser, contentLength int64) error {
bytes, err := ioutil.ReadAll(payload)
if err != nil {
return err
}
err = browser_dialer.DialPost(url, bytes)
if err != nil {
return err
}
return nil
}

View File

@ -0,0 +1,169 @@
package splithttp
import (
"bytes"
"context"
"io"
gonet "net"
"net/http"
"net/http/httptrace"
"sync"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/signal/done"
)
// interface to abstract between use of browser dialer, vs net/http
type DialerClient interface {
// (ctx, baseURL, payload) -> err
// baseURL already contains sessionId and seq
SendUploadRequest(context.Context, string, io.ReadWriteCloser, int64) error
// (ctx, baseURL) -> (downloadReader, remoteAddr, localAddr)
// baseURL already contains sessionId
OpenDownload(context.Context, string) (io.ReadCloser, net.Addr, net.Addr, error)
}
// implements splithttp.DialerClient in terms of direct network connections
type DefaultDialerClient struct {
transportConfig *Config
download *http.Client
upload *http.Client
isH2 bool
// pool of net.Conn, created using dialUploadConn
uploadRawPool *sync.Pool
dialUploadConn func(ctxInner context.Context) (net.Conn, error)
}
func (c *DefaultDialerClient) OpenDownload(ctx context.Context, baseURL string) (io.ReadCloser, gonet.Addr, gonet.Addr, error) {
var remoteAddr gonet.Addr
var localAddr gonet.Addr
// this is done when the TCP/UDP connection to the server was established,
// and we can unblock the Dial function and print correct net addresses in
// logs
gotConn := done.New()
var downResponse io.ReadCloser
gotDownResponse := done.New()
go func() {
trace := &httptrace.ClientTrace{
GotConn: func(connInfo httptrace.GotConnInfo) {
remoteAddr = connInfo.Conn.RemoteAddr()
localAddr = connInfo.Conn.LocalAddr()
gotConn.Close()
},
}
// in case we hit an error, we want to unblock this part
defer gotConn.Close()
req, err := http.NewRequestWithContext(
httptrace.WithClientTrace(ctx, trace),
"GET",
baseURL,
nil,
)
if err != nil {
errors.LogInfoInner(ctx, err, "failed to construct download http request")
gotDownResponse.Close()
return
}
req.Header = c.transportConfig.GetRequestHeader()
response, err := c.download.Do(req)
gotConn.Close()
if err != nil {
errors.LogInfoInner(ctx, err, "failed to send download http request")
gotDownResponse.Close()
return
}
if response.StatusCode != 200 {
response.Body.Close()
errors.LogInfo(ctx, "invalid status code on download:", response.Status)
gotDownResponse.Close()
return
}
downResponse = response.Body
gotDownResponse.Close()
}()
// we want to block Dial until we know the remote address of the server,
// for logging purposes
<-gotConn.Wait()
lazyDownload := &LazyReader{
CreateReader: func() (io.ReadCloser, error) {
<-gotDownResponse.Wait()
if downResponse == nil {
return nil, errors.New("downResponse failed")
}
return downResponse, nil
},
}
return lazyDownload, remoteAddr, localAddr, nil
}
func (c *DefaultDialerClient) SendUploadRequest(ctx context.Context, url string, payload io.ReadWriteCloser, contentLength int64) error {
req, err := http.NewRequest("POST", url, payload)
req.ContentLength = contentLength
if err != nil {
return err
}
req.Header = c.transportConfig.GetRequestHeader()
if c.isH2 {
resp, err := c.upload.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
return errors.New("bad status code:", resp.Status)
}
} else {
// stringify the entire HTTP/1.1 request so it can be
// safely retried. if instead req.Write is called multiple
// times, the body is already drained after the first
// request
requestBytes := new(bytes.Buffer)
common.Must(req.Write(requestBytes))
var uploadConn any
for {
uploadConn = c.uploadRawPool.Get()
newConnection := uploadConn == nil
if newConnection {
uploadConn, err = c.dialUploadConn(context.WithoutCancel(ctx))
if err != nil {
return err
}
}
_, err = uploadConn.(net.Conn).Write(requestBytes.Bytes())
// if the write failed, we try another connection from
// the pool, until the write on a new connection fails.
// failed writes to a pooled connection are normal when
// the connection has been closed in the meantime.
if err == nil {
break
} else if newConnection {
return err
}
}
c.uploadRawPool.Put(uploadConn)
}
return nil
}

View File

@ -1,13 +1,10 @@
package splithttp package splithttp
import ( import (
"bytes"
"context" "context"
gotls "crypto/tls" gotls "crypto/tls"
"io" "io"
gonet "net"
"net/http" "net/http"
"net/http/httptrace"
"net/url" "net/url"
"strconv" "strconv"
"sync" "sync"
@ -17,10 +14,10 @@ import (
"github.com/xtls/xray-core/common/buf" "github.com/xtls/xray-core/common/buf"
"github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/signal/done"
"github.com/xtls/xray-core/common/signal/semaphore" "github.com/xtls/xray-core/common/signal/semaphore"
"github.com/xtls/xray-core/common/uuid" "github.com/xtls/xray-core/common/uuid"
"github.com/xtls/xray-core/transport/internet" "github.com/xtls/xray-core/transport/internet"
"github.com/xtls/xray-core/transport/internet/browser_dialer"
"github.com/xtls/xray-core/transport/internet/stat" "github.com/xtls/xray-core/transport/internet/stat"
"github.com/xtls/xray-core/transport/internet/tls" "github.com/xtls/xray-core/transport/internet/tls"
"github.com/xtls/xray-core/transport/pipe" "github.com/xtls/xray-core/transport/pipe"
@ -32,32 +29,31 @@ type dialerConf struct {
*internet.MemoryStreamConfig *internet.MemoryStreamConfig
} }
type reusedClient struct {
download *http.Client
upload *http.Client
isH2 bool
// pool of net.Conn, created using dialUploadConn
uploadRawPool *sync.Pool
dialUploadConn func(ctxInner context.Context) (net.Conn, error)
}
var ( var (
globalDialerMap map[dialerConf]reusedClient globalDialerMap map[dialerConf]DialerClient
globalDialerAccess sync.Mutex globalDialerAccess sync.Mutex
) )
func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) reusedClient { func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) DialerClient {
if browser_dialer.HasBrowserDialer() {
return &BrowserDialerClient{}
}
globalDialerAccess.Lock() globalDialerAccess.Lock()
defer globalDialerAccess.Unlock() defer globalDialerAccess.Unlock()
if globalDialerMap == nil { if globalDialerMap == nil {
globalDialerMap = make(map[dialerConf]reusedClient) globalDialerMap = make(map[dialerConf]DialerClient)
} }
if client, found := globalDialerMap[dialerConf{dest, streamSettings}]; found { if client, found := globalDialerMap[dialerConf{dest, streamSettings}]; found {
return client return client
} }
if browser_dialer.HasBrowserDialer() {
return &BrowserDialerClient{}
}
tlsConfig := tls.ConfigFromStreamSettings(streamSettings) tlsConfig := tls.ConfigFromStreamSettings(streamSettings)
isH2 := tlsConfig != nil && !(len(tlsConfig.NextProtocol) == 1 && tlsConfig.NextProtocol[0] == "http/1.1") isH2 := tlsConfig != nil && !(len(tlsConfig.NextProtocol) == 1 && tlsConfig.NextProtocol[0] == "http/1.1")
@ -116,7 +112,8 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in
uploadTransport = nil uploadTransport = nil
} }
client := reusedClient{ client := &DefaultDialerClient{
transportConfig: streamSettings.ProtocolSettings.(*Config),
download: &http.Client{ download: &http.Client{
Transport: downloadTransport, Transport: downloadTransport,
}, },
@ -160,80 +157,9 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
httpClient := getHTTPClient(ctx, dest, streamSettings) httpClient := getHTTPClient(ctx, dest, streamSettings)
var remoteAddr gonet.Addr
var localAddr gonet.Addr
// this is done when the TCP/UDP connection to the server was established,
// and we can unblock the Dial function and print correct net addresses in
// logs
gotConn := done.New()
var downResponse io.ReadCloser
gotDownResponse := done.New()
sessionIdUuid := uuid.New() sessionIdUuid := uuid.New()
sessionId := sessionIdUuid.String() sessionId := sessionIdUuid.String()
baseURL := requestURL.String() + sessionId
go func() {
trace := &httptrace.ClientTrace{
GotConn: func(connInfo httptrace.GotConnInfo) {
remoteAddr = connInfo.Conn.RemoteAddr()
localAddr = connInfo.Conn.LocalAddr()
gotConn.Close()
},
}
// in case we hit an error, we want to unblock this part
defer gotConn.Close()
req, err := http.NewRequestWithContext(
httptrace.WithClientTrace(context.WithoutCancel(ctx), trace),
"GET",
requestURL.String()+sessionId,
nil,
)
if err != nil {
errors.LogInfoInner(ctx, err, "failed to construct download http request")
gotDownResponse.Close()
return
}
req.Header = transportConfiguration.GetRequestHeader()
response, err := httpClient.download.Do(req)
gotConn.Close()
if err != nil {
errors.LogInfoInner(ctx, err, "failed to send download http request")
gotDownResponse.Close()
return
}
if response.StatusCode != 200 {
response.Body.Close()
errors.LogInfo(ctx, "invalid status code on download:", response.Status)
gotDownResponse.Close()
return
}
// skip "ooooooooook" response
trashHeader := []byte{0}
for {
_, err = io.ReadFull(response.Body, trashHeader)
if err != nil {
response.Body.Close()
errors.LogInfoInner(ctx, err, "failed to read initial response")
gotDownResponse.Close()
return
}
if trashHeader[0] == 'k' {
break
}
}
downResponse = response.Body
gotDownResponse.Close()
}()
uploadUrl := requestURL.String() + sessionId + "/"
uploadPipeReader, uploadPipeWriter := pipe.New(pipe.WithSizeLimit(maxUploadSize)) uploadPipeReader, uploadPipeWriter := pipe.New(pipe.WithSizeLimit(maxUploadSize))
@ -252,97 +178,55 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
<-requestsLimiter.Wait() <-requestsLimiter.Wait()
url := uploadUrl + strconv.FormatInt(requestCounter, 10) seq := requestCounter
requestCounter += 1 requestCounter += 1
go func() { go func() {
defer requestsLimiter.Signal() defer requestsLimiter.Signal()
req, err := http.NewRequest("POST", url, &buf.MultiBufferContainer{MultiBuffer: chunk})
err := httpClient.SendUploadRequest(
context.WithoutCancel(ctx),
baseURL+"/"+strconv.FormatInt(seq, 10),
&buf.MultiBufferContainer{MultiBuffer: chunk},
int64(chunk.Len()),
)
if err != nil { if err != nil {
errors.LogInfoInner(ctx, err, "failed to send upload") errors.LogInfoInner(ctx, err, "failed to send upload")
uploadPipeReader.Interrupt() uploadPipeReader.Interrupt()
return
}
req.ContentLength = int64(chunk.Len())
req.Header = transportConfiguration.GetRequestHeader()
if httpClient.isH2 {
resp, err := httpClient.upload.Do(req)
if err != nil {
errors.LogInfoInner(ctx, err, "failed to send upload")
uploadPipeReader.Interrupt()
return
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
errors.LogInfo(ctx, "failed to send upload, bad status code:", resp.Status)
uploadPipeReader.Interrupt()
return
}
} else {
var uploadConn any
// stringify the entire HTTP/1.1 request so it can be
// safely retried. if instead req.Write is called multiple
// times, the body is already drained after the first
// request
requestBytes := new(bytes.Buffer)
common.Must(req.Write(requestBytes))
for {
uploadConn = httpClient.uploadRawPool.Get()
newConnection := uploadConn == nil
if newConnection {
uploadConn, err = httpClient.dialUploadConn(context.WithoutCancel(ctx))
if err != nil {
errors.LogInfoInner(ctx, err, "failed to connect upload")
uploadPipeReader.Interrupt()
return
}
}
_, err = uploadConn.(net.Conn).Write(requestBytes.Bytes())
// if the write failed, we try another connection from
// the pool, until the write on a new connection fails.
// failed writes to a pooled connection are normal when
// the connection has been closed in the meantime.
if err == nil {
break
} else if newConnection {
errors.LogInfoInner(ctx, err, "failed to send upload")
uploadPipeReader.Interrupt()
return
}
}
httpClient.uploadRawPool.Put(uploadConn)
} }
}() }()
} }
}() }()
// we want to block Dial until we know the remote address of the server, lazyRawDownload, remoteAddr, localAddr, err := httpClient.OpenDownload(context.WithoutCancel(ctx), baseURL)
// for logging purposes if err != nil {
<-gotConn.Wait() return nil, err
}
lazyDownload := &LazyReader{
CreateReader: func() (io.ReadCloser, error) {
// skip "ooooooooook" response
trashHeader := []byte{0}
for {
_, err := io.ReadFull(lazyRawDownload, trashHeader)
if err != nil {
return nil, errors.New("failed to read initial response").Base(err)
}
if trashHeader[0] == 'k' {
break
}
}
return lazyRawDownload, nil
},
}
// necessary in order to send larger chunks in upload // necessary in order to send larger chunks in upload
bufferedUploadPipeWriter := buf.NewBufferedWriter(uploadPipeWriter) bufferedUploadPipeWriter := buf.NewBufferedWriter(uploadPipeWriter)
bufferedUploadPipeWriter.SetBuffered(false) bufferedUploadPipeWriter.SetBuffered(false)
lazyDownload := &LazyReader{
CreateReader: func() (io.ReadCloser, error) {
<-gotDownResponse.Wait()
if downResponse == nil {
return nil, errors.New("downResponse failed")
}
return downResponse, nil
},
}
conn := splitConn{ conn := splitConn{
writer: bufferedUploadPipeWriter, writer: bufferedUploadPipeWriter,
reader: lazyDownload, reader: lazyDownload,

View File

@ -32,7 +32,7 @@ type requestHandler struct {
} }
type httpSession struct { type httpSession struct {
uploadQueue *UploadQueue uploadQueue *uploadQueue
// for as long as the GET request is not opened by the client, this will be // for as long as the GET request is not opened by the client, this will be
// open ("undone"), and the session may be expired within a certain TTL. // open ("undone"), and the session may be expired within a certain TTL.
// after the client connects, this becomes "done" and the session lives as // after the client connects, this becomes "done" and the session lives as
@ -163,7 +163,7 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
writer.Header().Set("X-Accel-Buffering", "no") writer.Header().Set("X-Accel-Buffering", "no")
// magic header to make the HTTP middle box consider this as SSE to disable buffer // magic header to make the HTTP middle box consider this as SSE to disable buffer
writer.Header().Set("Content-Type", "text/event-stream") writer.Header().Set("Content-Type", "text/event-stream")
writer.WriteHeader(http.StatusOK) writer.WriteHeader(http.StatusOK)
// send a chunk immediately to enable CDN streaming. // send a chunk immediately to enable CDN streaming.
// many CDN buffer the response headers until the origin starts sending // many CDN buffer the response headers until the origin starts sending

View File

@ -15,7 +15,7 @@ type Packet struct {
Seq uint64 Seq uint64
} }
type UploadQueue struct { type uploadQueue struct {
pushedPackets chan Packet pushedPackets chan Packet
heap uploadHeap heap uploadHeap
nextSeq uint64 nextSeq uint64
@ -23,8 +23,8 @@ type UploadQueue struct {
maxPackets int maxPackets int
} }
func NewUploadQueue(maxPackets int) *UploadQueue { func NewUploadQueue(maxPackets int) *uploadQueue {
return &UploadQueue{ return &uploadQueue{
pushedPackets: make(chan Packet, maxPackets), pushedPackets: make(chan Packet, maxPackets),
heap: uploadHeap{}, heap: uploadHeap{},
nextSeq: 0, nextSeq: 0,
@ -33,7 +33,7 @@ func NewUploadQueue(maxPackets int) *UploadQueue {
} }
} }
func (h *UploadQueue) Push(p Packet) error { func (h *uploadQueue) Push(p Packet) error {
if h.closed { if h.closed {
return errors.New("splithttp packet queue closed") return errors.New("splithttp packet queue closed")
} }
@ -42,13 +42,13 @@ func (h *UploadQueue) Push(p Packet) error {
return nil return nil
} }
func (h *UploadQueue) Close() error { func (h *uploadQueue) Close() error {
h.closed = true h.closed = true
close(h.pushedPackets) close(h.pushedPackets)
return nil return nil
} }
func (h *UploadQueue) Read(b []byte) (int, error) { func (h *uploadQueue) Read(b []byte) (int, error) {
if h.closed { if h.closed {
return 0, io.EOF return 0, io.EOF
} }

View File

@ -15,16 +15,14 @@ var _ buf.Writer = (*connection)(nil)
// connection is a wrapper for net.Conn over WebSocket connection. // connection is a wrapper for net.Conn over WebSocket connection.
type connection struct { type connection struct {
conn *websocket.Conn conn *websocket.Conn
reader io.Reader reader io.Reader
remoteAddr net.Addr
} }
func newConnection(conn *websocket.Conn, remoteAddr net.Addr, extraReader io.Reader) *connection { func NewConnection(conn *websocket.Conn, remoteAddr net.Addr, extraReader io.Reader) *connection {
return &connection{ return &connection{
conn: conn, conn: conn,
remoteAddr: remoteAddr, reader: extraReader,
reader: extraReader,
} }
} }
@ -92,7 +90,7 @@ func (c *connection) LocalAddr() net.Addr {
} }
func (c *connection) RemoteAddr() net.Addr { func (c *connection) RemoteAddr() net.Addr {
return c.remoteAddr return c.conn.RemoteAddr()
} }
func (c *connection) SetDeadline(t time.Time) error { func (c *connection) SetDeadline(t time.Time) error {

View File

@ -1,54 +1,23 @@
package websocket package websocket
import ( import (
"bytes"
"context" "context"
_ "embed" _ "embed"
"encoding/base64" "encoding/base64"
"io" "io"
gonet "net" gonet "net"
"net/http"
"time" "time"
"github.com/gorilla/websocket" "github.com/gorilla/websocket"
"github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/platform"
"github.com/xtls/xray-core/common/uuid"
"github.com/xtls/xray-core/transport/internet" "github.com/xtls/xray-core/transport/internet"
"github.com/xtls/xray-core/transport/internet/browser_dialer"
"github.com/xtls/xray-core/transport/internet/stat" "github.com/xtls/xray-core/transport/internet/stat"
"github.com/xtls/xray-core/transport/internet/tls" "github.com/xtls/xray-core/transport/internet/tls"
) )
//go:embed dialer.html
var webpage []byte
var conns chan *websocket.Conn
func init() {
addr := platform.NewEnvFlag(platform.BrowserDialerAddress).GetValue(func() string { return "" })
if addr != "" {
token := uuid.New()
csrfToken := token.String()
webpage = bytes.ReplaceAll(webpage, []byte("csrfToken"), []byte(csrfToken))
conns = make(chan *websocket.Conn, 256)
go http.ListenAndServe(addr, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/websocket" {
if r.URL.Query().Get("token") == csrfToken {
if conn, err := upgrader.Upgrade(w, r, nil); err == nil {
conns <- conn
} else {
errors.LogError(context.Background(), "Browser dialer http upgrade unexpected error")
}
}
} else {
w.Write(webpage)
}
}))
}
}
// Dial dials a WebSocket connection to the given destination. // Dial dials a WebSocket connection to the given destination.
func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (stat.Connection, error) { func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (stat.Connection, error) {
errors.LogInfo(ctx, "creating connection to ", dest) errors.LogInfo(ctx, "creating connection to ", dest)
@ -98,18 +67,18 @@ func dialWebSocket(ctx context.Context, dest net.Destination, streamSettings *in
// Like the NetDial in the dialer // Like the NetDial in the dialer
pconn, err := internet.DialSystem(ctx, dest, streamSettings.SocketSettings) pconn, err := internet.DialSystem(ctx, dest, streamSettings.SocketSettings)
if err != nil { if err != nil {
errors.LogErrorInner(ctx, err, "failed to dial to " + addr) errors.LogErrorInner(ctx, err, "failed to dial to "+addr)
return nil, err return nil, err
} }
// TLS and apply the handshake // TLS and apply the handshake
cn := tls.UClient(pconn, tlsConfig, fingerprint).(*tls.UConn) cn := tls.UClient(pconn, tlsConfig, fingerprint).(*tls.UConn)
if err := cn.WebsocketHandshakeContext(ctx); err != nil { if err := cn.WebsocketHandshakeContext(ctx); err != nil {
errors.LogErrorInner(ctx, err, "failed to dial to " + addr) errors.LogErrorInner(ctx, err, "failed to dial to "+addr)
return nil, err return nil, err
} }
if !tlsConfig.InsecureSkipVerify { if !tlsConfig.InsecureSkipVerify {
if err := cn.VerifyHostname(tlsConfig.ServerName); err != nil { if err := cn.VerifyHostname(tlsConfig.ServerName); err != nil {
errors.LogErrorInner(ctx, err, "failed to dial to " + addr) errors.LogErrorInner(ctx, err, "failed to dial to "+addr)
return nil, err return nil, err
} }
} }
@ -124,28 +93,13 @@ func dialWebSocket(ctx context.Context, dest net.Destination, streamSettings *in
} }
uri := protocol + "://" + host + wsSettings.GetNormalizedPath() uri := protocol + "://" + host + wsSettings.GetNormalizedPath()
if conns != nil { if browser_dialer.HasBrowserDialer() {
data := []byte(uri) conn, err := browser_dialer.DialWS(uri, ed)
if ed != nil { if err != nil {
data = append(data, " "+base64.RawURLEncoding.EncodeToString(ed)...)
}
var conn *websocket.Conn
for {
conn = <-conns
if conn.WriteMessage(websocket.TextMessage, data) != nil {
conn.Close()
} else {
break
}
}
if _, p, err := conn.ReadMessage(); err != nil {
conn.Close()
return nil, err return nil, err
} else if s := string(p); s != "ok" {
conn.Close()
return nil, errors.New(s)
} }
return newConnection(conn, conn.RemoteAddr(), nil), nil
return NewConnection(conn, conn.RemoteAddr(), nil), nil
} }
header := wsSettings.GetRequestHeader() header := wsSettings.GetRequestHeader()
@ -163,7 +117,7 @@ func dialWebSocket(ctx context.Context, dest net.Destination, streamSettings *in
return nil, errors.New("failed to dial to (", uri, "): ", reason).Base(err) return nil, errors.New("failed to dial to (", uri, "): ", reason).Base(err)
} }
return newConnection(conn, conn.RemoteAddr(), nil), nil return NewConnection(conn, conn.RemoteAddr(), nil), nil
} }
type delayDialConn struct { type delayDialConn struct {

View File

@ -1,59 +0,0 @@
<!DOCTYPE html>
<html>
<head>
<title>Browser Dialer</title>
</head>
<body>
<script>
// Copyright (c) 2021 XRAY. Mozilla Public License 2.0.
var url = "ws://" + window.location.host + "/websocket?token=csrfToken"
var count = 0
setInterval(check, 1000)
function check() {
if (count <= 0) {
count += 1
console.log("Prepare", url)
var ws = new WebSocket(url)
// arraybuffer is significantly faster in chrome than default
// blob, tested with chrome 123
ws.binaryType = "arraybuffer";
var wss = undefined
var first = true
ws.onmessage = function (event) {
if (first) {
first = false
count -= 1
var arr = event.data.split(" ")
console.log("Dial", arr[0], arr[1])
wss = new WebSocket(arr[0], arr[1])
wss.binaryType = "arraybuffer";
var opened = false
wss.onopen = function (event) {
opened = true
ws.send("ok")
}
wss.onmessage = function (event) {
ws.send(event.data)
}
wss.onclose = function (event) {
ws.close()
}
wss.onerror = function (event) {
!opened && ws.send("fail")
wss.close()
}
check()
} else wss.send(event.data)
}
ws.onclose = function (event) {
if (first) count -= 1
else wss.close()
}
ws.onerror = function (event) {
ws.close()
}
}
}
</script>
</body>
</html>

View File

@ -73,7 +73,7 @@ func (h *requestHandler) ServeHTTP(writer http.ResponseWriter, request *http.Req
} }
} }
h.ln.addConn(newConnection(conn, remoteAddr, extraReader)) h.ln.addConn(NewConnection(conn, remoteAddr, extraReader))
} }
type Listener struct { type Listener struct {