123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132 |
- using FastGithub.Configuration;
- using FastGithub.DomainResolve;
- using Microsoft.AspNetCore.Connections;
- using Microsoft.AspNetCore.Connections.Features;
- using Microsoft.AspNetCore.Http;
- using Microsoft.AspNetCore.Http.Features;
- using System;
- using System.Collections.Generic;
- using System.IO;
- using System.IO.Pipelines;
- using System.Net;
- using System.Net.Sockets;
- using System.Runtime.CompilerServices;
- using System.Threading;
- using System.Threading.Tasks;
- namespace FastGithub.HttpServer.TcpMiddlewares
- {
/// <summary>
/// Tunnel middleware: for tunnel-proxy connections it relays bytes between the
/// client transport and an upstream connection to the requested proxy host.
/// </summary>
sealed class TunnelMiddleware
{
    private readonly FastGithubConfig fastGithubConfig;
    private readonly IDomainResolver domainResolver;

    /// <summary>
    /// Time limit applied to each individual upstream connect attempt.
    /// </summary>
    private readonly TimeSpan connectTimeout = TimeSpan.FromSeconds(10d);

    /// <summary>
    /// Creates the tunnel middleware.
    /// </summary>
    /// <param name="fastGithubConfig">Configuration used to decide whether a host name is resolved through <paramref name="domainResolver"/>.</param>
    /// <param name="domainResolver">Resolver that supplies IP addresses for host names matched by the configuration.</param>
    public TunnelMiddleware(
        FastGithubConfig fastGithubConfig,
        IDomainResolver domainResolver)
    {
        (this.fastGithubConfig, this.domainResolver) = (fastGithubConfig, domainResolver);
    }

    /// <summary>
    /// Runs the middleware.
    /// The connection is handed to <paramref name="next"/> when it is not a proxy
    /// connection, is not a tunnel proxy, or already carries a TLS connection feature
    /// (https that has passed through the tunnel). Otherwise an upstream connection to
    /// the proxy host is opened and data is copied in both directions.
    /// </summary>
    /// <param name="next">The next connection delegate in the pipeline.</param>
    /// <param name="context">The current connection.</param>
    /// <returns>A task that completes when the connection was passed on, or when either copy direction has finished.</returns>
    public async Task InvokeAsync(ConnectionDelegate next, ConnectionContext context)
    {
        var proxy = context.Features.Get<IHttpProxyFeature>();
        if (proxy is null                                                // not a proxy connection
            || proxy.ProxyProtocol != ProxyProtocol.TunnelProxy          // not a tunnel proxy
            || context.Features.Get<ITlsConnectionFeature>() != null)    // https that passed through the tunnel
        {
            await next(context);
            return;
        }

        var transport = context.Features.Get<IConnectionTransportFeature>()?.Transport;
        if (transport is null)
        {
            // Without a transport there is nothing to relay.
            return;
        }

        var connectionClosed = context.ConnectionClosed;
        using var upstream = await this.CreateConnectionAsync(proxy.ProxyHost, connectionClosed);

        // Start both directions, then return as soon as either one finishes;
        // leaving the method disposes the upstream stream.
        var upstreamToClient = upstream.CopyToAsync(transport.Output, connectionClosed);
        var clientToUpstream = transport.Input.CopyToAsync(upstream, connectionClosed);
        await Task.WhenAny(upstreamToClient, clientToUpstream);
    }

    /// <summary>
    /// Opens a TCP connection to the first upstream endpoint of <paramref name="host"/>
    /// that accepts a connection within <see cref="connectTimeout"/>.
    /// </summary>
    /// <param name="host">The target host (and optional port).</param>
    /// <param name="cancellationToken">Token that aborts the whole operation.</param>
    /// <returns>A stream over the connected socket; the stream owns the socket.</returns>
    /// <exception cref="AggregateException">No endpoint could be connected; carries the failure of every attempt.</exception>
    /// <exception cref="OperationCanceledException">An attempt failed and <paramref name="cancellationToken"/> was already cancelled.</exception>
    private async Task<Stream> CreateConnectionAsync(HostString host, CancellationToken cancellationToken)
    {
        var failures = new List<Exception>();

        await foreach (var candidate in this.GetUpstreamEndPointsAsync(host, cancellationToken))
        {
            var socket = new Socket(SocketType.Stream, ProtocolType.Tcp);
            try
            {
                // Each attempt is bounded by its own timeout, linked with the caller's token.
                using (var timeoutSource = new CancellationTokenSource(this.connectTimeout))
                using (var attemptSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeoutSource.Token))
                {
                    await socket.ConnectAsync(candidate, attemptSource.Token);
                    return new NetworkStream(socket, ownsSocket: true);
                }
            }
            catch (Exception ex)
            {
                socket.Dispose();

                // Stop trying further endpoints if the caller cancelled; otherwise
                // remember the failure and move on to the next endpoint.
                cancellationToken.ThrowIfCancellationRequested();
                failures.Add(ex);
            }
        }

        throw new AggregateException($"无法连接到{host}", failures);
    }

    /// <summary>
    /// Enumerates the endpoints to try for <paramref name="host"/>.
    /// An IP literal yields a single <see cref="IPEndPoint"/>; a host name that the
    /// configuration does not match yields a single <see cref="DnsEndPoint"/>; any other
    /// host name yields one <see cref="IPEndPoint"/> per address produced by the domain resolver.
    /// The port defaults to 443 when <paramref name="host"/> specifies none.
    /// </summary>
    /// <param name="host">The target host (and optional port).</param>
    /// <param name="cancellationToken">Token passed on to the domain resolver.</param>
    /// <returns>An asynchronous sequence of endpoints.</returns>
    private async IAsyncEnumerable<EndPoint> GetUpstreamEndPointsAsync(HostString host, [EnumeratorCancellation] CancellationToken cancellationToken)
    {
        const int DefaultHttpsPort = 443;

        var hostName = host.Host;
        var port = host.Port ?? DefaultHttpsPort;

        if (IPAddress.TryParse(hostName, out var literalAddress))
        {
            yield return new IPEndPoint(literalAddress, port);
            yield break;
        }

        if (this.fastGithubConfig.IsMatch(hostName) is false)
        {
            yield return new DnsEndPoint(hostName, port);
            yield break;
        }

        var lookupTarget = new DnsEndPoint(hostName, port);
        await foreach (var resolvedAddress in this.domainResolver.ResolveAsync(lookupTarget, cancellationToken))
        {
            yield return new IPEndPoint(resolvedAddress, port);
        }
    }
}
- }
|