Selaa lähdekoodia

增加DnsOverHttps服务

老九 4 vuotta sitten
vanhempi
commit
0f94f118ca

+ 23 - 0
FastGithub.Dns/DnsOverHttpsApplicationBuilderExtensions.cs

@@ -0,0 +1,23 @@
+using FastGithub.Dns;
+using Microsoft.AspNetCore.Builder;
+using Microsoft.Extensions.DependencyInjection;
+
+namespace FastGithub
+{
+    /// <summary>
+    /// DoH的中间件扩展
+    /// </summary>
+    public static class DnsOverHttpsApplicationBuilderExtensions
+    {
+        /// <summary>
+        /// 使用DoH的中间件
+        /// </summary>
+        /// <param name="app"></param> 
+        /// <returns></returns>
+        public static IApplicationBuilder UseDnsOverHttps(this IApplicationBuilder app)
+        {
+            var middleware = app.ApplicationServices.GetRequiredService<DnsOverHttpsMiddleware>();
+            return app.Use(next => context => middleware.InvokeAsync(context, next));
+        }
+    }
+}

+ 103 - 0
FastGithub.Dns/DnsOverHttpsMiddleware.cs

@@ -0,0 +1,103 @@
+using DNS.Protocol;
+using Microsoft.AspNetCore.Http;
+using System;
+using System.IO;
+using System.Linq;
+using System.Net;
+using System.Threading.Tasks;
+
+namespace FastGithub.Dns
+{
+    /// <summary>
+    /// DoH中间件
+    /// </summary>
+    sealed class DnsOverHttpsMiddleware
+    {
+        private static readonly PathString dnsQueryPath = "/dns-query";
+        private const string MEDIA_TYPE = "application/dns-message";
+        private readonly RequestResolver requestResolver;
+
+        /// <summary>
+        /// DoH中间件
+        /// </summary>
+        /// <param name="requestResolver"></param>
+        public DnsOverHttpsMiddleware(RequestResolver requestResolver)
+        {
+            this.requestResolver = requestResolver;
+        }
+
+        /// <summary>
+        /// 执行请求
+        /// </summary>
+        /// <param name="context"></param>
+        /// <param name="next"></param>
+        /// <returns></returns>
+        public async Task InvokeAsync(HttpContext context, RequestDelegate next)
+        {
+            try
+            {
+                var request = await ParseDnsRequestAsync(context.Request);
+                if (request == null)
+                {
+                    await next(context);
+                }
+                else
+                {
+                    var remoteIPAddress = context.Connection.RemoteIpAddress ?? IPAddress.Loopback;
+                    var remoteEndPoint = new IPEndPoint(remoteIPAddress, context.Connection.RemotePort);
+                    var remoteEndPointRequest = new RemoteEndPointRequest(request, remoteEndPoint);
+                    var response = await this.requestResolver.Resolve(remoteEndPointRequest);
+
+                    context.Response.ContentType = MEDIA_TYPE;
+                    await context.Response.BodyWriter.WriteAsync(response.ToArray());
+                }
+            }
+            catch (Exception)
+            {
+                await next(context);
+            }
+        }
+
+        /// <summary>
+        /// 解析dns请求
+        /// </summary>
+        /// <param name="request"></param>
+        /// <returns></returns>
+        private static async Task<Request?> ParseDnsRequestAsync(HttpRequest request)
+        {
+            if (request.Path != dnsQueryPath ||
+                request.Headers.TryGetValue("accept", out var accept) == false ||
+                accept.Contains(MEDIA_TYPE) == false)
+            {
+                return default;
+            }
+
+            if (request.Method == HttpMethods.Get)
+            {
+                if (request.Query.TryGetValue("dns", out var dns) == false)
+                {
+                    return default;
+                }
+
+                var dnsRequest = dns.ToString().Replace('-', '+').Replace('_', '/');
+                int mod = dnsRequest.Length % 4;
+                if (mod > 0)
+                {
+                    dnsRequest = dnsRequest.PadRight(dnsRequest.Length - mod + 4, '=');
+                }
+
+                var message = Convert.FromBase64String(dnsRequest);
+                return Request.FromArray(message);
+            }
+
+            if (request.Method == HttpMethods.Post && request.ContentType == MEDIA_TYPE)
+            {
+                using var message = new MemoryStream();
+                await request.Body.CopyToAsync(message);
+                return Request.FromArray(message.ToArray());
+            }
+
+            return default;
+        }
+    }
+}

+ 1 - 0
FastGithub.Dns/FastGithub.Dns.csproj

@@ -5,6 +5,7 @@
 	</PropertyGroup>
 
 	<ItemGroup>
+		<FrameworkReference Include="Microsoft.AspNetCore.App" />
 		<PackageReference Include="DNS" Version="6.1.0" />
 		<PackageReference Include="Microsoft.Extensions.Hosting" Version="5.0.0" />
 	</ItemGroup>

+ 1 - 0
FastGithub.Dns/ServiceCollectionExtensions.cs

@@ -18,6 +18,7 @@ namespace FastGithub
         {
             services.TryAddSingleton<RequestResolver>();
             services.TryAddSingleton<DnsServer>();
+            services.TryAddSingleton<DnsOverHttpsMiddleware>();
             services.AddSingleton<IDnsValidator, HostsValidator>();
             services.AddSingleton<IDnsValidator, ProxyValidtor>();
             return services.AddHostedService<DnsHostedService>();

+ 1 - 0
FastGithub/Startup.cs

@@ -45,6 +45,7 @@ namespace FastGithub
         public void Configure(IApplicationBuilder app)
         {
             app.UseRequestLogging();
+            app.UseDnsOverHttps();
             app.UseReverseProxy();
             app.UseRouting();
             app.UseEndpoints(endpoints =>