DomainResolver.cs 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. using FastGithub.Configuration;
  2. using Microsoft.Extensions.Logging;
  3. using System;
  4. using System.Collections.Concurrent;
  5. using System.Collections.Generic;
  6. using System.Linq;
  7. using System.Net;
  8. using System.Runtime.CompilerServices;
  9. using System.Threading;
  10. using System.Threading.Tasks;
  11. namespace FastGithub.DomainResolve
  12. {
  /// <summary>
  /// Domain resolver. Serves per-endpoint cached addresses (filled and ordered by
  /// <see cref="TestAllEndPointsAsync"/>) and falls back to a live DNS lookup through
  /// <see cref="DnsClient"/> when an endpoint has no cached address yet.
  /// </summary>
  sealed class DomainResolver : IDomainResolver
  {
      private readonly DnsClient dnsClient;
      private readonly DomainPersistence persistence;
      private readonly ILogger<DomainResolver> logger;

      /// <summary>
      /// Known endpoints mapped to their tested addresses. Each array holds only entries whose
      /// <c>Elapsed</c> is below <see cref="TimeSpan.MaxValue"/>, sorted ascending by <c>Elapsed</c>.
      /// An empty array means the endpoint is known but has no usable tested address yet.
      /// </summary>
      private readonly ConcurrentDictionary<DnsEndPoint, IPAddressElapsed[]> dnsEndPointAddressElapseds = new();

      /// <summary>
      /// Creates the resolver and seeds the endpoint table from the persisted endpoint list.
      /// </summary>
      /// <param name="dnsClient">Client used for live DNS lookups.</param>
      /// <param name="persistence">Store that endpoints are read from here and written to when new ones appear.</param>
      /// <param name="logger">Logger used to report changes in an endpoint's address list.</param>
      public DomainResolver(
          DnsClient dnsClient,
          DomainPersistence persistence,
          ILogger<DomainResolver> logger)
      {
          this.dnsClient = dnsClient;
          this.persistence = persistence;
          this.logger = logger;

          // Register every persisted endpoint with an empty address list so that the next
          // TestAllEndPointsAsync sweep resolves and measures it.
          foreach (var endPoint in persistence.ReadDnsEndPoints())
          {
              this.dnsEndPointAddressElapseds.TryAdd(endPoint, Array.Empty<IPAddressElapsed>());
          }
      }

      /// <summary>
      /// Resolves a single IP address: the first address produced by <see cref="ResolveAllAsync"/>.
      /// </summary>
      /// <param name="endPoint">Endpoint (host and port) to resolve.</param>
      /// <param name="cancellationToken">Token used to cancel the resolution.</param>
      /// <returns>The first available address.</returns>
      /// <exception cref="FastGithubException">No address could be resolved for the host.</exception>
      public async Task<IPAddress> ResolveAnyAsync(DnsEndPoint endPoint, CancellationToken cancellationToken = default)
      {
          await foreach (var address in this.ResolveAllAsync(endPoint, cancellationToken))
          {
              return address;
          }
          throw new FastGithubException($"解析不到{endPoint.Host}的IP");
      }

      /// <summary>
      /// Enumerates IP addresses for <paramref name="endPoint"/>.
      /// If the endpoint has cached tested addresses, they are yielded in their cached order.
      /// Otherwise the endpoint is registered, and the endpoint list is persisted if it was not
      /// already known, so a later <see cref="TestAllEndPointsAsync"/> sweep will measure it.
      /// The addresses then come from a live DNS lookup with <c>fastSort: true</c>.
      /// </summary>
      /// <param name="endPoint">Endpoint (host and port) to resolve.</param>
      /// <param name="cancellationToken">Token used to cancel the resolution.</param>
      /// <returns>An async stream of addresses.</returns>
      public async IAsyncEnumerable<IPAddress> ResolveAllAsync(DnsEndPoint endPoint, [EnumeratorCancellation] CancellationToken cancellationToken)
      {
          if (this.dnsEndPointAddressElapseds.TryGetValue(endPoint, out var addressElapseds) && addressElapseds.Length > 0)
          {
              foreach (var addressElapsed in addressElapseds)
              {
                  yield return addressElapsed.Adddress;
              }
          }
          else
          {
              // Persist only when this call actually added the endpoint, to avoid redundant writes.
              // NOTE(review): this await is unguarded, so a persistence failure surfaces as a resolution failure.
              if (this.dnsEndPointAddressElapseds.TryAdd(endPoint, Array.Empty<IPAddressElapsed>()))
              {
                  await this.persistence.WriteDnsEndPointsAsync(this.dnsEndPointAddressElapseds.Keys, cancellationToken);
              }

              await foreach (var address in this.dnsClient.ResolveAsync(endPoint, fastSort: true, cancellationToken))
              {
                  yield return address;
              }
          }
      }

      /// <summary>
      /// Re-tests every known endpoint. For each one it merges the previously cached addresses
      /// with the addresses DNS currently returns, and refreshes the elapsed time of every candidate
      /// whose <c>CanUpdateElapsed()</c> is true. It then stores the candidates with a finite
      /// elapsed time, sorted fastest first, and logs the new list when it differs from the old one.
      /// </summary>
      /// <param name="cancellationToken">Token used to cancel DNS lookups and measurements.</param>
      /// <returns>A task that completes when all endpoints have been processed.</returns>
      public async Task TestAllEndPointsAsync(CancellationToken cancellationToken)
      {
          // Enumerating a ConcurrentDictionary while assigning to existing keys is permitted.
          foreach (var (dnsEndPoint, previous) in this.dnsEndPointAddressElapseds)
          {
              // Candidate set = previously tested addresses plus whatever DNS returns now.
              var candidates = new HashSet<IPAddressElapsed>(previous);
              await foreach (var address in this.dnsClient.ResolveAsync(dnsEndPoint, fastSort: false, cancellationToken))
              {
                  candidates.Add(new IPAddressElapsed(address, dnsEndPoint.Port));
              }

              // Measure only the candidates that report they are due for an update, in parallel.
              var updateTasks = candidates
                  .Where(item => item.CanUpdateElapsed())
                  .Select(item => item.UpdateElapsedAsync(cancellationToken));
              await Task.WhenAll(updateTasks);

              // TimeSpan.MaxValue marks an address with no usable measurement; drop those.
              var addressElapseds = candidates
                  .Where(item => item.Elapsed < TimeSpan.MaxValue)
                  .OrderBy(item => item.Elapsed)
                  .ToArray();

              // A static template keeps the log structured, and the IsEnabled guard skips the
              // string.Join when Information logging is off.
              if (!previous.SequenceEqual(addressElapseds) && this.logger.IsEnabled(LogLevel.Information))
              {
                  var addressArray = string.Join(", ", addressElapseds.Select(item => item.ToString()));
                  this.logger.LogInformation("{Host}->[{Addresses}]", dnsEndPoint.Host, addressArray);
              }

              this.dnsEndPointAddressElapseds[dnsEndPoint] = addressElapseds;
          }
      }
  }
  118. }