using DNS.Protocol;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Logging;
using System;
using System.IO;
using System.Linq;
using System.Net;
using System.Threading.Tasks;

namespace FastGithub.Dns
{
    /// <summary>
    /// DNS-over-HTTPS (DoH) middleware. Handles requests to <c>/dns-query</c> whose
    /// <c>Accept</c> header lists <c>application/dns-message</c>; every other request
    /// is forwarded to the next middleware.
    /// </summary>
    sealed class DnsOverHttpsMiddleware
    {
        private static readonly PathString dnsQueryPath = "/dns-query";
        private const string MEDIA_TYPE = "application/dns-message";

        /// <summary>
        /// Upper bound for a POSTed DNS message. The DNS wire format carries a 16-bit
        /// length, so a larger body cannot be a valid message.
        /// </summary>
        private const int MAX_MESSAGE_SIZE = ushort.MaxValue;

        private readonly RequestResolver requestResolver;
        private readonly ILogger<DnsOverHttpsMiddleware> logger;

        /// <summary>
        /// Creates the DoH middleware.
        /// </summary>
        /// <param name="requestResolver">Resolver used to answer the decoded DNS requests.</param>
        /// <param name="logger">Logger that receives resolution failures.</param>
        public DnsOverHttpsMiddleware(
            RequestResolver requestResolver,
            ILogger<DnsOverHttpsMiddleware> logger)
        {
            this.requestResolver = requestResolver;
            this.logger = logger;
        }

        /// <summary>
        /// Processes one HTTP request.
        /// </summary>
        /// <param name="context">The current HTTP context.</param>
        /// <param name="next">The next middleware in the pipeline.</param>
        /// <returns>A task that completes when the request has been handled.</returns>
        public async Task InvokeAsync(HttpContext context, RequestDelegate next)
        {
            Request? request;
            try
            {
                request = await ParseDnsRequestAsync(context.Request);
            }
            catch (Exception)
            {
                // Any failure while decoding the query (bad base64, undecodable or
                // oversized message, aborted body read) is answered with 400.
                context.Response.StatusCode = StatusCodes.Status400BadRequest;
                return;
            }

            if (request == null)
            {
                // Not a DoH request: let the rest of the pipeline handle it.
                await next(context);
                return;
            }

            var response = await this.ResolveAsync(context, request);
            context.Response.ContentType = MEDIA_TYPE;
            await context.Response.BodyWriter.WriteAsync(response.ToArray());
        }

        /// <summary>
        /// Resolves a DNS request on behalf of the remote client.
        /// </summary>
        /// <param name="context">The current HTTP context; supplies the client end point.</param>
        /// <param name="request">The decoded DNS request.</param>
        /// <returns>
        /// The resolver's response, or the result of <c>Response.FromRequest</c>
        /// when resolution throws.
        /// </returns>
        private async Task<IResponse> ResolveAsync(HttpContext context, Request request)
        {
            try
            {
                // Fall back to loopback when the connection exposes no remote address.
                var remoteIPAddress = context.Connection.RemoteIpAddress ?? IPAddress.Loopback;
                var remoteEndPoint = new IPEndPoint(remoteIPAddress, context.Connection.RemotePort);
                var remoteEndPointRequest = new RemoteEndPointRequest(request, remoteEndPoint);
                return await this.requestResolver.Resolve(remoteEndPointRequest);
            }
            catch (Exception ex)
            {
                // Message template instead of an interpolated string (CA2254);
                // the rendered text is unchanged.
                this.logger.LogWarning("处理DNS异常:{Message}", ex.Message);
                return Response.FromRequest(request);
            }
        }

        /// <summary>
        /// Extracts the DNS request carried by an HTTP request.
        /// </summary>
        /// <param name="request">The incoming HTTP request.</param>
        /// <returns>
        /// The decoded DNS request, or <c>null</c> when the HTTP request is not a
        /// DoH query (wrong path, missing or non-matching <c>Accept</c> header,
        /// missing <c>dns</c> parameter, unsupported method or content type).
        /// </returns>
        /// <exception cref="FormatException">The <c>dns</c> query parameter is not valid base64url.</exception>
        /// <exception cref="InvalidDataException">The POST body exceeds the maximum DNS message size.</exception>
        private static async Task<Request?> ParseDnsRequestAsync(HttpRequest request)
        {
            if (request.Path != dnsQueryPath ||
                request.Headers.TryGetValue("accept", out var accept) == false ||
                accept.Any(value => AcceptsDnsMessage(value)) == false)
            {
                return default;
            }

            if (HttpMethods.IsGet(request.Method))
            {
                if (request.Query.TryGetValue("dns", out var dns) == false)
                {
                    return default;
                }

                // Convert base64url to standard base64 and restore the padding
                // before decoding.
                var dnsRequest = dns.ToString().Replace('-', '+').Replace('_', '/');
                int mod = dnsRequest.Length % 4;
                if (mod > 0)
                {
                    dnsRequest = dnsRequest.PadRight(dnsRequest.Length - mod + 4, '=');
                }

                var message = Convert.FromBase64String(dnsRequest);
                return Request.FromArray(message);
            }

            if (HttpMethods.IsPost(request.Method) && IsDnsMediaType(request.ContentType))
            {
                var message = await ReadBodyAsync(request);
                return Request.FromArray(message);
            }

            return default;
        }

        /// <summary>
        /// Reads the request body, refusing anything larger than <see cref="MAX_MESSAGE_SIZE"/>
        /// so that a client cannot make the server buffer an arbitrarily large payload.
        /// </summary>
        /// <param name="request">The incoming HTTP request.</param>
        /// <returns>The raw body bytes.</returns>
        /// <exception cref="InvalidDataException">The body exceeds the maximum DNS message size.</exception>
        private static async Task<byte[]> ReadBodyAsync(HttpRequest request)
        {
            // Reject early when the client declares an oversized body.
            if (request.ContentLength > MAX_MESSAGE_SIZE)
            {
                throw new InvalidDataException("DNS message is too large.");
            }

            var cancellationToken = request.HttpContext.RequestAborted;
            var chunk = new byte[4096];
            using var buffer = new MemoryStream();

            int read;
            while ((read = await request.Body.ReadAsync(chunk.AsMemory(), cancellationToken)) > 0)
            {
                // Content-Length may be absent (chunked) or wrong, so enforce the
                // limit on the bytes actually received as well.
                if (buffer.Length + read > MAX_MESSAGE_SIZE)
                {
                    throw new InvalidDataException("DNS message is too large.");
                }

                buffer.Write(chunk, 0, read);
            }

            return buffer.ToArray();
        }

        /// <summary>
        /// Tells whether one <c>Accept</c> header value lists the DNS message media type
        /// among its comma-separated entries.
        /// </summary>
        /// <param name="headerValue">A single <c>Accept</c> header value.</param>
        /// <returns><c>true</c> when any entry is the DNS message media type.</returns>
        private static bool AcceptsDnsMessage(string? headerValue)
        {
            return headerValue != null && headerValue.Split(',').Any(IsDnsMediaType);
        }

        /// <summary>
        /// Tells whether a media type string denotes <c>application/dns-message</c>,
        /// ignoring letter case, surrounding whitespace and any <c>;</c> parameters.
        /// </summary>
        /// <param name="value">A media type, optionally followed by parameters.</param>
        /// <returns><c>true</c> when the media type matches.</returns>
        private static bool IsDnsMediaType(string? value)
        {
            if (string.IsNullOrEmpty(value))
            {
                return false;
            }

            var separator = value.IndexOf(';');
            var mediaType = separator < 0 ? value : value.Substring(0, separator);
            return mediaType.Trim().Equals(MEDIA_TYPE, StringComparison.OrdinalIgnoreCase);
        }
    }
}