Browse Source

减少DnsClient的分配

老九 3 years ago
parent
commit
8e2e3802f0
1 changed files with 24 additions and 10 deletions
  1. 24 10
      FastGithub.DomainResolve/DnsClient.cs

+ 24 - 10
FastGithub.DomainResolve/DnsClient.cs

@@ -39,7 +39,7 @@ namespace FastGithub.DomainResolve
         private readonly int resolveTimeout = (int)TimeSpan.FromSeconds(2d).TotalMilliseconds;
         private readonly int resolveTimeout = (int)TimeSpan.FromSeconds(2d).TotalMilliseconds;
         private static readonly TimeSpan connectTimeout = TimeSpan.FromSeconds(2d);
         private static readonly TimeSpan connectTimeout = TimeSpan.FromSeconds(2d);
 
 
-        private record LookupResult(IPAddress[] Addresses, TimeSpan TimeToLive);
+        private record LookupResult(IList<IPAddress> Addresses, TimeSpan TimeToLive);
 
 
         /// <summary>
         /// <summary>
         /// DNS客户端
         /// DNS客户端
@@ -110,7 +110,7 @@ namespace FastGithub.DomainResolve
         /// <param name="fastSort"></param>
         /// <param name="fastSort"></param>
         /// <param name="cancellationToken"></param>
         /// <param name="cancellationToken"></param>
         /// <returns></returns>
         /// <returns></returns>
-        private async Task<IPAddress[]> LookupAsync(IPEndPoint dns, DnsEndPoint endPoint, bool fastSort, CancellationToken cancellationToken = default)
+        private async Task<IList<IPAddress>> LookupAsync(IPEndPoint dns, DnsEndPoint endPoint, bool fastSort, CancellationToken cancellationToken = default)
         {
         {
             var key = $"{dns}/{endPoint}";
             var key = $"{dns}/{endPoint}";
             var semaphore = this.semaphoreSlims.GetOrAdd(key, _ => new SemaphoreSlim(1, 1));
             var semaphore = this.semaphoreSlims.GetOrAdd(key, _ => new SemaphoreSlim(1, 1));
@@ -118,7 +118,7 @@ namespace FastGithub.DomainResolve
 
 
             try
             try
             {
             {
-                if (this.dnsCache.TryGetValue<IPAddress[]>(key, out var value))
+                if (this.dnsCache.TryGetValue<IList<IPAddress>>(key, out var value))
                 {
                 {
                     return value;
                     return value;
                 }
                 }
@@ -155,7 +155,16 @@ namespace FastGithub.DomainResolve
         {
         {
             if (endPoint.Host == LOCALHOST)
             if (endPoint.Host == LOCALHOST)
             {
             {
-                return new LookupResult(new[] { IPAddress.Loopback }, TimeSpan.MaxValue);
+                var loopbacks = new List<IPAddress>();
+                if (Socket.OSSupportsIPv4 == true)
+                {
+                    loopbacks.Add(IPAddress.Loopback);
+                }
+                if (Socket.OSSupportsIPv6 == true)
+                {
+                    loopbacks.Add(IPAddress.IPv6Loopback);
+                }
+                return new LookupResult(loopbacks, TimeSpan.MaxValue);
             }
             }
 
 
             var resolver = dns.Port == DNS_PORT
             var resolver = dns.Port == DNS_PORT
@@ -163,17 +172,17 @@ namespace FastGithub.DomainResolve
                 : new UdpRequestResolver(dns, new TcpRequestResolver(dns), this.resolveTimeout);
                 : new UdpRequestResolver(dns, new TcpRequestResolver(dns), this.resolveTimeout);
 
 
             var addressRecords = await GetAddressRecordsAsync(resolver, endPoint.Host, cancellationToken);
             var addressRecords = await GetAddressRecordsAsync(resolver, endPoint.Host, cancellationToken);
-            var addresses = addressRecords
+            var addresses = (IList<IPAddress>)addressRecords
                 .Where(item => IPAddress.IsLoopback(item.IPAddress) == false)
                 .Where(item => IPAddress.IsLoopback(item.IPAddress) == false)
                 .Select(item => item.IPAddress)
                 .Select(item => item.IPAddress)
                 .ToArray();
                 .ToArray();
 
 
-            if (addresses.Length == 0)
+            if (addresses.Count == 0)
             {
             {
                 return new LookupResult(addresses, this.minTimeToLive);
                 return new LookupResult(addresses, this.minTimeToLive);
             }
             }
 
 
-            if (fastSort && addresses.Length > 1)
+            if (fastSort == true)
             {
             {
                 addresses = await OrderByConnectAnyAsync(addresses, endPoint.Port, cancellationToken);
                 addresses = await OrderByConnectAnyAsync(addresses, endPoint.Port, cancellationToken);
             }
             }
@@ -238,8 +247,13 @@ namespace FastGithub.DomainResolve
         /// <param name="port"></param>
         /// <param name="port"></param>
         /// <param name="cancellationToken"></param>
         /// <param name="cancellationToken"></param>
         /// <returns></returns>
         /// <returns></returns>
-        private static async Task<IPAddress[]> OrderByConnectAnyAsync(IPAddress[] addresses, int port, CancellationToken cancellationToken)
+        private static async Task<IList<IPAddress>> OrderByConnectAnyAsync(IList<IPAddress> addresses, int port, CancellationToken cancellationToken)
         {
         {
+            if (addresses.Count <= 1)
+            {
+                return addresses;
+            }
+
             using var controlTokenSource = new CancellationTokenSource(connectTimeout);
             using var controlTokenSource = new CancellationTokenSource(connectTimeout);
             using var linkedTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, controlTokenSource.Token);
             using var linkedTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, controlTokenSource.Token);
 
 
@@ -247,7 +261,7 @@ namespace FastGithub.DomainResolve
             var fastestAddress = await await Task.WhenAny(connectTasks);
             var fastestAddress = await await Task.WhenAny(connectTasks);
             controlTokenSource.Cancel();
             controlTokenSource.Cancel();
 
 
-            if (fastestAddress == null)
+            if (fastestAddress == null || addresses.First().Equals(fastestAddress))
             {
             {
                 return addresses;
                 return addresses;
             }
             }
@@ -260,7 +274,7 @@ namespace FastGithub.DomainResolve
                     list.Add(address);
                     list.Add(address);
                 }
                 }
             }
             }
-            return list.ToArray();
+            return list;
         }
         }
 
 
         /// <summary>
         /// <summary>