Browse Source

向系统插入自身做为主DNS

xljiulang 4 years ago
parent
commit
4ae101a76f
2 changed files with 112 additions and 96 deletions
  1. 50 67
      FastGithub.Dns/DnsServerHostedService.cs
  2. 62 29
      FastGithub.Dns/SystemDnsUtil.cs

+ 50 - 67
FastGithub.Dns/DnsServerHostedService.cs

@@ -6,8 +6,6 @@ using System;
 using System.Diagnostics;
 using System.Net;
 using System.Net.Sockets;
-using System.Runtime.InteropServices;
-using System.Runtime.Versioning;
 using System.Threading;
 using System.Threading.Tasks;
 
@@ -19,50 +17,37 @@ namespace FastGithub.Dns
     sealed class DnsServerHostedService : BackgroundService
     {
         private readonly RequestResolver requestResolver;
-        private readonly FastGithubConfig fastGithubConfig;
         private readonly HostsValidator hostsValidator;
         private readonly ILogger<DnsServerHostedService> logger;
 
         private readonly Socket socket = new(AddressFamily.InterNetwork, SocketType.Dgram, ProtocolType.Udp);
         private readonly byte[] buffer = new byte[ushort.MaxValue];
-        private IPAddress[]? dnsAddresses;
 
-        [SupportedOSPlatform("windows")]
-        [DllImport("dnsapi.dll", EntryPoint = "DnsFlushResolverCache", SetLastError = true)]
-        private static extern void DnsFlushResolverCache();
 
         /// <summary>
         /// dns后台服务
         /// </summary>
         /// <param name="requestResolver"></param>
-        /// <param name="fastGithubConfig"></param>
+        /// <param name="hostsValidator"></param>
         /// <param name="options"></param>
         /// <param name="logger"></param>
         public DnsServerHostedService(
             RequestResolver requestResolver,
-            FastGithubConfig fastGithubConfig,
             HostsValidator hostsValidator,
             IOptionsMonitor<FastGithubOptions> options,
             ILogger<DnsServerHostedService> logger)
         {
             this.requestResolver = requestResolver;
-            this.fastGithubConfig = fastGithubConfig;
             this.hostsValidator = hostsValidator;
             this.logger = logger;
-            options.OnChange(opt => FlushResolverCache());
-        }
 
-        /// <summary>
-        /// 刷新dns缓存
-        /// </summary>
-        private static void FlushResolverCache()
-        {
             if (OperatingSystem.IsWindows())
             {
-                DnsFlushResolverCache();
+                options.OnChange(opt => SystemDnsUtil.DnsFlushResolverCache());
             }
         }
 
+
         /// <summary>
         /// 启动dns
         /// </summary>
@@ -70,57 +55,75 @@ namespace FastGithub.Dns
         /// <returns></returns>
         public override async Task StartAsync(CancellationToken cancellationToken)
         {
-            const int DNS_PORT = 53;
-            if (OperatingSystem.IsWindows() && UdpTable.TryGetOwnerProcessId(DNS_PORT, out var processId))
-            {
-                Process.GetProcessById(processId).Kill();
-            }
-
-            await BindAsync(this.socket, new IPEndPoint(IPAddress.Any, DNS_PORT), cancellationToken);
-            if (OperatingSystem.IsWindows())
-            {
-                const int SIO_UDP_CONNRESET = unchecked((int)0x9800000C);
-                this.socket.IOControl(SIO_UDP_CONNRESET, new byte[4], new byte[4]);
-            }
-
-            // 验证host文件 
+            await this.BindAsync(cancellationToken);
             await this.hostsValidator.ValidateAsync();
-
-            // 设置网关的dns
-            var secondary = this.fastGithubConfig.FastDns.Address;
-            this.dnsAddresses = this.SetNameServers(IPAddress.Loopback, secondary);
-            FlushResolverCache();
-
+            this.SetAsPrimitiveNameServer();
             this.logger.LogInformation("dns服务启动成功");
+
             await base.StartAsync(cancellationToken);
         }
 
         /// <summary>
         /// 尝试多次绑定
         /// </summary>
-        /// <param name="socket"></param>
-        /// <param name="localEndPoint"></param>
         /// <param name="cancellationToken"></param>
         /// <returns></returns>
-        private static async Task BindAsync(Socket socket, IPEndPoint localEndPoint, CancellationToken cancellationToken)
+        private async Task BindAsync(CancellationToken cancellationToken)
         {
+            const int DNS_PORT = 53;
+            if (OperatingSystem.IsWindows() && UdpTable.TryGetOwnerProcessId(DNS_PORT, out var processId))
+            {
+                Process.GetProcessById(processId).Kill();
+            }
+
+            var localEndPoint = new IPEndPoint(IPAddress.Any, DNS_PORT);
             var delay = TimeSpan.FromMilliseconds(100d);
             for (var i = 10; i >= 0; i--)
             {
                 try
                 {
-                    socket.Bind(localEndPoint);
+                    this.socket.Bind(localEndPoint);
                     break;
                 }
                 catch (Exception)
                 {
                     if (i == 0)
                     {
-                        throw new FastGithubException($"无法监听{localEndPoint},{localEndPoint.Port}的udp端口已被其它程序占用");
+                        throw new FastGithubException($"无法监听{localEndPoint},udp端口已被其它程序占用");
                     }
                     await Task.Delay(delay, cancellationToken);
                 }
             }
+
+            if (OperatingSystem.IsWindows())
+            {
+                const int SIO_UDP_CONNRESET = unchecked((int)0x9800000C);
+                this.socket.IOControl(SIO_UDP_CONNRESET, new byte[4], new byte[4]);
+            }
+        }
+
+        /// <summary>
+        /// 设置自身为主dns
+        /// </summary>
+        private void SetAsPrimitiveNameServer()
+        {
+            if (OperatingSystem.IsWindows())
+            {
+                try
+                {
+                    SystemDnsUtil.DnsSetPrimitive(IPAddress.Loopback);
+                    SystemDnsUtil.DnsFlushResolverCache();
+                    this.logger.LogInformation($"设置为本机dns成功");
+                }
+                catch (Exception ex)
+                {
+                    this.logger.LogWarning($"设置为本机dns失败:{ex.Message}");
+                }
+            }
+            else
+            {
+                this.logger.LogWarning("平台不支持自动设置dns,请手动设置网卡的主dns为127.0.0.1");
+            }
         }
 
         /// <summary>
@@ -178,40 +181,20 @@ namespace FastGithub.Dns
             this.socket.Dispose();
             this.logger.LogInformation("dns服务已终止");
 
-            if (this.dnsAddresses != null)
-            {
-                this.SetNameServers(this.dnsAddresses);
-            }
-            FlushResolverCache();
-            return base.StopAsync(cancellationToken);
-        }
-
-        /// <summary>
-        /// 设置dns
-        /// </summary>
-        /// <param name="nameServers"></param>
-        /// <returns></returns>
-        private IPAddress[]? SetNameServers(params IPAddress[] nameServers)
-        {
             if (OperatingSystem.IsWindows())
             {
                 try
                 {
-                    var results = SystemDnsUtil.SetNameServers(nameServers);
-                    this.logger.LogInformation($"设置本机dns成功");
-                    return results;
+                    SystemDnsUtil.DnsFlushResolverCache();
+                    SystemDnsUtil.DnsRemovePrimitive(IPAddress.Loopback);
                 }
                 catch (Exception ex)
                 {
-                    this.logger.LogWarning($"设置本机dns失败:{ex.Message}");
+                    this.logger.LogWarning($"恢复DNS记录失败:{ex.Message}");
                 }
             }
-            else
-            {
-                this.logger.LogWarning("不支持自动设置dns,请手动设置网卡的dns为127.0.0.1");
-            }
 
-            return default;
+            return base.StopAsync(cancellationToken);
         }
     }
 }

+ 62 - 29
FastGithub.Dns/SystemDnsUtil.cs

@@ -1,4 +1,5 @@
 using System;
+using System.Collections.Generic;
 using System.Diagnostics;
 using System.Linq;
 using System.Net;
@@ -22,64 +23,96 @@ namespace FastGithub.Dns
         [DllImport("iphlpapi")]
         private static extern int GetBestInterface(uint dwDestAddr, ref uint pdwBestIfIndex);
 
+        /// <summary>
+        /// 刷新DNS缓存
+        /// </summary>
+        [DllImport("dnsapi.dll", EntryPoint = "DnsFlushResolverCache", SetLastError = true)]
+        public static extern void DnsFlushResolverCache();
+
+
         /// <summary>
         /// 通过远程地址查找匹配的网络适接口
         /// </summary>
         /// <param name="remoteAddress"></param>
         /// <returns></returns>
-        private static NetworkInterface? GetBestNetworkInterface(IPAddress remoteAddress)
+        private static NetworkInterface GetBestNetworkInterface(IPAddress remoteAddress)
         {
             var dwBestIfIndex = 0u;
             var dwDestAddr = BitConverter.ToUInt32(remoteAddress.GetAddressBytes());
             var errorCode = GetBestInterface(dwDestAddr, ref dwBestIfIndex);
-            return errorCode != 0
-                ? throw new NetworkInformationException(errorCode)
-                : NetworkInterface
+            if (errorCode != 0)
+            {
+                throw new NetworkInformationException(errorCode);
+            }
+
+            var @interface = NetworkInterface
                 .GetAllNetworkInterfaces()
                 .Where(item => item.GetIPProperties().GetIPv4Properties().Index == dwBestIfIndex)
                 .FirstOrDefault();
+
+            return @interface ?? throw new NotSupportedException("找不到网络适配器用来设置dns");
         }
 
+
         /// <summary>
-        /// 设置域名服务
+        /// 设置主dns
         /// </summary>
-        /// <param name="nameServers"></param>
+        /// <param name="primitive"></param>
         /// <exception cref="NetworkInformationException"></exception>
-        /// <exception cref="NotSupportedException"></exception>
-        /// <returns>未设置之前的记录</returns>
-        public static IPAddress[] SetNameServers(params IPAddress[] nameServers)
+        /// <exception cref="NotSupportedException"></exception> 
+        public static void DnsSetPrimitive(IPAddress primitive)
         {
-            var networkInterface = GetBestNetworkInterface(www_baidu_com);
-            if (networkInterface == null)
+            var @interface = GetBestNetworkInterface(www_baidu_com);
+            var dnsAddresses = @interface.GetIPProperties().DnsAddresses;
+            if (primitive.Equals(dnsAddresses.FirstOrDefault()) == false)
             {
-                throw new NotSupportedException("找不到网络适配器用来设置dns");
+                var nameServers = dnsAddresses.Prepend(primitive);
+                SetNameServers(@interface, nameServers);
             }
-            var dnsAddresses = networkInterface.GetIPProperties().DnsAddresses.ToArray();
+        }
 
-            Netsh($@"interface ipv4 delete dns ""{networkInterface.Name}"" all");
-            foreach (var address in nameServers)
+        /// <summary>
+        /// 移除主dns
+        /// </summary>
+        /// <param name="primitive"></param>
+        /// <exception cref="NetworkInformationException"></exception>
+        /// <exception cref="NotSupportedException"></exception> 
+        public static void DnsRemovePrimitive(IPAddress primitive)
+        {
+            var @interface = GetBestNetworkInterface(www_baidu_com);
+            var dnsAddresses = @interface.GetIPProperties().DnsAddresses;
+            if (primitive.Equals(dnsAddresses.FirstOrDefault()))
             {
-                Netsh($@"interface ipv4 add dns ""{networkInterface.Name}"" {address} validate=no");
+                var nameServers = dnsAddresses.Skip(1);
+                SetNameServers(@interface, nameServers);
             }
-
-            return dnsAddresses;
         }
 
         /// <summary>
-        /// 执行Netsh
+        /// 设置网口的dns
         /// </summary>
-        /// <param name="arguments"></param>
-        private static void Netsh(string arguments)
+        /// <param name="interface"></param>
+        /// <param name="nameServers"></param>
+        private static void SetNameServers(NetworkInterface @interface, IEnumerable<IPAddress> nameServers)
         {
-            var netsh = new ProcessStartInfo
+            Netsh($@"interface ipv4 delete dns ""{@interface.Name}"" all");
+            foreach (var address in nameServers)
             {
-                FileName = "netsh.exe",
-                Arguments = arguments,
-                CreateNoWindow = true,
-                UseShellExecute = false,
-                WindowStyle = ProcessWindowStyle.Hidden
-            };
-            Process.Start(netsh)?.WaitForExit();
+                Netsh($@"interface ipv4 add dns ""{@interface.Name}"" {address} validate=no");
+            }
+
+            static void Netsh(string arguments)
+            {
+                var netsh = new ProcessStartInfo
+                {
+                    FileName = "netsh.exe",
+                    Arguments = arguments,
+                    CreateNoWindow = true,
+                    UseShellExecute = false,
+                    WindowStyle = ProcessWindowStyle.Hidden
+                };
+                Process.Start(netsh)?.WaitForExit();
+            }
         }
     }
 }