DnsClient.cs 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. using DNS.Client;
  2. using DNS.Client.RequestResolver;
  3. using DNS.Protocol;
  4. using DNS.Protocol.ResourceRecords;
  5. using FastGithub.Configuration;
  6. using Microsoft.Extensions.Caching.Memory;
  7. using Microsoft.Extensions.Logging;
  8. using Microsoft.Extensions.Options;
  9. using System;
  10. using System.Collections.Concurrent;
  11. using System.Collections.Generic;
  12. using System.Diagnostics.CodeAnalysis;
  13. using System.Linq;
  14. using System.Net;
  15. using System.Net.NetworkInformation;
  16. using System.Runtime.CompilerServices;
  17. using System.Threading;
  18. using System.Threading.Tasks;
  namespace FastGithub.DomainResolve
  {
      /// <summary>
      /// DNS client.
      /// Resolves domain names through the local dnscrypt-proxy endpoint (when it has one) and the
      /// configured fallback DNS servers, caching every answer per "{dns}:{domain}" key.
      /// </summary>
      sealed class DnsClient
      {
          private const int DNS_PORT = 53;
          private const string LOCALHOST = "localhost";

          private readonly DnscryptProxy dnscryptProxy;
          private readonly FastGithubConfig fastGithubConfig;
          private readonly ILogger<DnsClient> logger;

          // Every domain requested through ResolveAsync, with the addresses PingAllDomainsAsync collected for it.
          private readonly ConcurrentDictionary<string, IPAddressCollection> domainIPAddressCollection = new();

          // One semaphore per "{dns}:{domain}" key so lookups for the same pair run one at a time.
          // NOTE(review): entries are never removed; the dictionary grows with the number of distinct pairs.
          private readonly ConcurrentDictionary<string, SemaphoreSlim> semaphoreSlims = new();

          private readonly IMemoryCache dnsCache = new MemoryCache(Options.Create(new MemoryCacheOptions()));

          // Cache lifetime for empty answers and for answers whose TTL is zero or negative.
          private readonly TimeSpan defaultEmptyTtl = TimeSpan.FromSeconds(30d);

          // Timeout, in milliseconds, handed to UdpRequestResolver.
          private readonly int resolveTimeout = (int)TimeSpan.FromSeconds(2d).TotalMilliseconds;

          /// <summary>
          /// Addresses produced by one lookup and how long they may be cached.
          /// <see cref="TimeSpan.MaxValue"/> means the entry never expires.
          /// </summary>
          private record LookupResult(IPAddress[] Addresses, TimeSpan TimeToLive);

          /// <summary>
          /// DNS client.
          /// </summary>
          /// <param name="dnscryptProxy">Provides the local dnscrypt-proxy endpoint, if any.</param>
          /// <param name="fastGithubConfig">Provides the fallback DNS servers.</param>
          /// <param name="logger">Logger for lookup results and failures.</param>
          public DnsClient(
              DnscryptProxy dnscryptProxy,
              FastGithubConfig fastGithubConfig,
              ILogger<DnsClient> logger)
          {
              this.dnscryptProxy = dnscryptProxy;
              this.fastGithubConfig = fastGithubConfig;
              this.logger = logger;
          }

          /// <summary>
          /// Resolves a domain name.
          /// If addresses have already been collected for the domain (by <see cref="PingAllDomainsAsync"/>),
          /// they are returned in the collection's order. Otherwise the domain is registered for later
          /// ping testing and resolved directly via DNS.
          /// </summary>
          /// <param name="domain">Domain name.</param>
          /// <param name="cancellationToken">Flows into each DNS request.</param>
          /// <returns>Distinct addresses for the domain.</returns>
          public async IAsyncEnumerable<IPAddress> ResolveAsync(string domain, [EnumeratorCancellation] CancellationToken cancellationToken)
          {
              if (this.TryGetPingedIPAddresses(domain, out var addresses))
              {
                  foreach (var address in addresses)
                  {
                      yield return address;
                  }
              }
              else
              {
                  // Register the domain so the next PingAllDomainsAsync pass fills its collection.
                  this.domainIPAddressCollection.TryAdd(domain, new IPAddressCollection());
                  await foreach (var address in this.ResolveCoreAsync(domain, cancellationToken))
                  {
                      yield return address;
                  }
              }
          }

          /// <summary>
          /// Re-resolves every registered domain, adds the resulting addresses to its collection and
          /// ping-tests all of them.
          /// Stops early (without throwing) once <paramref name="cancellationToken"/> is cancelled.
          /// </summary>
          /// <param name="cancellationToken">Cancels the DNS requests and the remaining iterations.</param>
          /// <returns>A task that completes when all domains have been processed.</returns>
          public async Task PingAllDomainsAsync(CancellationToken cancellationToken)
          {
              foreach (var (domain, collection) in this.domainIPAddressCollection)
              {
                  // LookupAsync swallows cancellation, so without this check the loop would keep querying
                  // and pinging every remaining domain after cancellation had been requested.
                  if (cancellationToken.IsCancellationRequested)
                  {
                      break;
                  }

                  await foreach (var address in this.ResolveCoreAsync(domain, cancellationToken))
                  {
                      collection.Add(address);
                  }
                  await collection.PingAllAsync();
              }
          }

          /// <summary>
          /// Tries to get the addresses already collected for a domain (presumably ordered by ping result
          /// inside <see cref="IPAddressCollection"/> — TODO confirm).
          /// </summary>
          /// <param name="domain">Domain name.</param>
          /// <param name="addresses">A snapshot of the collected addresses when this returns true.</param>
          /// <returns>True when the domain is registered and its collection is not empty.</returns>
          private bool TryGetPingedIPAddresses(string domain, [MaybeNullWhen(false)] out IPAddress[] addresses)
          {
              if (this.domainIPAddressCollection.TryGetValue(domain, out var collection) && collection.Count > 0)
              {
                  addresses = collection.ToArray();
                  return true;
              }

              addresses = default;
              return false;
          }

          /// <summary>
          /// Queries every DNS server from <see cref="GetDnsServers"/> in order and yields each address
          /// the first time it is seen.
          /// </summary>
          /// <param name="domain">Domain name.</param>
          /// <param name="cancellationToken">Flows into each DNS request.</param>
          /// <returns>Distinct addresses across all servers.</returns>
          private async IAsyncEnumerable<IPAddress> ResolveCoreAsync(string domain, [EnumeratorCancellation] CancellationToken cancellationToken)
          {
              var seen = new HashSet<IPAddress>();
              foreach (var dns in this.GetDnsServers())
              {
                  var addresses = await this.LookupAsync(dns, domain, cancellationToken);
                  foreach (var address in addresses)
                  {
                      if (seen.Add(address))
                      {
                          yield return address;
                      }
                  }
              }
          }

          /// <summary>
          /// Yields the DNS servers to query: the local dnscrypt-proxy endpoint (when it has one), then
          /// the configured fallback servers.
          /// The dnscrypt endpoint is yielded twice. A lookup that throws (e.g. a timeout) is not cached,
          /// so the second pass retries it. A lookup that succeeded, even with an empty answer, is served
          /// from the cache on the second pass.
          /// </summary>
          /// <returns>DNS server endpoints in query order.</returns>
          private IEnumerable<IPEndPoint> GetDnsServers()
          {
              var cryptDns = this.dnscryptProxy.LocalEndPoint;
              if (cryptDns != null)
              {
                  yield return cryptDns;
                  yield return cryptDns;
              }

              foreach (var fallbackDns in this.fastGithubConfig.FallbackDns)
              {
                  yield return fallbackDns;
              }
          }

          /// <summary>
          /// Resolves a domain against one DNS server, using the cache and serialising concurrent
          /// lookups for the same server/domain pair.
          /// Never throws: timeouts/cancellation and other errors are logged and yield an empty array.
          /// </summary>
          /// <param name="dns">DNS server endpoint.</param>
          /// <param name="domain">Domain name.</param>
          /// <param name="cancellationToken">Cancels the DNS request (reported as a timeout).</param>
          /// <returns>The resolved addresses, or an empty array on failure.</returns>
          private async Task<IPAddress[]> LookupAsync(IPEndPoint dns, string domain, CancellationToken cancellationToken = default)
          {
              var key = $"{dns}:{domain}";
              var semaphore = this.semaphoreSlims.GetOrAdd(key, _ => new SemaphoreSlim(1, 1));

              // The wait itself is not cancellable, so this method never throws; cancellation is observed
              // by the DNS request below and reported through the catch blocks.
              await semaphore.WaitAsync(CancellationToken.None);
              try
              {
                  if (this.dnsCache.TryGetValue<IPAddress[]>(key, out var value))
                  {
                      return value;
                  }

                  var result = await this.LookupCoreAsync(dns, domain, cancellationToken);
                  if (result.TimeToLive == TimeSpan.MaxValue)
                  {
                      // MemoryCache computes "now + TTL", which overflows for TimeSpan.MaxValue and throws
                      // ArgumentOutOfRangeException; store such entries without any expiration instead.
                      this.dnsCache.Set(key, result.Addresses);
                  }
                  else
                  {
                      this.dnsCache.Set(key, result.Addresses, result.TimeToLive);
                  }

                  var items = string.Join(", ", result.Addresses.Select(item => item.ToString()));
                  this.logger.LogInformation($"dns://{dns}:{domain}->[{items}]");
                  return result.Addresses;
              }
              catch (OperationCanceledException)
              {
                  this.logger.LogInformation($"dns://{dns}无法解析{domain}:请求超时");
                  return Array.Empty<IPAddress>();
              }
              catch (Exception ex)
              {
                  this.logger.LogWarning($"dns://{dns}无法解析{domain}:{ex.Message}");
                  return Array.Empty<IPAddress>();
              }
              finally
              {
                  semaphore.Release();
              }
          }

          /// <summary>
          /// Sends an A-record query for <paramref name="domain"/> to <paramref name="dns"/>.
          /// "localhost" (any letter case) is answered locally with the loopback address and never expires.
          /// Loopback addresses in a server's answer are discarded. When several addresses remain, the first
          /// one to answer a ping successfully is moved to the front.
          /// </summary>
          /// <param name="dns">DNS server endpoint.</param>
          /// <param name="domain">Domain name.</param>
          /// <param name="cancellationToken">Cancels the DNS request.</param>
          /// <returns>The addresses and the time they may be cached.</returns>
          private async Task<LookupResult> LookupCoreAsync(IPEndPoint dns, string domain, CancellationToken cancellationToken = default)
          {
              // DNS names are case-insensitive, so "LocalHost" must be short-circuited as well.
              if (string.Equals(domain, LOCALHOST, StringComparison.OrdinalIgnoreCase))
              {
                  return new LookupResult(new[] { IPAddress.Loopback }, TimeSpan.MaxValue);
              }

              // Port 53 is queried over TCP only; any other port uses a UDP resolver with TCP as its fallback.
              var resolver = dns.Port == DNS_PORT
                  ? (IRequestResolver)new TcpRequestResolver(dns)
                  : new UdpRequestResolver(dns, new TcpRequestResolver(dns), this.resolveTimeout);

              var request = new Request
              {
                  RecursionDesired = true,
                  OperationCode = OperationCode.Query
              };
              request.Questions.Add(new Question(new Domain(domain), RecordType.A));

              var clientRequest = new ClientRequest(resolver, request);
              var response = await clientRequest.Resolve(cancellationToken);

              var addresses = response.AnswerRecords
                  .OfType<IPAddressResourceRecord>()
                  .Where(item => IPAddress.IsLoopback(item.IPAddress) == false)
                  .Select(item => item.IPAddress)
                  .ToArray();

              if (addresses.Length == 0)
              {
                  return new LookupResult(addresses, this.defaultEmptyTtl);
              }

              if (addresses.Length > 1)
              {
                  addresses = await OrderByPingAnyAsync(addresses);
              }

              // The answer can contain records (e.g. CNAMEs) whose TTL differs from the address records;
              // the result is only valid as long as the shortest-lived record in it.
              var timeToLive = response.AnswerRecords.Min(item => item.TimeToLive);
              if (timeToLive <= TimeSpan.Zero)
              {
                  timeToLive = this.defaultEmptyTtl;
              }

              return new LookupResult(addresses, timeToLive);
          }

          /// <summary>
          /// Pings all addresses concurrently and moves the first one that replies successfully to the
          /// front, keeping the others in their original order.
          /// Failed pings are skipped, so a quick failure no longer hides a slower but reachable address.
          /// </summary>
          /// <param name="addresses">Addresses to order.</param>
          /// <returns>The reordered addresses, or <paramref name="addresses"/> unchanged when no ping succeeds.</returns>
          private static async Task<IPAddress[]> OrderByPingAnyAsync(IPAddress[] addresses)
          {
              var pending = addresses.Select(address => PingAsync(address)).ToList();
              while (pending.Count > 0)
              {
                  var completed = await Task.WhenAny(pending);
                  pending.Remove(completed);

                  // PingAsync never faults; a null result means that ping failed, so keep waiting.
                  var fastestAddress = await completed;
                  if (fastestAddress != null)
                  {
                      var ordered = new List<IPAddress>(addresses.Length) { fastestAddress };
                      foreach (var address in addresses)
                      {
                          if (address.Equals(fastestAddress) == false)
                          {
                              ordered.Add(address);
                          }
                      }
                      return ordered.ToArray();
                  }
              }

              return addresses;
          }

          /// <summary>
          /// Sends one ping to <paramref name="address"/> using <see cref="Ping"/>'s default timeout.
          /// Best-effort: any exception is treated as a failed ping.
          /// </summary>
          /// <param name="address">Address to ping.</param>
          /// <returns><paramref name="address"/> when the reply status is Success; otherwise null.</returns>
          private static async Task<IPAddress?> PingAsync(IPAddress address)
          {
              try
              {
                  using var ping = new Ping();
                  var reply = await ping.SendPingAsync(address);
                  return reply.Status == IPStatus.Success ? address : default;
              }
              catch (Exception)
              {
                  return default;
              }
          }
      }
  }