diff --git a/MxApiExtensions/Classes/SyncState.cs b/MxApiExtensions/Classes/SyncState.cs
index 6950954..e44d35c 100644
--- a/MxApiExtensions/Classes/SyncState.cs
+++ b/MxApiExtensions/Classes/SyncState.cs
@@ -1,15 +1,91 @@
using System.Collections.Concurrent;
+using System.Text.Json.Serialization;
+using LibMatrix;
+using LibMatrix.EventTypes.Spec.State;
using LibMatrix.Helpers;
using LibMatrix.Homeservers;
using LibMatrix.Responses;
+using Microsoft.OpenApi.Extensions;
namespace MxApiExtensions.Classes;
public class SyncState {
+ private Task? _nextSyncResponse;
public string? NextBatch { get; set; }
public ConcurrentQueue<SyncResponse> SyncQueue { get; set; } = new();
public bool IsInitialSync { get; set; }
- public Task? NextSyncResponse { get; set; }
+
+ [JsonIgnore]
+ public Task? NextSyncResponse {
+ get => _nextSyncResponse;
+ set {
+ _nextSyncResponse = value;
+ NextSyncResponseStartedAt = DateTime.Now;
+ }
+ }
+
public DateTime NextSyncResponseStartedAt { get; set; } = DateTime.Now;
+
+ [JsonIgnore]
public AuthenticatedHomeserverGeneric Homeserver { get; set; }
-}
+
+#region Debug stuff
+
+ public object NextSyncResponseTaskInfo => new {
+ NextSyncResponse?.Id,
+ NextSyncResponse?.IsCompleted,
+ NextSyncResponse?.IsCompletedSuccessfully,
+ NextSyncResponse?.IsCanceled,
+ NextSyncResponse?.IsFaulted,
+ Status = NextSyncResponse?.Status.GetDisplayName()
+ };
+
+
+#endregion
+
+ public void SendEphemeralTimelineEventInRoom(string roomId, StateEventResponse @event) {
+ SyncQueue.Enqueue(new() {
+ NextBatch = NextBatch ?? "null",
+ Rooms = new() {
+ Join = new() {
+ {
+ roomId,
+ new() {
+ Timeline = new() {
+ Events = new() {
+ @event
+ }
+ }
+ }
+ }
+ }
+ }
+ });
+ }
+
+ public void SendStatusMessage(string text) {
+ SyncQueue.Enqueue(new() {
+ NextBatch = NextBatch ?? "null",
+ Presence = new() {
+ Events = new() {
+ new StateEventResponse {
+ TypedContent = new PresenceEventContent {
+ DisplayName = "MxApiExtensions",
+ Presence = "online",
+ StatusMessage = text,
+ // AvatarUrl = (await syncState.Homeserver.GetProfile(syncState.Homeserver.WhoAmI.UserId)).AvatarUrl
+ AvatarUrl = "",
+ LastActiveAgo = 15,
+ CurrentlyActive = true
+ },
+ Type = "m.presence",
+ StateKey = Homeserver.WhoAmI.UserId,
+ Sender = Homeserver.WhoAmI.UserId,
+ EventId = Guid.NewGuid().ToString(),
+ OriginServerTs = 0
+ }
+ }
+ }
+ });
+ }
+}
\ No newline at end of file
diff --git a/MxApiExtensions/Controllers/ClientVersionsController.cs b/MxApiExtensions/Controllers/Client/ClientVersionsController.cs
index d29e3b2..d29e3b2 100644
--- a/MxApiExtensions/Controllers/ClientVersionsController.cs
+++ b/MxApiExtensions/Controllers/Client/ClientVersionsController.cs
diff --git a/MxApiExtensions/Controllers/Client/LoginController.cs b/MxApiExtensions/Controllers/Client/LoginController.cs
new file mode 100644
index 0000000..009aaef
--- /dev/null
+++ b/MxApiExtensions/Controllers/Client/LoginController.cs
@@ -0,0 +1,73 @@
+using System.Net.Http.Headers;
+using ArcaneLibs.Extensions;
+using LibMatrix;
+using LibMatrix.Extensions;
+using LibMatrix.Responses;
+using LibMatrix.Services;
+using Microsoft.AspNetCore.Mvc;
+using MxApiExtensions.Classes.LibMatrix;
+using MxApiExtensions.Services;
+
+namespace MxApiExtensions.Controllers;
+
+[ApiController]
+[Route("/")]
+public class LoginController(ILogger<LoginController> logger, HomeserverProviderService hsProvider, HomeserverResolverService hsResolver, AuthenticationService auth,
+ MxApiExtensionsConfiguration conf)
+ : ControllerBase {
+ private readonly ILogger _logger = logger;
+ private readonly HomeserverProviderService _hsProvider = hsProvider;
+ private readonly MxApiExtensionsConfiguration _conf = conf;
+
+ [HttpPost("/_matrix/client/{_}/login")]
+ public async Task Proxy([FromBody] LoginRequest request, string _) {
+ string hsCanonical = null;
+ if (Request.Headers.Keys.Any(x => x.ToUpper() == "MXAE_UPSTREAM")) {
+ hsCanonical = Request.Headers.GetByCaseInsensitiveKey("MXAE_UPSTREAM")[0]!;
+ _logger.LogInformation("Found upstream: {}", hsCanonical);
+ }
+ else {
+ if (!request.Identifier.User.Contains("#")) {
+ Response.StatusCode = (int)StatusCodes.Status403Forbidden;
+ Response.ContentType = "application/json";
+ await Response.StartAsync();
+ await Response.WriteAsync(new MxApiMatrixException {
+ ErrorCode = "M_FORBIDDEN",
+ Error = "[MxApiExtensions] Invalid username, must be of the form @user#domain:" + Request.Host.Value
+ }.GetAsJson() ?? "");
+ await Response.CompleteAsync();
+ }
+
+ hsCanonical = request.Identifier.User.Split('#')[1].Split(':')[0];
+ request.Identifier.User = request.Identifier.User.Split(':')[0].Replace('#', ':');
+ if (!request.Identifier.User.StartsWith('@')) request.Identifier.User = '@' + request.Identifier.User;
+ }
+
+ var hs = await hsResolver.ResolveHomeserverFromWellKnown(hsCanonical);
+ //var hs = await _hsProvider.Login(hsCanonical, mxid, request.Password);
+ var hsClient = new MatrixHttpClient { BaseAddress = new Uri(hs.Client) };
+ //hsClient.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", hsClient.DefaultRequestHeaders.Authorization!.Parameter);
+ if (!string.IsNullOrWhiteSpace(request.InitialDeviceDisplayName))
+ request.InitialDeviceDisplayName += $" (via MxApiExtensions at {Request.Host.Value})";
+ var resp = await hsClient.PostAsJsonAsync("/_matrix/client/r0/login", request);
+ var loginResp = await resp.Content.ReadAsStringAsync();
+ Response.StatusCode = (int)resp.StatusCode;
+ Response.ContentType = resp.Content.Headers.ContentType?.ToString() ?? "application/json";
+ await Response.StartAsync();
+ await Response.WriteAsync(loginResp);
+ await Response.CompleteAsync();
+ var token = (await resp.Content.ReadFromJsonAsync<LoginResponse>())!.AccessToken;
+ await auth.SaveMxidForToken(token, request.Identifier.User);
+ }
+
+ [HttpGet("/_matrix/client/{_}/login")]
+ public async Task<object> Proxy(string? _) {
+ return new {
+ flows = new[] {
+ new {
+ type = "m.login.password"
+ }
+ }
+ };
+ }
+}
\ No newline at end of file
diff --git a/MxApiExtensions/Controllers/Client/Room/RoomsSendMessageController.cs b/MxApiExtensions/Controllers/Client/Room/RoomsSendMessageController.cs
new file mode 100644
index 0000000..3d1d4e2
--- /dev/null
+++ b/MxApiExtensions/Controllers/Client/Room/RoomsSendMessageController.cs
@@ -0,0 +1,18 @@
+using LibMatrix.Services;
+using Microsoft.AspNetCore.Mvc;
+using MxApiExtensions.Services;
+
+namespace MxApiExtensions.Controllers.Client.Room;
+
+[ApiController]
+[Route("/")]
+public class RoomController(ILogger<LoginController> logger, HomeserverResolverService hsResolver, AuthenticationService auth, MxApiExtensionsConfiguration conf,
+ AuthenticatedHomeserverProviderService hsProvider)
+ : ControllerBase {
+ [HttpGet("/_matrix/client/{_}/rooms/{roomId}/members_by_homeserver")]
+ public async Task<Dictionary<string, List<string>>> GetRoomMembersByHomeserver(string _, [FromRoute] string roomId, [FromQuery] bool joinedOnly = true) {
+ var hs = await hsProvider.GetHomeserver();
+ var room = hs.GetRoom(roomId);
+ return await room.GetMembersByHomeserverAsync();
+ }
+}
\ No newline at end of file
diff --git a/MxApiExtensions/Controllers/Client/RoomsSendMessageController.cs b/MxApiExtensions/Controllers/Client/RoomsSendMessageController.cs
new file mode 100644
index 0000000..6d3a774
--- /dev/null
+++ b/MxApiExtensions/Controllers/Client/RoomsSendMessageController.cs
@@ -0,0 +1,72 @@
+using System.Buffers.Text;
+using System.Net.Http.Headers;
+using System.Text.Json;
+using System.Text.Json.Nodes;
+using ArcaneLibs.Extensions;
+using LibMatrix;
+using LibMatrix.EventTypes.Spec;
+using LibMatrix.Extensions;
+using LibMatrix.Helpers;
+using LibMatrix.Homeservers;
+using LibMatrix.Responses;
+using LibMatrix.Services;
+using Microsoft.AspNetCore.Mvc;
+using MxApiExtensions.Classes;
+using MxApiExtensions.Classes.LibMatrix;
+using MxApiExtensions.Services;
+
+namespace MxApiExtensions.Controllers;
+
+[ApiController]
+[Route("/")]
+public class RoomsSendMessageController(ILogger<LoginController> logger, HomeserverResolverService hsResolver, AuthenticationService auth, MxApiExtensionsConfiguration conf,
+ AuthenticatedHomeserverProviderService hsProvider)
+ : ControllerBase {
+ [HttpPut("/_matrix/client/{_}/rooms/{roomId}/send/m.room.message/{txnId}")]
+ public async Task Proxy([FromBody] JsonObject request, [FromRoute] string roomId, [FromRoute] string txnId, string _) {
+ var hs = await hsProvider.GetHomeserver();
+
+ var msg = request.Deserialize<RoomMessageEventContent>();
+ if (msg is not null && msg.Body.StartsWith("mxae!")) {
+#pragma warning disable CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed
+ handleMxaeCommand(hs, roomId, msg);
+#pragma warning restore CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed
+ await Response.WriteAsJsonAsync(new EventIdResponse() {
+ EventId = "$" + string.Join("", Random.Shared.GetItems("abcdefghijklmnopqrstuvwxyzABCDEFGHIJLKMNOPQRSTUVWXYZ0123456789".ToCharArray(), 100))
+ });
+ await Response.CompleteAsync();
+ }
+ else {
+ try {
+ var resp = await hs.ClientHttpClient.PutAsJsonAsync($"{Request.Path}{Request.QueryString}", request);
+ var loginResp = await resp.Content.ReadAsStringAsync();
+ Response.StatusCode = (int)resp.StatusCode;
+ Response.ContentType = resp.Content.Headers.ContentType?.ToString() ?? "application/json";
+ await Response.StartAsync();
+ await Response.WriteAsync(loginResp);
+ await Response.CompleteAsync();
+ }
+ catch (MatrixException e) {
+ await Response.StartAsync();
+ await Response.WriteAsync(e.GetAsJson());
+ await Response.CompleteAsync();
+ }
+ }
+ }
+
+ private async Task handleMxaeCommand(AuthenticatedHomeserverGeneric hs, string roomId, RoomMessageEventContent msg) {
+ var syncState = SyncController._syncStates.GetValueOrDefault(hs.AccessToken);
+ if (syncState is null) return;
+ syncState.SendEphemeralTimelineEventInRoom(roomId, new() {
+ Sender = "@mxae:" + Request.Host.Value,
+ Type = "m.room.message",
+ TypedContent = MessageFormatter.FormatSuccess("Thinking..."),
+ OriginServerTs = (ulong)new DateTimeOffset(DateTime.UtcNow.ToUniversalTime()).ToUnixTimeMilliseconds(),
+ Unsigned = new() {
+ Age = 1
+ },
+ RoomId = roomId,
+ EventId = "$" + string.Join("", Random.Shared.GetItems("abcdefghijklmnopqrstuvwxyzABCDEFGHIJLKMNOPQRSTUVWXYZ0123456789".ToCharArray(), 100))
+ });
+ }
+}
\ No newline at end of file
diff --git a/MxApiExtensions/Controllers/Client/SyncController.cs b/MxApiExtensions/Controllers/Client/SyncController.cs
new file mode 100644
index 0000000..2944c3b
--- /dev/null
+++ b/MxApiExtensions/Controllers/Client/SyncController.cs
@@ -0,0 +1,243 @@
+using System.Collections.Concurrent;
+using System.Diagnostics;
+using System.Text.Json;
+using System.Text.Json.Nodes;
+using System.Text.Json.Serialization;
+using System.Web;
+using ArcaneLibs;
+using LibMatrix;
+using LibMatrix.EventTypes.Spec.State;
+using LibMatrix.Helpers;
+using LibMatrix.Homeservers;
+using LibMatrix.Responses;
+using LibMatrix.RoomTypes;
+using Microsoft.AspNetCore.Mvc;
+using MxApiExtensions.Classes;
+using MxApiExtensions.Classes.LibMatrix;
+using MxApiExtensions.Extensions;
+using MxApiExtensions.Services;
+
+namespace MxApiExtensions.Controllers;
+
+[ApiController]
+[Route("/")]
+public class SyncController(ILogger<SyncController> logger, MxApiExtensionsConfiguration config, AuthenticationService auth, AuthenticatedHomeserverProviderService hsProvider)
+ : ControllerBase {
+ public static readonly ConcurrentDictionary<string, SyncState> _syncStates = new();
+
+ private static SemaphoreSlim _semaphoreSlim = new(1, 1);
+ private Stopwatch _syncElapsed = Stopwatch.StartNew();
+
+ [HttpGet("/_matrix/client/{_}/sync")]
+ public async Task Sync(string _, [FromQuery] string? since, [FromQuery] int timeout = 1000) {
+ Task? preloadTask = null;
+ AuthenticatedHomeserverGeneric? hs = null;
+ try {
+ hs = await hsProvider.GetHomeserver();
+ }
+ catch (Exception e) {
+ Console.WriteLine(e);
+ }
+
+ var qs = HttpUtility.ParseQueryString(Request.QueryString.Value!);
+ qs.Remove("access_token");
+ if (since == "null") qs.Remove("since");
+
+ if (!config.FastInitialSync.Enabled) {
+ logger.LogInformation("Starting sync for {} on {} ({})", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken);
+ var result = await hs.ClientHttpClient.GetAsync($"{Request.Path}?{qs}");
+ await Response.WriteHttpResponse(result);
+ return;
+ }
+
+ await _semaphoreSlim.WaitAsync();
+ var syncState = _syncStates.GetOrAdd($"{hs.WhoAmI.UserId}/{hs.WhoAmI.DeviceId}/{hs.ServerName}:{hs.AccessToken}", _ => {
+ logger.LogInformation("Started tracking sync state for {} on {} ({})", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken);
+ var ss = new SyncState {
+ IsInitialSync = string.IsNullOrWhiteSpace(since),
+ Homeserver = hs
+ };
+ if (ss.IsInitialSync) {
+ preloadTask = EnqueuePreloadData(ss);
+ }
+
+ logger.LogInformation("Starting sync for {} on {} ({})", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken);
+
+ ss.NextSyncResponseStartedAt = DateTime.Now;
+ ss.NextSyncResponse = Task.Delay(15_000);
+ ss.NextSyncResponse.ContinueWith(async x => {
+ logger.LogInformation("Sync for {} on {} ({}) starting", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken);
+ ss.NextSyncResponse = hs.ClientHttpClient.GetAsync($"/_matrix/client/v3/sync?{qs}");
+ (ss.NextSyncResponse as Task<HttpResponseMessage>).ContinueWith(async x => EnqueueSyncResponse(ss, await x));
+ });
+ return ss;
+ });
+ _semaphoreSlim.Release();
+
+ if (syncState.SyncQueue.Count > 0) {
+ logger.LogInformation("Sync for {} on {} ({}) has {} queued results", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken, syncState.SyncQueue.Count);
+ syncState.SyncQueue.TryDequeue(out var result);
+ while (result is null)
+ syncState.SyncQueue.TryDequeue(out result);
+ Response.StatusCode = StatusCodes.Status200OK;
+ Response.ContentType = "application/json";
+ await Response.StartAsync();
+ result.NextBatch ??= since ?? syncState.NextBatch;
+ await JsonSerializer.SerializeAsync(Response.Body, result, new JsonSerializerOptions {
+ WriteIndented = true,
+ DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull
+ });
+ await Response.CompleteAsync();
+ return;
+ }
+
+ var newTimeout = Math.Clamp(timeout, 0, syncState.IsInitialSync ? syncState.SyncQueue.Count >= 2 ? 0 : 250 : timeout);
+ logger.LogInformation("Sync for {} on {} ({}) is still running, waiting for {}ms, {} elapsed", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken, newTimeout,
+ DateTime.Now.Subtract(syncState.NextSyncResponseStartedAt));
+
+ try {
+ if (syncState.NextSyncResponse is not null)
+ await syncState.NextSyncResponse.WaitAsync(TimeSpan.FromMilliseconds(newTimeout));
+ else {
+ syncState.NextSyncResponse = hs.ClientHttpClient.GetAsync($"/_matrix/client/v3/sync?{qs}");
+ (syncState.NextSyncResponse as Task<HttpResponseMessage>).ContinueWith(async x => EnqueueSyncResponse(syncState, await x));
+ // await Task.Delay(250);
+ }
+ }
+ catch (TimeoutException) { }
+
+ // if (_syncElapsed.ElapsedMilliseconds > timeout)
+ if(syncState.NextSyncResponse?.IsCompleted == false)
+ syncState.SendStatusMessage(
+ $"M={Util.BytesToString(Process.GetCurrentProcess().WorkingSet64)} TE={DateTime.Now.Subtract(syncState.NextSyncResponseStartedAt)} S={syncState.NextSyncResponse?.Status} QL={syncState.SyncQueue.Count}");
+ Response.StatusCode = StatusCodes.Status200OK;
+ Response.ContentType = "application/json";
+ await Response.StartAsync();
+ var response = syncState.SyncQueue.FirstOrDefault();
+ if (response is null)
+ response = new();
+ response.NextBatch ??= since ?? syncState.NextBatch;
+ await JsonSerializer.SerializeAsync(Response.Body, response, new JsonSerializerOptions {
+ WriteIndented = true,
+ DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull
+ });
+ await Response.CompleteAsync();
+
+ Response.Body.Close();
+ if (preloadTask is not null)
+ await preloadTask;
+ }
+
+ private async Task EnqueuePreloadData(SyncState syncState) {
+ var rooms = await syncState.Homeserver.GetJoinedRooms();
+ var dm_rooms = (await syncState.Homeserver.GetAccountDataAsync<Dictionary<string, List<string>>>("m.direct")).Aggregate(new List<string>(), (list, entry) => {
+ list.AddRange(entry.Value);
+ return list;
+ });
+
+ var ownHs = syncState.Homeserver.WhoAmI.UserId.Split(':')[1];
+ rooms = rooms.OrderBy(x => {
+ if (dm_rooms.Contains(x.RoomId)) return -1;
+ var parts = x.RoomId.Split(':');
+ if (parts[1] == ownHs) return 200;
+ if (HomeserverWeightEstimation.EstimatedSize.ContainsKey(parts[1])) return HomeserverWeightEstimation.EstimatedSize[parts[1]] + parts[0].Length;
+ return 5000;
+ }).ToList();
+ var roomDataTasks = rooms.Select(room => EnqueueRoomData(syncState, room)).ToList();
+ logger.LogInformation("Preloading data for {} rooms on {} ({})", roomDataTasks.Count, syncState.Homeserver.ServerName, syncState.Homeserver.AccessToken);
+
+ await Task.WhenAll(roomDataTasks);
+ }
+
+ private SemaphoreSlim _roomDataSemaphore = new(32, 32);
+
+ private async Task EnqueueRoomData(SyncState syncState, GenericRoom room) {
+ await _roomDataSemaphore.WaitAsync();
+ var roomState = room.GetFullStateAsync();
+ var timeline = await room.GetMessagesAsync(limit: 100, dir: "b");
+ timeline.Chunk.Reverse();
+ var SyncResponse = new SyncResponse {
+ Rooms = new() {
+ Join = new() {
+ {
+ room.RoomId,
+ new SyncResponse.RoomsDataStructure.JoinedRoomDataStructure {
+ AccountData = new() {
+ Events = new()
+ },
+ Ephemeral = new() {
+ Events = new()
+ },
+ State = new() {
+ Events = timeline.State
+ },
+ UnreadNotifications = new() {
+ HighlightCount = 0,
+ NotificationCount = 0
+ },
+ Timeline = new() {
+ Events = timeline.Chunk,
+ Limited = false,
+ PrevBatch = timeline.Start
+ },
+ Summary = new() {
+ Heroes = new(),
+ InvitedMemberCount = 0,
+ JoinedMemberCount = 1
+ }
+ }
+ }
+ }
+ },
+ Presence = new() {
+ Events = new() {
+ await GetStatusMessage(syncState, $"{DateTime.Now.Subtract(syncState.NextSyncResponseStartedAt)} {syncState.NextSyncResponse.Status} {room.RoomId}")
+ }
+ },
+ NextBatch = ""
+ };
+
+ await foreach (var stateEvent in roomState) {
+ SyncResponse.Rooms.Join[room.RoomId].State.Events.Add(stateEvent);
+ }
+
+ var joinRoom = SyncResponse.Rooms.Join[room.RoomId];
+ joinRoom.Summary.Heroes.AddRange(joinRoom.State.Events
+ .Where(x =>
+ x.Type == "m.room.member"
+ && x.StateKey != syncState.Homeserver.WhoAmI.UserId
+ && (x.TypedContent as RoomMemberEventContent).Membership == "join"
+ )
+ .Select(x => x.StateKey));
+ joinRoom.Summary.JoinedMemberCount = joinRoom.Summary.Heroes.Count;
+
+ syncState.SyncQueue.Enqueue(SyncResponse);
+ _roomDataSemaphore.Release();
+ }
+
+ private async Task<StateEventResponse> GetStatusMessage(SyncState syncState, string message) {
+ return new StateEventResponse {
+ TypedContent = new PresenceEventContent {
+ DisplayName = "MxApiExtensions",
+ Presence = "online",
+ StatusMessage = message,
+ // AvatarUrl = (await syncState.Homeserver.GetProfile(syncState.Homeserver.WhoAmI.UserId)).AvatarUrl
+ AvatarUrl = ""
+ },
+ Type = "m.presence",
+ StateKey = syncState.Homeserver.WhoAmI.UserId,
+ Sender = syncState.Homeserver.WhoAmI.UserId,
+ EventId = Guid.NewGuid().ToString(),
+ OriginServerTs = 0
+ };
+ }
+
+ private async Task EnqueueSyncResponse(SyncState ss, HttpResponseMessage task) {
+ var sr = await task.Content.ReadFromJsonAsync<JsonObject>();
+ if (sr.ContainsKey("error")) throw sr.Deserialize<MatrixException>()!;
+ ss.NextBatch = sr["next_batch"].GetValue<string>();
+ ss.IsInitialSync = false;
+ ss.SyncQueue.Enqueue(sr.Deserialize<SyncResponse>());
+ ss.NextSyncResponse = null;
+ }
+}
\ No newline at end of file
diff --git a/MxApiExtensions/Controllers/Extensions/DebugController.cs b/MxApiExtensions/Controllers/Extensions/DebugController.cs
new file mode 100644
index 0000000..79ed2f0
--- /dev/null
+++ b/MxApiExtensions/Controllers/Extensions/DebugController.cs
@@ -0,0 +1,45 @@
+using System.Collections.Concurrent;
+using Microsoft.AspNetCore.Mvc;
+using MxApiExtensions.Classes.LibMatrix;
+using MxApiExtensions.Services;
+
+namespace MxApiExtensions.Controllers.Extensions;
+
+[ApiController]
+[Route("/")]
+public class DebugController : ControllerBase {
+ private readonly ILogger _logger;
+ private readonly MxApiExtensionsConfiguration _config;
+ private readonly AuthenticationService _authenticationService;
+
+ private static ConcurrentDictionary<string, RoomInfoEntry> _roomInfoCache = new();
+
+ public DebugController(ILogger<ProxyConfigurationController> logger, MxApiExtensionsConfiguration config, AuthenticationService authenticationService,
+ AuthenticatedHomeserverProviderService authenticatedHomeserverProviderService) {
+ _logger = logger;
+ _config = config;
+ _authenticationService = authenticationService;
+ }
+
+ [HttpGet("debug")]
+ public async Task<object?> GetDebug() {
+ var mxid = await _authenticationService.GetMxidFromToken();
+ if(!_config.Admins.Contains(mxid)) {
+ _logger.LogWarning("Got debug request for {user}, but they are not an admin", mxid);
+ Response.StatusCode = StatusCodes.Status403Forbidden;
+ Response.ContentType = "application/json";
+
+ await Response.WriteAsJsonAsync(new {
+ ErrorCode = "M_FORBIDDEN",
+ Error = "You are not an admin"
+ });
+ await Response.CompleteAsync();
+ return null;
+ }
+
+ _logger.LogInformation("Got debug request for {user}", mxid);
+ return new {
+ SyncStates = SyncController._syncStates
+ };
+ }
+}
diff --git a/MxApiExtensions/Controllers/LoginController.cs b/MxApiExtensions/Controllers/LoginController.cs
deleted file mode 100644
index bd354ef..0000000
--- a/MxApiExtensions/Controllers/LoginController.cs
+++ /dev/null
@@ -1,70 +0,0 @@
-using System.Net.Http.Headers;
-using LibMatrix;
-using LibMatrix.Extensions;
-using LibMatrix.Responses;
-using LibMatrix.Services;
-using Microsoft.AspNetCore.Mvc;
-using MxApiExtensions.Classes.LibMatrix;
-using MxApiExtensions.Services;
-
-namespace MxApiExtensions.Controllers;
-
-[ApiController]
-[Route("/")]
-public class LoginController : ControllerBase {
- private readonly ILogger _logger;
- private readonly HomeserverProviderService _hsProvider;
- private readonly HomeserverResolverService _hsResolver;
- private readonly AuthenticationService _auth;
- private readonly MxApiExtensionsConfiguration _conf;
-
- public LoginController(ILogger<LoginController> logger, HomeserverProviderService hsProvider, HomeserverResolverService hsResolver, AuthenticationService auth, MxApiExtensionsConfiguration conf) {
- _logger = logger;
- _hsProvider = hsProvider;
- _hsResolver = hsResolver;
- _auth = auth;
- _conf = conf;
- }
-
- [HttpPost("/_matrix/client/{_}/login")]
- public async Task Proxy([FromBody] LoginRequest request, string _) {
- if (!request.Identifier.User.Contains("#")) {
- Response.StatusCode = (int)StatusCodes.Status403Forbidden;
- Response.ContentType = "application/json";
- await Response.StartAsync();
- await Response.WriteAsync(new MxApiMatrixException {
- ErrorCode = "M_FORBIDDEN",
- Error = "[MxApiExtensions] Invalid username, must be of the form @user#domain:" + Request.Host.Value
- }.GetAsJson() ?? "");
- await Response.CompleteAsync();
- }
- var hsCanonical = request.Identifier.User.Split('#')[1].Split(':')[0];
- request.Identifier.User = request.Identifier.User.Split(':')[0].Replace('#', ':');
- if(!request.Identifier.User.StartsWith('@')) request.Identifier.User = '@' + request.Identifier.User;
- var hs = await _hsResolver.ResolveHomeserverFromWellKnown(hsCanonical);
- //var hs = await _hsProvider.Login(hsCanonical, mxid, request.Password);
- var hsClient = new MatrixHttpClient { BaseAddress = new Uri(hs.client) };
- //hsClient.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", hsClient.DefaultRequestHeaders.Authorization!.Parameter);
- var resp = await hsClient.PostAsJsonAsync("/_matrix/client/r0/login", request);
- var loginResp = await resp.Content.ReadAsStringAsync();
- Response.StatusCode = (int)resp.StatusCode;
- Response.ContentType = resp.Content.Headers.ContentType?.ToString() ?? "application/json";
- await Response.StartAsync();
- await Response.WriteAsync(loginResp);
- await Response.CompleteAsync();
- var token = (await resp.Content.ReadFromJsonAsync<LoginResponse>())!.AccessToken;
- await _auth.SaveMxidForToken(token, request.Identifier.User);
- }
-
-
- [HttpGet("/_matrix/client/{_}/login")]
- public async Task<object> Proxy(string? _) {
- return new {
- flows = new[] {
- new {
- type = "m.login.password"
- }
- }
- };
- }
-}
diff --git a/MxApiExtensions/Controllers/GenericProxyController.cs b/MxApiExtensions/Controllers/Other/GenericProxyController.cs
index c004fcb..bae07c0 100644
--- a/MxApiExtensions/Controllers/GenericProxyController.cs
+++ b/MxApiExtensions/Controllers/Other/GenericProxyController.cs
@@ -7,29 +7,17 @@ namespace MxApiExtensions.Controllers;
[ApiController]
[Route("/{*_}")]
-public class GenericController : ControllerBase {
- private readonly ILogger<GenericController> _logger;
- private readonly MxApiExtensionsConfiguration _config;
- private readonly AuthenticationService _authenticationService;
- private readonly AuthenticatedHomeserverProviderService _authenticatedHomeserverProviderService;
- private static Dictionary<string, string> _tokenMap = new();
-
- public GenericController(ILogger<GenericController> logger, MxApiExtensionsConfiguration config, AuthenticationService authenticationService,
- AuthenticatedHomeserverProviderService authenticatedHomeserverProviderService) {
- _logger = logger;
- _config = config;
- _authenticationService = authenticationService;
- _authenticatedHomeserverProviderService = authenticatedHomeserverProviderService;
- }
-
+public class GenericController(ILogger<GenericController> logger, MxApiExtensionsConfiguration config, AuthenticationService authenticationService,
+ AuthenticatedHomeserverProviderService authenticatedHomeserverProviderService)
+ : ControllerBase {
[HttpGet]
public async Task Proxy([FromQuery] string? access_token, string? _) {
try {
- access_token ??= _authenticationService.GetToken(fail: false);
- var mxid = await _authenticationService.GetMxidFromToken(fail: false);
- var hs = await _authenticatedHomeserverProviderService.GetHomeserver();
+ // access_token ??= _authenticationService.GetToken(fail: false);
+ // var mxid = await _authenticationService.GetMxidFromToken(fail: false);
+ var hs = await authenticatedHomeserverProviderService.GetRemoteHomeserver();
- _logger.LogInformation("Proxying request for {}: {}{}", mxid, Request.Path, Request.QueryString);
+ logger.LogInformation("Proxying request: {}{}", Request.Path, Request.QueryString);
//remove access_token from query string
Request.QueryString = new QueryString(
@@ -55,7 +43,7 @@ public class GenericController : ControllerBase {
await Response.CompleteAsync();
}
catch (MxApiMatrixException e) {
- _logger.LogError(e, "Matrix error");
+ logger.LogError(e, "Matrix error");
Response.StatusCode = StatusCodes.Status500InternalServerError;
Response.ContentType = "application/json";
@@ -63,7 +51,7 @@ public class GenericController : ControllerBase {
await Response.CompleteAsync();
}
catch (Exception e) {
- _logger.LogError(e, "Unhandled error");
+ logger.LogError(e, "Unhandled error");
Response.StatusCode = StatusCodes.Status500InternalServerError;
Response.ContentType = "text/plain";
@@ -75,11 +63,11 @@ public class GenericController : ControllerBase {
[HttpPost]
public async Task ProxyPost([FromQuery] string? access_token, string _) {
try {
- access_token ??= _authenticationService.GetToken(fail: false);
- var mxid = await _authenticationService.GetMxidFromToken(fail: false);
- var hs = await _authenticatedHomeserverProviderService.GetHomeserver();
+ access_token ??= authenticationService.GetToken(fail: false);
+ var mxid = await authenticationService.GetMxidFromToken(fail: false);
+ var hs = await authenticatedHomeserverProviderService.GetHomeserver();
- _logger.LogInformation("Proxying request for {}: {}{}", mxid, Request.Path, Request.QueryString);
+ logger.LogInformation("Proxying request for {}: {}{}", mxid, Request.Path, Request.QueryString);
using var hc = new HttpClient();
hc.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", access_token);
@@ -112,7 +100,7 @@ public class GenericController : ControllerBase {
await Response.CompleteAsync();
}
catch (MxApiMatrixException e) {
- _logger.LogError(e, "Matrix error");
+ logger.LogError(e, "Matrix error");
Response.StatusCode = StatusCodes.Status500InternalServerError;
Response.ContentType = "application/json";
@@ -120,7 +108,7 @@ public class GenericController : ControllerBase {
await Response.CompleteAsync();
}
catch (Exception e) {
- _logger.LogError(e, "Unhandled error");
+ logger.LogError(e, "Unhandled error");
Response.StatusCode = StatusCodes.Status500InternalServerError;
Response.ContentType = "text/plain";
@@ -132,11 +120,11 @@ public class GenericController : ControllerBase {
[HttpPut]
public async Task ProxyPut([FromQuery] string? access_token, string _) {
try {
- access_token ??= _authenticationService.GetToken(fail: false);
- var mxid = await _authenticationService.GetMxidFromToken(fail: false);
- var hs = await _authenticatedHomeserverProviderService.GetHomeserver();
+ access_token ??= authenticationService.GetToken(fail: false);
+ var mxid = await authenticationService.GetMxidFromToken(fail: false);
+ var hs = await authenticatedHomeserverProviderService.GetHomeserver();
- _logger.LogInformation("Proxying request for {}: {}{}", mxid, Request.Path, Request.QueryString);
+ logger.LogInformation("Proxying request for {}: {}{}", mxid, Request.Path, Request.QueryString);
using var hc = new HttpClient();
hc.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", access_token);
@@ -169,7 +157,7 @@ public class GenericController : ControllerBase {
await Response.CompleteAsync();
}
catch (MxApiMatrixException e) {
- _logger.LogError(e, "Matrix error");
+ logger.LogError(e, "Matrix error");
Response.StatusCode = StatusCodes.Status500InternalServerError;
Response.ContentType = "application/json";
@@ -177,7 +165,7 @@ public class GenericController : ControllerBase {
await Response.CompleteAsync();
}
catch (Exception e) {
- _logger.LogError(e, "Unhandled error");
+ logger.LogError(e, "Unhandled error");
Response.StatusCode = StatusCodes.Status500InternalServerError;
Response.ContentType = "text/plain";
diff --git a/MxApiExtensions/Controllers/Other/MediaProxyController.cs b/MxApiExtensions/Controllers/Other/MediaProxyController.cs
new file mode 100644
index 0000000..03b68ba
--- /dev/null
+++ b/MxApiExtensions/Controllers/Other/MediaProxyController.cs
@@ -0,0 +1,78 @@
+using System.Net.Http.Headers;
+using LibMatrix.Homeservers;
+using LibMatrix.Services;
+using Microsoft.AspNetCore.Mvc;
+using MxApiExtensions.Classes.LibMatrix;
+using MxApiExtensions.Services;
+
+namespace MxApiExtensions.Controllers;
+
+[ApiController]
+[Route("/")]
+public class MediaProxyController(ILogger<GenericController> logger, MxApiExtensionsConfiguration config, AuthenticationService authenticationService,
+ AuthenticatedHomeserverProviderService authenticatedHomeserverProviderService, HomeserverProviderService hsProvider)
+ : ControllerBase {
+ private class MediaCacheEntry {
+ public DateTime LastRequested { get; set; } = DateTime.Now;
+ public byte[] Data { get; set; }
+ public string ContentType { get; set; }
+ public long Size => Data.LongCount();
+ }
+
+ private static Dictionary<string, MediaCacheEntry> _mediaCache = new();
+ private static SemaphoreSlim _semaphore = new(1, 1);
+
+ [HttpGet("/_matrix/media/{_}/download/{serverName}/{mediaId}")]
+ public async Task Proxy(string? _, string serverName, string mediaId) {
+ try {
+ logger.LogInformation("Proxying media: {}{}", serverName, mediaId);
+
+ await _semaphore.WaitAsync();
+ MediaCacheEntry entry;
+ if (!_mediaCache.ContainsKey($"{serverName}/{mediaId}")) {
+ _mediaCache.Add($"{serverName}/{mediaId}", entry = new());
+ List<RemoteHomeserver> FeasibleHomeservers = new();
+ {
+ var a = await authenticatedHomeserverProviderService.TryGetRemoteHomeserver();
+ if(a is not null)
+ FeasibleHomeservers.Add(a);
+ }
+
+ FeasibleHomeservers.Add(await hsProvider.GetRemoteHomeserver(serverName));
+
+ foreach (var homeserver in FeasibleHomeservers) {
+ var resp = await homeserver.ClientHttpClient.GetAsync($"{Request.Path}");
+ if(!resp.IsSuccessStatusCode) continue;
+ entry.ContentType = resp.Content.Headers.ContentType?.ToString() ?? "application/json";
+ entry.Data = await resp.Content.ReadAsByteArrayAsync();
+ break;
+ }
+ }
+ else entry = _mediaCache[$"{serverName}/{mediaId}"];
+ _semaphore.Release();
+
+ Response.StatusCode = 200;
+ Response.ContentType = entry.ContentType;
+ await Response.StartAsync();
+ await Response.Body.WriteAsync(entry.Data, 0, entry.Data.Length);
+ await Response.Body.FlushAsync();
+ await Response.CompleteAsync();
+ }
+ catch (MxApiMatrixException e) {
+ logger.LogError(e, "Matrix error");
+ Response.StatusCode = StatusCodes.Status500InternalServerError;
+ Response.ContentType = "application/json";
+
+ await Response.WriteAsync(e.GetAsJson());
+ await Response.CompleteAsync();
+ }
+ catch (Exception e) {
+ logger.LogError(e, "Unhandled error");
+ Response.StatusCode = StatusCodes.Status500InternalServerError;
+ Response.ContentType = "text/plain";
+
+ await Response.WriteAsync(e.ToString());
+ await Response.CompleteAsync();
+ }
+ }
+}
diff --git a/MxApiExtensions/Controllers/WellKnownController.cs b/MxApiExtensions/Controllers/Other/WellKnownController.cs
index b27451f..c0e255f 100644
--- a/MxApiExtensions/Controllers/WellKnownController.cs
+++ b/MxApiExtensions/Controllers/Other/WellKnownController.cs
@@ -5,18 +5,14 @@ namespace MxApiExtensions.Controllers;
[ApiController]
[Route("/")]
-public class WellKnownController : ControllerBase {
- private readonly MxApiExtensionsConfiguration _config;
-
- public WellKnownController(MxApiExtensionsConfiguration config) {
- _config = config;
- }
+public class WellKnownController(MxApiExtensionsConfiguration config) : ControllerBase {
+ private readonly MxApiExtensionsConfiguration _config = config;
[HttpGet("/.well-known/matrix/client")]
public object GetWellKnown() {
var res = new JsonObject();
res.Add("m.homeserver", new JsonObject {
- { "base_url", Request.Scheme + "://" + Request.Host + "/" },
+ { "base_url", Request.Scheme + "://" + Request.Host },
});
return res;
}
diff --git a/MxApiExtensions/Controllers/SyncController.cs b/MxApiExtensions/Controllers/SyncController.cs
deleted file mode 100644
index 0b0007f..0000000
--- a/MxApiExtensions/Controllers/SyncController.cs
+++ /dev/null
@@ -1,274 +0,0 @@
-using System.Collections.Concurrent;
-using System.Text.Json;
-using System.Text.Json.Serialization;
-using System.Web;
-using LibMatrix;
-using LibMatrix.EventTypes.Spec.State;
-using LibMatrix.Helpers;
-using LibMatrix.Homeservers;
-using LibMatrix.Responses;
-using LibMatrix.RoomTypes;
-using Microsoft.AspNetCore.Mvc;
-using MxApiExtensions.Classes;
-using MxApiExtensions.Classes.LibMatrix;
-using MxApiExtensions.Extensions;
-using MxApiExtensions.Services;
-
-namespace MxApiExtensions.Controllers;
-
-[ApiController]
-[Route("/")]
-public class SyncController : ControllerBase {
- private readonly ILogger<SyncController> _logger;
- private readonly MxApiExtensionsConfiguration _config;
- private readonly AuthenticationService _auth;
- private readonly AuthenticatedHomeserverProviderService _hs;
-
- private static readonly ConcurrentDictionary<string, SyncState> _syncStates = new();
-
- public SyncController(ILogger<SyncController> logger, MxApiExtensionsConfiguration config, AuthenticationService auth, AuthenticatedHomeserverProviderService hs) {
- _logger = logger;
- _config = config;
- _auth = auth;
- _hs = hs;
- }
-
- [HttpGet("/_matrix/client/v3/sync")]
- public async Task Sync([FromQuery] string? since, [FromQuery] int timeout = 1000) {
- Task? preloadTask = null;
- AuthenticatedHomeserverGeneric? hs = null;
- try {
- hs = await _hs.GetHomeserver();
- }
- catch (Exception e) {
- Console.WriteLine();
- }
- var qs = HttpUtility.ParseQueryString(Request.QueryString.Value!);
- qs.Remove("access_token");
-
- if (!_config.FastInitialSync.Enabled) {
- _logger.LogInformation("Starting sync for {} on {} ({})", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken);
- var result = await hs.ClientHttpClient.GetAsync($"{Request.Path}?{qs}");
- await Response.WriteHttpResponse(result);
- return;
- }
-
- try {
- var syncState = _syncStates.GetOrAdd(hs.AccessToken, _ => {
- _logger.LogInformation("Started tracking sync state for {} on {} ({})", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken);
- return new SyncState {
- IsInitialSync = string.IsNullOrWhiteSpace(since),
- Homeserver = hs
- };
- });
-
- if (syncState.NextSyncResponse is null) {
- _logger.LogInformation("Starting sync for {} on {} ({})", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken);
-
- if (syncState.IsInitialSync) {
- preloadTask = EnqueuePreloadData(syncState);
- }
-
- syncState.NextSyncResponseStartedAt = DateTime.Now;
- syncState.NextSyncResponse = Task.Delay(30_000);
- syncState.NextSyncResponse.ContinueWith(x => {
- _logger.LogInformation("Sync for {} on {} ({}) starting", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken);
- syncState.NextSyncResponse = hs.ClientHttpClient.GetAsync($"{Request.Path}?{qs}");
- });
- }
-
- if (syncState.SyncQueue.Count > 0) {
- _logger.LogInformation("Sync for {} on {} ({}) has {} queued results", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken, syncState.SyncQueue.Count);
- syncState.SyncQueue.TryDequeue(out var result);
-
- Response.StatusCode = StatusCodes.Status200OK;
- Response.ContentType = "application/json";
- await Response.StartAsync();
- await JsonSerializer.SerializeAsync(Response.Body, result, new JsonSerializerOptions {
- WriteIndented = true,
- DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull
- });
- await Response.CompleteAsync();
- return;
- }
-
- timeout = Math.Clamp(timeout, 0, 100);
- _logger.LogInformation("Sync for {} on {} ({}) is still running, waiting for {}ms, {} elapsed", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken, timeout,
- DateTime.Now.Subtract(syncState.NextSyncResponseStartedAt));
-
- try {
- await syncState.NextSyncResponse.WaitAsync(TimeSpan.FromMilliseconds(timeout));
- }
- catch { }
-
- if (syncState.NextSyncResponse is Task<HttpResponseMessage> { IsCompleted: true } response) {
- _logger.LogInformation("Sync for {} on {} ({}) completed", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken);
- var resp = await response;
- await Response.WriteHttpResponse(resp);
- return;
- }
-
- // await Task.Delay(timeout);
- _logger.LogInformation("Sync for {} on {} ({}): sending bogus response", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken);
- Response.StatusCode = StatusCodes.Status200OK;
- Response.ContentType = "application/json";
- await Response.StartAsync();
- var SyncResponse = new SyncResponse {
- // NextBatch = "MxApiExtensions::Next" + Random.Shared.NextInt64(),
- NextBatch = since ?? "",
- Presence = new() {
- Events = new() {
- await GetStatusMessage(syncState, $"{DateTime.Now.Subtract(syncState.NextSyncResponseStartedAt)} {syncState.NextSyncResponse.Status}")
- }
- },
- Rooms = new() {
- Invite = new(),
- Join = new()
- }
- };
- await JsonSerializer.SerializeAsync(Response.Body, SyncResponse, new JsonSerializerOptions {
- WriteIndented = true,
- DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull
- });
- await Response.CompleteAsync();
- }
- catch (MxApiMatrixException e) {
- _logger.LogError(e, "Error while syncing for {} on {} ({})", _hs.GetHomeserver().Result.WhoAmI.UserId,
- _hs.GetHomeserver().Result.ServerName, _hs.GetHomeserver().Result.AccessToken);
-
- Response.StatusCode = StatusCodes.Status500InternalServerError;
- Response.ContentType = "application/json";
-
- await Response.WriteAsJsonAsync(e.GetAsJson());
- await Response.CompleteAsync();
- }
-
- catch (Exception e) {
- //catch SSL connection errors and retry
- if (e.InnerException is HttpRequestException && e.InnerException.Message.Contains("The SSL connection could not be established")) {
- _logger.LogWarning("Caught SSL connection error, retrying sync for {} on {} ({})", _hs.GetHomeserver().Result.WhoAmI.UserId,
- _hs.GetHomeserver().Result.ServerName, _hs.GetHomeserver().Result.AccessToken);
- await Sync(since, timeout);
- return;
- }
-
- _logger.LogError(e, "Error while syncing for {} on {} ({})", _hs.GetHomeserver().Result.WhoAmI.UserId,
- _hs.GetHomeserver().Result.ServerName, _hs.GetHomeserver().Result.AccessToken);
-
- Response.StatusCode = StatusCodes.Status500InternalServerError;
- Response.ContentType = "text/plain";
-
- await Response.WriteAsync(e.ToString());
- await Response.CompleteAsync();
- }
-
- Response.Body.Close();
- if (preloadTask is not null)
- await preloadTask;
- }
-
- private async Task EnqueuePreloadData(SyncState syncState) {
- var rooms = await syncState.Homeserver.GetJoinedRooms();
- var dm_rooms = (await syncState.Homeserver.GetAccountDataAsync<Dictionary<string, List<string>>>("m.direct")).Aggregate(new List<string>(), (list, entry) => {
- list.AddRange(entry.Value);
- return list;
- });
-
- var ownHs = syncState.Homeserver.WhoAmI.UserId.Split(':')[1];
- rooms = rooms.OrderBy(x => {
- if (dm_rooms.Contains(x.RoomId)) return -1;
- var parts = x.RoomId.Split(':');
- if (parts[1] == ownHs) return 200;
- if (HomeserverWeightEstimation.EstimatedSize.ContainsKey(parts[1])) return HomeserverWeightEstimation.EstimatedSize[parts[1]] + parts[0].Length;
- return 5000;
- }).ToList();
- var roomDataTasks = rooms.Select(room => EnqueueRoomData(syncState, room)).ToList();
- _logger.LogInformation("Preloading data for {} rooms on {} ({})", roomDataTasks.Count, syncState.Homeserver.ServerName, syncState.Homeserver.AccessToken);
-
- await Task.WhenAll(roomDataTasks);
- }
-
- private SemaphoreSlim _roomDataSemaphore = new(4, 4);
-
- private async Task EnqueueRoomData(SyncState syncState, GenericRoom room) {
- await _roomDataSemaphore.WaitAsync();
- var roomState = room.GetFullStateAsync();
- var timeline = await room.GetMessagesAsync(limit: 100, dir: "b");
- timeline.Chunk.Reverse();
- var SyncResponse = new SyncResponse {
- Rooms = new() {
- Join = new() {
- {
- room.RoomId,
- new SyncResponse.RoomsDataStructure.JoinedRoomDataStructure {
- AccountData = new() {
- Events = new()
- },
- Ephemeral = new() {
- Events = new()
- },
- State = new() {
- Events = timeline.State
- },
- UnreadNotifications = new() {
- HighlightCount = 0,
- NotificationCount = 0
- },
- Timeline = new() {
- Events = timeline.Chunk,
- Limited = false,
- PrevBatch = timeline.Start
- },
- Summary = new() {
- Heroes = new(),
- InvitedMemberCount = 0,
- JoinedMemberCount = 1
- }
- }
- }
- }
- },
- Presence = new() {
- Events = new() {
- await GetStatusMessage(syncState, $"{DateTime.Now.Subtract(syncState.NextSyncResponseStartedAt)} {syncState.NextSyncResponse.Status} {room.RoomId}")
- }
- },
- NextBatch = ""
- };
-
- await foreach (var stateEvent in roomState) {
- SyncResponse.Rooms.Join[room.RoomId].State.Events.Add(stateEvent);
- }
-
- var joinRoom = SyncResponse.Rooms.Join[room.RoomId];
- joinRoom.Summary.Heroes.AddRange(joinRoom.State.Events
- .Where(x =>
- x.Type == "m.room.member"
- && x.StateKey != syncState.Homeserver.WhoAmI.UserId
- && (x.TypedContent as RoomMemberEventContent).Membership == "join"
- )
- .Select(x => x.StateKey));
- joinRoom.Summary.JoinedMemberCount = joinRoom.Summary.Heroes.Count;
-
- syncState.SyncQueue.Enqueue(SyncResponse);
- _roomDataSemaphore.Release();
- }
-
- private async Task<StateEventResponse> GetStatusMessage(SyncState syncState, string message) {
- return new StateEventResponse {
- TypedContent = new PresenceEventContent {
- DisplayName = "MxApiExtensions",
- Presence = "online",
- StatusMessage = message,
- // AvatarUrl = (await syncState.Homeserver.GetProfile(syncState.Homeserver.WhoAmI.UserId)).AvatarUrl
- AvatarUrl = ""
- },
- Type = "m.presence",
- StateKey = syncState.Homeserver.WhoAmI.UserId,
- Sender = syncState.Homeserver.WhoAmI.UserId,
- UserId = syncState.Homeserver.WhoAmI.UserId,
- EventId = Guid.NewGuid().ToString(),
- OriginServerTs = 0
- };
- }
-}
diff --git a/MxApiExtensions/MxApiExtensions.csproj b/MxApiExtensions/MxApiExtensions.csproj
index 92474e0..b34ef78 100644
--- a/MxApiExtensions/MxApiExtensions.csproj
+++ b/MxApiExtensions/MxApiExtensions.csproj
@@ -10,26 +10,15 @@
<ItemGroup>
<PackageReference Include="ArcaneLibs" Version="1.0.0-preview6437853305.78f6d30" />
+ <PackageReference Include="EasyCompressor.LZMA" Version="1.4.0" />
<PackageReference Include="Microsoft.AspNetCore.OpenApi" Version="8.0.0-preview.7.23375.9" />
<PackageReference Include="Swashbuckle.AspNetCore" Version="6.4.0" />
</ItemGroup>
-
<ItemGroup>
<ProjectReference Include="..\..\ArcaneLibs\ArcaneLibs\ArcaneLibs.csproj" />
<ProjectReference Include="..\..\LibMatrix\LibMatrix\LibMatrix.csproj" />
<ProjectReference Include="..\MxApiExtensions.Classes.LibMatrix\MxApiExtensions.Classes.LibMatrix.csproj" />
<ProjectReference Include="..\MxApiExtensions.Classes\MxApiExtensions.Classes.csproj" />
-
-
-
-
</ItemGroup>
-
-
-
-
-
-
-
</Project>
diff --git a/MxApiExtensions/Program.cs b/MxApiExtensions/Program.cs
index cdbdb59..72a5dc9 100644
--- a/MxApiExtensions/Program.cs
+++ b/MxApiExtensions/Program.cs
@@ -1,4 +1,7 @@
+using System.Net.Mime;
+using LibMatrix;
using LibMatrix.Services;
+using Microsoft.AspNetCore.Diagnostics;
using Microsoft.AspNetCore.Http.Timeouts;
using MxApiExtensions;
using MxApiExtensions.Classes.LibMatrix;
@@ -8,9 +11,7 @@ var builder = WebApplication.CreateBuilder(args);
// Add services to the container.
-builder.Services.AddControllers().AddJsonOptions(options => {
- options.JsonSerializerOptions.WriteIndented = true;
-});
+builder.Services.AddControllers().AddJsonOptions(options => { options.JsonSerializerOptions.WriteIndented = true; });
// Learn more about configuring Swagger/OpenAPI at https://aka.ms/aspnetcore/swashbuckle
builder.Services.AddEndpointsApiExplorer();
builder.Services.AddSwaggerGen();
@@ -47,6 +48,13 @@ builder.Services.AddRequestTimeouts(x => {
};
});
+// builder.Services.AddCors(x => x.AddDefaultPolicy(y => y.AllowAnyHeader().AllowCredentials().AllowAnyOrigin().AllowAnyMethod()));
+builder.Services.AddCors(options => {
+ options.AddPolicy(
+ "Open",
+ policy => policy.AllowAnyOrigin().AllowAnyHeader());
+});
+
var app = builder.Build();
// Configure the HTTP request pipeline.
@@ -56,9 +64,38 @@ if (app.Environment.IsDevelopment()) {
}
// app.UseHttpsRedirection();
+app.UseCors("Open");
+
+app.UseExceptionHandler(exceptionHandlerApp =>
+{
+ exceptionHandlerApp.Run(async context =>
+ {
+
+ var exceptionHandlerPathFeature =
+ context.Features.Get<IExceptionHandlerPathFeature>();
+
+ if (exceptionHandlerPathFeature?.Error is MatrixException mxe) {
+ context.Response.StatusCode = mxe.ErrorCode switch {
+ "M_NOT_FOUND" => StatusCodes.Status404NotFound,
+ _ => StatusCodes.Status500InternalServerError
+ };
+ context.Response.ContentType = MediaTypeNames.Application.Json;
+ await context.Response.WriteAsync(mxe.GetAsJson()!);
+ }
+ else {
+ context.Response.StatusCode = StatusCodes.Status500InternalServerError;
+ context.Response.ContentType = MediaTypeNames.Application.Json;
+ await context.Response.WriteAsync(new MxApiMatrixException() {
+ ErrorCode = "M_UNKNOWN",
+ Error = exceptionHandlerPathFeature?.Error.ToString()
+ }.GetAsJson());
+ }
+ });
+});
+
-app.UseAuthorization();
+// app.UseAuthorization();
app.MapControllers();
-app.Run();
+app.Run();
\ No newline at end of file
diff --git a/MxApiExtensions/Services/AuthenticatedHomeserverProviderService.cs b/MxApiExtensions/Services/AuthenticatedHomeserverProviderService.cs
index dc8a8dc..e0f9db5 100644
--- a/MxApiExtensions/Services/AuthenticatedHomeserverProviderService.cs
+++ b/MxApiExtensions/Services/AuthenticatedHomeserverProviderService.cs
@@ -1,20 +1,37 @@
+using ArcaneLibs.Extensions;
using LibMatrix.Homeservers;
using LibMatrix.Services;
using MxApiExtensions.Classes.LibMatrix;
namespace MxApiExtensions.Services;
-public class AuthenticatedHomeserverProviderService {
- private readonly AuthenticationService _authenticationService;
- private readonly HomeserverProviderService _homeserverProviderService;
-
- public AuthenticatedHomeserverProviderService(AuthenticationService authenticationService, HomeserverProviderService homeserverProviderService) {
- _authenticationService = authenticationService;
- _homeserverProviderService = homeserverProviderService;
+public class AuthenticatedHomeserverProviderService(AuthenticationService authenticationService, HomeserverProviderService homeserverProviderService, IHttpContextAccessor request) {
+ public async Task<RemoteHomeserver?> TryGetRemoteHomeserver() {
+ try {
+ return await GetRemoteHomeserver();
+ }
+ catch {
+ return null;
+ }
+ }
+
+ public async Task<RemoteHomeserver> GetRemoteHomeserver() {
+ try {
+ return await GetHomeserver();
+ }
+ catch (MxApiMatrixException e) {
+ if (e is not { ErrorCode: "M_MISSING_TOKEN" }) throw;
+ if (!request.HttpContext!.Request.Headers.Keys.Any(x=>x.ToUpper() == "MXAE_UPSTREAM"))
+ throw new MxApiMatrixException() {
+ ErrorCode = "MXAE_MISSING_UPSTREAM",
+ Error = "[MxApiExtensions] Missing MXAE_UPSTREAM header for unauthenticated request, this should be a server_name!"
+ };
+ return await homeserverProviderService.GetRemoteHomeserver(request.HttpContext.Request.Headers.GetByCaseInsensitiveKey("MXAE_UPSTREAM")[0]);
+ }
}
public async Task<AuthenticatedHomeserverGeneric> GetHomeserver() {
- var token = _authenticationService.GetToken();
+ var token = authenticationService.GetToken();
if (token == null) {
throw new MxApiMatrixException {
ErrorCode = "M_MISSING_TOKEN",
@@ -22,7 +39,7 @@ public class AuthenticatedHomeserverProviderService {
};
}
- var mxid = await _authenticationService.GetMxidFromToken(token);
+ var mxid = await authenticationService.GetMxidFromToken(token);
if (mxid == "@anonymous:*") {
throw new MxApiMatrixException {
ErrorCode = "M_MISSING_TOKEN",
@@ -31,6 +48,6 @@ public class AuthenticatedHomeserverProviderService {
}
var hsCanonical = string.Join(":", mxid.Split(':').Skip(1));
- return await _homeserverProviderService.GetAuthenticatedWithToken(hsCanonical, token);
+ return await homeserverProviderService.GetAuthenticatedWithToken(hsCanonical, token);
}
}
diff --git a/MxApiExtensions/Services/AuthenticationService.cs b/MxApiExtensions/Services/AuthenticationService.cs
index 9eac20a..0dcc8b1 100644
--- a/MxApiExtensions/Services/AuthenticationService.cs
+++ b/MxApiExtensions/Services/AuthenticationService.cs
@@ -3,21 +3,11 @@ using MxApiExtensions.Classes.LibMatrix;
namespace MxApiExtensions.Services;
-public class AuthenticationService {
- private readonly ILogger<AuthenticationService> _logger;
- private readonly MxApiExtensionsConfiguration _config;
- private readonly HomeserverProviderService _homeserverProviderService;
- private readonly HttpRequest _request;
+public class AuthenticationService(ILogger<AuthenticationService> logger, MxApiExtensionsConfiguration config, IHttpContextAccessor request, HomeserverProviderService homeserverProviderService) {
+ private readonly HttpRequest _request = request.HttpContext!.Request;
private static Dictionary<string, string> _tokenMap = new();
- public AuthenticationService(ILogger<AuthenticationService> logger, MxApiExtensionsConfiguration config, IHttpContextAccessor request, HomeserverProviderService homeserverProviderService) {
- _logger = logger;
- _config = config;
- _homeserverProviderService = homeserverProviderService;
- _request = request.HttpContext!.Request;
- }
-
internal string? GetToken(bool fail = true) {
string? token;
if (_request.Headers.TryGetValue("Authorization", out var tokens)) {
@@ -59,7 +49,7 @@ public class AuthenticationService {
if (_tokenMap.TryGetValue(token, out var mxid)) return mxid;
var lookupTasks = new Dictionary<string, Task<string?>>();
- foreach (var homeserver in _config.AuthHomeservers) {
+ foreach (var homeserver in config.AuthHomeservers) {
lookupTasks.Add(homeserver, GetMxidFromToken(token, homeserver));
await lookupTasks[homeserver].WaitAsync(TimeSpan.FromMilliseconds(250));
if(lookupTasks[homeserver].IsCompletedSuccessfully && !string.IsNullOrWhiteSpace(lookupTasks[homeserver].Result)) break;
@@ -70,7 +60,7 @@ public class AuthenticationService {
if(mxid is null) {
throw new MxApiMatrixException {
ErrorCode = "M_UNKNOWN_TOKEN",
- Error = "Token not found on any configured homeservers: " + string.Join(", ", _config.AuthHomeservers)
+ Error = "Token not found on any configured homeservers: " + string.Join(", ", config.AuthHomeservers)
};
}
@@ -93,17 +83,17 @@ public class AuthenticationService {
//
// var json = (await JsonDocument.ParseAsync(await resp.Content.ReadAsStreamAsync())).RootElement;
// var mxid = json.GetProperty("user_id").GetString()!;
- _logger.LogInformation("Got mxid {} from token {}", mxid, token);
+ logger.LogInformation("Got mxid {} from token {}", mxid, token);
await SaveMxidForToken(token, mxid);
return mxid;
}
private async Task<string?> GetMxidFromToken(string token, string hsDomain) {
- _logger.LogInformation("Looking up mxid for token {} on {}", token, hsDomain);
- var hs = await _homeserverProviderService.GetAuthenticatedWithToken(hsDomain, token);
+ logger.LogInformation("Looking up mxid for token {} on {}", token, hsDomain);
+ var hs = await homeserverProviderService.GetAuthenticatedWithToken(hsDomain, token);
try {
var res = hs.WhoAmI.UserId;
- _logger.LogInformation("Got mxid {} for token {} on {}", res, token, hsDomain);
+ logger.LogInformation("Got mxid {} for token {} on {}", res, token, hsDomain);
return res;
}
catch (MxApiMatrixException e) {
|