about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--.idea/.idea.MatrixMediaGate/.idea/indexLayout.xml8
-rw-r--r--.idea/.idea.MatrixMediaGate/.idea/vcs.xml6
-rw-r--r--MatrixMediaGate/Program.cs25
-rw-r--r--MatrixMediaGate/Services/AuthValidator.cs27
4 files changed, 40 insertions, 26 deletions
diff --git a/.idea/.idea.MatrixMediaGate/.idea/indexLayout.xml b/.idea/.idea.MatrixMediaGate/.idea/indexLayout.xml
new file mode 100644
index 0000000..7b08163
--- /dev/null
+++ b/.idea/.idea.MatrixMediaGate/.idea/indexLayout.xml
@@ -0,0 +1,8 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="UserContentModel">
+    <attachedFolders />
+    <explicitIncludes />
+    <explicitExcludes />
+  </component>
+</project>
\ No newline at end of file
diff --git a/.idea/.idea.MatrixMediaGate/.idea/vcs.xml b/.idea/.idea.MatrixMediaGate/.idea/vcs.xml
new file mode 100644
index 0000000..35eb1dd
--- /dev/null
+++ b/.idea/.idea.MatrixMediaGate/.idea/vcs.xml
@@ -0,0 +1,6 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<project version="4">
+  <component name="VcsDirectoryMappings">
+    <mapping directory="" vcs="Git" />
+  </component>
+</project>
\ No newline at end of file
diff --git a/MatrixMediaGate/Program.cs b/MatrixMediaGate/Program.cs
index a812d78..9ebac51 100644
--- a/MatrixMediaGate/Program.cs
+++ b/MatrixMediaGate/Program.cs
@@ -15,14 +15,10 @@ var app = builder.Build();
 async Task Proxy(ProxyConfiguration cfg, AuthValidator auth, HttpContext ctx, ILogger<Program> logger) {

     if (ctx is null) return;

     var path = ctx.Request.Path.Value;

-    if(path is null) return;

+    if (path is null) return;

     if (path.StartsWith('/'))

         path = path[1..];

     path += ctx.Request.QueryString.Value;

-    

-#pragma warning disable CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed

-    auth.UpdateAuth();

-#pragma warning restore CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed

 

     using var handler = new HttpClientHandler() { AutomaticDecompression = DecompressionMethods.None };

 

@@ -30,7 +26,8 @@ async Task Proxy(ProxyConfiguration cfg, AuthValidator auth, HttpContext ctx, IL
     var method = new HttpMethod(ctx.Request.Method);

     using var req = new HttpRequestMessage(method, path);

     foreach (var header in ctx.Request.Headers) {

-        if (header.Key != "Accept-Encoding" && header.Key != "Content-Type" && header.Key != "Content-Length") req.Headers.Add(header.Key, header.Value.ToArray());

+        // if (header.Key != "Accept-Encoding" && header.Key != "Content-Type" && header.Key != "Content-Length")

+            req.Headers.Add(header.Key, header.Value.ToArray());

     }

 

     if (ctx.Request.ContentLength > 0) {

@@ -44,7 +41,8 @@ async Task Proxy(ProxyConfiguration cfg, AuthValidator auth, HttpContext ctx, IL
     using var response = await hc.SendAsync(req, HttpCompletionOption.ResponseHeadersRead);

     ctx.Response.Headers.Clear();

     foreach (var header in response.Headers) {

-        if (header.Key != "Transfer-Encoding") ctx.Response.Headers[header.Key] = header.Value.ToArray();

+        // if (header.Key != "Transfer-Encoding")

+            ctx.Response.Headers[header.Key] = header.Value.ToArray();

     }

 

     ctx.Response.StatusCode = (int)response.StatusCode;

@@ -56,9 +54,17 @@ async Task Proxy(ProxyConfiguration cfg, AuthValidator auth, HttpContext ctx, IL
     await ctx.Response.CompleteAsync();

 }

 

+async Task ProxyMaybeAuth(ProxyConfiguration cfg, AuthValidator auth, HttpContext ctx, ILogger<Program> logger) {

+    if (ctx is null) return;

+

+    await auth.UpdateAuth();

+

+    await Proxy(cfg, auth, ctx, logger);

+}

+

 async Task ProxyMedia(string serverName, ProxyConfiguration cfg, AuthValidator auth, HttpContext ctx, ILogger<Program> logger) {

     if (cfg.TrustedServers.Contains(serverName) || auth.ValidateAuth()) {

-        await Proxy(cfg, auth, ctx, logger);

+        await ProxyMaybeAuth(cfg, auth, ctx, logger); // Some clients may send Authorization header...

     }

     else {

         ctx.Response.StatusCode = 403;

@@ -71,7 +77,8 @@ async Task ProxyMedia(string serverName, ProxyConfiguration cfg, AuthValidator a
     }

 }

 

-app.Map("{*_}", Proxy);

+app.Map("{*_}", ProxyMaybeAuth);

+app.Map("/_matrix/federation/{*_}", Proxy);

 

 foreach (var route in (string[]) [

              "/_matrix/media/{version}/download/{serverName}/{mediaId}",

diff --git a/MatrixMediaGate/Services/AuthValidator.cs b/MatrixMediaGate/Services/AuthValidator.cs
index 08ccd14..6f2b0c1 100644
--- a/MatrixMediaGate/Services/AuthValidator.cs
+++ b/MatrixMediaGate/Services/AuthValidator.cs
@@ -6,43 +6,36 @@ namespace MatrixMediaGate.Services;
 public class AuthValidator(ILogger<AuthValidator> logger, ProxyConfiguration cfg, IHttpContextAccessor ctx) {
     private static Dictionary<string, DateTime> _authCache = new();
 
-    public async Task<bool> UpdateAuth() {
-        if (ctx.HttpContext is null) return false;
-        if (ctx.HttpContext.Connection.RemoteIpAddress is null) return false;
+    public async Task UpdateAuth() {
+        if (ctx.HttpContext?.Connection.RemoteIpAddress is null) return;
         var remote = ctx.HttpContext.Connection.RemoteIpAddress.ToString();
-
-
+        
         if (_authCache.TryGetValue(remote, out var value)) {
             if (value > DateTime.Now.AddSeconds(30)) {
-                return true;
+                return;
             }
 
             _authCache.Remove(remote);
         }
 
         string? token = getToken();
-        if (token is null) return false;
+        if (token is null) return;
         
         using var hc = new HttpClient();
         using var req = new HttpRequestMessage(HttpMethod.Get, $"{cfg.Upstream}/_matrix/client/v3/account/whoami?access_token={token}");
-        req.Headers.Host = cfg.Host;
         var response = await hc.SendAsync(req);
 
-        if (response.Content.Headers.ContentType?.MediaType != "application/json") return false;
+        if (response.Content.Headers.ContentType?.MediaType != "application/json") return;
         var content = await response.Content.ReadAsStringAsync();
         var json = JsonDocument.Parse(content);
         if (json.RootElement.TryGetProperty("user_id", out var userId)) {
             _authCache[remote] = DateTime.Now.AddMinutes(5);
             logger.LogInformation("Authenticated {userId} on {remote}, expiring at {time}", userId, remote, _authCache[remote]);
-            return true;
         }
-
-        return false;
     }
 
     public bool ValidateAuth() {
-        if (ctx.HttpContext is null) return false;
-        if (ctx.HttpContext.Connection.RemoteIpAddress is null) return false;
+        if (ctx.HttpContext?.Connection.RemoteIpAddress is null) return false;
         var remote = ctx.HttpContext.Connection.RemoteIpAddress.ToString();
         
         if (_authCache.ContainsKey(remote)) {
@@ -57,9 +50,9 @@ public class AuthValidator(ILogger<AuthValidator> logger, ProxyConfiguration cfg
     }
 
     private string? getToken() {
-        if (ctx is null) return null;
-        if (ctx.HttpContext.Request.Headers.ContainsKey("Authorization")) {
-            return ctx.HttpContext.Request.Headers["Authorization"].ToString().Split(' ', 2)[1];
+        if (ctx.HttpContext is null) return null;
+        if (ctx.HttpContext.Request.Headers.TryGetValue("Authorization", out var header)) {
+            return header.ToString().Split(' ', 2)[1];
         }
         else if (ctx.HttpContext.Request.Query.ContainsKey("access_token")) {
             return ctx.HttpContext.Request.Query["access_token"]!;