// TcpReverseProxyHandler.cs (2.6 KB)
  1. using FastGithub.DomainResolve;
  2. using Microsoft.AspNetCore.Connections;
  3. using System;
  4. using System.Collections.Generic;
  5. using System.IO;
  6. using System.IO.Pipelines;
  7. using System.Net;
  8. using System.Net.Sockets;
  9. using System.Threading;
  10. using System.Threading.Tasks;
  11. namespace FastGithub.HttpServer
  12. {
  13. /// <summary>
  14. /// tcp反射代理处理者
  15. /// </summary>
  16. abstract class TcpReverseProxyHandler : ConnectionHandler
  17. {
  18. private readonly IDomainResolver domainResolver;
  19. private readonly DnsEndPoint endPoint;
  20. private readonly TimeSpan connectTimeout = TimeSpan.FromSeconds(10d);
  21. /// <summary>
  22. /// tcp反射代理处理者
  23. /// </summary>
  24. /// <param name="domainResolver"></param>
  25. /// <param name="endPoint"></param>
  26. public TcpReverseProxyHandler(IDomainResolver domainResolver, DnsEndPoint endPoint)
  27. {
  28. this.domainResolver = domainResolver;
  29. this.endPoint = endPoint;
  30. }
  31. /// <summary>
  32. /// ssh连接后
  33. /// </summary>
  34. /// <param name="context"></param>
  35. /// <returns></returns>
  36. public override async Task OnConnectedAsync(ConnectionContext context)
  37. {
  38. using var connection = await this.CreateConnectionAsync();
  39. var task1 = connection.CopyToAsync(context.Transport.Output);
  40. var task2 = context.Transport.Input.CopyToAsync(connection);
  41. await Task.WhenAny(task1, task2);
  42. }
  43. /// <summary>
  44. /// 创建连接
  45. /// </summary>
  46. /// <returns></returns>
  47. /// <exception cref="AggregateException"></exception>
  48. private async Task<Stream> CreateConnectionAsync()
  49. {
  50. var innerExceptions = new List<Exception>();
  51. await foreach (var address in this.domainResolver.ResolveAsync(this.endPoint))
  52. {
  53. var socket = new Socket(address.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
  54. try
  55. {
  56. using var timeoutTokenSource = new CancellationTokenSource(this.connectTimeout);
  57. await socket.ConnectAsync(address, this.endPoint.Port, timeoutTokenSource.Token);
  58. return new NetworkStream(socket, ownsSocket: false);
  59. }
  60. catch (Exception ex)
  61. {
  62. socket.Dispose();
  63. innerExceptions.Add(ex);
  64. }
  65. }
  66. throw new AggregateException($"无法连接到{this.endPoint.Host}:{this.endPoint.Port}", innerExceptions);
  67. }
  68. }
  69. }