|
@@ -4,12 +4,16 @@ using Microsoft.AspNetCore.Connections.Features;
|
|
|
using Microsoft.AspNetCore.Http;
|
|
|
using Microsoft.AspNetCore.Http.Features;
|
|
|
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http;
|
|
|
+using System;
|
|
|
+using System.Collections.Generic;
|
|
|
+using System.IO;
|
|
|
using System.IO.Pipelines;
|
|
|
using System.Net;
|
|
|
using System.Net.Http;
|
|
|
using System.Net.Sockets;
|
|
|
using System.Reflection;
|
|
|
using System.Text;
|
|
|
+using System.Threading;
|
|
|
using System.Threading.Tasks;
|
|
|
using Yarp.ReverseProxy.Forwarder;
|
|
|
|
|
@@ -30,6 +34,7 @@ namespace FastGithub.HttpServer
|
|
|
private readonly HttpReverseProxyMiddleware httpReverseProxy;
|
|
|
|
|
|
private readonly HttpMessageInvoker defaultHttpClient;
|
|
|
+ private readonly TimeSpan connectTimeout = TimeSpan.FromSeconds(10d);
|
|
|
|
|
|
static HttpProxyMiddleware()
|
|
|
{
|
|
@@ -83,10 +88,7 @@ namespace FastGithub.HttpServer
|
|
|
}
|
|
|
else if (context.Request.Method == HttpMethods.Connect)
|
|
|
{
|
|
|
- var endpoint = await this.GetTargetEndPointAsync(host);
|
|
|
- using var targetSocket = new Socket(SocketType.Stream, ProtocolType.Tcp);
|
|
|
- await targetSocket.ConnectAsync(endpoint);
|
|
|
-
|
|
|
+ using var connection = await this.CreateConnectionAsync(host);
|
|
|
var responseFeature = context.Features.Get<IHttpResponseFeature>();
|
|
|
if (responseFeature != null)
|
|
|
{
|
|
@@ -98,8 +100,7 @@ namespace FastGithub.HttpServer
|
|
|
var transport = context.Features.Get<IConnectionTransportFeature>()?.Transport;
|
|
|
if (transport != null)
|
|
|
{
|
|
|
- var targetStream = new NetworkStream(targetSocket, ownsSocket: false);
|
|
|
- await Task.WhenAny(targetStream.CopyToAsync(transport.Output), transport.Input.CopyToAsync(targetStream));
|
|
|
+ await Task.WhenAny(connection.CopyToAsync(transport.Output), transport.Input.CopyToAsync(connection));
|
|
|
}
|
|
|
}
|
|
|
else
|
|
@@ -151,40 +152,73 @@ namespace FastGithub.HttpServer
|
|
|
return buidler.ToString();
|
|
|
}
|
|
|
|
|
|
+ /// <summary>
|
|
|
+ /// 创建连接
|
|
|
+ /// </summary>
|
|
|
+ /// <param name="host"></param>
|
|
|
+ /// <returns></returns>
|
|
|
+ /// <exception cref="AggregateException"></exception>
|
|
|
+ private async Task<Stream> CreateConnectionAsync(HostString host)
|
|
|
+ {
|
|
|
+ var innerExceptions = new List<Exception>();
|
|
|
+ await foreach (var endPoint in this.GetTargetEndPointsAsync(host))
|
|
|
+ {
|
|
|
+ var socket = new Socket(SocketType.Stream, ProtocolType.Tcp);
|
|
|
+ try
|
|
|
+ {
|
|
|
+ using var timeoutTokenSource = new CancellationTokenSource(this.connectTimeout);
|
|
|
+ await socket.ConnectAsync(endPoint, timeoutTokenSource.Token);
|
|
|
+ return new NetworkStream(socket, ownsSocket: false);
|
|
|
+ }
|
|
|
+ catch (Exception ex)
|
|
|
+ {
|
|
|
+ socket.Dispose();
|
|
|
+ innerExceptions.Add(ex);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ throw new AggregateException($"无法连接到{host}", innerExceptions);
|
|
|
+ }
|
|
|
+
|
|
|
/// <summary>
|
|
|
/// 获取目标终节点
|
|
|
/// </summary>
|
|
|
/// <param name="host"></param>
|
|
|
/// <returns></returns>
|
|
|
- private async Task<EndPoint> GetTargetEndPointAsync(HostString host)
|
|
|
+ private async IAsyncEnumerable<EndPoint> GetTargetEndPointsAsync(HostString host)
|
|
|
{
|
|
|
var targetHost = host.Host;
|
|
|
var targetPort = host.Port ?? HTTPS_PORT;
|
|
|
|
|
|
if (IPAddress.TryParse(targetHost, out var address) == true)
|
|
|
{
|
|
|
- return new IPEndPoint(address, targetPort);
|
|
|
+ yield return new IPEndPoint(address, targetPort);
|
|
|
+ yield break;
|
|
|
}
|
|
|
|
|
|
// 不关心的域名,直接使用系统dns
|
|
|
if (this.fastGithubConfig.IsMatch(targetHost) == false)
|
|
|
{
|
|
|
- return new DnsEndPoint(targetHost, targetPort);
|
|
|
+ yield return new DnsEndPoint(targetHost, targetPort);
|
|
|
+ yield break;
|
|
|
}
|
|
|
|
|
|
if (targetPort == HTTP_PORT)
|
|
|
{
|
|
|
- return new IPEndPoint(IPAddress.Loopback, ReverseProxyPort.Http);
|
|
|
+ yield return new IPEndPoint(IPAddress.Loopback, ReverseProxyPort.Http);
|
|
|
+ yield break;
|
|
|
}
|
|
|
|
|
|
if (targetPort == HTTPS_PORT)
|
|
|
{
|
|
|
- return new IPEndPoint(IPAddress.Loopback, ReverseProxyPort.Https);
|
|
|
+ yield return new IPEndPoint(IPAddress.Loopback, ReverseProxyPort.Https);
|
|
|
+ yield break;
|
|
|
}
|
|
|
|
|
|
- // 不使用系统dns
|
|
|
- address = await this.domainResolver.ResolveAnyAsync(new DnsEndPoint(targetHost, targetPort));
|
|
|
- return new IPEndPoint(address, targetPort);
|
|
|
+ var dnsEndPoint = new DnsEndPoint(targetHost, targetPort);
|
|
|
+ await foreach (var item in this.domainResolver.ResolveAsync(dnsEndPoint))
|
|
|
+ {
|
|
|
+ yield return new IPEndPoint(item, targetPort);
|
|
|
+ }
|
|
|
}
|
|
|
|
|
|
/// <summary>
|