DnsInterceptor.cs 6.7 KB


  1. using DNS.Protocol;
  2. using DNS.Protocol.ResourceRecords;
  3. using FastGithub.Configuration;
  4. using Microsoft.Extensions.Logging;
  5. using Microsoft.Extensions.Options;
  6. using System;
  7. using System.ComponentModel;
  8. using System.Diagnostics.CodeAnalysis;
  9. using System.Linq;
  10. using System.Net;
  11. using System.Runtime.InteropServices;
  12. using System.Runtime.Versioning;
  13. using System.Threading;
  14. using System.Threading.Tasks;
  15. using WindivertDotnet;
  16. namespace FastGithub.PacketIntercept.Dns
  17. {
  /// <summary>
  /// DNS interceptor. Captures UDP packets destined for port 53 through WinDivert and,
  /// for A/AAAA queries whose name is accepted by <see cref="FastGithubConfig.IsMatch"/>,
  /// rewrites the captured query in place into a response that answers with the
  /// loopback address. All other packets are re-injected unchanged.
  /// </summary>
  [SupportedOSPlatform("windows")]
  sealed class DnsInterceptor : IDnsInterceptor
  {
      /// <summary>WinDivert filter: any UDP packet whose destination port is 53.</summary>
      private static readonly Filter filter = Filter.True.And(f => f.Udp.DstPort == 53);

      private readonly FastGithubConfig fastGithubConfig;
      private readonly ILogger<DnsInterceptor> logger;

      /// <summary>TTL placed on the forged answer record.</summary>
      private readonly TimeSpan ttl = TimeSpan.FromMinutes(5d);

      /// <summary>
      /// Flushes the DNS resolver cache (dnsapi.dll export <c>DnsFlushResolverCache</c>).
      /// </summary>
      [DllImport("dnsapi.dll", EntryPoint = "DnsFlushResolverCache", SetLastError = true)]
      private static extern void DnsFlushResolverCache();

      /// <summary>
      /// The first load of the driver often throws, so it is loaded ahead of time here.
      /// </summary>
      static DnsInterceptor()
      {
          try
          {
              // Open and immediately close a handle with a filter that matches nothing;
              // the only purpose is to trigger the driver load.
              using (new WinDivert(Filter.False, WinDivertLayer.Network)) { }
          }
          catch (Exception)
          {
              // Deliberately ignored: this is a best-effort warm-up and a static
              // constructor must not throw. InterceptAsync opens its own handle later.
          }
      }

      /// <summary>
      /// Creates the DNS interceptor.
      /// </summary>
      /// <param name="fastGithubConfig">Configuration used to decide which domains are intercepted.</param>
      /// <param name="logger">Logger for intercepted domains and per-packet failures.</param>
      /// <param name="options">Options monitor; every change triggers a DNS resolver cache flush.</param>
      public DnsInterceptor(
          FastGithubConfig fastGithubConfig,
          ILogger<DnsInterceptor> logger,
          IOptionsMonitor<FastGithubOptions> options)
      {
          this.fastGithubConfig = fastGithubConfig;
          this.logger = logger;

          // Flush the resolver cache whenever the options change.
          options.OnChange(_ => DnsFlushResolverCache());
      }

      /// <summary>
      /// Runs the DNS interception loop until <paramref name="cancellationToken"/> is cancelled.
      /// Each captured packet is handed to <see cref="ModifyDnsPacket"/> and then sent back
      /// through the WinDivert handle, whether or not it was modified.
      /// </summary>
      /// <param name="cancellationToken">
      /// Cancelling the token disposes the WinDivert handle and flushes the DNS resolver cache.
      /// </param>
      /// <exception cref="Win32Exception">
      /// Presumably propagated from the WinDivert calls (declared by the original author;
      /// not verifiable from this file).
      /// </exception>
      /// <returns>A task that completes when the loop ends.</returns>
      public async Task InterceptAsync(CancellationToken cancellationToken)
      {
          // Yield first so the synchronous receive loop below does not run inline
          // on the caller before it gets its Task back.
          await Task.Yield();

          using var divert = new WinDivert(filter, WinDivertLayer.Network);

          // On cancellation, dispose the handle and flush the resolver cache.
          // The registration itself is disposed when this method exits, so a faulted
          // loop does not leave the callback (and the captured handle) attached to a
          // long-lived token.
          using var registration = cancellationToken.Register(d =>
          {
              ((WinDivert)d!).Dispose();
              DnsFlushResolverCache();
          }, divert);

          var addr = new WinDivertAddress();
          using var packet = new WinDivertPacket();

          // Start from a clean resolver cache.
          DnsFlushResolverCache();

          while (cancellationToken.IsCancellationRequested == false)
          {
              divert.Recv(packet, ref addr);

              try
              {
                  this.ModifyDnsPacket(packet, ref addr);
              }
              catch (Exception ex)
              {
                  // Constant template: the message text is data, not a format string.
                  this.logger.LogWarning("{Message}", ex.Message);
              }
              finally
              {
                  // Always re-inject the packet, modified or not.
                  divert.Send(packet, ref addr);
              }
          }
      }

      /// <summary>
      /// Rewrites a captured DNS query in place into a loopback answer when it is an
      /// A/AAAA query for a domain accepted by the configuration; otherwise leaves the
      /// packet untouched.
      /// </summary>
      /// <param name="packet">The captured packet; its payload, length and headers may be rewritten.</param>
      /// <param name="addr">The WinDivert address of the packet; its flags are updated when the packet is rewritten.</param>
      private unsafe void ModifyDnsPacket(WinDivertPacket packet, ref WinDivertAddress addr)
      {
          var result = packet.GetParseResult();
          var requestPayload = result.DataSpan.ToArray();

          if (TryParseRequest(requestPayload, out var request) == false ||
              request.OperationCode != OperationCode.Query ||
              request.Questions.Count == 0)
          {
              return;
          }

          // Only the first question is considered.
          var question = request.Questions.First();
          if (question.Type != RecordType.A && question.Type != RecordType.AAAA)
          {
              return;
          }

          var domain = question.Name;
          if (this.fastGithubConfig.IsMatch(domain.ToString()) == false)
          {
              return;
          }

          // Build the DNS response: one address record pointing at loopback.
          var response = Response.FromRequest(request);
          var loopback = question.Type == RecordType.A ? IPAddress.Loopback : IPAddress.IPv6Loopback;
          var record = new IPAddressResourceRecord(domain, loopback, this.ttl);
          response.AnswerRecords.Add(record);
          var responsePayload = response.ToArray();

          // The lengths written below are 16-bit header fields. If the rewritten packet
          // would not fit, leave the packet unmodified rather than writing past the
          // representable size through the raw pointer.
          var newLength = packet.Length + responsePayload.Length - requestPayload.Length;
          if (newLength > ushort.MaxValue)
          {
              return;
          }

          // Overwrite the payload and update the packet length.
          responsePayload.CopyTo(new Span<byte>(result.Data, responsePayload.Length));
          packet.Length = newLength;

          // Rewrite the IP header: swap source and destination, update the length field.
          IPAddress destAddress;
          if (result.IPV4Header != null)
          {
              destAddress = result.IPV4Header->DstAddr;
              result.IPV4Header->DstAddr = result.IPV4Header->SrcAddr;
              result.IPV4Header->SrcAddr = destAddress;
              result.IPV4Header->Length = (ushort)packet.Length;
          }
          else
          {
              destAddress = result.IPV6Header->DstAddr;
              result.IPV6Header->DstAddr = result.IPV6Header->SrcAddr;
              result.IPV6Header->SrcAddr = destAddress;
              // The IPv6 length field excludes the fixed IPv6 header itself.
              result.IPV6Header->Length = (ushort)(packet.Length - sizeof(IPV6Header));
          }

          // Rewrite the UDP header: swap ports, update the length field.
          var destPort = result.UdpHeader->DstPort;
          result.UdpHeader->DstPort = result.UdpHeader->SrcPort;
          result.UdpHeader->SrcPort = destPort;
          result.UdpHeader->Length = (ushort)(sizeof(UdpHeader) + responsePayload.Length);

          // Mark the packet as an impostor and fix its direction: loopback packets are
          // forced outbound, every other packet has its direction flipped.
          addr.Flags |= WinDivertAddressFlag.Impostor;
          if (addr.Flags.HasFlag(WinDivertAddressFlag.Loopback))
          {
              addr.Flags |= WinDivertAddressFlag.Outbound;
          }
          else
          {
              addr.Flags ^= WinDivertAddressFlag.Outbound;
          }

          packet.CalcChecksums(ref addr);
          this.logger.LogInformation("{Domain}->{Loopback}", domain, loopback);
      }

      /// <summary>
      /// Tries to parse a DNS request from a raw payload.
      /// </summary>
      /// <param name="payload">The raw payload bytes of the captured packet.</param>
      /// <param name="request">The parsed request when the method returns true; otherwise null.</param>
      /// <returns>true if the payload parsed as a request; false if parsing threw.</returns>
      private static bool TryParseRequest(byte[] payload, [MaybeNullWhen(false)] out Request request)
      {
          try
          {
              request = Request.FromArray(payload);
              return true;
          }
          catch (Exception)
          {
              // Any parse failure means "not a request we can handle".
              request = null;
              return false;
          }
      }
  }
  181. }