陈国伟 преди 4 години
родител
ревизия
97727944ff
променени са 2 файла, в които са добавени 43 реда и са изтрити 17 реда
  1. 42 16
      FastGithub.Dns/DnsOverHttpsMiddleware.cs
  2. 1 1
      FastGithub.DomainResolve/DomainResolver.cs

+ 42 - 16
FastGithub.Dns/DnsOverHttpsMiddleware.cs

@@ -1,5 +1,6 @@
 using DNS.Protocol;
 using Microsoft.AspNetCore.Http;
+using Microsoft.Extensions.Logging;
 using System;
 using System.IO;
 using System.Linq;
@@ -16,14 +17,19 @@ namespace FastGithub.Dns
         private static readonly PathString dnsQueryPath = "/dns-query";
         private const string MEDIA_TYPE = "application/dns-message";
         private readonly RequestResolver requestResolver;
+        private readonly ILogger<DnsOverHttpsMiddleware> logger;
 
         /// <summary>
         /// DoH中间件
         /// </summary>
         /// <param name="requestResolver"></param>
-        public DnsOverHttpsMiddleware(RequestResolver requestResolver)
+        /// <param name="logger"></param>
+        public DnsOverHttpsMiddleware(
+            RequestResolver requestResolver,
+            ILogger<DnsOverHttpsMiddleware> logger)
         {
             this.requestResolver = requestResolver;
+            this.logger = logger;
         }
 
         /// <summary>
@@ -34,27 +40,47 @@ namespace FastGithub.Dns
         /// <returns></returns>
         public async Task InvokeAsync(HttpContext context, RequestDelegate next)
         {
+            Request? request;
             try
             {
-                var request = await ParseDnsRequestAsync(context.Request);
-                if (request == null)
-                {
-                    await next(context);
-                }
-                else
-                {
-                    var remoteIPAddress = context.Connection.RemoteIpAddress ?? IPAddress.Loopback;
-                    var remoteEndPoint = new IPEndPoint(remoteIPAddress, context.Connection.RemotePort);
-                    var remoteEndPointRequest = new RemoteEndPointRequest(request, remoteEndPoint);
-                    var response = await this.requestResolver.Resolve(remoteEndPointRequest);
-
-                    context.Response.ContentType = MEDIA_TYPE;
-                    await context.Response.BodyWriter.WriteAsync(response.ToArray());
-                }
+                request = await ParseDnsRequestAsync(context.Request);
             }
             catch (Exception)
+            {
+                context.Response.StatusCode = StatusCodes.Status400BadRequest;
+                return;
+            }
+
+            if (request == null)
             {
                 await next(context);
+                return;
+            }
+
+            var response = await this.ResolveAsync(context, request);
+            context.Response.ContentType = MEDIA_TYPE;
+            await context.Response.BodyWriter.WriteAsync(response.ToArray());
+        }
+
+        /// <summary>
+        /// 解析dns域名
+        /// </summary>
+        /// <param name="context"></param>
+        /// <param name="request"></param>
+        /// <returns></returns>
+        private async Task<IResponse> ResolveAsync(HttpContext context, Request request)
+        {
+            try
+            {
+                var remoteIPAddress = context.Connection.RemoteIpAddress ?? IPAddress.Loopback;
+                var remoteEndPoint = new IPEndPoint(remoteIPAddress, context.Connection.RemotePort);
+                var remoteEndPointRequest = new RemoteEndPointRequest(request, remoteEndPoint);
+                return await this.requestResolver.Resolve(remoteEndPointRequest);
+            }
+            catch (Exception ex)
+            {
+                this.logger.LogWarning($"处理DNS异常:{ex.Message}");
+                return Response.FromRequest(request);
             }
         }
 

+ 1 - 1
FastGithub.DomainResolve/DomainResolver.cs

@@ -61,7 +61,7 @@ namespace FastGithub.DomainResolve
             var semaphore = this.semaphoreSlims.GetOrAdd(domain, _ => new SemaphoreSlim(1, 1));
             try
             {
-                await semaphore.WaitAsync(cancellationToken);
+                await semaphore.WaitAsync();
                 return await this.LookupAsync(domain, cancellationToken);
             }
             finally