diff options
author | Rory& <root@rory.gay> | 2024-02-26 17:10:46 +0100 |
---|---|---|
committer | Rory& <root@rory.gay> | 2024-02-26 17:10:46 +0100 |
commit | 9371a2eb9d10d9492a168fbb735ba0b0e4d76671 (patch) | |
tree | bca4c1ba38370e46265251f00945e9ef1f71f438 | |
parent | Nix: add mainProgram (diff) | |
download | MatrixMediaGate-9371a2eb9d10d9492a168fbb735ba0b0e4d76671.tar.xz |
Fix auth code to be excluded on federation.
-rw-r--r-- | .idea/.idea.MatrixMediaGate/.idea/indexLayout.xml | 8 | ||||
-rw-r--r-- | .idea/.idea.MatrixMediaGate/.idea/vcs.xml | 6 | ||||
-rw-r--r-- | MatrixMediaGate/Program.cs | 25 | ||||
-rw-r--r-- | MatrixMediaGate/Services/AuthValidator.cs | 27 |
4 files changed, 40 insertions, 26 deletions
diff --git a/.idea/.idea.MatrixMediaGate/.idea/indexLayout.xml b/.idea/.idea.MatrixMediaGate/.idea/indexLayout.xml new file mode 100644 index 0000000..7b08163 --- /dev/null +++ b/.idea/.idea.MatrixMediaGate/.idea/indexLayout.xml @@ -0,0 +1,8 @@ +<?xml version="1.0" encoding="UTF-8"?> +<project version="4"> + <component name="UserContentModel"> + <attachedFolders /> + <explicitIncludes /> + <explicitExcludes /> + </component> +</project> \ No newline at end of file diff --git a/.idea/.idea.MatrixMediaGate/.idea/vcs.xml b/.idea/.idea.MatrixMediaGate/.idea/vcs.xml new file mode 100644 index 0000000..35eb1dd --- /dev/null +++ b/.idea/.idea.MatrixMediaGate/.idea/vcs.xml @@ -0,0 +1,6 @@ +<?xml version="1.0" encoding="UTF-8"?> +<project version="4"> + <component name="VcsDirectoryMappings"> + <mapping directory="" vcs="Git" /> + </component> +</project> \ No newline at end of file diff --git a/MatrixMediaGate/Program.cs b/MatrixMediaGate/Program.cs index a812d78..9ebac51 100644 --- a/MatrixMediaGate/Program.cs +++ b/MatrixMediaGate/Program.cs @@ -15,14 +15,10 @@ var app = builder.Build(); async Task Proxy(ProxyConfiguration cfg, AuthValidator auth, HttpContext ctx, ILogger<Program> logger) { if (ctx is null) return; var path = ctx.Request.Path.Value; - if(path is null) return; + if (path is null) return; if (path.StartsWith('/')) path = path[1..]; path += ctx.Request.QueryString.Value; - -#pragma warning disable CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed - auth.UpdateAuth(); -#pragma warning restore CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed using var handler = new HttpClientHandler() { AutomaticDecompression = DecompressionMethods.None }; @@ -30,7 +26,8 @@ async Task Proxy(ProxyConfiguration cfg, AuthValidator auth, HttpContext ctx, IL var method = new HttpMethod(ctx.Request.Method); using var req = new HttpRequestMessage(method, path); foreach (var header in ctx.Request.Headers) { - if (header.Key != "Accept-Encoding" && header.Key != "Content-Type" && header.Key != "Content-Length") req.Headers.Add(header.Key, header.Value.ToArray()); + // if (header.Key != "Accept-Encoding" && header.Key != "Content-Type" && header.Key != "Content-Length") + req.Headers.Add(header.Key, header.Value.ToArray()); } if (ctx.Request.ContentLength > 0) { @@ -44,7 +41,8 @@ async Task Proxy(ProxyConfiguration cfg, AuthValidator auth, HttpContext ctx, IL using var response = await hc.SendAsync(req, HttpCompletionOption.ResponseHeadersRead); ctx.Response.Headers.Clear(); foreach (var header in response.Headers) { - if (header.Key != "Transfer-Encoding") ctx.Response.Headers[header.Key] = header.Value.ToArray(); + // if (header.Key != "Transfer-Encoding") + ctx.Response.Headers[header.Key] = header.Value.ToArray(); } ctx.Response.StatusCode = (int)response.StatusCode; @@ -56,9 +54,17 @@ async Task Proxy(ProxyConfiguration cfg, AuthValidator auth, HttpContext ctx, IL await ctx.Response.CompleteAsync(); } +async Task ProxyMaybeAuth(ProxyConfiguration cfg, AuthValidator auth, HttpContext ctx, ILogger<Program> logger) { + if (ctx is null) return; + + await auth.UpdateAuth(); + + await Proxy(cfg, auth, ctx, logger); +} + async Task ProxyMedia(string serverName, ProxyConfiguration cfg, AuthValidator auth, HttpContext ctx, ILogger<Program> logger) { if (cfg.TrustedServers.Contains(serverName) || auth.ValidateAuth()) { - await Proxy(cfg, auth, ctx, logger); + await ProxyMaybeAuth(cfg, auth, ctx, logger); // Some clients may send Authorization header... } else { ctx.Response.StatusCode = 403; @@ -71,7 +77,8 @@ async Task ProxyMedia(string serverName, ProxyConfiguration cfg, AuthValidator a } } -app.Map("{*_}", Proxy); +app.Map("{*_}", ProxyMaybeAuth); +app.Map("/_matrix/federation/{*_}", Proxy); foreach (var route in (string[]) [ "/_matrix/media/{version}/download/{serverName}/{mediaId}", diff --git a/MatrixMediaGate/Services/AuthValidator.cs b/MatrixMediaGate/Services/AuthValidator.cs index 08ccd14..6f2b0c1 100644 --- a/MatrixMediaGate/Services/AuthValidator.cs +++ b/MatrixMediaGate/Services/AuthValidator.cs @@ -6,43 +6,36 @@ namespace MatrixMediaGate.Services; public class AuthValidator(ILogger<AuthValidator> logger, ProxyConfiguration cfg, IHttpContextAccessor ctx) { private static Dictionary<string, DateTime> _authCache = new(); - public async Task<bool> UpdateAuth() { - if (ctx.HttpContext is null) return false; - if (ctx.HttpContext.Connection.RemoteIpAddress is null) return false; + public async Task UpdateAuth() { + if (ctx.HttpContext?.Connection.RemoteIpAddress is null) return; var remote = ctx.HttpContext.Connection.RemoteIpAddress.ToString(); - - + if (_authCache.TryGetValue(remote, out var value)) { if (value > DateTime.Now.AddSeconds(30)) { - return true; + return; } _authCache.Remove(remote); } string? token = getToken(); - if (token is null) return false; + if (token is null) return; using var hc = new HttpClient(); using var req = new HttpRequestMessage(HttpMethod.Get, $"{cfg.Upstream}/_matrix/client/v3/account/whoami?access_token={token}"); - req.Headers.Host = cfg.Host; var response = await hc.SendAsync(req); - if (response.Content.Headers.ContentType?.MediaType != "application/json") return false; + if (response.Content.Headers.ContentType?.MediaType != "application/json") return; var content = await response.Content.ReadAsStringAsync(); var json = JsonDocument.Parse(content); if (json.RootElement.TryGetProperty("user_id", out var userId)) { _authCache[remote] = DateTime.Now.AddMinutes(5); logger.LogInformation("Authenticated {userId} on {remote}, expiring at {time}", userId, remote, _authCache[remote]); - return true; } - - return false; } public bool ValidateAuth() { - if (ctx.HttpContext is null) return false; - if (ctx.HttpContext.Connection.RemoteIpAddress is null) return false; + if (ctx.HttpContext?.Connection.RemoteIpAddress is null) return false; var remote = ctx.HttpContext.Connection.RemoteIpAddress.ToString(); if (_authCache.ContainsKey(remote)) { @@ -57,9 +50,9 @@ public class AuthValidator(ILogger<AuthValidator> logger, ProxyConfiguration cfg } private string? getToken() { - if (ctx is null) return null; - if (ctx.HttpContext.Request.Headers.ContainsKey("Authorization")) { - return ctx.HttpContext.Request.Headers["Authorization"].ToString().Split(' ', 2)[1]; + if (ctx.HttpContext is null) return null; + if (ctx.HttpContext.Request.Headers.TryGetValue("Authorization", out var header)) { + return header.ToString().Split(' ', 2)[1]; } else if (ctx.HttpContext.Request.Query.ContainsKey("access_token")) { return ctx.HttpContext.Request.Query["access_token"]!; |