about summary refs log tree commit diff
path: root/MatrixMediaGate/Program.cs
diff options
context:
space:
mode:
authorRory& <root@rory.gay>2024-02-29 09:00:17 +0000
committerRory& <root@rory.gay>2024-02-29 09:00:17 +0000
commit8f10d01613e794c145f4d61fc50924a765d37f0e (patch)
treeb88e12359bae3a0fbb3bc5fb6a2e55cb94ae25ca /MatrixMediaGate/Program.cs
parentAlternate header handling for request headers (diff)
downloadMatrixMediaGate-8f10d01613e794c145f4d61fc50924a765d37f0e.tar.xz
Dump failed requests
Diffstat (limited to 'MatrixMediaGate/Program.cs')
-rw-r--r--MatrixMediaGate/Program.cs161
1 files changed, 115 insertions, 46 deletions
diff --git a/MatrixMediaGate/Program.cs b/MatrixMediaGate/Program.cs
index 9436230..19ae4f8 100644
--- a/MatrixMediaGate/Program.cs
+++ b/MatrixMediaGate/Program.cs
@@ -18,56 +18,93 @@ builder.Services.AddSingleton<HttpClient>(services => {
 
 var app = builder.Build();
 
+app.Map("{*_}", ProxyMaybeAuth);
+app.Map("/_matrix/federation/{*_}", Proxy); // Don't bother with auth for federation
+
+foreach (var route in (string[]) [ // Require recent auth for these routes
+             "/_matrix/media/{version}/download/{serverName}/{mediaId}",
+             "/_matrix/media/{version}/download/{serverName}/{mediaId}/{fileName}",
+             "/_matrix/media/{version}/thumbnail/{serverName}/{mediaId}",
+         ])
+    app.Map(route, ProxyMedia);
+
+app.Run();
+
+// Proxy a request
 async Task Proxy(HttpClient hc, ProxyConfiguration cfg, HttpContext ctx, ILogger<Program> logger) {
-    var path = ctx.Request.GetEncodedPathAndQuery();
-    if (path.StartsWith('/'))
-        path = path[1..];
-
-    var method = new HttpMethod(ctx.Request.Method);
-    using var req = new HttpRequestMessage(method, path);
-    hc.DefaultRequestHeaders.Clear();
-    req.Headers.Clear();
-    foreach (var header in ctx.Request.Headers) {
-        if (header.Key != "Accept-Encoding" && header.Key != "Content-Type" && header.Key != "Content-Length") {
-            // req.Headers.TryAddWithoutValidation(header.Key, header.Value.ToArray());
-            req.Headers.Remove(header.Key);
-            req.Headers.TryAddWithoutValidation(header.Key, header.Value.ToArray());
+    HttpRequestMessage? upstreamRequest = null;
+    HttpResponseMessage? upstreamResponse = null;
+    Exception? exception = null;
+    try {
+        var path = ctx.Request.GetEncodedPathAndQuery();
+        if (path.StartsWith('/'))
+            path = path[1..];
+
+        var method = new HttpMethod(ctx.Request.Method);
+        upstreamRequest = new HttpRequestMessage(method, path);
+        hc.DefaultRequestHeaders.Clear();
+        upstreamRequest.Headers.Clear();
+        foreach (var header in ctx.Request.Headers) {
+            if (header.Key != "Accept-Encoding" && header.Key != "Content-Type" && header.Key != "Content-Length") {
+                // req.Headers.TryAddWithoutValidation(header.Key, header.Value.ToArray());
+                upstreamRequest.Headers.Remove(header.Key);
+                upstreamRequest.Headers.TryAddWithoutValidation(header.Key, header.Value.ToArray());
+            }
         }
-    }
 
-    req.Headers.Host = cfg.Host;
+        upstreamRequest.Headers.Host = cfg.Host;
 
-    if (ctx.Request.ContentLength > 0) {
-        req.Content = new StreamContent(ctx.Request.Body);
-        if (ctx.Request.ContentType != null) req.Content.Headers.ContentType = MediaTypeHeaderValue.Parse(ctx.Request.ContentType);
-        
-        if (ctx.Request.ContentLength != null) req.Content.Headers.ContentLength = ctx.Request.ContentLength;
-    }
+        if (ctx.Request.ContentLength > 0) {
+            upstreamRequest.Content = new StreamContent(ctx.Request.Body);
+            if (ctx.Request.ContentType != null) upstreamRequest.Content.Headers.ContentType = MediaTypeHeaderValue.Parse(ctx.Request.ContentType);
 
-    logger.LogInformation("Proxying {method} {path} to {target}", method, path, hc.BaseAddress + path);
+            if (ctx.Request.ContentLength != null) upstreamRequest.Content.Headers.ContentLength = ctx.Request.ContentLength;
+        }
+
+        logger.LogInformation("Proxying {method} {path} to {target}", method, path, hc.BaseAddress + path);
 
-    using var response = await hc.SendAsync(req, HttpCompletionOption.ResponseHeadersRead);
-    ctx.Response.Headers.Clear();
-    foreach (var header in response.Headers) {
-        if (header.Key != "Transfer-Encoding")
-            ctx.Response.Headers[header.Key] = header.Value.ToArray();
+        upstreamResponse = await hc.SendAsync(upstreamRequest, HttpCompletionOption.ResponseHeadersRead);
+        ctx.Response.Headers.Clear();
+        foreach (var header in upstreamResponse.Headers) {
+            if (header.Key != "Transfer-Encoding")
+                ctx.Response.Headers[header.Key] = header.Value.ToArray();
+        }
+
+        ctx.Response.StatusCode = (int)upstreamResponse.StatusCode;
+        ctx.Response.ContentType = upstreamResponse.Content.Headers.ContentType?.ToString() ?? "application/json";
+        if (upstreamResponse.Content.Headers.ContentLength != null) ctx.Response.ContentLength = upstreamResponse.Content.Headers.ContentLength;
+        await ctx.Response.StartAsync();
+        await using var content = await upstreamResponse.Content.ReadAsStreamAsync();
+        await content.CopyToAsync(ctx.Response.Body);
+    }
+    catch (HttpRequestException e) {
+        exception = e;
+        logger.LogError(e, "Failed to proxy request");
+        ctx.Response.StatusCode = 502;
+        ctx.Response.ContentType = "application/json";
+        await ctx.Response.StartAsync();
+        await JsonSerializer.SerializeAsync(ctx.Response.Body, new { errcode = "M_UNAVAILABLE", error = "Failed to proxy request" });
+    }
+    finally {
+        await ctx.Response.Body.FlushAsync();
+        await ctx.Response.CompleteAsync();
+        if (ctx.Response.StatusCode >= 400) {
+            await ProxyDump(cfg, ctx, upstreamRequest, upstreamResponse, exception);
+        }
+        
+        upstreamRequest?.Dispose();
+        upstreamResponse?.Dispose();
     }
 
-    ctx.Response.StatusCode = (int)response.StatusCode;
-    ctx.Response.ContentType = response.Content.Headers.ContentType?.ToString() ?? "application/json";
-    if (response.Content.Headers.ContentLength != null) ctx.Response.ContentLength = response.Content.Headers.ContentLength;
-    await ctx.Response.StartAsync();
-    await using var content = await response.Content.ReadAsStreamAsync();
-    await content.CopyToAsync(ctx.Response.Body);
-    await ctx.Response.Body.FlushAsync();
-    await ctx.Response.CompleteAsync();
 }
 
+// We attempt to update auth, but we don't require it
 async Task ProxyMaybeAuth(HttpClient hc, ProxyConfiguration cfg, AuthValidator auth, HttpContext ctx, ILogger<Program> logger) {
     await auth.UpdateAuth(ctx);
     await Proxy(hc, cfg, ctx, logger);
 }
 
+// We know this is a media path, we require recent auth here to prevent abuse
 async Task ProxyMedia(string serverName, ProxyConfiguration cfg, HttpClient hc, AuthValidator auth, HttpContext ctx, ILogger<Program> logger) {
     // Some clients may send Authorization header, so we handle this last...
     if (cfg.TrustedServers.Contains(serverName) || auth.ValidateAuth(ctx) || await auth.UpdateAuth(ctx)) {
@@ -83,14 +120,46 @@ async Task ProxyMedia(string serverName, ProxyConfiguration cfg, HttpClient hc,
     }
 }
 
-app.Map("{*_}", ProxyMaybeAuth);
-app.Map("/_matrix/federation/{*_}", Proxy); // Don't bother with auth for federation
-
-foreach (var route in (string[]) [ // Require recent auth for these routes
-             "/_matrix/media/{version}/download/{serverName}/{mediaId}",
-             "/_matrix/media/{version}/download/{serverName}/{mediaId}/{fileName}",
-             "/_matrix/media/{version}/thumbnail/{serverName}/{mediaId}",
-         ])
-    app.Map(route, ProxyMedia);
-
-app.Run();
\ No newline at end of file
+var jsonOptions = new JsonSerializerOptions {
+    WriteIndented = true
+};
+// We dump failed requests to disk
+async Task ProxyDump(ProxyConfiguration cfg, HttpContext ctx, HttpRequestMessage? req, HttpResponseMessage? resp, Exception? e) {
+    if (ctx.Response.StatusCode >= 400 && cfg.DumpFailedRequests) {
+        var path = Path.Combine(cfg.DumpPath, "failed_requests", $"{resp.StatusCode}-{DateTimeOffset.UtcNow.ToUnixTimeSeconds()}-{ctx.Request.Path}.json");
+        await using var file = File.Create(path);
+        await JsonSerializer.SerializeAsync(file, new {
+            Self = new {
+                Request = new {
+                    ctx.Request.Method,
+                    Url = ctx.Request.GetEncodedUrl(),
+                    ctx.Request.Headers
+                },
+                Response = new {
+                    ctx.Response.StatusCode,
+                    ctx.Response.Headers,
+                    ctx.Response.ContentType,
+                    ctx.Response.ContentLength
+                }
+            },
+            Upstream = new {
+                Request = new {
+                    req?.Method,
+                    Url = req?.RequestUri,
+                    req?.Headers
+                },
+                Response = new {
+                    resp.StatusCode,
+                    resp.Headers,
+                    resp.Content.Headers.ContentType,
+                    resp.Content.Headers.ContentLength
+                }
+            },
+            Exception = new {
+                Type = e?.GetType().ToString(),
+                Message = e?.Message.ReplaceLineEndings().Split(Environment.NewLine),
+                StackTrace = e?.StackTrace?.ReplaceLineEndings().Split(Environment.NewLine)
+            }
+        });
+    }
+}
\ No newline at end of file