diff options
author | Rory& <root@rory.gay> | 2024-02-26 18:28:13 +0100 |
---|---|---|
committer | Rory& <root@rory.gay> | 2024-02-26 18:28:13 +0100 |
commit | e93f51bd1b4f029982e227a0e7ea7a7ad9885d0e (patch) | |
tree | 2e87230390534748c8d5ac456d39880472ea774b | |
parent | Fix auth code to be excluded on federation. (diff) | |
download | MatrixMediaGate-e93f51bd1b4f029982e227a0e7ea7a7ad9885d0e.tar.xz |
Optimise hot paths
-rw-r--r-- | MatrixMediaGate/Program.cs | 35 | ||||
-rw-r--r-- | MatrixMediaGate/ProxyConfiguration.cs | 1 | ||||
-rw-r--r-- | MatrixMediaGate/Services/AuthValidator.cs | 37 |
3 files changed, 40 insertions, 33 deletions
diff --git a/MatrixMediaGate/Program.cs b/MatrixMediaGate/Program.cs index 9ebac51..4a62ebc 100644 --- a/MatrixMediaGate/Program.cs +++ b/MatrixMediaGate/Program.cs @@ -8,11 +8,19 @@ var builder = WebApplication.CreateBuilder(args); builder.Services.AddSingleton<ProxyConfiguration>(); builder.Services.AddSingleton<IHttpContextAccessor, HttpContextAccessor>(); -builder.Services.AddScoped<AuthValidator>(); +builder.Services.AddSingleton<AuthValidator>(); +builder.Services.AddSingleton<HttpClient>(services => { + var cfg = services.GetRequiredService<ProxyConfiguration>(); + // var handler = new HttpClientHandler() { AutomaticDecompression = DecompressionMethods.None }; + return new HttpClient() { + BaseAddress = new Uri(cfg.Upstream), + MaxResponseContentBufferSize = 1 * 1024 * 1024 // 1MB + }; +}); var app = builder.Build(); -async Task Proxy(ProxyConfiguration cfg, AuthValidator auth, HttpContext ctx, ILogger<Program> logger) { +async Task Proxy(HttpClient hc, ProxyConfiguration cfg, HttpContext ctx, ILogger<Program> logger) { if (ctx is null) return; var path = ctx.Request.Path.Value; if (path is null) return; @@ -20,15 +28,13 @@ async Task Proxy(ProxyConfiguration cfg, AuthValidator auth, HttpContext ctx, IL path = path[1..]; path += ctx.Request.QueryString.Value; - using var handler = new HttpClientHandler() { AutomaticDecompression = DecompressionMethods.None }; - - using var hc = new HttpClient(handler) { BaseAddress = new Uri(cfg.Upstream) }; var method = new HttpMethod(ctx.Request.Method); using var req = new HttpRequestMessage(method, path); foreach (var header in ctx.Request.Headers) { - // if (header.Key != "Accept-Encoding" && header.Key != "Content-Type" && header.Key != "Content-Length") + if (header.Key != "Accept-Encoding" && header.Key != "Content-Type" && header.Key != "Content-Length") req.Headers.Add(header.Key, header.Value.ToArray()); } + req.Headers.Host = cfg.Host; if (ctx.Request.ContentLength > 0) { req.Content = new StreamContent(ctx.Request.Body); @@ -41,12 +47,13 @@ async Task Proxy(ProxyConfiguration cfg, AuthValidator auth, HttpContext ctx, IL using var response = await hc.SendAsync(req, HttpCompletionOption.ResponseHeadersRead); ctx.Response.Headers.Clear(); foreach (var header in response.Headers) { - // if (header.Key != "Transfer-Encoding") + if (header.Key != "Transfer-Encoding") ctx.Response.Headers[header.Key] = header.Value.ToArray(); } ctx.Response.StatusCode = (int)response.StatusCode; ctx.Response.ContentType = response.Content.Headers.ContentType?.ToString() ?? "application/json"; + if (response.Content.Headers.ContentLength != null) ctx.Response.ContentLength = response.Content.Headers.ContentLength; await ctx.Response.StartAsync(); await using var content = await response.Content.ReadAsStreamAsync(); await content.CopyToAsync(ctx.Response.Body); @@ -54,17 +61,15 @@ async Task Proxy(ProxyConfiguration cfg, AuthValidator auth, HttpContext ctx, IL await ctx.Response.CompleteAsync(); } -async Task ProxyMaybeAuth(ProxyConfiguration cfg, AuthValidator auth, HttpContext ctx, ILogger<Program> logger) { - if (ctx is null) return; - - await auth.UpdateAuth(); +async Task ProxyMaybeAuth(HttpClient hc, ProxyConfiguration cfg, AuthValidator auth, HttpContext ctx, ILogger<Program> logger) { + await auth.UpdateAuth(ctx); - await Proxy(cfg, auth, ctx, logger); + await Proxy(hc, cfg, ctx, logger); } -async Task ProxyMedia(string serverName, ProxyConfiguration cfg, AuthValidator auth, HttpContext ctx, ILogger<Program> logger) { - if (cfg.TrustedServers.Contains(serverName) || auth.ValidateAuth()) { - await ProxyMaybeAuth(cfg, auth, ctx, logger); // Some clients may send Authorization header... +async Task ProxyMedia(string serverName, ProxyConfiguration cfg, HttpClient hc, AuthValidator auth, HttpContext ctx, ILogger<Program> logger) { + if (cfg.TrustedServers.Contains(serverName) || auth.ValidateAuth(ctx)) { + await ProxyMaybeAuth(hc, cfg, auth, ctx, logger); // Some clients may send Authorization header... } else { ctx.Response.StatusCode = 403; diff --git a/MatrixMediaGate/ProxyConfiguration.cs b/MatrixMediaGate/ProxyConfiguration.cs index 0a126d4..ebe3509 100644 --- a/MatrixMediaGate/ProxyConfiguration.cs +++ b/MatrixMediaGate/ProxyConfiguration.cs @@ -10,5 +10,4 @@ public class ProxyConfiguration { public required string Upstream { get; set; } public required string Host { get; set; } public required List<string> TrustedServers { get; set; } - public bool ForceHost { get; set; } } \ No newline at end of file diff --git a/MatrixMediaGate/Services/AuthValidator.cs b/MatrixMediaGate/Services/AuthValidator.cs index 6f2b0c1..cf53cef 100644 --- a/MatrixMediaGate/Services/AuthValidator.cs +++ b/MatrixMediaGate/Services/AuthValidator.cs @@ -3,12 +3,18 @@ using System.Text.Json; namespace MatrixMediaGate.Services; -public class AuthValidator(ILogger<AuthValidator> logger, ProxyConfiguration cfg, IHttpContextAccessor ctx) { - private static Dictionary<string, DateTime> _authCache = new(); +public class AuthValidator(ILogger<AuthValidator> logger, ProxyConfiguration cfg) { + private readonly Dictionary<string, DateTime> _authCache = new(); + private readonly HttpClient _hc = new() { + BaseAddress = new Uri(cfg.Upstream), + DefaultRequestHeaders = { + Host = cfg.Host + } + }; - public async Task UpdateAuth() { - if (ctx.HttpContext?.Connection.RemoteIpAddress is null) return; - var remote = ctx.HttpContext.Connection.RemoteIpAddress.ToString(); + public async Task UpdateAuth(HttpContext ctx) { + if (ctx.Connection.RemoteIpAddress is null) return; + var remote = ctx.Connection.RemoteIpAddress.ToString(); if (_authCache.TryGetValue(remote, out var value)) { if (value > DateTime.Now.AddSeconds(30)) { @@ -18,12 +24,10 @@ public class AuthValidator(ILogger<AuthValidator> logger, ProxyConfiguration cfg _authCache.Remove(remote); } - string? token = getToken(); + string? token = getToken(ctx); if (token is null) return; - - using var hc = new HttpClient(); using var req = new HttpRequestMessage(HttpMethod.Get, $"{cfg.Upstream}/_matrix/client/v3/account/whoami?access_token={token}"); - var response = await hc.SendAsync(req); + var response = await _hc.SendAsync(req); if (response.Content.Headers.ContentType?.MediaType != "application/json") return; var content = await response.Content.ReadAsStringAsync(); @@ -34,9 +38,9 @@ public class AuthValidator(ILogger<AuthValidator> logger, ProxyConfiguration cfg } } - public bool ValidateAuth() { - if (ctx.HttpContext?.Connection.RemoteIpAddress is null) return false; - var remote = ctx.HttpContext.Connection.RemoteIpAddress.ToString(); + public bool ValidateAuth(HttpContext ctx) { + if (ctx.Connection.RemoteIpAddress is null) return false; + var remote = ctx.Connection.RemoteIpAddress.ToString(); if (_authCache.ContainsKey(remote)) { if (_authCache[remote] > DateTime.Now) { @@ -49,13 +53,12 @@ public class AuthValidator(ILogger<AuthValidator> logger, ProxyConfiguration cfg return false; } - private string? getToken() { - if (ctx.HttpContext is null) return null; - if (ctx.HttpContext.Request.Headers.TryGetValue("Authorization", out var header)) { + private string? getToken(HttpContext ctx) { + if (ctx.Request.Headers.TryGetValue("Authorization", out var header)) { return header.ToString().Split(' ', 2)[1]; } - else if (ctx.HttpContext.Request.Query.ContainsKey("access_token")) { - return ctx.HttpContext.Request.Query["access_token"]!; + else if (ctx.Request.Query.ContainsKey("access_token")) { + return ctx.Request.Query["access_token"]!; } else { return null; |