diff --git a/app/proxyman/outbound/handler.go b/app/proxyman/outbound/handler.go index 361c88ad..3c3c6918 100644 --- a/app/proxyman/outbound/handler.go +++ b/app/proxyman/outbound/handler.go @@ -108,7 +108,7 @@ func NewHandler(ctx context.Context, config *core.OutboundHandlerConfig) (outbou } h.proxyConfig = proxyConfig - ctx = session.ContextWithHandler(ctx, h) + ctx = session.ContextWithFullHandler(ctx, h) rawProxyHandler, err := common.CreateObject(ctx, proxyConfig) if err != nil { diff --git a/app/reverse/bridge.go b/app/reverse/bridge.go index 74e20497..fc83a740 100644 --- a/app/reverse/bridge.go +++ b/app/reverse/bridge.go @@ -198,9 +198,11 @@ func (w *BridgeWorker) handleInternalConn(link *transport.Link) { func (w *BridgeWorker) Dispatch(ctx context.Context, dest net.Destination) (*transport.Link, error) { if !isInternalDomain(dest) { - ctx = session.ContextWithInbound(ctx, &session.Inbound{ - Tag: w.Tag, - }) + if session.InboundFromContext(ctx) == nil { + ctx = session.ContextWithInbound(ctx, &session.Inbound{ + Tag: w.Tag, + }) + } return w.Dispatcher.Dispatch(ctx, dest) } @@ -221,9 +223,11 @@ func (w *BridgeWorker) Dispatch(ctx context.Context, dest net.Destination) (*tra func (w *BridgeWorker) DispatchLink(ctx context.Context, dest net.Destination, link *transport.Link) error { if !isInternalDomain(dest) { - ctx = session.ContextWithInbound(ctx, &session.Inbound{ - Tag: w.Tag, - }) + if session.InboundFromContext(ctx) == nil { + ctx = session.ContextWithInbound(ctx, &session.Inbound{ + Tag: w.Tag, + }) + } return w.Dispatcher.DispatchLink(ctx, dest, link) } diff --git a/common/mux/client.go b/common/mux/client.go index 93357574..28380331 100644 --- a/common/mux/client.go +++ b/common/mux/client.go @@ -264,7 +264,11 @@ func fetchInput(ctx context.Context, s *Session, output buf.Writer) { transferType = protocol.TransferTypePacket } s.transferType = transferType - writer := NewWriter(s.ID, ob.Target, output, transferType, xudp.GetGlobalID(ctx)) + var inbound *session.Inbound + if session.IsReverseMuxFromContext(ctx) { + inbound = session.InboundFromContext(ctx) + } + writer := NewWriter(s.ID, ob.Target, output, transferType, xudp.GetGlobalID(ctx), inbound) defer s.Close(false) defer writer.Close() @@ -384,7 +388,7 @@ func (m *ClientWorker) fetchOutput() { var meta FrameMetadata for { - err := meta.Unmarshal(reader) + err := meta.Unmarshal(reader, false) if err != nil { if errors.Cause(err) != io.EOF { errors.LogInfoInner(context.Background(), err, "failed to read metadata") diff --git a/common/mux/frame.go b/common/mux/frame.go index bdf5cc8c..f248fbdf 100644 --- a/common/mux/frame.go +++ b/common/mux/frame.go @@ -11,6 +11,7 @@ import ( "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/protocol" "github.com/xtls/xray-core/common/serial" + "github.com/xtls/xray-core/common/session" ) type SessionStatus byte @@ -60,6 +61,7 @@ type FrameMetadata struct { Option bitmask.Byte SessionStatus SessionStatus GlobalID [8]byte + Inbound *session.Inbound } func (f FrameMetadata) WriteTo(b *buf.Buffer) error { @@ -79,11 +81,23 @@ func (f FrameMetadata) WriteTo(b *buf.Buffer) error { case net.Network_UDP: common.Must(b.WriteByte(byte(TargetNetworkUDP))) } - if err := addrParser.WriteAddressPort(b, f.Target.Address, f.Target.Port); err != nil { return err } - if b.UDP != nil { // make sure it's user's proxy request + if f.Inbound != nil { + if f.Inbound.Source.Network == net.Network_TCP || f.Inbound.Source.Network == net.Network_UDP { + common.Must(b.WriteByte(byte(f.Inbound.Source.Network - 1))) + if err := addrParser.WriteAddressPort(b, f.Inbound.Source.Address, f.Inbound.Source.Port); err != nil { + return err + } + if f.Inbound.Local.Network == net.Network_TCP || f.Inbound.Local.Network == net.Network_UDP { + common.Must(b.WriteByte(byte(f.Inbound.Local.Network - 1))) + if err := addrParser.WriteAddressPort(b, f.Inbound.Local.Address, f.Inbound.Local.Port); err != nil { + return err + } + } + } + } else if b.UDP != nil { // make sure it's user's proxy request b.Write(f.GlobalID[:]) // no need to check whether it's empty } } else if b.UDP != nil { @@ -97,7 +111,7 @@ func (f FrameMetadata) WriteTo(b *buf.Buffer) error { } // Unmarshal reads FrameMetadata from the given reader. -func (f *FrameMetadata) Unmarshal(reader io.Reader) error { +func (f *FrameMetadata) Unmarshal(reader io.Reader, readSourceAndLocal bool) error { metaLen, err := serial.ReadUint16(reader) if err != nil { return err @@ -112,12 +126,12 @@ func (f *FrameMetadata) Unmarshal(reader io.Reader) error { if _, err := b.ReadFullFrom(reader, int32(metaLen)); err != nil { return err } - return f.UnmarshalFromBuffer(b) + return f.UnmarshalFromBuffer(b, readSourceAndLocal) } // UnmarshalFromBuffer reads a FrameMetadata from the given buffer. // Visible for testing only. -func (f *FrameMetadata) UnmarshalFromBuffer(b *buf.Buffer) error { +func (f *FrameMetadata) UnmarshalFromBuffer(b *buf.Buffer, readSourceAndLocal bool) error { if b.Len() < 4 { return errors.New("insufficient buffer: ", b.Len()) } @@ -150,6 +164,54 @@ func (f *FrameMetadata) UnmarshalFromBuffer(b *buf.Buffer) error { } } + if f.SessionStatus == SessionStatusNew && readSourceAndLocal { + f.Inbound = &session.Inbound{} + + if b.Len() == 0 { + return nil // for heartbeat, etc. + } + network := TargetNetwork(b.Byte(0)) + if network == 0 { + return nil // may be padding + } + b.Advance(1) + addr, port, err := addrParser.ReadAddressPort(nil, b) + if err != nil { + return errors.New("reading source: failed to parse address and port").Base(err) + } + switch network { + case TargetNetworkTCP: + f.Inbound.Source = net.TCPDestination(addr, port) + case TargetNetworkUDP: + f.Inbound.Source = net.UDPDestination(addr, port) + default: + return errors.New("reading source: unknown network type: ", network) + } + + if b.Len() == 0 { + return nil + } + network = TargetNetwork(b.Byte(0)) + if network == 0 { + return nil + } + b.Advance(1) + addr, port, err = addrParser.ReadAddressPort(nil, b) + if err != nil { + return errors.New("reading local: failed to parse address and port").Base(err) + } + switch network { + case TargetNetworkTCP: + f.Inbound.Local = net.TCPDestination(addr, port) + case TargetNetworkUDP: + f.Inbound.Local = net.UDPDestination(addr, port) + default: + return errors.New("reading local: unknown network type: ", network) + } + + return nil + } + // Application data is essential, to test whether the pipe is closed. if f.SessionStatus == SessionStatusNew && f.Option.Has(OptionData) && f.Target.Network == net.Network_UDP && b.Len() >= 8 { diff --git a/common/mux/mux_test.go b/common/mux/mux_test.go index f326ffd7..db01d372 100644 --- a/common/mux/mux_test.go +++ b/common/mux/mux_test.go @@ -10,6 +10,7 @@ import ( . "github.com/xtls/xray-core/common/mux" "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/protocol" + "github.com/xtls/xray-core/common/session" "github.com/xtls/xray-core/transport/pipe" ) @@ -32,13 +33,13 @@ func TestReaderWriter(t *testing.T) { pReader, pWriter := pipe.New(pipe.WithSizeLimit(1024)) dest := net.TCPDestination(net.DomainAddress("example.com"), 80) - writer := NewWriter(1, dest, pWriter, protocol.TransferTypeStream, [8]byte{}) + writer := NewWriter(1, dest, pWriter, protocol.TransferTypeStream, [8]byte{}, &session.Inbound{}) dest2 := net.TCPDestination(net.LocalHostIP, 443) - writer2 := NewWriter(2, dest2, pWriter, protocol.TransferTypeStream, [8]byte{}) + writer2 := NewWriter(2, dest2, pWriter, protocol.TransferTypeStream, [8]byte{}, &session.Inbound{}) dest3 := net.TCPDestination(net.LocalHostIPv6, 18374) - writer3 := NewWriter(3, dest3, pWriter, protocol.TransferTypeStream, [8]byte{}) + writer3 := NewWriter(3, dest3, pWriter, protocol.TransferTypeStream, [8]byte{}, &session.Inbound{}) writePayload := func(writer *Writer, payload ...byte) error { b := buf.New() @@ -62,7 +63,7 @@ func TestReaderWriter(t *testing.T) { { var meta FrameMetadata - common.Must(meta.Unmarshal(bytesReader)) + common.Must(meta.Unmarshal(bytesReader, false)) if r := cmp.Diff(meta, FrameMetadata{ SessionID: 1, SessionStatus: SessionStatusNew, @@ -81,7 +82,7 @@ func TestReaderWriter(t *testing.T) { { var meta FrameMetadata - common.Must(meta.Unmarshal(bytesReader)) + common.Must(meta.Unmarshal(bytesReader, false)) if r := cmp.Diff(meta, FrameMetadata{ SessionStatus: SessionStatusNew, SessionID: 2, @@ -94,7 +95,7 @@ func TestReaderWriter(t *testing.T) { { var meta FrameMetadata - common.Must(meta.Unmarshal(bytesReader)) + common.Must(meta.Unmarshal(bytesReader, false)) if r := cmp.Diff(meta, FrameMetadata{ SessionID: 1, SessionStatus: SessionStatusKeep, @@ -112,7 +113,7 @@ func TestReaderWriter(t *testing.T) { { var meta FrameMetadata - common.Must(meta.Unmarshal(bytesReader)) + common.Must(meta.Unmarshal(bytesReader, false)) if r := cmp.Diff(meta, FrameMetadata{ SessionID: 3, SessionStatus: SessionStatusNew, @@ -131,7 +132,7 @@ func TestReaderWriter(t *testing.T) { { var meta FrameMetadata - common.Must(meta.Unmarshal(bytesReader)) + common.Must(meta.Unmarshal(bytesReader, false)) if r := cmp.Diff(meta, FrameMetadata{ SessionID: 1, SessionStatus: SessionStatusEnd, @@ -143,7 +144,7 @@ func TestReaderWriter(t *testing.T) { { var meta FrameMetadata - common.Must(meta.Unmarshal(bytesReader)) + common.Must(meta.Unmarshal(bytesReader, false)) if r := cmp.Diff(meta, FrameMetadata{ SessionID: 3, SessionStatus: SessionStatusEnd, @@ -155,7 +156,7 @@ func TestReaderWriter(t *testing.T) { { var meta FrameMetadata - common.Must(meta.Unmarshal(bytesReader)) + common.Must(meta.Unmarshal(bytesReader, false)) if r := cmp.Diff(meta, FrameMetadata{ SessionID: 2, SessionStatus: SessionStatusKeep, @@ -173,7 +174,7 @@ func TestReaderWriter(t *testing.T) { { var meta FrameMetadata - common.Must(meta.Unmarshal(bytesReader)) + common.Must(meta.Unmarshal(bytesReader, false)) if r := cmp.Diff(meta, FrameMetadata{ SessionID: 2, SessionStatus: SessionStatusEnd, @@ -187,7 +188,7 @@ func TestReaderWriter(t *testing.T) { { var meta FrameMetadata - err := meta.Unmarshal(bytesReader) + err := meta.Unmarshal(bytesReader, false) if err == nil { t.Error("nil error") } diff --git a/common/mux/server.go b/common/mux/server.go index 70c5ed24..f01c325d 100644 --- a/common/mux/server.go +++ b/common/mux/server.go @@ -166,6 +166,14 @@ func (w *ServerWorker) handleStatusKeepAlive(meta *FrameMetadata, reader *buf.Bu func (w *ServerWorker) handleStatusNew(ctx context.Context, meta *FrameMetadata, reader *buf.BufferedReader) error { ctx = session.SubContextFromMuxInbound(ctx) + if meta.Inbound != nil && meta.Inbound.Source.IsValid() && meta.Inbound.Local.IsValid() { + if inbound := session.InboundFromContext(ctx); inbound != nil { + newInbound := *inbound + newInbound.Source = meta.Inbound.Source + newInbound.Local = meta.Inbound.Local + ctx = session.ContextWithInbound(ctx, &newInbound) + } + } errors.LogInfo(ctx, "received request for ", meta.Target) { msg := &log.AccessMessage{ @@ -329,7 +337,7 @@ func (w *ServerWorker) handleStatusEnd(meta *FrameMetadata, reader *buf.Buffered func (w *ServerWorker) handleFrame(ctx context.Context, reader *buf.BufferedReader) error { var meta FrameMetadata - err := meta.Unmarshal(reader) + err := meta.Unmarshal(reader, session.IsReverseMuxFromContext(ctx)) if err != nil { return errors.New("failed to read metadata").Base(err) } @@ -340,7 +348,7 @@ func (w *ServerWorker) handleFrame(ctx context.Context, reader *buf.BufferedRead case SessionStatusEnd: err = w.handleStatusEnd(&meta, reader) case SessionStatusNew: - err = w.handleStatusNew(ctx, &meta, reader) + err = w.handleStatusNew(session.ContextWithIsReverseMux(ctx, false), &meta, reader) case SessionStatusKeep: err = w.handleStatusKeep(&meta, reader) default: diff --git a/common/mux/writer.go b/common/mux/writer.go index a6dc551d..0429f4fa 100644 --- a/common/mux/writer.go +++ b/common/mux/writer.go @@ -6,6 +6,7 @@ import ( "github.com/xtls/xray-core/common/net" "github.com/xtls/xray-core/common/protocol" "github.com/xtls/xray-core/common/serial" + "github.com/xtls/xray-core/common/session" ) type Writer struct { @@ -16,9 +17,10 @@ type Writer struct { hasError bool transferType protocol.TransferType globalID [8]byte + inbound *session.Inbound } -func NewWriter(id uint16, dest net.Destination, writer buf.Writer, transferType protocol.TransferType, globalID [8]byte) *Writer { +func NewWriter(id uint16, dest net.Destination, writer buf.Writer, transferType protocol.TransferType, globalID [8]byte, inbound *session.Inbound) *Writer { return &Writer{ id: id, dest: dest, @@ -26,6 +28,7 @@ func NewWriter(id uint16, dest net.Destination, writer buf.Writer, transferType followup: false, transferType: transferType, globalID: globalID, + inbound: inbound, } } @@ -43,6 +46,7 @@ func (w *Writer) getNextFrameMeta() FrameMetadata { SessionID: w.id, Target: w.dest, GlobalID: w.globalID, + Inbound: w.inbound, } if w.followup { diff --git a/common/session/context.go b/common/session/context.go index 6a812f99..c28f2081 100644 --- a/common/session/context.go +++ b/common/session/context.go @@ -17,13 +17,13 @@ const ( inboundSessionKey ctx.SessionKey = 1 outboundSessionKey ctx.SessionKey = 2 contentSessionKey ctx.SessionKey = 3 - muxPreferredSessionKey ctx.SessionKey = 4 // unused + isReverseMuxKey ctx.SessionKey = 4 // is reverse mux sockoptSessionKey ctx.SessionKey = 5 // used by dokodemo to only receive sockopt.Mark trackedConnectionErrorKey ctx.SessionKey = 6 // used by observer to get outbound error dispatcherKey ctx.SessionKey = 7 // used by ss2022 inbounds to get dispatcher timeoutOnlyKey ctx.SessionKey = 8 // mux context's child contexts to only cancel when its own traffic times out allowedNetworkKey ctx.SessionKey = 9 // muxcool server control incoming request tcp/udp - handlerSessionKey ctx.SessionKey = 10 // outbound gets full handler + fullHandlerKey ctx.SessionKey = 10 // outbound gets full handler mitmAlpn11Key ctx.SessionKey = 11 // used by TLS dialer mitmServerNameKey ctx.SessionKey = 12 // used by TLS dialer ) @@ -75,25 +75,21 @@ func ContentFromContext(ctx context.Context) *Content { return nil } -// ContextWithMuxPreferred returns a new context with the given bool -func ContextWithMuxPreferred(ctx context.Context, forced bool) context.Context { - return context.WithValue(ctx, muxPreferredSessionKey, forced) +func ContextWithIsReverseMux(ctx context.Context, isReverseMux bool) context.Context { + return context.WithValue(ctx, isReverseMuxKey, isReverseMux) } -// MuxPreferredFromContext returns value in this context, or false if not contained. -func MuxPreferredFromContext(ctx context.Context) bool { - if val, ok := ctx.Value(muxPreferredSessionKey).(bool); ok { +func IsReverseMuxFromContext(ctx context.Context) bool { + if val, ok := ctx.Value(isReverseMuxKey).(bool); ok { return val } return false } -// ContextWithSockopt returns a new context with Socket configs included func ContextWithSockopt(ctx context.Context, s *Sockopt) context.Context { return context.WithValue(ctx, sockoptSessionKey, s) } -// SockoptFromContext returns Socket configs in this context, or nil if not contained. func SockoptFromContext(ctx context.Context) *Sockopt { if sockopt, ok := ctx.Value(sockoptSessionKey).(*Sockopt); ok { return sockopt @@ -164,12 +160,12 @@ func AllowedNetworkFromContext(ctx context.Context) net.Network { return net.Network_Unknown } -func ContextWithHandler(ctx context.Context, handler outbound.Handler) context.Context { - return context.WithValue(ctx, handlerSessionKey, handler) +func ContextWithFullHandler(ctx context.Context, handler outbound.Handler) context.Context { + return context.WithValue(ctx, fullHandlerKey, handler) } -func HandlerFromContext(ctx context.Context) outbound.Handler { - if val, ok := ctx.Value(handlerSessionKey).(outbound.Handler); ok { +func FullHandlerFromContext(ctx context.Context) outbound.Handler { + if val, ok := ctx.Value(fullHandlerKey).(outbound.Handler); ok { return val } return nil diff --git a/proxy/vless/inbound/inbound.go b/proxy/vless/inbound/inbound.go index 2f15b5d8..223aade0 100644 --- a/proxy/vless/inbound/inbound.go +++ b/proxy/vless/inbound/inbound.go @@ -666,7 +666,7 @@ func (r *Reverse) Dispatch(ctx context.Context, link *transport.Link) { link.Reader = &buf.EndpointOverrideReader{Reader: link.Reader, Dest: ob.Target.Address, OriginalDest: ob.OriginalTarget.Address} link.Writer = &buf.EndpointOverrideWriter{Writer: link.Writer, Dest: ob.Target.Address, OriginalDest: ob.OriginalTarget.Address} } - r.client.Dispatch(ctx, link) + r.client.Dispatch(session.ContextWithIsReverseMux(ctx, true), link) } } diff --git a/proxy/vless/outbound/outbound.go b/proxy/vless/outbound/outbound.go index 041cc605..c425151f 100644 --- a/proxy/vless/outbound/outbound.go +++ b/proxy/vless/outbound/outbound.go @@ -89,8 +89,11 @@ func New(ctx context.Context, config *Config) (*Handler, error) { handler.reverse = &Reverse{ tag: a.Reverse.Tag, dispatcher: v.GetFeature(routing.DispatcherType()).(routing.Dispatcher), - ctx: ctx, - handler: handler, + ctx: session.ContextWithInbound(ctx, &session.Inbound{ + Tag: a.Reverse.Tag, + User: handler.server.User, // TODO: email + }), + handler: handler, } handler.reverse.monitorTask = &task.Periodic{ Execute: handler.reverse.monitor, @@ -397,7 +400,7 @@ func (r *Reverse) monitor() error { Tag: r.tag, Dispatcher: r.dispatcher, } - worker, err := mux.NewServerWorker(r.ctx, w, link1) + worker, err := mux.NewServerWorker(session.ContextWithIsReverseMux(r.ctx, true), w, link1) if err != nil { errors.LogWarningInner(r.ctx, err, "failed to create mux server worker") return nil @@ -408,7 +411,7 @@ func (r *Reverse) monitor() error { ctx := session.ContextWithOutbounds(r.ctx, []*session.Outbound{{ Target: net.Destination{Address: net.DomainAddress("v1.rvs.cool")}, }}) - r.handler.Process(ctx, link2, session.HandlerFromContext(ctx).(*proxyman.Handler)) + r.handler.Process(ctx, link2, session.FullHandlerFromContext(ctx).(*proxyman.Handler)) common.Interrupt(reader1) common.Interrupt(reader2) }()