about summary refs log tree commit diff
diff options
context:
space:
mode:
authorRory& <root@rory.gay>2024-02-26 18:28:13 +0100
committerRory& <root@rory.gay>2024-02-26 18:28:13 +0100
commite93f51bd1b4f029982e227a0e7ea7a7ad9885d0e (patch)
tree2e87230390534748c8d5ac456d39880472ea774b
parentFix auth code to be excluded on federation. (diff)
downloadMatrixMediaGate-e93f51bd1b4f029982e227a0e7ea7a7ad9885d0e.tar.xz
Optimise hot paths
-rw-r--r--MatrixMediaGate/Program.cs35
-rw-r--r--MatrixMediaGate/ProxyConfiguration.cs1
-rw-r--r--MatrixMediaGate/Services/AuthValidator.cs37
3 files changed, 40 insertions, 33 deletions
diff --git a/MatrixMediaGate/Program.cs b/MatrixMediaGate/Program.cs
index 9ebac51..4a62ebc 100644
--- a/MatrixMediaGate/Program.cs
+++ b/MatrixMediaGate/Program.cs
@@ -8,11 +8,19 @@ var builder = WebApplication.CreateBuilder(args);
 

 builder.Services.AddSingleton<ProxyConfiguration>();

 builder.Services.AddSingleton<IHttpContextAccessor, HttpContextAccessor>();

-builder.Services.AddScoped<AuthValidator>();

+builder.Services.AddSingleton<AuthValidator>();

+builder.Services.AddSingleton<HttpClient>(services => {

+    var cfg = services.GetRequiredService<ProxyConfiguration>();

+    // var handler = new HttpClientHandler() { AutomaticDecompression = DecompressionMethods.None };

+    return new HttpClient() {

+        BaseAddress = new Uri(cfg.Upstream),

+        MaxResponseContentBufferSize = 1 * 1024 * 1024 // 1MB

+    };

+});

 

 var app = builder.Build();

 

-async Task Proxy(ProxyConfiguration cfg, AuthValidator auth, HttpContext ctx, ILogger<Program> logger) {

+async Task Proxy(HttpClient hc, ProxyConfiguration cfg, HttpContext ctx, ILogger<Program> logger) {

     if (ctx is null) return;

     var path = ctx.Request.Path.Value;

     if (path is null) return;

@@ -20,15 +28,13 @@ async Task Proxy(ProxyConfiguration cfg, AuthValidator auth, HttpContext ctx, IL
         path = path[1..];

     path += ctx.Request.QueryString.Value;

 

-    using var handler = new HttpClientHandler() { AutomaticDecompression = DecompressionMethods.None };

-

-    using var hc = new HttpClient(handler) { BaseAddress = new Uri(cfg.Upstream) };

     var method = new HttpMethod(ctx.Request.Method);

     using var req = new HttpRequestMessage(method, path);

     foreach (var header in ctx.Request.Headers) {

-        // if (header.Key != "Accept-Encoding" && header.Key != "Content-Type" && header.Key != "Content-Length")

+        if (header.Key != "Accept-Encoding" && header.Key != "Content-Type" && header.Key != "Content-Length")

             req.Headers.Add(header.Key, header.Value.ToArray());

     }

+    req.Headers.Host = cfg.Host;

 

     if (ctx.Request.ContentLength > 0) {

         req.Content = new StreamContent(ctx.Request.Body);

@@ -41,12 +47,13 @@ async Task Proxy(ProxyConfiguration cfg, AuthValidator auth, HttpContext ctx, IL
     using var response = await hc.SendAsync(req, HttpCompletionOption.ResponseHeadersRead);

     ctx.Response.Headers.Clear();

     foreach (var header in response.Headers) {

-        // if (header.Key != "Transfer-Encoding")

+        if (header.Key != "Transfer-Encoding")

             ctx.Response.Headers[header.Key] = header.Value.ToArray();

     }

 

     ctx.Response.StatusCode = (int)response.StatusCode;

     ctx.Response.ContentType = response.Content.Headers.ContentType?.ToString() ?? "application/json";

+    if (response.Content.Headers.ContentLength != null) ctx.Response.ContentLength = response.Content.Headers.ContentLength;

     await ctx.Response.StartAsync();

     await using var content = await response.Content.ReadAsStreamAsync();

     await content.CopyToAsync(ctx.Response.Body);

@@ -54,17 +61,15 @@ async Task Proxy(ProxyConfiguration cfg, AuthValidator auth, HttpContext ctx, IL
     await ctx.Response.CompleteAsync();

 }

 

-async Task ProxyMaybeAuth(ProxyConfiguration cfg, AuthValidator auth, HttpContext ctx, ILogger<Program> logger) {

-    if (ctx is null) return;

-

-    await auth.UpdateAuth();

+async Task ProxyMaybeAuth(HttpClient hc, ProxyConfiguration cfg, AuthValidator auth, HttpContext ctx, ILogger<Program> logger) {

+    await auth.UpdateAuth(ctx);

 

-    await Proxy(cfg, auth, ctx, logger);

+    await Proxy(hc, cfg, ctx, logger);

 }

 

-async Task ProxyMedia(string serverName, ProxyConfiguration cfg, AuthValidator auth, HttpContext ctx, ILogger<Program> logger) {

-    if (cfg.TrustedServers.Contains(serverName) || auth.ValidateAuth()) {

-        await ProxyMaybeAuth(cfg, auth, ctx, logger); // Some clients may send Authorization header...

+async Task ProxyMedia(string serverName, ProxyConfiguration cfg, HttpClient hc, AuthValidator auth, HttpContext ctx, ILogger<Program> logger) {

+    if (cfg.TrustedServers.Contains(serverName) || auth.ValidateAuth(ctx)) {

+        await ProxyMaybeAuth(hc, cfg, auth, ctx, logger); // Some clients may send Authorization header...

     }

     else {

         ctx.Response.StatusCode = 403;

diff --git a/MatrixMediaGate/ProxyConfiguration.cs b/MatrixMediaGate/ProxyConfiguration.cs
index 0a126d4..ebe3509 100644
--- a/MatrixMediaGate/ProxyConfiguration.cs
+++ b/MatrixMediaGate/ProxyConfiguration.cs
@@ -10,5 +10,4 @@ public class ProxyConfiguration {
     public required string Upstream { get; set; }
     public required string Host { get; set; }
     public required List<string> TrustedServers { get; set; }
-    public bool ForceHost { get; set; }
 }
\ No newline at end of file
diff --git a/MatrixMediaGate/Services/AuthValidator.cs b/MatrixMediaGate/Services/AuthValidator.cs
index 6f2b0c1..cf53cef 100644
--- a/MatrixMediaGate/Services/AuthValidator.cs
+++ b/MatrixMediaGate/Services/AuthValidator.cs
@@ -3,12 +3,18 @@ using System.Text.Json;
 
 namespace MatrixMediaGate.Services;
 
-public class AuthValidator(ILogger<AuthValidator> logger, ProxyConfiguration cfg, IHttpContextAccessor ctx) {
-    private static Dictionary<string, DateTime> _authCache = new();
+public class AuthValidator(ILogger<AuthValidator> logger, ProxyConfiguration cfg) {
+    private readonly Dictionary<string, DateTime> _authCache = new();
+    private readonly HttpClient _hc = new() {
+        BaseAddress = new Uri(cfg.Upstream),
+        DefaultRequestHeaders = {
+            Host = cfg.Host
+        }
+    };
 
-    public async Task UpdateAuth() {
-        if (ctx.HttpContext?.Connection.RemoteIpAddress is null) return;
-        var remote = ctx.HttpContext.Connection.RemoteIpAddress.ToString();
+    public async Task UpdateAuth(HttpContext ctx) {
+        if (ctx.Connection.RemoteIpAddress is null) return;
+        var remote = ctx.Connection.RemoteIpAddress.ToString();
         
         if (_authCache.TryGetValue(remote, out var value)) {
             if (value > DateTime.Now.AddSeconds(30)) {
@@ -18,12 +24,10 @@ public class AuthValidator(ILogger<AuthValidator> logger, ProxyConfiguration cfg
             _authCache.Remove(remote);
         }
 
-        string? token = getToken();
+        string? token = getToken(ctx);
         if (token is null) return;
-        
-        using var hc = new HttpClient();
         using var req = new HttpRequestMessage(HttpMethod.Get, $"{cfg.Upstream}/_matrix/client/v3/account/whoami?access_token={token}");
-        var response = await hc.SendAsync(req);
+        var response = await _hc.SendAsync(req);
 
         if (response.Content.Headers.ContentType?.MediaType != "application/json") return;
         var content = await response.Content.ReadAsStringAsync();
@@ -34,9 +38,9 @@ public class AuthValidator(ILogger<AuthValidator> logger, ProxyConfiguration cfg
         }
     }
 
-    public bool ValidateAuth() {
-        if (ctx.HttpContext?.Connection.RemoteIpAddress is null) return false;
-        var remote = ctx.HttpContext.Connection.RemoteIpAddress.ToString();
+    public bool ValidateAuth(HttpContext ctx) {
+        if (ctx.Connection.RemoteIpAddress is null) return false;
+        var remote = ctx.Connection.RemoteIpAddress.ToString();
         
         if (_authCache.ContainsKey(remote)) {
             if (_authCache[remote] > DateTime.Now) {
@@ -49,13 +53,12 @@ public class AuthValidator(ILogger<AuthValidator> logger, ProxyConfiguration cfg
         return false;
     }
 
-    private string? getToken() {
-        if (ctx.HttpContext is null) return null;
-        if (ctx.HttpContext.Request.Headers.TryGetValue("Authorization", out var header)) {
+    private string? getToken(HttpContext ctx) {
+        if (ctx.Request.Headers.TryGetValue("Authorization", out var header)) {
             return header.ToString().Split(' ', 2)[1];
         }
-        else if (ctx.HttpContext.Request.Query.ContainsKey("access_token")) {
-            return ctx.HttpContext.Request.Query["access_token"]!;
+        else if (ctx.Request.Query.ContainsKey("access_token")) {
+            return ctx.Request.Query["access_token"]!;
         }
         else {
             return null;