Przeglądaj źródła

Merge pull request #4928 from fatedier/xtcp

improve context and polling logic in xtcp visitor
fatedier 1 tydzień temu
rodzic
commit
024c334d9d
1 zmienionych plików z 23 dodań i 26 usunięć
  1. 23 26
      client/visitor/xtcp.go

+ 23 - 26
client/visitor/xtcp.go

@@ -145,7 +145,7 @@ func (sv *XTCPVisitor) keepTunnelOpenWorker() {
 			return
 		case <-ticker.C:
 			xl.Debugf("keepTunnelOpenWorker try to check tunnel...")
-			conn, err := sv.getTunnelConn()
+			conn, err := sv.getTunnelConn(sv.ctx)
 			if err != nil {
 				xl.Warnf("keepTunnelOpenWorker get tunnel connection error: %v", err)
 				_ = sv.retryLimiter.Wait(sv.ctx)
@@ -161,9 +161,9 @@ func (sv *XTCPVisitor) keepTunnelOpenWorker() {
 
 func (sv *XTCPVisitor) handleConn(userConn net.Conn) {
 	xl := xlog.FromContextSafe(sv.ctx)
-	isConnTransfered := false
+	isConnTransferred := false
 	defer func() {
-		if !isConnTransfered {
+		if !isConnTransferred {
 			userConn.Close()
 		}
 	}()
@@ -172,7 +172,7 @@ func (sv *XTCPVisitor) handleConn(userConn net.Conn) {
 
 	// Open a tunnel connection to the server. If there is already a successful hole-punching connection,
 	// it will be reused. Otherwise, it will block and wait for a successful hole-punching connection until timeout.
-	ctx := context.Background()
+	ctx := sv.ctx
 	if sv.cfg.FallbackTo != "" {
 		timeoutCtx, cancel := context.WithTimeout(ctx, time.Duration(sv.cfg.FallbackTimeoutMs)*time.Millisecond)
 		defer cancel()
@@ -191,7 +191,7 @@ func (sv *XTCPVisitor) handleConn(userConn net.Conn) {
 			xl.Errorf("transfer connection to visitor %s error: %v", sv.cfg.FallbackTo, err)
 			return
 		}
-		isConnTransfered = true
+		isConnTransferred = true
 		return
 	}
 
@@ -219,40 +219,37 @@ func (sv *XTCPVisitor) handleConn(userConn net.Conn) {
 // openTunnel will open a tunnel connection to the target server.
 func (sv *XTCPVisitor) openTunnel(ctx context.Context) (conn net.Conn, err error) {
 	xl := xlog.FromContextSafe(sv.ctx)
-	ticker := time.NewTicker(500 * time.Millisecond)
-	defer ticker.Stop()
+	ctx, cancel := context.WithTimeout(ctx, 20*time.Second)
+	defer cancel()
 
-	timeoutC := time.After(20 * time.Second)
-	immediateTrigger := make(chan struct{}, 1)
-	defer close(immediateTrigger)
-	immediateTrigger <- struct{}{}
+	timer := time.NewTimer(0)
+	defer timer.Stop()
 
 	for {
 		select {
 		case <-sv.ctx.Done():
 			return nil, sv.ctx.Err()
 		case <-ctx.Done():
+			if errors.Is(ctx.Err(), context.DeadlineExceeded) {
+				return nil, fmt.Errorf("open tunnel timeout")
+			}
 			return nil, ctx.Err()
-		case <-immediateTrigger:
-			conn, err = sv.getTunnelConn()
-		case <-ticker.C:
-			conn, err = sv.getTunnelConn()
-		case <-timeoutC:
-			return nil, fmt.Errorf("open tunnel timeout")
-		}
-
-		if err != nil {
-			if err != ErrNoTunnelSession {
-				xl.Warnf("get tunnel connection error: %v", err)
+		case <-timer.C:
+			conn, err = sv.getTunnelConn(ctx)
+			if err != nil {
+				if !errors.Is(err, ErrNoTunnelSession) {
+					xl.Warnf("get tunnel connection error: %v", err)
+				}
+				timer.Reset(500 * time.Millisecond)
+				continue
 			}
-			continue
+			return conn, nil
 		}
-		return conn, nil
 	}
 }
 
-func (sv *XTCPVisitor) getTunnelConn() (net.Conn, error) {
-	conn, err := sv.session.OpenConn(sv.ctx)
+func (sv *XTCPVisitor) getTunnelConn(ctx context.Context) (net.Conn, error) {
+	conn, err := sv.session.OpenConn(ctx)
 	if err == nil {
 		return conn, nil
 	}