0
mirror of https://github.com/XTLS/Xray-core.git synced 2025-06-15 04:47:16 +03:00

SplitHTTP client: Add xmux (multiplex controller) for H3 & H2 (#3613)

https://github.com/XTLS/Xray-core/pull/3613#issuecomment-2351954957

Closes https://github.com/XTLS/Xray-core/issues/3560#issuecomment-2247495778

---------

Co-authored-by: mmmray <142015632+mmmray@users.noreply.github.com>
This commit is contained in:
ll11l1lIllIl1lll
2024-09-16 12:42:01 +00:00
committed by GitHub
parent a931507dd6
commit b1c6471eeb
9 changed files with 475 additions and 64 deletions

View File

@ -41,32 +41,51 @@ type dialerConf struct {
}
var (
globalDialerMap map[dialerConf]DialerClient
globalDialerMap map[dialerConf]*muxManager
globalDialerAccess sync.Mutex
)
func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) DialerClient {
func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *internet.MemoryStreamConfig) (DialerClient, *muxResource) {
if browser_dialer.HasBrowserDialer() {
return &BrowserDialerClient{}
return &BrowserDialerClient{}, nil
}
tlsConfig := tls.ConfigFromStreamSettings(streamSettings)
isH2 := tlsConfig != nil && !(len(tlsConfig.NextProtocol) == 1 && tlsConfig.NextProtocol[0] == "http/1.1")
isH3 := tlsConfig != nil && (len(tlsConfig.NextProtocol) == 1 && tlsConfig.NextProtocol[0] == "h3")
globalDialerAccess.Lock()
defer globalDialerAccess.Unlock()
if globalDialerMap == nil {
globalDialerMap = make(map[dialerConf]DialerClient)
globalDialerMap = make(map[dialerConf]*muxManager)
}
key := dialerConf{dest, streamSettings}
muxManager, found := globalDialerMap[key]
if !found {
transportConfig := streamSettings.ProtocolSettings.(*Config)
var mux Multiplexing
if transportConfig.Xmux != nil {
mux = *transportConfig.Xmux
}
muxManager = NewMuxManager(mux, func() interface{} {
return createHTTPClient(dest, streamSettings)
})
globalDialerMap[key] = muxManager
}
res := muxManager.GetResource(ctx)
return res.Resource.(DialerClient), res
}
func createHTTPClient(dest net.Destination, streamSettings *internet.MemoryStreamConfig) DialerClient {
tlsConfig := tls.ConfigFromStreamSettings(streamSettings)
isH2 := tlsConfig != nil && !(len(tlsConfig.NextProtocol) == 1 && tlsConfig.NextProtocol[0] == "http/1.1")
isH3 := tlsConfig != nil && (len(tlsConfig.NextProtocol) == 1 && tlsConfig.NextProtocol[0] == "h3")
if isH3 {
dest.Network = net.Network_UDP
}
if client, found := globalDialerMap[dialerConf{dest, streamSettings}]; found {
return client
}
var gotlsConfig *gotls.Config
@ -74,6 +93,8 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in
gotlsConfig = tlsConfig.GetTLSConfig(tls.WithDestination(dest))
}
transportConfig := streamSettings.ProtocolSettings.(*Config)
dialContext := func(ctxInner context.Context) (net.Conn, error) {
conn, err := internet.DialSystem(ctxInner, dest, streamSettings.SocketSettings)
if err != nil {
@ -94,8 +115,7 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in
return conn, nil
}
var downloadTransport http.RoundTripper
var uploadTransport http.RoundTripper
var transport http.RoundTripper
if isH3 {
quicConfig := &quic.Config{
@ -107,7 +127,7 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in
MaxIncomingStreams: -1,
KeepAlivePeriod: h3KeepalivePeriod,
}
roundTripper := &http3.RoundTripper{
transport = &http3.RoundTripper{
QUICConfig: quicConfig,
TLSClientConfig: gotlsConfig,
Dial: func(ctx context.Context, addr string, tlsCfg *gotls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
@ -147,23 +167,20 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in
return quic.DialEarly(ctx, udpConn, udpAddr, tlsCfg, cfg)
},
}
downloadTransport = roundTripper
uploadTransport = roundTripper
} else if isH2 {
downloadTransport = &http2.Transport{
transport = &http2.Transport{
DialTLSContext: func(ctxInner context.Context, network string, addr string, cfg *gotls.Config) (net.Conn, error) {
return dialContext(ctxInner)
},
IdleConnTimeout: connIdleTimeout,
ReadIdleTimeout: h2KeepalivePeriod,
}
uploadTransport = downloadTransport
} else {
httpDialContext := func(ctxInner context.Context, network string, addr string) (net.Conn, error) {
return dialContext(ctxInner)
}
downloadTransport = &http.Transport{
transport = &http.Transport{
DialTLSContext: httpDialContext,
DialContext: httpDialContext,
IdleConnTimeout: connIdleTimeout,
@ -171,17 +188,12 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in
// http.Client and our custom dial context.
DisableKeepAlives: true,
}
// we use uploadRawPool for that
uploadTransport = nil
}
client := &DefaultDialerClient{
transportConfig: streamSettings.ProtocolSettings.(*Config),
download: &http.Client{
Transport: downloadTransport,
},
upload: &http.Client{
Transport: uploadTransport,
transportConfig: transportConfig,
client: &http.Client{
Transport: transport,
},
isH2: isH2,
isH3: isH3,
@ -189,7 +201,6 @@ func getHTTPClient(ctx context.Context, dest net.Destination, streamSettings *in
dialUploadConn: dialContext,
}
globalDialerMap[dialerConf{dest, streamSettings}] = client
return client
}
@ -223,7 +234,7 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
requestURL.Path = transportConfiguration.GetNormalizedPath() + sessionIdUuid.String()
requestURL.RawQuery = transportConfiguration.GetNormalizedQuery()
httpClient := getHTTPClient(ctx, dest, streamSettings)
httpClient, muxResource := getHTTPClient(ctx, dest, streamSettings)
maxUploadSize := scMaxEachPostBytes.roll()
// WithSizeLimit(0) will still allow single bytes to pass, and a lot of
@ -231,7 +242,15 @@ func Dial(ctx context.Context, dest net.Destination, streamSettings *internet.Me
// uploadWriter wrapper, exact size limits can be enforced
uploadPipeReader, uploadPipeWriter := pipe.New(pipe.WithSizeLimit(maxUploadSize - 1))
if muxResource != nil {
muxResource.OpenRequests.Add(1)
}
go func() {
if muxResource != nil {
defer muxResource.OpenRequests.Add(-1)
}
requestsLimiter := semaphore.New(int(scMaxConcurrentPosts.roll()))
var requestCounter int64