Selaa lähdekoodia

使用多ip尝试连接

陈国伟 3 vuotta sitten
vanhempi
commit
82efd98448

+ 3 - 20
FastGithub.DomainResolve/DomainResolver.cs

@@ -1,5 +1,4 @@
-using FastGithub.Configuration;
-using Microsoft.Extensions.Logging;
+using Microsoft.Extensions.Logging;
 using System;
 using System;
 using System.Collections.Concurrent;
 using System.Collections.Concurrent;
 using System.Collections.Generic;
 using System.Collections.Generic;
@@ -44,23 +43,7 @@ namespace FastGithub.DomainResolve
             {
             {
                 this.dnsEndPointAddress.TryAdd(endPoint, Array.Empty<IPAddress>());
                 this.dnsEndPointAddress.TryAdd(endPoint, Array.Empty<IPAddress>());
             }
             }
-        }
-
-
-        /// <summary>
-        /// 解析ip
-        /// </summary>
-        /// <param name="endPoint">节点</param>
-        /// <param name="cancellationToken"></param>
-        /// <returns></returns>
-        public async Task<IPAddress> ResolveAnyAsync(DnsEndPoint endPoint, CancellationToken cancellationToken = default)
-        {
-            await foreach (var address in this.ResolveAllAsync(endPoint, cancellationToken))
-            {
-                return address;
-            }
-            throw new FastGithubException($"解析不到{endPoint.Host}的IP");
-        }
+        } 
 
 
         /// <summary>
         /// <summary>
         /// 解析域名
         /// 解析域名
@@ -68,7 +51,7 @@ namespace FastGithub.DomainResolve
         /// <param name="endPoint">节点</param>
         /// <param name="endPoint">节点</param>
         /// <param name="cancellationToken"></param>
         /// <param name="cancellationToken"></param>
         /// <returns></returns>
         /// <returns></returns>
-        public async IAsyncEnumerable<IPAddress> ResolveAllAsync(DnsEndPoint endPoint, [EnumeratorCancellation] CancellationToken cancellationToken)
+        public async IAsyncEnumerable<IPAddress> ResolveAsync(DnsEndPoint endPoint, [EnumeratorCancellation] CancellationToken cancellationToken)
         {
         {
             if (this.dnsEndPointAddress.TryGetValue(endPoint, out var addresses) && addresses.Length > 0)
             if (this.dnsEndPointAddress.TryGetValue(endPoint, out var addresses) && addresses.Length > 0)
             {
             {

+ 2 - 10
FastGithub.DomainResolve/IDomainResolver.cs

@@ -9,22 +9,14 @@ namespace FastGithub.DomainResolve
     /// 域名解析器
     /// 域名解析器
     /// </summary>
     /// </summary>
     public interface IDomainResolver
     public interface IDomainResolver
-    {
-        /// <summary>
-        /// 解析ip
-        /// </summary>
-        /// <param name="endPoint">节点</param>
-        /// <param name="cancellationToken"></param>
-        /// <returns></returns>
-        Task<IPAddress> ResolveAnyAsync(DnsEndPoint endPoint, CancellationToken cancellationToken = default);
-
+    { 
         /// <summary>
         /// <summary>
         /// 解析所有ip
         /// 解析所有ip
         /// </summary>
         /// </summary>
         /// <param name="endPoint">节点</param>
         /// <param name="endPoint">节点</param>
         /// <param name="cancellationToken"></param>
         /// <param name="cancellationToken"></param>
         /// <returns></returns>
         /// <returns></returns>
-        IAsyncEnumerable<IPAddress> ResolveAllAsync(DnsEndPoint endPoint, CancellationToken cancellationToken = default);
+        IAsyncEnumerable<IPAddress> ResolveAsync(DnsEndPoint endPoint, CancellationToken cancellationToken = default);
 
 
         /// <summary>
         /// <summary>
         /// 对所有节点进行测速
         /// 对所有节点进行测速

+ 1 - 1
FastGithub.Http/HttpClientHandler.cs

@@ -186,7 +186,7 @@ namespace FastGithub.Http
                     yield return new IPEndPoint(this.domainConfig.IPAddress, dnsEndPoint.Port);
                     yield return new IPEndPoint(this.domainConfig.IPAddress, dnsEndPoint.Port);
                 }
                 }
 
 
-                await foreach (var item in this.domainResolver.ResolveAllAsync(dnsEndPoint, cancellationToken))
+                await foreach (var item in this.domainResolver.ResolveAsync(dnsEndPoint, cancellationToken))
                 {
                 {
                     yield return new IPEndPoint(item, dnsEndPoint.Port);
                     yield return new IPEndPoint(item, dnsEndPoint.Port);
                 }
                 }

+ 48 - 14
FastGithub.HttpServer/HttpProxyMiddleware.cs

@@ -4,12 +4,16 @@ using Microsoft.AspNetCore.Connections.Features;
 using Microsoft.AspNetCore.Http;
 using Microsoft.AspNetCore.Http;
 using Microsoft.AspNetCore.Http.Features;
 using Microsoft.AspNetCore.Http.Features;
 using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http;
 using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http;
+using System;
+using System.Collections.Generic;
+using System.IO;
 using System.IO.Pipelines;
 using System.IO.Pipelines;
 using System.Net;
 using System.Net;
 using System.Net.Http;
 using System.Net.Http;
 using System.Net.Sockets;
 using System.Net.Sockets;
 using System.Reflection;
 using System.Reflection;
 using System.Text;
 using System.Text;
+using System.Threading;
 using System.Threading.Tasks;
 using System.Threading.Tasks;
 using Yarp.ReverseProxy.Forwarder;
 using Yarp.ReverseProxy.Forwarder;
 
 
@@ -30,6 +34,7 @@ namespace FastGithub.HttpServer
         private readonly HttpReverseProxyMiddleware httpReverseProxy;
         private readonly HttpReverseProxyMiddleware httpReverseProxy;
 
 
         private readonly HttpMessageInvoker defaultHttpClient;
         private readonly HttpMessageInvoker defaultHttpClient;
+        private readonly TimeSpan connectTimeout = TimeSpan.FromSeconds(10d);
 
 
         static HttpProxyMiddleware()
         static HttpProxyMiddleware()
         {
         {
@@ -83,10 +88,7 @@ namespace FastGithub.HttpServer
             }
             }
             else if (context.Request.Method == HttpMethods.Connect)
             else if (context.Request.Method == HttpMethods.Connect)
             {
             {
-                var endpoint = await this.GetTargetEndPointAsync(host);
-                using var targetSocket = new Socket(SocketType.Stream, ProtocolType.Tcp);
-                await targetSocket.ConnectAsync(endpoint);
-
+                using var connection = await this.CreateConnectionAsync(host);
                 var responseFeature = context.Features.Get<IHttpResponseFeature>();
                 var responseFeature = context.Features.Get<IHttpResponseFeature>();
                 if (responseFeature != null)
                 if (responseFeature != null)
                 {
                 {
@@ -98,8 +100,7 @@ namespace FastGithub.HttpServer
                 var transport = context.Features.Get<IConnectionTransportFeature>()?.Transport;
                 var transport = context.Features.Get<IConnectionTransportFeature>()?.Transport;
                 if (transport != null)
                 if (transport != null)
                 {
                 {
-                    var targetStream = new NetworkStream(targetSocket, ownsSocket: false);
-                    await Task.WhenAny(targetStream.CopyToAsync(transport.Output), transport.Input.CopyToAsync(targetStream));
+                    await Task.WhenAny(connection.CopyToAsync(transport.Output), transport.Input.CopyToAsync(connection));
                 }
                 }
             }
             }
             else
             else
@@ -151,40 +152,73 @@ namespace FastGithub.HttpServer
             return buidler.ToString();
             return buidler.ToString();
         }
         }
 
 
+        /// <summary>
+        /// 创建连接
+        /// </summary>
+        /// <param name="host"></param>
+        /// <returns></returns>
+        /// <exception cref="AggregateException"></exception>
+        private async Task<Stream> CreateConnectionAsync(HostString host)
+        {
+            var innerExceptions = new List<Exception>();
+            await foreach (var endPoint in this.GetTargetEndPointsAsync(host))
+            {
+                var socket = new Socket(SocketType.Stream, ProtocolType.Tcp);
+                try
+                {
+                    using var timeoutTokenSource = new CancellationTokenSource(this.connectTimeout);
+                    await socket.ConnectAsync(endPoint, timeoutTokenSource.Token);
+                    return new NetworkStream(socket, ownsSocket: false);
+                }
+                catch (Exception ex)
+                {
+                    socket.Dispose();
+                    innerExceptions.Add(ex);
+                }
+            }
+            throw new AggregateException($"无法连接到{host}", innerExceptions);
+        }
+
         /// <summary>
         /// <summary>
         /// 获取目标终节点
         /// 获取目标终节点
         /// </summary>
         /// </summary>
         /// <param name="host"></param>
         /// <param name="host"></param>
         /// <returns></returns>
         /// <returns></returns>
-        private async Task<EndPoint> GetTargetEndPointAsync(HostString host)
+        private async IAsyncEnumerable<EndPoint> GetTargetEndPointsAsync(HostString host)
         {
         {
             var targetHost = host.Host;
             var targetHost = host.Host;
             var targetPort = host.Port ?? HTTPS_PORT;
             var targetPort = host.Port ?? HTTPS_PORT;
 
 
             if (IPAddress.TryParse(targetHost, out var address) == true)
             if (IPAddress.TryParse(targetHost, out var address) == true)
             {
             {
-                return new IPEndPoint(address, targetPort);
+                yield return new IPEndPoint(address, targetPort);
+                yield break;
             }
             }
 
 
             // 不关心的域名,直接使用系统dns
             // 不关心的域名,直接使用系统dns
             if (this.fastGithubConfig.IsMatch(targetHost) == false)
             if (this.fastGithubConfig.IsMatch(targetHost) == false)
             {
             {
-                return new DnsEndPoint(targetHost, targetPort);
+                yield return new DnsEndPoint(targetHost, targetPort);
+                yield break;
             }
             }
 
 
             if (targetPort == HTTP_PORT)
             if (targetPort == HTTP_PORT)
             {
             {
-                return new IPEndPoint(IPAddress.Loopback, ReverseProxyPort.Http);
+                yield return new IPEndPoint(IPAddress.Loopback, ReverseProxyPort.Http);
+                yield break;
             }
             }
 
 
             if (targetPort == HTTPS_PORT)
             if (targetPort == HTTPS_PORT)
             {
             {
-                return new IPEndPoint(IPAddress.Loopback, ReverseProxyPort.Https);
+                yield return new IPEndPoint(IPAddress.Loopback, ReverseProxyPort.Https);
+                yield break;
             }
             }
 
 
-            // 不使用系统dns
-            address = await this.domainResolver.ResolveAnyAsync(new DnsEndPoint(targetHost, targetPort));
-            return new IPEndPoint(address, targetPort);
+            var dnsEndPoint = new DnsEndPoint(targetHost, targetPort);
+            await foreach (var item in this.domainResolver.ResolveAsync(dnsEndPoint))
+            {
+                yield return new IPEndPoint(item, targetPort);
+            }
         }
         }
 
 
         /// <summary>
         /// <summary>

+ 8 - 5
FastGithub.HttpServer/TcpReverseProxyHandler.cs

@@ -6,6 +6,7 @@ using System.IO;
 using System.IO.Pipelines;
 using System.IO.Pipelines;
 using System.Net;
 using System.Net;
 using System.Net.Sockets;
 using System.Net.Sockets;
+using System.Threading;
 using System.Threading.Tasks;
 using System.Threading.Tasks;
 
 
 namespace FastGithub.HttpServer
 namespace FastGithub.HttpServer
@@ -17,6 +18,7 @@ namespace FastGithub.HttpServer
     {
     {
         private readonly IDomainResolver domainResolver;
         private readonly IDomainResolver domainResolver;
         private readonly DnsEndPoint endPoint;
         private readonly DnsEndPoint endPoint;
+        private readonly TimeSpan connectTimeout = TimeSpan.FromSeconds(10d);
 
 
         /// <summary>
         /// <summary>
         /// tcp反射代理处理者
         /// tcp反射代理处理者
@@ -36,9 +38,9 @@ namespace FastGithub.HttpServer
         /// <returns></returns>
         /// <returns></returns>
         public override async Task OnConnectedAsync(ConnectionContext context)
         public override async Task OnConnectedAsync(ConnectionContext context)
         {
         {
-            using var targetStream = await this.CreateConnectionAsync();
-            var task1 = targetStream.CopyToAsync(context.Transport.Output);
-            var task2 = context.Transport.Input.CopyToAsync(targetStream);
+            using var connection = await this.CreateConnectionAsync();
+            var task1 = connection.CopyToAsync(context.Transport.Output);
+            var task2 = context.Transport.Input.CopyToAsync(connection);
             await Task.WhenAny(task1, task2);
             await Task.WhenAny(task1, task2);
         }
         }
 
 
@@ -50,12 +52,13 @@ namespace FastGithub.HttpServer
         private async Task<Stream> CreateConnectionAsync()
         private async Task<Stream> CreateConnectionAsync()
         {
         {
             var innerExceptions = new List<Exception>();
             var innerExceptions = new List<Exception>();
-            await foreach (var address in this.domainResolver.ResolveAllAsync(this.endPoint))
+            await foreach (var address in this.domainResolver.ResolveAsync(this.endPoint))
             {
             {
                 var socket = new Socket(address.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
                 var socket = new Socket(address.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
                 try
                 try
                 {
                 {
-                    await socket.ConnectAsync(address, this.endPoint.Port);
+                    using var timeoutTokenSource = new CancellationTokenSource(this.connectTimeout);
+                    await socket.ConnectAsync(address, this.endPoint.Port, timeoutTokenSource.Token);
                     return new NetworkStream(socket, ownsSocket: false);
                     return new NetworkStream(socket, ownsSocket: false);
                 }
                 }
                 catch (Exception ex)
                 catch (Exception ex)