2
0

HttpProxyMiddleware.cs 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. using Microsoft.AspNetCore.Connections;
  2. using Microsoft.AspNetCore.Http;
  3. using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http;
  4. using System;
  5. using System.Buffers;
  6. using System.IO.Pipelines;
  7. using System.Text;
  8. using System.Threading.Tasks;
  namespace FastGithub.HttpServer.TcpMiddlewares
  {
      /// <summary>
      /// Forward proxy middleware.
      /// Reads the HTTP request head from the raw connection, determines which proxy
      /// protocol (if any) the client is speaking, publishes the outcome as an
      /// <see cref="IHttpProxyFeature"/> and then passes the connection to the next middleware.
      /// </summary>
      sealed class HttpProxyMiddleware
      {
          /// <summary>
          /// Upper bound on the number of bytes buffered while waiting for a complete request head.
          /// Sized as Kestrel's default request-line limit (8 KB) plus its default total-headers
          /// limit (32 KB), so a client that never finishes its head cannot make the read loop
          /// buffer without bound.
          /// </summary>
          private const int MaxRequestHeadLength = (8 * 1024) + (32 * 1024);

          private readonly HttpParser<HttpRequestHandler> httpParser = new();
          private readonly byte[] http200 = Encoding.ASCII.GetBytes("HTTP/1.1 200 Connection Established\r\n\r\n");
          private readonly byte[] http400 = Encoding.ASCII.GetBytes("HTTP/1.1 400 Bad Request\r\n\r\n");

          /// <summary>
          /// Runs the middleware.
          /// Keeps reading until the request line and headers are complete, then:
          /// for a CONNECT tunnel the head is consumed and a 200 is written back;
          /// for every other request the buffered bytes are left untouched so the next
          /// middleware sees the original request. A head that is malformed, truncated by
          /// the end of the stream, or larger than <see cref="MaxRequestHeadLength"/> is
          /// answered with a 400 and <paramref name="next"/> is not invoked.
          /// </summary>
          /// <param name="next">The next connection middleware in the pipeline.</param>
          /// <param name="context">The connection being served.</param>
          /// <returns>A task that completes when the connection has been handled.</returns>
          public async Task InvokeAsync(ConnectionDelegate next, ConnectionContext context)
          {
              var input = context.Transport.Input;
              var output = context.Transport.Output;

              while (true)
              {
                  var result = await input.ReadAsync();

                  HttpRequestHandler httpRequest;
                  long consumed;
                  try
                  {
                      httpRequest = this.GetHttpRequestHandler(result, out consumed);
                  }
                  catch (Microsoft.AspNetCore.Http.BadHttpRequestException)
                  {
                      // Kestrel's parser rejects malformed input by throwing, not by returning false.
                      await this.RejectAsync(context, result);
                      return;
                  }

                  if (consumed == 0L)
                  {
                      // The head is not complete yet. Waiting for more bytes only makes sense
                      // while the peer can still send and the size limit has not been reached.
                      var canReadMore = result.IsCompleted == false
                          && result.IsCanceled == false
                          && result.Buffer.Length < MaxRequestHeadLength;

                      if (canReadMore)
                      {
                          // Consume nothing but mark everything as examined, so the next
                          // ReadAsync completes only when additional data has arrived.
                          input.AdvanceTo(result.Buffer.Start, result.Buffer.End);
                          continue;
                      }

                      await this.RejectAsync(context, result);
                      return;
                  }

                  if (httpRequest.ProxyProtocol == ProxyProtocol.TunnelProxy)
                  {
                      // Tunnel request: swallow the CONNECT head and confirm the tunnel.
                      input.AdvanceTo(result.Buffer.GetPosition(consumed));
                      await output.WriteAsync(this.http200, context.ConnectionClosed);
                  }
                  else
                  {
                      // Plain request: leave every byte in the pipe for the next middleware.
                      input.AdvanceTo(result.Buffer.Start);
                  }

                  context.Features.Set<IHttpProxyFeature>(httpRequest);
                  await next(context);
                  return;
              }
          }

          /// <summary>
          /// Releases the current read and answers the client with "400 Bad Request".
          /// </summary>
          /// <param name="context">The connection being served.</param>
          /// <param name="result">The read result whose buffer is being discarded.</param>
          /// <returns>A task that completes when the response has been written.</returns>
          private async Task RejectAsync(ConnectionContext context, ReadResult result)
          {
              context.Transport.Input.AdvanceTo(result.Buffer.End);
              await context.Transport.Output.WriteAsync(this.http400, context.ConnectionClosed);
          }

          /// <summary>
          /// Parses the request line and headers found in the buffered data.
          /// </summary>
          /// <param name="result">The read result holding the buffered bytes.</param>
          /// <param name="consumed">
          /// Number of bytes occupied by the request head, or 0 when the parser
          /// could not find a complete request line and header block in the buffer.
          /// </param>
          /// <returns>The handler populated by the parser callbacks.</returns>
          private HttpRequestHandler GetHttpRequestHandler(ReadResult result, out long consumed)
          {
              var handler = new HttpRequestHandler();
              var reader = new SequenceReader<byte>(result.Buffer);

              var isComplete = this.httpParser.ParseRequestLine(handler, ref reader)
                  && this.httpParser.ParseHeaders(handler, ref reader);

              consumed = isComplete ? reader.Consumed : 0L;
              return handler;
          }

          /// <summary>
          /// Proxy request handler: receives the parser callbacks and exposes the
          /// proxy target and protocol derived from the request line.
          /// </summary>
          private class HttpRequestHandler : IHttpRequestLineHandler, IHttpHeadersHandler, IHttpProxyFeature
          {
              private HttpMethod method;

              /// <summary>
              /// Host the client asked to be proxied to; has no value when the request
              /// target did not name one.
              /// </summary>
              public HostString ProxyHost { get; private set; }

              /// <summary>
              /// Proxy protocol in use: <c>None</c> without a proxy host, <c>TunnelProxy</c>
              /// for CONNECT, otherwise <c>HttpProxy</c>.
              /// </summary>
              public ProxyProtocol ProxyProtocol
              {
                  get
                  {
                      if (this.ProxyHost.HasValue == false)
                      {
                          return ProxyProtocol.None;
                      }

                      return this.method == HttpMethod.Connect
                          ? ProxyProtocol.TunnelProxy
                          : ProxyProtocol.HttpProxy;
                  }
              }

              /// <summary>
              /// Records the request method and extracts the proxy host from the request target.
              /// </summary>
              void IHttpRequestLineHandler.OnStartLine(HttpVersionAndMethod versionAndMethod, TargetOffsetPathLength targetPath, Span<byte> startLine)
              {
                  this.method = versionAndMethod.Method;
                  var target = Encoding.ASCII.GetString(startLine.Slice(targetPath.Offset, targetPath.Length));

                  if (versionAndMethod.Method == HttpMethod.Connect)
                  {
                      // CONNECT: the request target itself is the host to tunnel to.
                      this.ProxyHost = HostString.FromUriComponent(target);
                  }
                  else if (Uri.TryCreate(target, UriKind.Absolute, out var uri))
                  {
                      // Other methods: a target that parses as an absolute URI supplies the host.
                      this.ProxyHost = HostString.FromUriComponent(uri);
                  }
              }

              /// <summary>Headers are not inspected by this handler.</summary>
              void IHttpHeadersHandler.OnHeader(ReadOnlySpan<byte> name, ReadOnlySpan<byte> value)
              {
              }

              /// <summary>No action is taken when the header block ends.</summary>
              void IHttpHeadersHandler.OnHeadersComplete(bool endStream)
              {
              }

              /// <summary>Indexed headers are ignored.</summary>
              void IHttpHeadersHandler.OnStaticIndexedHeader(int index)
              {
              }

              /// <summary>Indexed headers are ignored.</summary>
              void IHttpHeadersHandler.OnStaticIndexedHeader(int index, ReadOnlySpan<byte> value)
              {
              }
          }
      }
  }