diff --git a/MatrixMediaGate/Program.cs b/MatrixMediaGate/Program.cs
index a812d78..9ebac51 100644
--- a/MatrixMediaGate/Program.cs
+++ b/MatrixMediaGate/Program.cs
@@ -15,14 +15,10 @@ var app = builder.Build();
async Task Proxy(ProxyConfiguration cfg, AuthValidator auth, HttpContext ctx, ILogger<Program> logger) {
if (ctx is null) return;
var path = ctx.Request.Path.Value;
- if(path is null) return;
+ if (path is null) return;
if (path.StartsWith('/'))
path = path[1..];
path += ctx.Request.QueryString.Value;
-
-#pragma warning disable CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed
- auth.UpdateAuth();
-#pragma warning restore CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed
using var handler = new HttpClientHandler() { AutomaticDecompression = DecompressionMethods.None };
@@ -30,7 +26,8 @@ async Task Proxy(ProxyConfiguration cfg, AuthValidator auth, HttpContext ctx, IL
var method = new HttpMethod(ctx.Request.Method);
using var req = new HttpRequestMessage(method, path);
foreach (var header in ctx.Request.Headers) {
- if (header.Key != "Accept-Encoding" && header.Key != "Content-Type" && header.Key != "Content-Length") req.Headers.Add(header.Key, header.Value.ToArray());
+ // if (header.Key != "Accept-Encoding" && header.Key != "Content-Type" && header.Key != "Content-Length")
+ req.Headers.Add(header.Key, header.Value.ToArray());
}
if (ctx.Request.ContentLength > 0) {
@@ -44,7 +41,8 @@ async Task Proxy(ProxyConfiguration cfg, AuthValidator auth, HttpContext ctx, IL
using var response = await hc.SendAsync(req, HttpCompletionOption.ResponseHeadersRead);
ctx.Response.Headers.Clear();
foreach (var header in response.Headers) {
- if (header.Key != "Transfer-Encoding") ctx.Response.Headers[header.Key] = header.Value.ToArray();
+ // if (header.Key != "Transfer-Encoding")
+ ctx.Response.Headers[header.Key] = header.Value.ToArray();
}
ctx.Response.StatusCode = (int)response.StatusCode;
@@ -56,9 +54,17 @@ async Task Proxy(ProxyConfiguration cfg, AuthValidator auth, HttpContext ctx, IL
await ctx.Response.CompleteAsync();
}
+async Task ProxyMaybeAuth(ProxyConfiguration cfg, AuthValidator auth, HttpContext ctx, ILogger<Program> logger) {
+ if (ctx is null) return;
+
+ await auth.UpdateAuth();
+
+ await Proxy(cfg, auth, ctx, logger);
+}
+
async Task ProxyMedia(string serverName, ProxyConfiguration cfg, AuthValidator auth, HttpContext ctx, ILogger<Program> logger) {
if (cfg.TrustedServers.Contains(serverName) || auth.ValidateAuth()) {
- await Proxy(cfg, auth, ctx, logger);
+ await ProxyMaybeAuth(cfg, auth, ctx, logger); // Some clients may send Authorization header...
}
else {
ctx.Response.StatusCode = 403;
@@ -71,7 +77,8 @@ async Task ProxyMedia(string serverName, ProxyConfiguration cfg, AuthValidator a
}
}
-app.Map("{*_}", Proxy);
+app.Map("{*_}", ProxyMaybeAuth);
+app.Map("/_matrix/federation/{*_}", Proxy);
foreach (var route in (string[]) [
"/_matrix/media/{version}/download/{serverName}/{mediaId}",
diff --git a/MatrixMediaGate/Services/AuthValidator.cs b/MatrixMediaGate/Services/AuthValidator.cs
index 08ccd14..6f2b0c1 100644
--- a/MatrixMediaGate/Services/AuthValidator.cs
+++ b/MatrixMediaGate/Services/AuthValidator.cs
@@ -6,43 +6,36 @@ namespace MatrixMediaGate.Services;
public class AuthValidator(ILogger<AuthValidator> logger, ProxyConfiguration cfg, IHttpContextAccessor ctx) {
private static Dictionary<string, DateTime> _authCache = new();
- public async Task<bool> UpdateAuth() {
- if (ctx.HttpContext is null) return false;
- if (ctx.HttpContext.Connection.RemoteIpAddress is null) return false;
+ public async Task UpdateAuth() {
+ if (ctx.HttpContext?.Connection.RemoteIpAddress is null) return;
var remote = ctx.HttpContext.Connection.RemoteIpAddress.ToString();
-
-
+
if (_authCache.TryGetValue(remote, out var value)) {
if (value > DateTime.Now.AddSeconds(30)) {
- return true;
+ return;
}
_authCache.Remove(remote);
}
string? token = getToken();
- if (token is null) return false;
+ if (token is null) return;
using var hc = new HttpClient();
using var req = new HttpRequestMessage(HttpMethod.Get, $"{cfg.Upstream}/_matrix/client/v3/account/whoami?access_token={token}");
- req.Headers.Host = cfg.Host;
var response = await hc.SendAsync(req);
- if (response.Content.Headers.ContentType?.MediaType != "application/json") return false;
+ if (response.Content.Headers.ContentType?.MediaType != "application/json") return;
var content = await response.Content.ReadAsStringAsync();
var json = JsonDocument.Parse(content);
if (json.RootElement.TryGetProperty("user_id", out var userId)) {
_authCache[remote] = DateTime.Now.AddMinutes(5);
logger.LogInformation("Authenticated {userId} on {remote}, expiring at {time}", userId, remote, _authCache[remote]);
- return true;
}
-
- return false;
}
public bool ValidateAuth() {
- if (ctx.HttpContext is null) return false;
- if (ctx.HttpContext.Connection.RemoteIpAddress is null) return false;
+ if (ctx.HttpContext?.Connection.RemoteIpAddress is null) return false;
var remote = ctx.HttpContext.Connection.RemoteIpAddress.ToString();
if (_authCache.ContainsKey(remote)) {
@@ -57,9 +50,9 @@ public class AuthValidator(ILogger<AuthValidator> logger, ProxyConfiguration cfg
}
private string? getToken() {
- if (ctx is null) return null;
- if (ctx.HttpContext.Request.Headers.ContainsKey("Authorization")) {
- return ctx.HttpContext.Request.Headers["Authorization"].ToString().Split(' ', 2)[1];
+ if (ctx.HttpContext is null) return null;
+ if (ctx.HttpContext.Request.Headers.TryGetValue("Authorization", out var header)) {
+ return header.ToString().Split(' ', 2)[1];
}
else if (ctx.HttpContext.Request.Query.ContainsKey("access_token")) {
return ctx.HttpContext.Request.Query["access_token"]!;
|