Bläddra i källkod

基于tcp连接测速

陈国伟 3 år sedan
förälder
incheckning
255c44af9c

+ 0 - 55
FastGithub.DomainResolve/DnsClient.cs

@@ -30,7 +30,6 @@ namespace FastGithub.DomainResolve
         private readonly FastGithubConfig fastGithubConfig;
         private readonly ILogger<DnsClient> logger;
 
-        private readonly ConcurrentDictionary<string, IPAddressCollection> domainIPAddressCollection = new();
         private readonly ConcurrentDictionary<string, SemaphoreSlim> semaphoreSlims = new();
         private readonly IMemoryCache dnsCache = new MemoryCache(Options.Create(new MemoryCacheOptions()));
         private readonly TimeSpan defaultEmptyTtl = TimeSpan.FromSeconds(30d);
@@ -54,15 +53,6 @@ namespace FastGithub.DomainResolve
             this.logger = logger;
         }
 
-        /// <summary>
-        /// 预加载
-        /// </summary>
-        /// <param name="domain">域名</param>
-        public void Prefetch(string domain)
-        {
-            this.domainIPAddressCollection.TryAdd(domain, new IPAddressCollection());
-        }
-
         /// <summary>
         /// 解析域名
         /// </summary>
@@ -70,51 +60,6 @@ namespace FastGithub.DomainResolve
         /// <param name="cancellationToken"></param>
         /// <returns></returns>
         public async IAsyncEnumerable<IPAddress> ResolveAsync(string domain, [EnumeratorCancellation] CancellationToken cancellationToken)
-        {
-            if (this.domainIPAddressCollection.TryGetValue(domain, out var collection) && collection.Count > 0)
-            {
-                foreach (var address in collection.ToArray())
-                {
-                    yield return address;
-                }
-            }
-            else
-            {
-                this.domainIPAddressCollection.TryAdd(domain, new IPAddressCollection());
-                await foreach (var adddress in this.ResolveCoreAsync(domain, cancellationToken))
-                {
-                    yield return adddress;
-                }
-            }
-        }
-
-        /// <summary>
-        /// 对所有域名所有IP进行ping测试
-        /// </summary>
-        /// <param name="cancellationToken"></param>
-        /// <returns></returns>
-        public async Task PingAllDomainsAsync(CancellationToken cancellationToken)
-        {
-            foreach (var keyValue in this.domainIPAddressCollection)
-            {
-                var domain = keyValue.Key;
-                var collection = keyValue.Value;
-
-                await foreach (var address in this.ResolveCoreAsync(domain, cancellationToken))
-                {
-                    collection.Add(address);
-                }
-                await collection.PingAllAsync();
-            }
-        }
-
-        /// <summary>
-        /// 解析域名
-        /// </summary>
-        /// <param name="domain">域名</param>
-        /// <param name="cancellationToken"></param>
-        /// <returns></returns>
-        private async IAsyncEnumerable<IPAddress> ResolveCoreAsync(string domain, [EnumeratorCancellation] CancellationToken cancellationToken)
         {
             var hashSet = new HashSet<IPAddress>();
             foreach (var dns in this.GetDnsServers())

+ 7 - 7
FastGithub.DomainResolve/DomainResolveHostedService.cs

@@ -11,20 +11,20 @@ namespace FastGithub.DomainResolve
     sealed class DomainResolveHostedService : BackgroundService
     {
         private readonly DnscryptProxy dnscryptProxy;
-        private readonly DnsClient dnsClient;
-        private readonly TimeSpan pingPeriodTimeSpan = TimeSpan.FromSeconds(10d);
+        private readonly IDomainResolver domainResolver;
+        private readonly TimeSpan testPeriodTimeSpan = TimeSpan.FromSeconds (1d);
 
         /// <summary>
         /// 域名解析后台服务
         /// </summary>
         /// <param name="dnscryptProxy"></param>
-        /// <param name="dnsClient"></param>
+        /// <param name="domainResolver"></param>
         public DomainResolveHostedService(
             DnscryptProxy dnscryptProxy,
-            DnsClient dnsClient)
+            IDomainResolver domainResolver)
         {
             this.dnscryptProxy = dnscryptProxy;
-            this.dnsClient = dnsClient;
+            this.domainResolver = domainResolver;
         }
 
         /// <summary>
@@ -37,8 +37,8 @@ namespace FastGithub.DomainResolve
             await this.dnscryptProxy.StartAsync(stoppingToken);
             while (stoppingToken.IsCancellationRequested == false)
             {
-                await this.dnsClient.PingAllDomainsAsync(stoppingToken);
-                await Task.Delay(this.pingPeriodTimeSpan, stoppingToken);
+                await this.domainResolver.TestAllEndPointsAsync(stoppingToken);
+                await Task.Delay(this.testPeriodTimeSpan, stoppingToken);
             }
         }
 

+ 95 - 9
FastGithub.DomainResolve/DomainResolver.cs

@@ -1,6 +1,12 @@
 using FastGithub.Configuration;
+using System;
+using System.Collections.Concurrent;
 using System.Collections.Generic;
+using System.Diagnostics;
+using System.Linq;
 using System.Net;
+using System.Net.Sockets;
+using System.Runtime.CompilerServices;
 using System.Threading;
 using System.Threading.Tasks;
 
@@ -12,6 +18,7 @@ namespace FastGithub.DomainResolve
     sealed class DomainResolver : IDomainResolver
     {
         private readonly DnsClient dnsClient;
+        private readonly ConcurrentDictionary<DnsEndPoint, IPAddressTestResult> dnsEndPointAddressTestResult = new();
 
         /// <summary>
         /// 域名解析器
@@ -23,38 +30,117 @@ namespace FastGithub.DomainResolve
         }
 
         /// <summary>
-        /// 加载
+        /// 加载
         /// </summary>
         /// <param name="domain">域名</param>
         public void Prefetch(string domain)
         {
-            this.dnsClient.Prefetch(domain);
+            var endPoint = new DnsEndPoint(domain, 443);
+            this.dnsEndPointAddressTestResult.TryAdd(endPoint, IPAddressTestResult.Empty);
+        }
+
+        /// <summary>
+        /// 对所有节点进行测速
+        /// </summary>
+        /// <param name="cancellationToken"></param>
+        /// <returns></returns>
+        public async Task TestAllEndPointsAsync(CancellationToken cancellationToken)
+        {
+            foreach (var keyValue in this.dnsEndPointAddressTestResult)
+            {
+                if (keyValue.Value.IsEmpty || keyValue.Value.IsExpired)
+                {
+                    var dnsEndPoint = keyValue.Key;
+                    var addresses = new List<IPAddress>();
+                    await foreach (var adddress in this.dnsClient.ResolveAsync(dnsEndPoint.Host, cancellationToken))
+                    {
+                        addresses.Add(adddress);
+                    }
+
+                    var addressTestResult = IPAddressTestResult.Empty;
+                    if (addresses.Count == 1)
+                    {
+                        var addressElapseds = new[] { new IPAddressElapsed(addresses[0], TimeSpan.Zero) };
+                        addressTestResult = new IPAddressTestResult(addressElapseds);
+                    }
+                    else if (addresses.Count > 1)
+                    {
+                        var tasks = addresses.Select(item => GetIPAddressElapsedAsync(item, dnsEndPoint.Port, cancellationToken));
+                        var addressElapseds = await Task.WhenAll(tasks);
+                        addressTestResult = new IPAddressTestResult(addressElapseds);
+                    }
+                    this.dnsEndPointAddressTestResult[dnsEndPoint] = addressTestResult;
+                }
+            }
+        }
+
+        /// <summary>
+        /// 获取连接耗时
+        /// </summary>
+        /// <param name="address"></param>
+        /// <param name="port"></param>
+        /// <param name="cancellationToken"></param>
+        /// <returns></returns>
+        private static async Task<IPAddressElapsed> GetIPAddressElapsedAsync(IPAddress address, int port, CancellationToken cancellationToken)
+        {
+            var stopWatch = Stopwatch.StartNew();
+            try
+            {
+                using var timeoutTokenSource = new CancellationTokenSource(TimeSpan.FromSeconds(10d));
+                using var linkedTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeoutTokenSource.Token);
+                using var socket = new Socket(SocketType.Stream, ProtocolType.Tcp);
+                await socket.ConnectAsync(address, port, linkedTokenSource.Token);
+                return new IPAddressElapsed(address, stopWatch.Elapsed);
+            }
+            catch (Exception)
+            {
+                cancellationToken.ThrowIfCancellationRequested();
+                return new IPAddressElapsed(address, TimeSpan.MaxValue);
+            }
+            finally
+            {
+                stopWatch.Stop();
+            }
         }
 
         /// <summary>
         /// 解析ip
         /// </summary>
-        /// <param name="domain">域名</param>
+        /// <param name="endPoint">节点</param>
         /// <param name="cancellationToken"></param>
         /// <returns></returns>
-        public async Task<IPAddress> ResolveAnyAsync(string domain, CancellationToken cancellationToken = default)
+        public async Task<IPAddress> ResolveAnyAsync(DnsEndPoint endPoint, CancellationToken cancellationToken = default)
         {
-            await foreach (var address in this.ResolveAllAsync(domain, cancellationToken))
+            await foreach (var address in this.ResolveAllAsync(endPoint, cancellationToken))
             {
                 return address;
             }
-            throw new FastGithubException($"解析不到{domain}的IP");
+            throw new FastGithubException($"解析不到{endPoint.Host}的IP");
         }
 
         /// <summary>
         /// 解析域名
         /// </summary>
-        /// <param name="domain">域名</param>
+        /// <param name="endPoint">节点</param>
         /// <param name="cancellationToken"></param>
         /// <returns></returns>
-        public IAsyncEnumerable<IPAddress> ResolveAllAsync(string domain, CancellationToken cancellationToken)
+        public async IAsyncEnumerable<IPAddress> ResolveAllAsync(DnsEndPoint endPoint, [EnumeratorCancellation] CancellationToken cancellationToken)
         {
-            return this.dnsClient.ResolveAsync(domain, cancellationToken);
+            if (this.dnsEndPointAddressTestResult.TryGetValue(endPoint, out var speedTestResult) && speedTestResult.IsEmpty == false)
+            {
+                foreach (var addressElapsed in speedTestResult.AddressElapseds)
+                {
+                    yield return addressElapsed.Adddress;
+                }
+            }
+            else
+            {
+                this.dnsEndPointAddressTestResult.TryAdd(endPoint, IPAddressTestResult.Empty);
+                await foreach (var adddress in this.dnsClient.ResolveAsync(endPoint.Host, cancellationToken))
+                {
+                    yield return adddress;
+                }
+            }
         }
     }
 }

+ 12 - 5
FastGithub.DomainResolve/IDomainResolver.cs

@@ -11,25 +11,32 @@ namespace FastGithub.DomainResolve
     public interface IDomainResolver
     {
         /// <summary>
-        /// 加载
+        /// 加载
         /// </summary>
         /// <param name="domain">域名</param>
         void Prefetch(string domain);
 
+        /// <summary>
+        /// 对所有节点进行测速
+        /// </summary>
+        /// <param name="cancellationToken"></param>
+        /// <returns></returns>
+        Task TestAllEndPointsAsync(CancellationToken cancellationToken);
+
         /// <summary>
         /// 解析ip
         /// </summary>
-        /// <param name="domain">域名</param>
+        /// <param name="endPoint">节点</param>
         /// <param name="cancellationToken"></param>
         /// <returns></returns>
-        Task<IPAddress> ResolveAnyAsync(string domain, CancellationToken cancellationToken = default);
+        Task<IPAddress> ResolveAnyAsync(DnsEndPoint endPoint, CancellationToken cancellationToken = default);
 
         /// <summary>
         /// 解析所有ip
         /// </summary>
-        /// <param name="domain">域名</param>
+        /// <param name="endPoint">节点</param>
         /// <param name="cancellationToken"></param>
         /// <returns></returns>
-        IAsyncEnumerable<IPAddress> ResolveAllAsync(string domain, CancellationToken cancellationToken = default);
+        IAsyncEnumerable<IPAddress> ResolveAllAsync(DnsEndPoint endPoint, CancellationToken cancellationToken = default);
     }
 }

+ 0 - 165
FastGithub.DomainResolve/IPAddressCollection.cs

@@ -1,165 +0,0 @@
-using System;
-using System.Collections.Generic;
-using System.Diagnostics;
-using System.Linq;
-using System.Net;
-using System.Net.NetworkInformation;
-using System.Threading.Tasks;
-
-namespace FastGithub.DomainResolve
-{
-    /// <summary>
-    /// IPAddress集合
-    /// </summary>
-    [DebuggerDisplay("Count = {Count}")]
-    sealed class IPAddressCollection
-    {
-        private readonly object syncRoot = new();
-        private readonly HashSet<IPAddressItem> hashSet = new();
-
-        /// <summary>
-        /// 获取元素数量
-        /// </summary>
-        public int Count => this.hashSet.Count;
-
-        /// <summary>
-        /// 添加元素
-        /// </summary>
-        /// <param name="address"></param>
-        /// <returns></returns>
-        public bool Add(IPAddress address)
-        {
-            lock (this.syncRoot)
-            {
-                return this.hashSet.Add(new IPAddressItem(address));
-            }
-        }
-
-        /// <summary>
-        /// 转后为数组
-        /// </summary>
-        /// <returns></returns>
-        public IPAddress[] ToArray()
-        {
-            lock (this.syncRoot)
-            {
-                return this.hashSet.OrderBy(item => item.PingElapsed).Select(item => item.Address).ToArray();
-            }
-        }
-
-        /// <summary>
-        /// Ping所有IP
-        /// </summary>
-        /// <returns></returns>
-        public async Task PingAllAsync()
-        {
-            foreach (var item in this.ToItemArray())
-            {
-                await item.PingAsync();
-            }
-        }
-
-        /// <summary>
-        /// 转换为数组
-        /// </summary>
-        /// <returns></returns>
-        private IPAddressItem[] ToItemArray()
-        {
-            lock (this.syncRoot)
-            {
-                return this.hashSet.ToArray();
-            }
-        }
-
-        /// <summary>
-        /// IP地址项
-        /// </summary>
-        [DebuggerDisplay("Address = {Address}, PingElapsed = {PingElapsed}")]
-        private class IPAddressItem : IEquatable<IPAddressItem>
-        {
-            /// <summary>
-            /// Ping的时间点
-            /// </summary>
-            private int? pingTicks;
-
-            /// <summary>
-            /// 地址
-            /// </summary>
-            public IPAddress Address { get; }
-
-            /// <summary>
-            /// Ping耗时
-            /// </summary>
-            public TimeSpan PingElapsed { get; private set; } = TimeSpan.MaxValue;
-
-            /// <summary>
-            /// IP地址项
-            /// </summary>
-            /// <param name="address"></param>
-            public IPAddressItem(IPAddress address)
-            {
-                this.Address = address;
-            }
-            /// <summary>
-            /// 发起ping请求
-            /// </summary>
-            /// <returns></returns>
-            public async Task PingAsync()
-            {
-                if (this.NeedToPing() == false)
-                {
-                    return;
-                }
-
-                try
-                {
-                    using var ping = new Ping();
-                    var reply = await ping.SendPingAsync(this.Address);
-                    this.PingElapsed = reply.Status == IPStatus.Success
-                        ? TimeSpan.FromMilliseconds(reply.RoundtripTime)
-                        : TimeSpan.MaxValue;
-                }
-                catch (Exception)
-                {
-                    this.PingElapsed = TimeSpan.MaxValue;
-                }
-                finally
-                {
-                    this.pingTicks = Environment.TickCount;
-                }
-            }
-
-            /// <summary>
-            /// 是否需要ping
-            /// 5分钟内只ping一次
-            /// </summary>
-            /// <returns></returns>
-            private bool NeedToPing()
-            {
-                var ticks = this.pingTicks;
-                if (ticks == null)
-                {
-                    return true;
-                }
-
-                var pingTimeSpan = TimeSpan.FromMilliseconds(Environment.TickCount - ticks.Value);
-                return pingTimeSpan > TimeSpan.FromMinutes(5d);
-            }
-
-            public bool Equals(IPAddressItem? other)
-            {
-                return other != null && other.Address.Equals(this.Address);
-            }
-
-            public override bool Equals(object? obj)
-            {
-                return obj is IPAddressItem other && this.Equals(other);
-            }
-
-            public override int GetHashCode()
-            {
-                return this.Address.GetHashCode();
-            }
-        }
-    }
-}

+ 34 - 0
FastGithub.DomainResolve/IPAddressElapsed.cs

@@ -0,0 +1,34 @@
+using System;
+using System.Diagnostics;
+using System.Net;
+
+namespace FastGithub.DomainResolve
+{
+    /// <summary>
+    /// IP连接耗时
+    /// </summary>
+    [DebuggerDisplay("Adddress={Adddress} Elapsed={Elapsed}")]
+    struct IPAddressElapsed
+    {
+        /// <summary>
+        /// 获取IP地址
+        /// </summary>
+        public IPAddress Adddress { get; }
+
+        /// <summary>
+        /// 获取连接耗时
+        /// </summary>
+        public TimeSpan Elapsed { get; }
+
+        /// <summary>
+        /// IP连接耗时
+        /// </summary>
+        /// <param name="adddress"></param>
+        /// <param name="elapsed"></param>
+        public IPAddressElapsed(IPAddress adddress, TimeSpan elapsed)
+        {
+            this.Adddress = adddress;
+            this.Elapsed = elapsed;
+        }
+    }
+}

+ 44 - 0
FastGithub.DomainResolve/IPAddressTestResult.cs

@@ -0,0 +1,44 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+
+namespace FastGithub.DomainResolve
+{
+    /// <summary>
+    /// IP测速结果
+    /// </summary>
+    sealed class IPAddressTestResult
+    {
+        private static readonly TimeSpan lifeTime = TimeSpan.FromMinutes(2d);
+        private readonly int creationTickCount = Environment.TickCount;
+
+        /// <summary>
+        /// 获取空的
+        /// </summary>
+        public static IPAddressTestResult Empty = new(Array.Empty<IPAddressElapsed>());
+
+        /// <summary>
+        /// 获取是否为空
+        /// </summary>
+        public bool IsEmpty => this.AddressElapseds.Length == 0;
+
+        /// <summary>
+        /// 获取是否已过期
+        /// </summary>
+        public bool IsExpired => lifeTime < TimeSpan.FromMilliseconds(Environment.TickCount - this.creationTickCount);
+
+        /// <summary>
+        /// 获取测速结果
+        /// </summary>
+        public IPAddressElapsed[] AddressElapseds { get; }
+
+        /// <summary>
+        /// 测速结果
+        /// </summary>
+        /// <param name="result"></param>
+        public IPAddressTestResult(IEnumerable<IPAddressElapsed> addressElapseds)
+        {
+            this.AddressElapseds = addressElapseds.OrderBy(item => item.Elapsed).ToArray();
+        }
+    }
+}

+ 1 - 1
FastGithub.Http/HttpClientHandler.cs

@@ -186,7 +186,7 @@ namespace FastGithub.Http
             }
             else
             {
-                await foreach (var item in this.domainResolver.ResolveAllAsync(dnsEndPoint.Host, cancellationToken))
+                await foreach (var item in this.domainResolver.ResolveAllAsync(dnsEndPoint, cancellationToken))
                 {
                     yield return new IPEndPoint(item, dnsEndPoint.Port);
                 }

+ 1 - 1
FastGithub.HttpServer/HttpProxyMiddleware.cs

@@ -158,7 +158,7 @@ namespace FastGithub.HttpServer
             }
 
             // 不使用系统dns
-            address = await this.domainResolver.ResolveAnyAsync(targetHost);
+            address = await this.domainResolver.ResolveAnyAsync(new DnsEndPoint(targetHost, targetPort));
             return new IPEndPoint(address, targetPort);
         }
 

+ 4 - 4
FastGithub.HttpServer/SshReverseProxyHandler.cs

@@ -1,6 +1,7 @@
 using FastGithub.DomainResolve;
 using Microsoft.AspNetCore.Connections;
 using System.IO.Pipelines;
+using System.Net;
 using System.Net.Sockets;
 using System.Threading.Tasks;
 
@@ -12,8 +13,7 @@ namespace FastGithub.HttpServer
     sealed class SshReverseProxyHandler : ConnectionHandler
     {
         private readonly IDomainResolver domainResolver;
-        private const string SSH_GITHUB_COM = "ssh.github.com";
-        private const int SSH_OVER_HTTPS_PORT = 443;
+        private readonly DnsEndPoint sshOverHttpsEndPoint = new("ssh.github.com", 443);
 
         /// <summary>
         /// github的ssh代理处理者
@@ -31,9 +31,9 @@ namespace FastGithub.HttpServer
         /// <returns></returns>
         public override async Task OnConnectedAsync(ConnectionContext context)
         {
-            var address = await this.domainResolver.ResolveAnyAsync(SSH_GITHUB_COM);
+            var address = await this.domainResolver.ResolveAnyAsync(this.sshOverHttpsEndPoint);
             using var socket = new Socket(SocketType.Stream, ProtocolType.Tcp);
-            await socket.ConnectAsync(address, SSH_OVER_HTTPS_PORT);
+            await socket.ConnectAsync(address, this.sshOverHttpsEndPoint.Port);
             var targetStream = new NetworkStream(socket, ownsSocket: false);
 
             var task1 = targetStream.CopyToAsync(context.Transport.Output);