|
@@ -1,5 +1,6 @@
|
|
using DNS.Protocol;
|
|
using DNS.Protocol;
|
|
using Microsoft.AspNetCore.Http;
|
|
using Microsoft.AspNetCore.Http;
|
|
|
|
+using Microsoft.Extensions.Logging;
|
|
using System;
|
|
using System;
|
|
using System.IO;
|
|
using System.IO;
|
|
using System.Linq;
|
|
using System.Linq;
|
|
@@ -16,14 +17,19 @@ namespace FastGithub.Dns
|
|
private static readonly PathString dnsQueryPath = "/dns-query";
|
|
private static readonly PathString dnsQueryPath = "/dns-query";
|
|
private const string MEDIA_TYPE = "application/dns-message";
|
|
private const string MEDIA_TYPE = "application/dns-message";
|
|
private readonly RequestResolver requestResolver;
|
|
private readonly RequestResolver requestResolver;
|
|
|
|
+ private readonly ILogger<DnsOverHttpsMiddleware> logger;
|
|
|
|
|
|
/// <summary>
|
|
/// <summary>
|
|
/// DoH中间件
|
|
/// DoH中间件
|
|
/// </summary>
|
|
/// </summary>
|
|
/// <param name="requestResolver"></param>
|
|
/// <param name="requestResolver"></param>
|
|
- public DnsOverHttpsMiddleware(RequestResolver requestResolver)
|
|
|
|
|
|
+ /// <param name="logger"></param>
|
|
|
|
+ public DnsOverHttpsMiddleware(
|
|
|
|
+ RequestResolver requestResolver,
|
|
|
|
+ ILogger<DnsOverHttpsMiddleware> logger)
|
|
{
|
|
{
|
|
this.requestResolver = requestResolver;
|
|
this.requestResolver = requestResolver;
|
|
|
|
+ this.logger = logger;
|
|
}
|
|
}
|
|
|
|
|
|
/// <summary>
|
|
/// <summary>
|
|
@@ -34,27 +40,47 @@ namespace FastGithub.Dns
|
|
/// <returns></returns>
|
|
/// <returns></returns>
|
|
public async Task InvokeAsync(HttpContext context, RequestDelegate next)
|
|
public async Task InvokeAsync(HttpContext context, RequestDelegate next)
|
|
{
|
|
{
|
|
|
|
+ Request? request;
|
|
try
|
|
try
|
|
{
|
|
{
|
|
- var request = await ParseDnsRequestAsync(context.Request);
|
|
|
|
- if (request == null)
|
|
|
|
- {
|
|
|
|
- await next(context);
|
|
|
|
- }
|
|
|
|
- else
|
|
|
|
- {
|
|
|
|
- var remoteIPAddress = context.Connection.RemoteIpAddress ?? IPAddress.Loopback;
|
|
|
|
- var remoteEndPoint = new IPEndPoint(remoteIPAddress, context.Connection.RemotePort);
|
|
|
|
- var remoteEndPointRequest = new RemoteEndPointRequest(request, remoteEndPoint);
|
|
|
|
- var response = await this.requestResolver.Resolve(remoteEndPointRequest);
|
|
|
|
-
|
|
|
|
- context.Response.ContentType = MEDIA_TYPE;
|
|
|
|
- await context.Response.BodyWriter.WriteAsync(response.ToArray());
|
|
|
|
- }
|
|
|
|
|
|
+ request = await ParseDnsRequestAsync(context.Request);
|
|
}
|
|
}
|
|
catch (Exception)
|
|
catch (Exception)
|
|
|
|
+ {
|
|
|
|
+ context.Response.StatusCode = StatusCodes.Status400BadRequest;
|
|
|
|
+ return;
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ if (request == null)
|
|
{
|
|
{
|
|
await next(context);
|
|
await next(context);
|
|
|
|
+ return;
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ var response = await this.ResolveAsync(context, request);
|
|
|
|
+ context.Response.ContentType = MEDIA_TYPE;
|
|
|
|
+ await context.Response.BodyWriter.WriteAsync(response.ToArray());
|
|
|
|
+ }
|
|
|
|
+
|
|
|
|
+ /// <summary>
|
|
|
|
+ /// 解析dns域名
|
|
|
|
+ /// </summary>
|
|
|
|
+ /// <param name="context"></param>
|
|
|
|
+ /// <param name="request"></param>
|
|
|
|
+ /// <returns></returns>
|
|
|
|
+ private async Task<IResponse> ResolveAsync(HttpContext context, Request request)
|
|
|
|
+ {
|
|
|
|
+ try
|
|
|
|
+ {
|
|
|
|
+ var remoteIPAddress = context.Connection.RemoteIpAddress ?? IPAddress.Loopback;
|
|
|
|
+ var remoteEndPoint = new IPEndPoint(remoteIPAddress, context.Connection.RemotePort);
|
|
|
|
+ var remoteEndPointRequest = new RemoteEndPointRequest(request, remoteEndPoint);
|
|
|
|
+ return await this.requestResolver.Resolve(remoteEndPointRequest);
|
|
|
|
+ }
|
|
|
|
+ catch (Exception ex)
|
|
|
|
+ {
|
|
|
|
+ this.logger.LogWarning($"处理DNS异常:{ex.Message}");
|
|
|
|
+ return Response.FromRequest(request);
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|