diff --git a/MatrixMediaGate/Program.cs b/MatrixMediaGate/Program.cs
index 4a62ebc..ee06b9a 100644
--- a/MatrixMediaGate/Program.cs
+++ b/MatrixMediaGate/Program.cs
@@ -68,8 +68,9 @@ async Task ProxyMaybeAuth(HttpClient hc, ProxyConfiguration cfg, AuthValidator a
}
async Task ProxyMedia(string serverName, ProxyConfiguration cfg, HttpClient hc, AuthValidator auth, HttpContext ctx, ILogger<Program> logger) {
- if (cfg.TrustedServers.Contains(serverName) || auth.ValidateAuth(ctx)) {
- await ProxyMaybeAuth(hc, cfg, auth, ctx, logger); // Some clients may send Authorization header...
+ // Some clients may send Authorization header, so we handle this last...
+ if (cfg.TrustedServers.Contains(serverName) || auth.ValidateAuth(ctx) || await auth.UpdateAuth(ctx)) {
+ await Proxy(hc, cfg, ctx, logger);
}
else {
ctx.Response.StatusCode = 403;
diff --git a/MatrixMediaGate/Services/AuthValidator.cs b/MatrixMediaGate/Services/AuthValidator.cs
index cf53cef..4b74006 100644
--- a/MatrixMediaGate/Services/AuthValidator.cs
+++ b/MatrixMediaGate/Services/AuthValidator.cs
@@ -5,6 +5,7 @@ namespace MatrixMediaGate.Services;
public class AuthValidator(ILogger<AuthValidator> logger, ProxyConfiguration cfg) {
private readonly Dictionary<string, DateTime> _authCache = new();
+
private readonly HttpClient _hc = new() {
BaseAddress = new Uri(cfg.Upstream),
DefaultRequestHeaders = {
@@ -12,36 +13,45 @@ public class AuthValidator(ILogger<AuthValidator> logger, ProxyConfiguration cfg
}
};
- public async Task UpdateAuth(HttpContext ctx) {
- if (ctx.Connection.RemoteIpAddress is null) return;
- var remote = ctx.Connection.RemoteIpAddress.ToString();
-
+ public async Task<bool> UpdateAuth(HttpContext ctx) {
+ if (ctx.Connection.RemoteIpAddress is null) return false;
+ var remote = GetRemote(ctx);
+ if (string.IsNullOrWhiteSpace(remote)) return false;
+
if (_authCache.TryGetValue(remote, out var value)) {
if (value > DateTime.Now.AddSeconds(30)) {
- return;
+ return true;
}
_authCache.Remove(remote);
}
- string? token = getToken(ctx);
- if (token is null) return;
+ string? token = GetToken(ctx);
+ if (string.IsNullOrWhiteSpace(token)) return false;
using var req = new HttpRequestMessage(HttpMethod.Get, $"{cfg.Upstream}/_matrix/client/v3/account/whoami?access_token={token}");
var response = await _hc.SendAsync(req);
- if (response.Content.Headers.ContentType?.MediaType != "application/json") return;
- var content = await response.Content.ReadAsStringAsync();
- var json = JsonDocument.Parse(content);
- if (json.RootElement.TryGetProperty("user_id", out var userId)) {
- _authCache[remote] = DateTime.Now.AddMinutes(5);
- logger.LogInformation("Authenticated {userId} on {remote}, expiring at {time}", userId, remote, _authCache[remote]);
+ try {
+ var content = await response.Content.ReadAsStringAsync();
+ var json = JsonDocument.Parse(content);
+ if (json.RootElement.TryGetProperty("user_id", out var userId)) {
+ _authCache[remote] = DateTime.Now.AddMinutes(5);
+ logger.LogInformation("Authenticated {userId} on {remote}, expiring at {time}", userId, remote, _authCache[remote]);
+ return true;
+ }
+ }
+ catch (Exception e) {
+ logger.LogError(e, "Failed to authenticate {remote}", remote);
+ return false;
}
+
+ return false;
}
public bool ValidateAuth(HttpContext ctx) {
if (ctx.Connection.RemoteIpAddress is null) return false;
var remote = ctx.Connection.RemoteIpAddress.ToString();
-
+
if (_authCache.ContainsKey(remote)) {
if (_authCache[remote] > DateTime.Now) {
return true;
@@ -53,7 +63,7 @@ public class AuthValidator(ILogger<AuthValidator> logger, ProxyConfiguration cfg
return false;
}
- private string? getToken(HttpContext ctx) {
+ public string? GetToken(HttpContext ctx) {
if (ctx.Request.Headers.TryGetValue("Authorization", out var header)) {
return header.ToString().Split(' ', 2)[1];
}
@@ -64,4 +74,12 @@ public class AuthValidator(ILogger<AuthValidator> logger, ProxyConfiguration cfg
return null;
}
}
+
+ private string? GetRemote(HttpContext ctx) {
+ foreach (var (key, value) in ctx.Request.Headers) {
+ Console.WriteLine($"Authorized (ignore me) - Headers: {key}: {value}");
+ }
+
+ return ctx.Connection.RemoteIpAddress?.ToString();
+ }
}
\ No newline at end of file
|