Browse Source

dns服务器状态缓存

陈国伟 3 years ago
parent
commit
c51cda09b7
1 changed files with 52 additions and 16 deletions
  1. 52 16
      FastGithub.DomainResolve/DnsClient.cs

+ 52 - 16
FastGithub.DomainResolve/DnsClient.cs

@@ -31,13 +31,15 @@ namespace FastGithub.DomainResolve
         private readonly ILogger<DnsClient> logger;
 
         private readonly ConcurrentDictionary<string, SemaphoreSlim> semaphoreSlims = new();
+        private readonly IMemoryCache dnsStateCache = new MemoryCache(Options.Create(new MemoryCacheOptions()));
         private readonly IMemoryCache dnsLookupCache = new MemoryCache(Options.Create(new MemoryCacheOptions()));
 
+        private readonly TimeSpan stateExpiration = TimeSpan.FromMinutes(5d);
         private readonly TimeSpan minTimeToLive = TimeSpan.FromSeconds(30d);
         private readonly TimeSpan maxTimeToLive = TimeSpan.FromMinutes(10d);
 
         private readonly int resolveTimeout = (int)TimeSpan.FromSeconds(4d).TotalMilliseconds;
-        private static readonly TimeSpan connectTimeout = TimeSpan.FromSeconds(2d);
+        private static readonly TimeSpan tcpConnectTimeout = TimeSpan.FromSeconds(2d);
 
         private record LookupResult(IList<IPAddress> Addresses, TimeSpan TimeToLive);
 
@@ -67,7 +69,7 @@ namespace FastGithub.DomainResolve
         public async IAsyncEnumerable<IPAddress> ResolveAsync(DnsEndPoint endPoint, bool fastSort, [EnumeratorCancellation] CancellationToken cancellationToken)
         {
             var hashSet = new HashSet<IPAddress>();
-            foreach (var dns in this.GetDnsServers())
+            await foreach (var dns in this.GetDnsServersAsync(cancellationToken))
             {
                 var addresses = await this.LookupAsync(dns, endPoint, fastSort, cancellationToken);
                 foreach (var address in addresses)
@@ -84,7 +86,7 @@ namespace FastGithub.DomainResolve
         /// 获取dns服务
         /// </summary>
         /// <returns></returns>
-        private IEnumerable<IPEndPoint> GetDnsServers()
+        private async IAsyncEnumerable<IPEndPoint> GetDnsServersAsync([EnumeratorCancellation] CancellationToken cancellationToken)
         {
             var cryptDns = this.dnscryptProxy.LocalEndPoint;
             if (cryptDns != null)
@@ -93,15 +95,52 @@ namespace FastGithub.DomainResolve
                 yield return cryptDns;
             }
 
-            foreach (var fallbackDns in this.fastGithubConfig.FallbackDns)
+            foreach (var dns in this.fastGithubConfig.FallbackDns)
             {
-                if (Socket.OSSupportsIPv6 || fallbackDns.AddressFamily != AddressFamily.InterNetworkV6)
+                if (await this.IsDnsAvailableAsync(dns, cancellationToken))
                 {
-                    yield return fallbackDns;
+                    yield return dns;
                 }
             }
         }
 
+        /// <summary>
+        /// 获取dns是否可用
+        /// </summary>
+        /// <param name="dns"></param>
+        /// <param name="cancellationToken"></param>
+        /// <returns></returns>
+        private async ValueTask<bool> IsDnsAvailableAsync(IPEndPoint dns, CancellationToken cancellationToken)
+        {
+            if (dns.Port != DNS_PORT)
+            {
+                return true;
+            }
+
+            if (this.dnsStateCache.TryGetValue<bool>(dns, out var state))
+            {
+                return state;
+            }
+
+            var key = dns.ToString();
+            var semaphore = this.semaphoreSlims.GetOrAdd(key, _ => new SemaphoreSlim(1, 1));
+            await semaphore.WaitAsync(CancellationToken.None);
+
+            try
+            {
+                using var timeoutTokenSource = new CancellationTokenSource(tcpConnectTimeout);
+                using var linkedTokenSource = CancellationTokenSource.CreateLinkedTokenSource(timeoutTokenSource.Token, cancellationToken);
+                using var socket = new Socket(dns.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
+                await socket.ConnectAsync(dns, linkedTokenSource.Token);
+                return this.dnsStateCache.Set(dns, true, this.stateExpiration);
+            }
+            catch (Exception)
+            {
+                cancellationToken.ThrowIfCancellationRequested();
+                return this.dnsStateCache.Set(dns, false, this.stateExpiration);
+            }
+        }
+
         /// <summary>
         /// 解析域名
         /// </summary>
@@ -132,7 +171,7 @@ namespace FastGithub.DomainResolve
             catch (Exception ex)
             {
                 this.logger.LogWarning($"{endPoint.Host}@{dns}->{ex.Message}");
-                var expiration = IsTcpResetException(ex) ? this.maxTimeToLive : this.minTimeToLive;
+                var expiration = IsSocketException(ex) ? this.maxTimeToLive : this.minTimeToLive;
                 return this.dnsLookupCache.Set(key, Array.Empty<IPAddress>(), expiration);
             }
             finally
@@ -142,22 +181,19 @@ namespace FastGithub.DomainResolve
         }
 
         /// <summary>
-        /// 是否为因收到tcp reset导致的关闭
+        /// 是否为Socket异常
         /// </summary>
         /// <param name="ex"></param>
         /// <returns></returns>
-        private static bool IsTcpResetException(Exception ex)
+        private static bool IsSocketException(Exception ex)
         {
-            if (ex is SocketException socketException)
+            if (ex is SocketException)
             {
-                if (socketException.SocketErrorCode == SocketError.ConnectionReset)
-                {
-                    return true;
-                }
+                return true;
             }
 
             var inner = ex.InnerException;
-            return inner != null && IsTcpResetException(inner);
+            return inner != null && IsSocketException(inner);
         }
 
 
@@ -272,7 +308,7 @@ namespace FastGithub.DomainResolve
                 return addresses;
             }
 
-            using var controlTokenSource = new CancellationTokenSource(connectTimeout);
+            using var controlTokenSource = new CancellationTokenSource(tcpConnectTimeout);
             using var linkedTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, controlTokenSource.Token);
 
             var connectTasks = addresses.Select(address => ConnectAsync(address, port, linkedTokenSource.Token));