// TcpReverseProxyHandler.cs

  1. using FastGithub.DomainResolve;
  2. using Microsoft.AspNetCore.Connections;
  3. using System;
  4. using System.Collections.Generic;
  5. using System.IO;
  6. using System.IO.Pipelines;
  7. using System.Net;
  8. using System.Net.Sockets;
  9. using System.Threading;
  10. using System.Threading.Tasks;
  namespace FastGithub.HttpServer.TcpMiddlewares
  {
      /// <summary>
      /// Base TCP connection handler that relays every accepted connection to a fixed upstream endpoint.
      /// </summary>
      abstract class TcpReverseProxyHandler : ConnectionHandler
      {
          /// <summary>
          /// Upper bound for a single connect attempt against one resolved address.
          /// </summary>
          private static readonly TimeSpan connectTimeout = TimeSpan.FromSeconds(10d);

          private readonly IDomainResolver domainResolver;
          private readonly DnsEndPoint endPoint;

          /// <summary>
          /// Creates a handler that forwards accepted connections to <paramref name="endPoint"/>.
          /// </summary>
          /// <param name="domainResolver">Resolver used to turn the host of <paramref name="endPoint"/> into IP addresses.</param>
          /// <param name="endPoint">Upstream host and port to relay to.</param>
          protected TcpReverseProxyHandler(IDomainResolver domainResolver, DnsEndPoint endPoint)
          {
              this.domainResolver = domainResolver;
              this.endPoint = endPoint;
          }

          /// <summary>
          /// Opens an upstream connection and copies bytes in both directions between it and the
          /// accepted client connection until either direction finishes.
          /// </summary>
          /// <param name="context">The accepted client connection.</param>
          /// <returns>A task that completes when the relay stops.</returns>
          public override async Task OnConnectedAsync(ConnectionContext context)
          {
              // Token tied to the client connection; it also aborts resolving/connecting upstream.
              var cancellationToken = context.ConnectionClosed;
              using var connection = await CreateConnectionAsync(cancellationToken);

              var upstreamToClient = connection.CopyToAsync(context.Transport.Output, cancellationToken);
              var clientToUpstream = context.Transport.Input.CopyToAsync(connection, cancellationToken);

              // Stop as soon as either direction ends, whether it completed or failed; a failure is not
              // rethrown here. The other copy is abandoned and the upstream stream (and, because it owns
              // it, the socket) is disposed when this method returns.
              // NOTE: this means TCP half-close is not supported.
              await Task.WhenAny(upstreamToClient, clientToUpstream);
          }

          /// <summary>
          /// Opens a TCP connection to the first resolved address of <c>endPoint</c> that accepts it,
          /// trying the addresses in the order the resolver yields them.
          /// </summary>
          /// <param name="cancellationToken">Aborts both resolution and connecting.</param>
          /// <returns>A stream that owns the connected socket; disposing the stream closes the socket.</returns>
          /// <exception cref="OperationCanceledException"><paramref name="cancellationToken"/> was cancelled.</exception>
          /// <exception cref="AggregateException">No resolved address could be connected; the inner exceptions hold each failed attempt.</exception>
          private async Task<Stream> CreateConnectionAsync(CancellationToken cancellationToken)
          {
              var innerExceptions = new List<Exception>();
              await foreach (var address in domainResolver.ResolveAsync(endPoint, cancellationToken))
              {
                  var socket = new Socket(address.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
                  try
                  {
                      // Each address gets its own timeout while still honouring the caller's token.
                      using var timeoutTokenSource = new CancellationTokenSource(connectTimeout);
                      using var linkedTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeoutTokenSource.Token);
                      await socket.ConnectAsync(address, endPoint.Port, linkedTokenSource.Token);

                      // ownsSocket: true so disposing the returned stream also closes the socket;
                      // with ownsSocket: false the connected socket was never disposed and leaked.
                      return new NetworkStream(socket, ownsSocket: true);
                  }
                  catch (Exception ex)
                  {
                      socket.Dispose();
                      // Caller cancellation aborts the whole operation; any other failure
                      // (including the per-address timeout) falls through to the next address.
                      cancellationToken.ThrowIfCancellationRequested();
                      innerExceptions.Add(ex);
                  }
              }
              throw new AggregateException($"无法连接到{endPoint.Host}:{endPoint.Port}", innerExceptions);
          }
      }
  }