2
0
老九 3 жил өмнө
parent
commit
e88312d177

+ 13 - 2
FastGithub.DomainResolve/DnsClient.cs

@@ -56,12 +56,22 @@ namespace FastGithub.DomainResolve
         /// <param name="domain">域名</param>
         /// <param name="cancellationToken"></param>
         /// <returns></returns>
-        public async IAsyncEnumerable<IPAddress> ResolveAsync(string domain, [EnumeratorCancellation] CancellationToken cancellationToken)
+        public async IAsyncEnumerable<IPAddress[]> ResolveAsync(string domain, [EnumeratorCancellation] CancellationToken cancellationToken)
         {
             var hashSet = new HashSet<IPAddress>();
             foreach (var dns in this.GetDnsServers())
             {
-                foreach (var address in await this.LookupAsync(dns, domain, cancellationToken))
+                var addresses = await this.LookupAsync(dns, domain, cancellationToken);
+                var value = Filter(hashSet, addresses).ToArray();
+                if (value.Length > 0)
+                {
+                    yield return value;
+                }
+            }
+
+            static IEnumerable<IPAddress> Filter(HashSet<IPAddress> hashSet, IPAddress[] addresses)
+            {
+                foreach (var address in addresses)
                 {
                     if (hashSet.Add(address) == true)
                     {
@@ -71,6 +81,7 @@ namespace FastGithub.DomainResolve
             }
         }
 
+
         /// <summary>
         /// 获取dns服务
         /// </summary>

+ 4 - 10
FastGithub.DomainResolve/DomainResolver.cs

@@ -12,19 +12,14 @@ namespace FastGithub.DomainResolve
     /// </summary> 
     sealed class DomainResolver : IDomainResolver
     {
-        private readonly DnsClient dnsClient;
         private readonly DomainSpeedTester speedTester;
 
         /// <summary>
         /// 域名解析器
-        /// </summary>
-        /// <param name="dnsClient"></param>
+        /// </summary> 
         /// <param name="speedTester"></param>
-        public DomainResolver(
-            DnsClient dnsClient,
-            DomainSpeedTester speedTester)
+        public DomainResolver(DomainSpeedTester speedTester)
         {
-            this.dnsClient = dnsClient;
             this.speedTester = speedTester;
         }
 
@@ -51,8 +46,7 @@ namespace FastGithub.DomainResolve
         /// <returns></returns>
         public async IAsyncEnumerable<IPAddress> ResolveAllAsync(string domain, [EnumeratorCancellation] CancellationToken cancellationToken)
         {
-            var addresses = this.speedTester.GetIPAddresses(domain);
-            if (addresses.Length > 0)
+            if (this.speedTester.TryGetOrderAllIPAddresses(domain, out var addresses))
             {
                 foreach (var address in addresses)
                 {
@@ -62,7 +56,7 @@ namespace FastGithub.DomainResolve
             else
             {
                 this.speedTester.Add(domain);
-                await foreach (var address in this.dnsClient.ResolveAsync(domain, cancellationToken))
+                await foreach (var address in this.speedTester.GetOrderAnyIPAddressAsync(domain, cancellationToken))
                 {
                     yield return address;
                 }

+ 33 - 6
FastGithub.DomainResolve/DomainSpeedTester.cs

@@ -2,9 +2,11 @@
 using Microsoft.Extensions.Logging;
 using System;
 using System.Collections.Generic;
+using System.Diagnostics.CodeAnalysis;
 using System.IO;
 using System.Linq;
 using System.Net;
+using System.Runtime.CompilerServices;
 using System.Text.Json;
 using System.Threading;
 using System.Threading.Tasks;
@@ -106,20 +108,42 @@ namespace FastGithub.DomainResolve
             File.WriteAllBytes(DOMAINS_JSON_FILE, utf8Json);
         }
 
+
         /// <summary>
-        /// 获取测试后排序的IP
+        /// 尝试获取测试后排序的IP地址
         /// </summary>
         /// <param name="domain"></param>
+        /// <param name="addresses"></param>
         /// <returns></returns>
-        public IPAddress[] GetIPAddresses(string domain)
+        public bool TryGetOrderAllIPAddresses(string domain, [MaybeNullWhen(false)] out IPAddress[] addresses)
         {
             lock (this.syncRoot)
             {
                 if (this.domainIPAddressHashSet.TryGetValue(domain, out var hashSet) && hashSet.Count > 0)
                 {
-                    return hashSet.ToArray().OrderBy(item => item.PingElapsed).Select(item => item.Address).ToArray();
+                    addresses = hashSet.ToArray().OrderBy(item => item.PingElapsed).Select(item => item.Address).ToArray();
+                    return true;
+                }
+            }
+
+            addresses = default;
+            return false;
+        }
+
+        /// <summary>
+        /// 获取只排序头个元素的IP地址
+        /// </summary>
+        /// <param name="domain">域名</param>
+        /// <param name="cancellationToken"></param>
+        /// <returns></returns>
+        public async IAsyncEnumerable<IPAddress> GetOrderAnyIPAddressAsync(string domain, [EnumeratorCancellation] CancellationToken cancellationToken)
+        {
+            await foreach (var addresses in this.dnsClient.ResolveAsync(domain, cancellationToken))
+            {
+                foreach (var address in addresses)
+                {
+                    yield return address;
                 }
-                return Array.Empty<IPAddress>();
             }
         }
 
@@ -140,9 +164,12 @@ namespace FastGithub.DomainResolve
             {
                 var domain = keyValue.Key;
                 var hashSet = keyValue.Value;
-                await foreach (var address in this.dnsClient.ResolveAsync(domain, cancellationToken))
+                await foreach (var addresses in this.dnsClient.ResolveAsync(domain, cancellationToken))
                 {
-                    hashSet.Add(new IPAddressItem(address));
+                    foreach (var address in addresses)
+                    {
+                        hashSet.Add(new IPAddressItem(address));
+                    }
                 }
                 await hashSet.PingAllAsync();
             }