HttpReverseProxyMiddleware.cs 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. using FastGithub.Configuration;
  2. using FastGithub.Http;
  3. using Microsoft.AspNetCore.Http;
  4. using Microsoft.Extensions.Logging;
  5. using System;
  6. using System.Diagnostics.CodeAnalysis;
  7. using System.Net;
  8. using System.Threading.Tasks;
  9. using Yarp.ReverseProxy.Forwarder;
  namespace FastGithub.HttpServer
  {
      /// <summary>
      /// Reverse-proxy middleware.
      /// For hosts that have a <see cref="DomainConfig"/>, either forwards the request upstream
      /// or, when the config defines a fixed <c>Response</c>, writes that response directly.
      /// Requests for hosts without a config are passed to the next middleware.
      /// </summary>
      sealed class HttpReverseProxyMiddleware
      {
          /// <summary>
          /// Fallback configuration (TLS SNI enabled) used on Windows for domain-like hosts
          /// that have no explicit configuration.
          /// </summary>
          private static readonly DomainConfig defaultDomainConfig = new() { TlsSni = true };

          private readonly IHttpForwarder httpForwarder;
          private readonly IHttpClientFactory httpClientFactory;
          private readonly FastGithubConfig fastGithubConfig;
          private readonly ILogger<HttpReverseProxyMiddleware> logger;

          /// <summary>
          /// Creates the reverse-proxy middleware.
          /// </summary>
          /// <param name="httpForwarder">YARP forwarder that sends the request to the destination.</param>
          /// <param name="httpClientFactory">Creates the client used for forwarding, per host and domain config.</param>
          /// <param name="fastGithubConfig">Configuration used to look up the per-domain settings.</param>
          /// <param name="logger">Logger.</param>
          public HttpReverseProxyMiddleware(
              IHttpForwarder httpForwarder,
              IHttpClientFactory httpClientFactory,
              FastGithubConfig fastGithubConfig,
              ILogger<HttpReverseProxyMiddleware> logger)
          {
              this.httpForwarder = httpForwarder;
              this.httpClientFactory = httpClientFactory;
              this.fastGithubConfig = fastGithubConfig;
              this.logger = logger;
          }

          /// <summary>
          /// Handles a request.
          /// </summary>
          /// <param name="context">The current HTTP context.</param>
          /// <param name="next">The next middleware; invoked when the request host has no domain config.</param>
          /// <returns>A task that completes when the request has been handled.</returns>
          public async Task InvokeAsync(HttpContext context, RequestDelegate next)
          {
              var host = context.Request.Host;
              if (!this.TryGetDomainConfig(host, out var domainConfig))
              {
                  await next(context);
                  return;
              }

              if (domainConfig.Response == null)
              {
                  // No fixed response configured: forward the request upstream.
                  var scheme = context.Request.Scheme;
                  var destinationPrefix = this.GetDestinationPrefix(scheme, host, domainConfig.Destination);
                  var httpClient = this.httpClientFactory.CreateHttpClient(host.Host, domainConfig);
                  var error = await this.httpForwarder.SendAsync(context, destinationPrefix, httpClient);
                  await HandleErrorAsync(context, error);
              }
              else
              {
                  // A fixed response is configured: answer locally without contacting any upstream.
                  context.Response.StatusCode = domainConfig.Response.StatusCode;
                  context.Response.ContentType = domainConfig.Response.ContentType;
                  if (domainConfig.Response.ContentValue != null)
                  {
                      await context.Response.WriteAsync(domainConfig.Response.ContentValue);
                  }
              }
          }

          /// <summary>
          /// Looks up the <see cref="DomainConfig"/> for a host.
          /// </summary>
          /// <param name="host">The request host.</param>
          /// <param name="domainConfig">The matching config, or the Windows fallback config; undefined when false is returned.</param>
          /// <returns>true when a config (explicit or fallback) applies to the host; otherwise false.</returns>
          private bool TryGetDomainConfig(HostString host, [MaybeNullWhen(false)] out DomainConfig domainConfig)
          {
              if (this.fastGithubConfig.TryGetDomainConfig(host.Host, out domainConfig))
              {
                  return true;
              }

              // The domain is not configured, yet the request still reached this machine,
              // i.e. the domain resolved to a local IP.
              // NOTE(review): the fallback only applies on Windows; the reason is not visible here.
              if (OperatingSystem.IsWindows() && IsDomain(host.Host))
              {
                  // Message template instead of string interpolation (CA2254); the rendered text is unchanged.
                  this.logger.LogWarning("域名{Domain}可能已经被DNS污染,如果域名为本机域名,请解析为非回环IP", host.Host);
                  domainConfig = defaultDomainConfig;
                  return true;
              }

              return false;

              // A host counts as a domain when it is not an IP literal and contains a dot.
              static bool IsDomain(string host)
              {
                  return !IPAddress.TryParse(host, out _) && host.Contains('.');
              }
          }

          /// <summary>
          /// Builds the destination prefix the request is forwarded to.
          /// </summary>
          /// <param name="scheme">The request scheme.</param>
          /// <param name="host">The request host.</param>
          /// <param name="destination">Optional configured destination, resolved against <c>{scheme}://{host}/</c>.</param>
          /// <returns><c>{scheme}://{host}/</c> when <paramref name="destination"/> is null; otherwise the resolved URI string.</returns>
          private string GetDestinationPrefix(string scheme, HostString host, Uri? destination)
          {
              var defaultValue = $"{scheme}://{host}/";
              if (destination == null)
              {
                  return defaultValue;
              }

              var baseUri = new Uri(defaultValue);
              var result = new Uri(baseUri, destination).ToString();
              this.logger.LogInformation("{Source} => {Destination}", defaultValue, result);
              return result;
          }

          /// <summary>
          /// Writes a JSON error body when forwarding failed.
          /// Does nothing when <paramref name="error"/> is <see cref="ForwarderError.None"/> or the response
          /// has already started. The status code is not changed here; it stays whatever the forwarder left.
          /// </summary>
          /// <param name="context">The current HTTP context.</param>
          /// <param name="error">The result returned by the forwarder.</param>
          /// <returns>A task that completes when the body has been written (or immediately if nothing is written).</returns>
          private static async Task HandleErrorAsync(HttpContext context, ForwarderError error)
          {
              if (error == ForwarderError.None || context.Response.HasStarted)
              {
                  return;
              }

              await context.Response.WriteAsJsonAsync(new
              {
                  error = error.ToString(),
                  message = context.GetForwarderErrorFeature()?.Exception?.Message
              });
          }
      }
  }