DnsOverHttpsMiddleware.cs 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. using DNS.Protocol;
  2. using Microsoft.AspNetCore.Http;
  3. using Microsoft.Extensions.Logging;
  4. using System;
  5. using System.IO;
  6. using System.Linq;
  7. using System.Net;
  8. using System.Threading.Tasks;
  9. namespace FastGithub.Dns
  10. {
  11. /// <summary>
  12. /// DoH中间件
  13. /// </summary>
  14. sealed class DnsOverHttpsMiddleware
  15. {
  16. private static readonly PathString dnsQueryPath = "/dns-query";
  17. private const string MEDIA_TYPE = "application/dns-message";
  18. private readonly RequestResolver requestResolver;
  19. private readonly ILogger<DnsOverHttpsMiddleware> logger;
  20. /// <summary>
  21. /// DoH中间件
  22. /// </summary>
  23. /// <param name="requestResolver"></param>
  24. /// <param name="logger"></param>
  25. public DnsOverHttpsMiddleware(
  26. RequestResolver requestResolver,
  27. ILogger<DnsOverHttpsMiddleware> logger)
  28. {
  29. this.requestResolver = requestResolver;
  30. this.logger = logger;
  31. }
  32. /// <summary>
  33. /// 执行请求
  34. /// </summary>
  35. /// <param name="context"></param>
  36. /// <param name="next"></param>
  37. /// <returns></returns>
  38. public async Task InvokeAsync(HttpContext context, RequestDelegate next)
  39. {
  40. Request? request;
  41. try
  42. {
  43. request = await ParseDnsRequestAsync(context.Request);
  44. }
  45. catch (Exception)
  46. {
  47. context.Response.StatusCode = StatusCodes.Status400BadRequest;
  48. return;
  49. }
  50. if (request == null)
  51. {
  52. await next(context);
  53. return;
  54. }
  55. var response = await this.ResolveAsync(context, request);
  56. context.Response.ContentType = MEDIA_TYPE;
  57. await context.Response.BodyWriter.WriteAsync(response.ToArray());
  58. }
  59. /// <summary>
  60. /// 解析dns域名
  61. /// </summary>
  62. /// <param name="context"></param>
  63. /// <param name="request"></param>
  64. /// <returns></returns>
  65. private async Task<IResponse> ResolveAsync(HttpContext context, Request request)
  66. {
  67. try
  68. {
  69. var remoteIPAddress = context.Connection.RemoteIpAddress ?? IPAddress.Loopback;
  70. var remoteEndPoint = new IPEndPoint(remoteIPAddress, context.Connection.RemotePort);
  71. var remoteEndPointRequest = new RemoteEndPointRequest(request, remoteEndPoint);
  72. return await this.requestResolver.Resolve(remoteEndPointRequest);
  73. }
  74. catch (Exception ex)
  75. {
  76. this.logger.LogWarning($"处理DNS异常:{ex.Message}");
  77. return Response.FromRequest(request);
  78. }
  79. }
  80. /// <summary>
  81. /// 解析dns请求
  82. /// </summary>
  83. /// <param name="request"></param>
  84. /// <returns></returns>
  85. private static async Task<Request?> ParseDnsRequestAsync(HttpRequest request)
  86. {
  87. if (request.Path != dnsQueryPath ||
  88. request.Headers.TryGetValue("accept", out var accept) == false ||
  89. accept.Contains(MEDIA_TYPE) == false)
  90. {
  91. return default;
  92. }
  93. if (request.Method == HttpMethods.Get)
  94. {
  95. if (request.Query.TryGetValue("dns", out var dns) == false)
  96. {
  97. return default;
  98. }
  99. var dnsRequest = dns.ToString().Replace('-', '+').Replace('_', '/');
  100. int mod = dnsRequest.Length % 4;
  101. if (mod > 0)
  102. {
  103. dnsRequest = dnsRequest.PadRight(dnsRequest.Length - mod + 4, '=');
  104. }
  105. var message = Convert.FromBase64String(dnsRequest);
  106. return Request.FromArray(message);
  107. }
  108. if (request.Method == HttpMethods.Post && request.ContentType == MEDIA_TYPE)
  109. {
  110. using var message = new MemoryStream();
  111. await request.Body.CopyToAsync(message);
  112. return Request.FromArray(message.ToArray());
  113. }
  114. return default;
  115. }
  116. }
  117. }