陈国伟 2 anni fa
parent
commit
e493765433

+ 46 - 49
FastGithub.PacketIntercept/Dns/DnsInterceptor.cs

@@ -1,7 +1,6 @@
 using DNS.Protocol;
 using DNS.Protocol.ResourceRecords;
 using FastGithub.Configuration;
-using FastGithub.WinDiverts;
 using Microsoft.Extensions.Logging;
 using Microsoft.Extensions.Options;
 using System;
@@ -13,6 +12,7 @@ using System.Runtime.InteropServices;
 using System.Runtime.Versioning;
 using System.Threading;
 using System.Threading.Tasks;
+using WindivertDotnet;
 
 namespace FastGithub.PacketIntercept.Dns
 {
@@ -22,7 +22,7 @@ namespace FastGithub.PacketIntercept.Dns
     [SupportedOSPlatform("windows")]
     sealed class DnsInterceptor : IDnsInterceptor
     {
-        private const string DNS_FILTER = "udp.DstPort == 53";
+        private static readonly Filter filter = Filter.True.And(f => f.Udp.DstPort == 53);
 
         private readonly FastGithubConfig fastGithubConfig;
         private readonly ILogger<DnsInterceptor> logger;
@@ -40,8 +40,11 @@ namespace FastGithub.PacketIntercept.Dns
         /// </summary>
         static DnsInterceptor()
         {
-            var handle = WinDivert.WinDivertOpen("false", WinDivertLayer.Network, 0, WinDivertOpenFlags.None);
-            WinDivert.WinDivertClose(handle);
+            try
+            {
+                using (new WinDivert(Filter.False, WinDivertLayer.Network)) { }
+            }
+            catch (Exception) { }
         }
 
         /// <summary>
@@ -71,33 +74,23 @@ namespace FastGithub.PacketIntercept.Dns
         {
             await Task.Yield();
 
-            var handle = WinDivert.WinDivertOpen(DNS_FILTER, WinDivertLayer.Network, 0, WinDivertOpenFlags.None);
-            if (handle == new IntPtr(unchecked((long)ulong.MaxValue)))
+            using var divert = new WinDivert(filter, WinDivertLayer.Network);
+            cancellationToken.Register(d =>
             {
-                throw new Win32Exception();
-            }
-
-            cancellationToken.Register(hwnd =>
-            {
-                WinDivert.WinDivertClose((IntPtr)hwnd!);
+                ((WinDivert)d!).Dispose();
                 DnsFlushResolverCache();
-            }, handle);
+            }, divert);
 
-            var packetLength = 0U;
-            using var winDivertBuffer = new WinDivertBuffer();
-            var winDivertAddress = new WinDivertAddress();
+            var addr = new WinDivertAddress();
+            using var packet = new WinDivertPacket();
 
             DnsFlushResolverCache();
             while (cancellationToken.IsCancellationRequested == false)
             {
-                if (WinDivert.WinDivertRecv(handle, winDivertBuffer, ref winDivertAddress, ref packetLength) == false)
-                {
-                    throw new Win32Exception();
-                }
-
+                divert.Recv(packet, ref addr);
                 try
                 {
-                    this.ModifyDnsPacket(winDivertBuffer, ref winDivertAddress, ref packetLength);
+                    this.ModifyDnsPacket(packet, ref addr);
                 }
                 catch (Exception ex)
                 {
@@ -105,7 +98,7 @@ namespace FastGithub.PacketIntercept.Dns
                 }
                 finally
                 {
-                    WinDivert.WinDivertSend(handle, winDivertBuffer, packetLength, ref winDivertAddress);
+                    divert.Send(packet, ref addr);
                 }
             }
         }
@@ -113,13 +106,12 @@ namespace FastGithub.PacketIntercept.Dns
         /// <summary>
         /// 修改DNS数据包
         /// </summary>
-        /// <param name="winDivertBuffer"></param>
-        /// <param name="winDivertAddress"></param>
-        /// <param name="packetLength"></param> 
-        unsafe private void ModifyDnsPacket(WinDivertBuffer winDivertBuffer, ref WinDivertAddress winDivertAddress, ref uint packetLength)
+        /// <param name="packet"></param>
+        /// <param name="addr"></param>
+        unsafe private void ModifyDnsPacket(WinDivertPacket packet, ref WinDivertAddress addr)
         {
-            var packet = WinDivert.WinDivertHelperParsePacket(winDivertBuffer, packetLength);
-            var requestPayload = new Span<byte>(packet.PacketPayload, (int)packet.PacketPayloadLength).ToArray();
+            var result = packet.GetParseResult();
+            var requestPayload = result.DataSpan.ToArray();
 
             if (TryParseRequest(requestPayload, out var request) == false ||
                 request.OperationCode != OperationCode.Query ||
@@ -148,38 +140,43 @@ namespace FastGithub.PacketIntercept.Dns
             var responsePayload = response.ToArray();
 
             // 修改payload和包长 
-            responsePayload.CopyTo(new Span<byte>(packet.PacketPayload, responsePayload.Length));
-            packetLength = (uint)((int)packetLength + responsePayload.Length - requestPayload.Length);
+            responsePayload.CopyTo(new Span<byte>(result.Data, responsePayload.Length));
+            packet.Length = packet.Length + responsePayload.Length - requestPayload.Length;
 
             // 修改ip包
             IPAddress destAddress;
-            if (packet.IPv4Header != null)
+            if (result.IPV4Header != null)
             {
-                destAddress = packet.IPv4Header->DstAddr;
-                packet.IPv4Header->DstAddr = packet.IPv4Header->SrcAddr;
-                packet.IPv4Header->SrcAddr = destAddress;
-                packet.IPv4Header->Length = (ushort)packetLength;
+                destAddress = result.IPV4Header->DstAddr;
+                result.IPV4Header->DstAddr = result.IPV4Header->SrcAddr;
+                result.IPV4Header->SrcAddr = destAddress;
+                result.IPV4Header->Length = (ushort)packet.Length;
             }
             else
             {
-                destAddress = packet.IPv6Header->DstAddr;
-                packet.IPv6Header->DstAddr = packet.IPv6Header->SrcAddr;
-                packet.IPv6Header->SrcAddr = destAddress;
-                packet.IPv6Header->Length = (ushort)(packetLength - sizeof(IPv6Header));
+                destAddress = result.IPV6Header->DstAddr;
+                result.IPV6Header->DstAddr = result.IPV6Header->SrcAddr;
+                result.IPV6Header->SrcAddr = destAddress;
+                result.IPV6Header->Length = (ushort)(packet.Length - sizeof(IPV6Header));
             }
 
             // 修改udp包
-            var destPort = packet.UdpHeader->DstPort;
-            packet.UdpHeader->DstPort = packet.UdpHeader->SrcPort;
-            packet.UdpHeader->SrcPort = destPort;
-            packet.UdpHeader->Length = (ushort)(sizeof(UdpHeader) + responsePayload.Length);
+            var destPort = result.UdpHeader->DstPort;
+            result.UdpHeader->DstPort = result.UdpHeader->SrcPort;
+            result.UdpHeader->SrcPort = destPort;
+            result.UdpHeader->Length = (ushort)(sizeof(UdpHeader) + responsePayload.Length);
 
-            winDivertAddress.Impostor = true;
-            winDivertAddress.Direction = winDivertAddress.Loopback
-                ? WinDivertDirection.Outbound
-                : WinDivertDirection.Inbound;
+            addr.Flags |= WinDivertAddressFlag.Impostor;
+            if (addr.Flags.HasFlag(WinDivertAddressFlag.Loopback))
+            {
+                addr.Flags |= WinDivertAddressFlag.Outbound;
+            }
+            else
+            {
+                addr.Flags ^= WinDivertAddressFlag.Outbound;
+            } 
 
-            WinDivert.WinDivertHelperCalcChecksums(winDivertBuffer, packetLength, ref winDivertAddress, WinDivertChecksumHelperParam.All);
+            packet.CalcChecksums(ref addr);
             this.logger.LogInformation($"{domain}->{loopback}");
         }
 

+ 1 - 1
FastGithub.PacketIntercept/FastGithub.PacketIntercept.csproj

@@ -7,7 +7,7 @@
 	<ItemGroup>
 		<FrameworkReference Include="Microsoft.AspNetCore.App" />
 		<PackageReference Include="DNS" Version="7.0.0" />
-		<PackageReference Include="FastGithub.WinDiverts" Version="1.4.1" />
+		<PackageReference Include="WindivertDotnet" Version="1.0.0-beta1" />
 	</ItemGroup>
 
 	<ItemGroup>

+ 26 - 34
FastGithub.PacketIntercept/Tcp/TcpInterceptor.cs

@@ -1,5 +1,4 @@
-using FastGithub.WinDiverts;
-using Microsoft.Extensions.Logging;
+using Microsoft.Extensions.Logging;
 using System;
 using System.ComponentModel;
 using System.Net;
@@ -7,6 +6,7 @@ using System.Net.Sockets;
 using System.Runtime.Versioning;
 using System.Threading;
 using System.Threading.Tasks;
+using WindivertDotnet;
 
 namespace FastGithub.PacketIntercept.Tcp
 {
@@ -16,7 +16,7 @@ namespace FastGithub.PacketIntercept.Tcp
     [SupportedOSPlatform("windows")]
     abstract class TcpInterceptor : ITcpInterceptor
     {
-        private readonly string filter;
+        private readonly Filter filter;
         private readonly ushort oldServerPort;
         private readonly ushort newServerPort;
         private readonly ILogger logger;
@@ -29,7 +29,10 @@ namespace FastGithub.PacketIntercept.Tcp
         /// <param name="logger"></param>
         public TcpInterceptor(int oldServerPort, int newServerPort, ILogger logger)
         {
-            this.filter = $"loopback and (tcp.DstPort == {oldServerPort} or tcp.SrcPort == {newServerPort})";
+            this.filter = Filter.True
+                .And(f => f.Network.Loopback)
+                .And(f => f.Tcp.DstPort == oldServerPort || f.Tcp.SrcPort == newServerPort);
+
             this.oldServerPort = (ushort)oldServerPort;
             this.newServerPort = (ushort)newServerPort;
             this.logger = logger;
@@ -49,12 +52,7 @@ namespace FastGithub.PacketIntercept.Tcp
 
             await Task.Yield();
 
-            var handle = WinDivert.WinDivertOpen(this.filter, WinDivertLayer.Network, 0, WinDivertOpenFlags.None);
-            if (handle == new IntPtr(unchecked((long)ulong.MaxValue)))
-            {
-                throw new Win32Exception();
-            }
-
+            using var divert = new WinDivert(this.filter, WinDivertLayer.Network, 0, WinDivertFlag.None);
             if (Socket.OSSupportsIPv4)
             {
                 this.logger.LogInformation($"{IPAddress.Loopback}:{this.oldServerPort} <=> {IPAddress.Loopback}:{this.newServerPort}");
@@ -63,23 +61,18 @@ namespace FastGithub.PacketIntercept.Tcp
             {
                 this.logger.LogInformation($"{IPAddress.IPv6Loopback}:{this.oldServerPort} <=> {IPAddress.IPv6Loopback}:{this.newServerPort}");
             }
-            cancellationToken.Register(hwnd => WinDivert.WinDivertClose((IntPtr)hwnd!), handle);
-
-            var packetLength = 0U;
-            using var winDivertBuffer = new WinDivertBuffer();
-            var winDivertAddress = new WinDivertAddress();
+            cancellationToken.Register(d => ((WinDivert)d!).Dispose(), divert);
 
+            var addr = new WinDivertAddress();
+            using var packet = new WinDivertPacket();
             while (cancellationToken.IsCancellationRequested == false)
             {
-                winDivertAddress.Reset();
-                if (WinDivert.WinDivertRecv(handle, winDivertBuffer, ref winDivertAddress, ref packetLength) == false)
-                {
-                    throw new Win32Exception();
-                }
+                addr.Clear();
+                divert.Recv(packet, ref addr);
 
                 try
                 {
-                    this.ModifyTcpPacket(winDivertBuffer, ref winDivertAddress, ref packetLength);
+                    this.ModifyTcpPacket(packet, ref addr);
                 }
                 catch (Exception ex)
                 {
@@ -87,7 +80,7 @@ namespace FastGithub.PacketIntercept.Tcp
                 }
                 finally
                 {
-                    WinDivert.WinDivertSend(handle, winDivertBuffer, packetLength, ref winDivertAddress);
+                    divert.Send(packet, ref addr);
                 }
             }
         }
@@ -95,31 +88,30 @@ namespace FastGithub.PacketIntercept.Tcp
         /// <summary>
         /// 修改tcp数据端口的端口
         /// </summary>
-        /// <param name="winDivertBuffer"></param>
-        /// <param name="winDivertAddress"></param>
-        /// <param name="packetLength"></param> 
-        unsafe private void ModifyTcpPacket(WinDivertBuffer winDivertBuffer, ref WinDivertAddress winDivertAddress, ref uint packetLength)
+        /// <param name="packet"></param>
+        /// <param name="addr"></param>
+        unsafe private void ModifyTcpPacket(WinDivertPacket packet, ref WinDivertAddress addr)
         {
-            var packet = WinDivert.WinDivertHelperParsePacket(winDivertBuffer, packetLength);
-            if (packet.IPv4Header != null && packet.IPv4Header->SrcAddr.Equals(IPAddress.Loopback) == false)
+            var result = packet.GetParseResult();
+            if (result.IPV4Header != null && result.IPV4Header->SrcAddr.Equals(IPAddress.Loopback) == false)
             {
                 return;
             }
-            if (packet.IPv6Header != null && packet.IPv6Header->SrcAddr.Equals(IPAddress.IPv6Loopback) == false)
+            if (result.IPV6Header != null && result.IPV6Header->SrcAddr.Equals(IPAddress.IPv6Loopback) == false)
             {
                 return;
             }
 
-            if (packet.TcpHeader->DstPort == oldServerPort)
+            if (result.TcpHeader->DstPort == oldServerPort)
             {
-                packet.TcpHeader->DstPort = this.newServerPort;
+                result.TcpHeader->DstPort = this.newServerPort;
             }
             else
             {
-                packet.TcpHeader->SrcPort = oldServerPort;
+                result.TcpHeader->SrcPort = oldServerPort;
             }
-            winDivertAddress.Impostor = true;
-            WinDivert.WinDivertHelperCalcChecksums(winDivertBuffer, packetLength, ref winDivertAddress, WinDivertChecksumHelperParam.All);
+            addr.Flags |= WinDivertAddressFlag.Impostor;
+            packet.CalcChecksums(ref addr);
         }
     }
 }