HttpReverseProxyMiddleware.cs 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. using FastGithub.Configuration;
  2. using FastGithub.Http;
  3. using Microsoft.AspNetCore.Http;
  4. using Microsoft.Extensions.Logging;
  5. using System;
  6. using System.Diagnostics.CodeAnalysis;
  7. using System.Threading.Tasks;
  8. using Yarp.ReverseProxy.Forwarder;
  namespace FastGithub.HttpServer
  {
      /// <summary>
      /// Reverse-proxy middleware. For a host that has a domain config (or, failing that,
      /// any host name containing a '.'), it either forwards the request through YARP or
      /// writes the fixed response configured for that domain; every other request is
      /// passed on to the next middleware.
      /// </summary>
      sealed class HttpReverseProxyMiddleware
      {
          /// <summary>
          /// Config used for host names that have no explicit domain config:
          /// sets <c>TlsSni = true</c> and leaves everything else at its defaults.
          /// </summary>
          private static readonly DomainConfig defaultDomainConfig = new() { TlsSni = true };

          private readonly IHttpForwarder httpForwarder;
          private readonly IHttpClientFactory httpClientFactory;
          private readonly FastGithubConfig fastGithubConfig;
          private readonly ILogger<HttpReverseProxyMiddleware> logger;

          /// <summary>
          /// Creates the middleware with its forwarding, client-creation and configuration dependencies.
          /// </summary>
          /// <param name="httpForwarder">YARP forwarder that proxies the request to the destination.</param>
          /// <param name="httpClientFactory">Creates the outbound client for a host and its domain config.</param>
          /// <param name="fastGithubConfig">Source of the per-domain configuration.</param>
          /// <param name="logger">Logger used to report destination rewrites.</param>
          public HttpReverseProxyMiddleware(
              IHttpForwarder httpForwarder,
              IHttpClientFactory httpClientFactory,
              FastGithubConfig fastGithubConfig,
              ILogger<HttpReverseProxyMiddleware> logger)
          {
              this.httpForwarder = httpForwarder;
              this.httpClientFactory = httpClientFactory;
              this.fastGithubConfig = fastGithubConfig;
              this.logger = logger;
          }

          /// <summary>
          /// Handles a request. If no domain config applies to the request host, the request is
          /// passed to <paramref name="next"/>. If the config has no fixed response, the request is
          /// forwarded to its destination. Otherwise the configured fixed response is written.
          /// </summary>
          /// <param name="context">The current HTTP context.</param>
          /// <param name="next">The next middleware in the pipeline.</param>
          /// <returns>A task that completes when the request has been handled.</returns>
          public async Task InvokeAsync(HttpContext context, RequestDelegate next)
          {
              var host = context.Request.Host;
              if (!this.TryGetDomainConfig(host, out var domainConfig))
              {
                  await next(context);
                  return;
              }

              if (domainConfig.Response is null)
              {
                  // No fixed response configured: proxy the request to its destination.
                  var scheme = context.Request.Scheme;
                  var destinationPrefix = this.GetDestinationPrefix(scheme, host, domainConfig.Destination);
                  var httpClient = this.httpClientFactory.CreateHttpClient(host.Host, domainConfig);
                  var error = await this.httpForwarder.SendAsync(context, destinationPrefix, httpClient);
                  await HandleErrorAsync(context, error);
                  return;
              }

              // A fixed response is configured: answer directly without forwarding.
              var fixedResponse = domainConfig.Response;
              context.Response.StatusCode = fixedResponse.StatusCode;
              context.Response.ContentType = fixedResponse.ContentType;
              if (fixedResponse.ContentValue is not null)
              {
                  await context.Response.WriteAsync(fixedResponse.ContentValue);
              }
          }

          /// <summary>
          /// Looks up the domain config for a request host.
          /// </summary>
          /// <param name="host">The request host.</param>
          /// <param name="domainConfig">
          /// The explicit config for the host, or the shared default config for an
          /// unconfigured host name that contains a '.'; unset when false is returned.
          /// </param>
          /// <returns>True when a config (explicit or default) applies to the host.</returns>
          private bool TryGetDomainConfig(HostString host, [MaybeNullWhen(false)] out DomainConfig domainConfig)
          {
              if (this.fastGithubConfig.TryGetDomainConfig(host.Host, out domainConfig))
              {
                  return true;
              }

              // Domains that are not configured but are still resolved to this machine's IP
              // are proxied with the default config; only host names containing a '.' qualify.
              if (host.Host.Contains('.'))
              {
                  domainConfig = defaultDomainConfig;
                  return true;
              }

              return false;
          }

          /// <summary>
          /// Builds the URL prefix the request is forwarded to.
          /// </summary>
          /// <param name="scheme">The request scheme (e.g. "http" or "https").</param>
          /// <param name="host">The request host.</param>
          /// <param name="destination">Optional destination from the domain config.</param>
          /// <returns>
          /// <c>"{scheme}://{host}/"</c> when <paramref name="destination"/> is null; otherwise
          /// <paramref name="destination"/> resolved against that base URI. In that case the
          /// rewrite is logged.
          /// </returns>
          private string GetDestinationPrefix(string scheme, HostString host, Uri? destination)
          {
              var defaultValue = $"{scheme}://{host}/";
              if (destination is null)
              {
                  return defaultValue;
              }

              var baseUri = new Uri(defaultValue);
              var result = new Uri(baseUri, destination).ToString();

              // Constant message template (CA2254) keeps the values as structured log properties;
              // the rendered text is the same "<default> => <result>" as before.
              this.logger.LogInformation("{DefaultPrefix} => {DestinationPrefix}", defaultValue, result);
              return result;
          }

          /// <summary>
          /// Writes a JSON error body describing a forwarding failure.
          /// </summary>
          /// <param name="context">The current HTTP context.</param>
          /// <param name="error">The result reported by the forwarder.</param>
          /// <returns>A task that completes when the error body (if any) has been written.</returns>
          private static async Task HandleErrorAsync(HttpContext context, ForwarderError error)
          {
              // Nothing to report on success. Once the response has started, the body can no
              // longer be replaced, so the error is not written in that case either.
              if (error == ForwarderError.None || context.Response.HasStarted)
              {
                  return;
              }

              await context.Response.WriteAsJsonAsync(new
              {
                  error = error.ToString(),
                  message = context.GetForwarderErrorFeature()?.Exception?.Message
              });
          }
      }
  }