diff --git a/MatrixMediaGate/Program.cs b/MatrixMediaGate/Program.cs
index 9436230..19ae4f8 100644
--- a/MatrixMediaGate/Program.cs
+++ b/MatrixMediaGate/Program.cs
@@ -18,56 +18,93 @@ builder.Services.AddSingleton<HttpClient>(services => {
var app = builder.Build();
+app.Map("{*_}", ProxyMaybeAuth);
+app.Map("/_matrix/federation/{*_}", Proxy); // Don't bother with auth for federation
+
+foreach (var route in (string[]) [ // Require recent auth for these routes
+ "/_matrix/media/{version}/download/{serverName}/{mediaId}",
+ "/_matrix/media/{version}/download/{serverName}/{mediaId}/{fileName}",
+ "/_matrix/media/{version}/thumbnail/{serverName}/{mediaId}",
+ ])
+ app.Map(route, ProxyMedia);
+
+app.Run();
+
+// Proxy a request
async Task Proxy(HttpClient hc, ProxyConfiguration cfg, HttpContext ctx, ILogger<Program> logger) {
- var path = ctx.Request.GetEncodedPathAndQuery();
- if (path.StartsWith('/'))
- path = path[1..];
-
- var method = new HttpMethod(ctx.Request.Method);
- using var req = new HttpRequestMessage(method, path);
- hc.DefaultRequestHeaders.Clear();
- req.Headers.Clear();
- foreach (var header in ctx.Request.Headers) {
- if (header.Key != "Accept-Encoding" && header.Key != "Content-Type" && header.Key != "Content-Length") {
- // req.Headers.TryAddWithoutValidation(header.Key, header.Value.ToArray());
- req.Headers.Remove(header.Key);
- req.Headers.TryAddWithoutValidation(header.Key, header.Value.ToArray());
+ HttpRequestMessage? upstreamRequest = null;
+ HttpResponseMessage? upstreamResponse = null;
+ Exception? exception = null;
+ try {
+ var path = ctx.Request.GetEncodedPathAndQuery();
+ if (path.StartsWith('/'))
+ path = path[1..];
+
+ var method = new HttpMethod(ctx.Request.Method);
+ upstreamRequest = new HttpRequestMessage(method, path);
+ hc.DefaultRequestHeaders.Clear();
+ upstreamRequest.Headers.Clear();
+ foreach (var header in ctx.Request.Headers) {
+ if (header.Key != "Accept-Encoding" && header.Key != "Content-Type" && header.Key != "Content-Length") {
+ // req.Headers.TryAddWithoutValidation(header.Key, header.Value.ToArray());
+ upstreamRequest.Headers.Remove(header.Key);
+ upstreamRequest.Headers.TryAddWithoutValidation(header.Key, header.Value.ToArray());
+ }
}
- }
- req.Headers.Host = cfg.Host;
+ upstreamRequest.Headers.Host = cfg.Host;
- if (ctx.Request.ContentLength > 0) {
- req.Content = new StreamContent(ctx.Request.Body);
- if (ctx.Request.ContentType != null) req.Content.Headers.ContentType = MediaTypeHeaderValue.Parse(ctx.Request.ContentType);
-
- if (ctx.Request.ContentLength != null) req.Content.Headers.ContentLength = ctx.Request.ContentLength;
- }
+ if (ctx.Request.ContentLength > 0) {
+ upstreamRequest.Content = new StreamContent(ctx.Request.Body);
+ if (ctx.Request.ContentType != null) upstreamRequest.Content.Headers.ContentType = MediaTypeHeaderValue.Parse(ctx.Request.ContentType);
- logger.LogInformation("Proxying {method} {path} to {target}", method, path, hc.BaseAddress + path);
+ if (ctx.Request.ContentLength != null) upstreamRequest.Content.Headers.ContentLength = ctx.Request.ContentLength;
+ }
+
+ logger.LogInformation("Proxying {method} {path} to {target}", method, path, hc.BaseAddress + path);
- using var response = await hc.SendAsync(req, HttpCompletionOption.ResponseHeadersRead);
- ctx.Response.Headers.Clear();
- foreach (var header in response.Headers) {
- if (header.Key != "Transfer-Encoding")
- ctx.Response.Headers[header.Key] = header.Value.ToArray();
+ upstreamResponse = await hc.SendAsync(upstreamRequest, HttpCompletionOption.ResponseHeadersRead);
+ ctx.Response.Headers.Clear();
+ foreach (var header in upstreamResponse.Headers) {
+ if (header.Key != "Transfer-Encoding")
+ ctx.Response.Headers[header.Key] = header.Value.ToArray();
+ }
+
+ ctx.Response.StatusCode = (int)upstreamResponse.StatusCode;
+ ctx.Response.ContentType = upstreamResponse.Content.Headers.ContentType?.ToString() ?? "application/json";
+ if (upstreamResponse.Content.Headers.ContentLength != null) ctx.Response.ContentLength = upstreamResponse.Content.Headers.ContentLength;
+ await ctx.Response.StartAsync();
+ await using var content = await upstreamResponse.Content.ReadAsStreamAsync();
+ await content.CopyToAsync(ctx.Response.Body);
+ }
+ catch (HttpRequestException e) {
+ exception = e;
+ logger.LogError(e, "Failed to proxy request");
+ ctx.Response.StatusCode = 502;
+ ctx.Response.ContentType = "application/json";
+ await ctx.Response.StartAsync();
+ await JsonSerializer.SerializeAsync(ctx.Response.Body, new { errcode = "M_UNAVAILABLE", error = "Failed to proxy request" });
+ }
+ finally {
+ await ctx.Response.Body.FlushAsync();
+ await ctx.Response.CompleteAsync();
+ if (ctx.Response.StatusCode >= 400) {
+ await ProxyDump(cfg, ctx, upstreamRequest, upstreamResponse, exception);
+ }
+
+ upstreamRequest?.Dispose();
+ upstreamResponse?.Dispose();
}
- ctx.Response.StatusCode = (int)response.StatusCode;
- ctx.Response.ContentType = response.Content.Headers.ContentType?.ToString() ?? "application/json";
- if (response.Content.Headers.ContentLength != null) ctx.Response.ContentLength = response.Content.Headers.ContentLength;
- await ctx.Response.StartAsync();
- await using var content = await response.Content.ReadAsStreamAsync();
- await content.CopyToAsync(ctx.Response.Body);
- await ctx.Response.Body.FlushAsync();
- await ctx.Response.CompleteAsync();
}
+// We attempt to update auth, but we don't require it
async Task ProxyMaybeAuth(HttpClient hc, ProxyConfiguration cfg, AuthValidator auth, HttpContext ctx, ILogger<Program> logger) {
await auth.UpdateAuth(ctx);
await Proxy(hc, cfg, ctx, logger);
}
+// We know this is a media path, we require recent auth here to prevent abuse
async Task ProxyMedia(string serverName, ProxyConfiguration cfg, HttpClient hc, AuthValidator auth, HttpContext ctx, ILogger<Program> logger) {
// Some clients may send Authorization header, so we handle this last...
if (cfg.TrustedServers.Contains(serverName) || auth.ValidateAuth(ctx) || await auth.UpdateAuth(ctx)) {
@@ -83,14 +120,46 @@ async Task ProxyMedia(string serverName, ProxyConfiguration cfg, HttpClient hc,
}
}
-app.Map("{*_}", ProxyMaybeAuth);
-app.Map("/_matrix/federation/{*_}", Proxy); // Don't bother with auth for federation
-
-foreach (var route in (string[]) [ // Require recent auth for these routes
- "/_matrix/media/{version}/download/{serverName}/{mediaId}",
- "/_matrix/media/{version}/download/{serverName}/{mediaId}/{fileName}",
- "/_matrix/media/{version}/thumbnail/{serverName}/{mediaId}",
- ])
- app.Map(route, ProxyMedia);
-
-app.Run();
\ No newline at end of file
+var jsonOptions = new JsonSerializerOptions {
+ WriteIndented = true
+};
+// We dump failed requests to disk
+async Task ProxyDump(ProxyConfiguration cfg, HttpContext ctx, HttpRequestMessage? req, HttpResponseMessage? resp, Exception? e) {
+ if (ctx.Response.StatusCode >= 400 && cfg.DumpFailedRequests) {
+ var path = Path.Combine(cfg.DumpPath, "failed_requests", $"{resp.StatusCode}-{DateTimeOffset.UtcNow.ToUnixTimeSeconds()}-{ctx.Request.Path}.json");
+ await using var file = File.Create(path);
+ await JsonSerializer.SerializeAsync(file, new {
+ Self = new {
+ Request = new {
+ ctx.Request.Method,
+ Url = ctx.Request.GetEncodedUrl(),
+ ctx.Request.Headers
+ },
+ Response = new {
+ ctx.Response.StatusCode,
+ ctx.Response.Headers,
+ ctx.Response.ContentType,
+ ctx.Response.ContentLength
+ }
+ },
+ Upstream = new {
+ Request = new {
+ req?.Method,
+ Url = req?.RequestUri,
+ req?.Headers
+ },
+ Response = new {
+ resp.StatusCode,
+ resp.Headers,
+ resp.Content.Headers.ContentType,
+ resp.Content.Headers.ContentLength
+ }
+ },
+ Exception = new {
+ Type = e?.GetType().ToString(),
+ Message = e?.Message.ReplaceLineEndings().Split(Environment.NewLine),
+ StackTrace = e?.StackTrace?.ReplaceLineEndings().Split(Environment.NewLine)
+ }
+ });
+ }
+}
\ No newline at end of file
|