TlsInvadeMiddleware.cs 2.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. using Microsoft.AspNetCore.Connections;
  2. using Microsoft.AspNetCore.Http.Features;
  3. using System.Buffers;
  4. using System.IO.Pipelines;
  5. using System.Threading.Tasks;
  namespace FastGithub.HttpServer.TlsMiddlewares
  {
      /// <summary>
      /// HTTPS "invade" middleware. When an incoming connection does not start with a TLS
      /// handshake record header and no earlier TLS middleware has registered an
      /// <see cref="ITlsConnectionFeature"/>, it registers a fake feature so that the
      /// downstream HTTPS middleware skips its own work.
      /// </summary>
      sealed class TlsInvadeMiddleware
      {
          /// <summary>First byte of the expected record header: TLS content type "handshake" (22).</summary>
          private const byte HandshakeContentType = 0x16;

          /// <summary>Second byte of the expected record header: protocol major version 3.</summary>
          private const byte TlsMajorVersion = 0x03;

          /// <summary>
          /// Inspects the connection, registers the fake TLS feature when appropriate, and then
          /// always hands the connection to <paramref name="next"/>.
          /// </summary>
          /// <param name="next">The next delegate in the connection pipeline.</param>
          /// <param name="context">The current connection.</param>
          /// <returns>A task that completes once <paramref name="next"/> has completed.</returns>
          public async Task InvokeAsync(ConnectionDelegate next, ConnectionContext context)
          {
              // Only a non-TLS connection that no TLS middleware has touched yet gets the fake
              // feature. The fake forces the HTTPS middleware to skip its own work.
              var isTls = await IsTlsConnectionAsync(context);
              if (!isTls && context.Features.Get<ITlsConnectionFeature>() is null)
              {
                  context.Features.Set<ITlsConnectionFeature>(FakeTlsConnectionFeature.Instance);
              }

              await next(context);
          }

          /// <summary>
          /// Peeks at the first two bytes of the connection's input to decide whether the
          /// connection looks like TLS, without consuming any data.
          /// </summary>
          /// <param name="context">The connection whose transport input is inspected.</param>
          /// <returns>
          /// <c>true</c> if the input begins with a TLS handshake record header. <c>false</c> if it
          /// does not, if fewer than two bytes are available when the read returns, or if the
          /// read throws.
          /// </returns>
          private static async Task<bool> IsTlsConnectionAsync(ConnectionContext context)
          {
              try
              {
                  var input = context.Transport.Input;
                  var readResult = await input.ReadAtLeastAsync(2, context.ConnectionClosed);
                  var looksLikeTls = StartsWithTlsRecordHeader(readResult.Buffer);

                  // Consume and examine nothing, so the same bytes remain available to the next reader.
                  input.AdvanceTo(readResult.Buffer.Start);
                  return looksLikeTls;
              }
              catch
              {
                  // Best effort: any failure (e.g. cancellation via ConnectionClosed) counts as "not TLS".
                  return false;
              }

              // A TLS record header starts with the content type byte followed by the major version byte.
              static bool StartsWithTlsRecordHeader(ReadOnlySequence<byte> buffer)
              {
                  var reader = new SequenceReader<byte>(buffer);
                  if (!reader.TryRead(out var contentType) || !reader.TryRead(out var majorVersion))
                  {
                      return false;
                  }

                  return contentType == HandshakeContentType && majorVersion == TlsMajorVersion;
              }
          }
      }
  }