diff --git a/proxy/vless/outbound/outbound.go b/proxy/vless/outbound/outbound.go index fc44ed57..6bdbe776 100644 --- a/proxy/vless/outbound/outbound.go +++ b/proxy/vless/outbound/outbound.go @@ -54,9 +54,11 @@ type Handler struct { encryption *encryption.ClientInstance reverse *Reverse - testpre uint32 - locker sync.Mutex - conns []stat.Connection + testpre uint32 + initConns sync.Once + preConns chan stat.Connection + preConnWait chan struct{} + preConnStop chan struct{} } // New creates a new VLess outbound handler. @@ -117,6 +119,13 @@ func New(ctx context.Context, config *Config) (*Handler, error) { // Close implements common.Closable.Close(). func (h *Handler) Close() error { + if h.preConnStop != nil { + close(h.preConnStop) + for range h.testpre { + conn := <-h.preConns + common.CloseIfExists(conn) + } + } if h.reverse != nil { return h.reverse.Close() } @@ -136,30 +145,19 @@ func (h *Handler) Process(ctx context.Context, link *transport.Link, dialer inte var conn stat.Connection if h.testpre > 0 && h.reverse == nil { - h.locker.Lock() - if h.conns == nil { - h.conns = make([]stat.Connection, 0) - go func() { - for { // TODO: close & inactive - time.Sleep(100 * time.Millisecond) // TODO: customize & randomize - h.locker.Lock() - if len(h.conns) >= int(h.testpre) { - h.locker.Unlock() - continue - } - h.locker.Unlock() - if conn, err := dialer.Dial(context.Background(), rec.Destination); err == nil { // TODO: timeout & concurrency? & ctx mitm? - h.locker.Lock() - h.conns = append(h.conns, conn) // TODO: vision paddings - h.locker.Unlock() - } - } - }() - } else if len(h.conns) > 0 { - conn = h.conns[0] - h.conns = h.conns[1:] + h.initConns.Do(func() { + h.preConns = make(chan stat.Connection, h.testpre) + h.preConnStop = make(chan struct{}) + go h.preConnWorker(dialer, rec.Destination) + }) + select { + case h.preConnWait <- struct{}{}: + default: + } + select { + case conn = <-h.preConns: + default: } - h.locker.Unlock() } if conn == nil { @@ -464,3 +462,51 @@ func (r *Reverse) Start() error { func (r *Reverse) Close() error { return r.monitorTask.Close() } + +func (h *Handler) preConnWorker(dialer internet.Dialer, dest net.Destination) { + // conn in conns may be nil + conns := make(chan stat.Connection) + dial := func() { + conn, err := dialer.Dial(context.Background(), dest) + if err != nil { + errors.LogError(context.Background(), "failed to dial VLESS pre connection: ", err) + common.CloseIfExists(conn) + } + conns <- conn + } + go func() { + go dial() // get a conn immediately + for range h.testpre - 1 { + select { + case <-h.preConnWait: + go dial() + case <-h.preConnStop: + return + } + } + }() + for { + select { + case conn := <-conns: + if conn != nil { + select { + case h.preConns <- conn: + case <-h.preConnStop: + common.CloseIfExists(conn) + return + } + go dial() + } else { + // sleep until next client try if dial failed + select { + case <-h.preConnWait: + go dial() + case <-h.preConnStop: + return + } + } + case <-h.preConnStop: + return + } + } +}