diff --git a/MxApiExtensions/Classes/CustomLogFormatter.cs b/MxApiExtensions/Classes/CustomLogFormatter.cs
new file mode 100644
index 0000000..69812e5
--- /dev/null
+++ b/MxApiExtensions/Classes/CustomLogFormatter.cs
@@ -0,0 +1,12 @@
+using Microsoft.Extensions.Logging.Abstractions;
+using Microsoft.Extensions.Logging.Console;
+
+namespace MxApiExtensions.Classes;
+
+public class CustomLogFormatter : ConsoleFormatter {
+ public CustomLogFormatter(string name) : base(name) { }
+
+ public override void Write<TState>(in LogEntry<TState> logEntry, IExternalScopeProvider? scopeProvider, TextWriter textWriter) {
+ Console.WriteLine("Log message");
+ }
+}
\ No newline at end of file
diff --git a/MxApiExtensions/Classes/MxApiExtensionsUserConfiguration.cs b/MxApiExtensions/Classes/MxApiExtensionsUserConfiguration.cs
new file mode 100644
index 0000000..db83637
--- /dev/null
+++ b/MxApiExtensions/Classes/MxApiExtensionsUserConfiguration.cs
@@ -0,0 +1,21 @@
+using LibMatrix.EventTypes;
+using LibMatrix.Interfaces;
+
+namespace MxApiExtensions.Classes;
+
+[MatrixEvent(EventName = EventId)]
+public class MxApiExtensionsUserConfiguration : EventContent {
+ public const string EventId = "gay.rory.mxapiextensions.userconfig";
+ public ProtocolChangeConfiguration ProtocolChanges { get; set; } = new();
+ public InitialSyncConfiguration InitialSyncPreload { get; set; } = new();
+
+ public class InitialSyncConfiguration {
+ public bool Enable { get; set; } = true;
+ }
+
+ public class ProtocolChangeConfiguration {
+ public bool DisableThreads { get; set; } = false;
+ public bool DisableVoip { get; set; } = false;
+ public bool AutoFollowTombstones { get; set; } = false;
+ }
+}
\ No newline at end of file
diff --git a/MxApiExtensions/Classes/SyncFilter.cs b/MxApiExtensions/Classes/SyncFilter.cs
new file mode 100644
index 0000000..7f2bd08
--- /dev/null
+++ b/MxApiExtensions/Classes/SyncFilter.cs
@@ -0,0 +1,7 @@
+using LibMatrix.Responses;
+
+namespace MxApiExtensions.Classes;
+
+public interface ISyncFilter {
+ public Task<SyncResponse> Apply(SyncResponse syncResponse);
+}
\ No newline at end of file
diff --git a/MxApiExtensions/Classes/SyncState.cs b/MxApiExtensions/Classes/SyncState.cs
index e44d35c..733f26d 100644
--- a/MxApiExtensions/Classes/SyncState.cs
+++ b/MxApiExtensions/Classes/SyncState.cs
@@ -10,13 +10,12 @@ using Microsoft.OpenApi.Extensions;
namespace MxApiExtensions.Classes;
public class SyncState {
- private Task? _nextSyncResponse;
+ private Task<HttpResponseMessage>? _nextSyncResponse;
public string? NextBatch { get; set; }
public ConcurrentQueue<SyncResponse> SyncQueue { get; set; } = new();
- public bool IsInitialSync { get; set; }
[JsonIgnore]
- public Task? NextSyncResponse {
+ public Task<HttpResponseMessage>? NextSyncResponse {
get => _nextSyncResponse;
set {
_nextSyncResponse = value;
@@ -25,7 +24,7 @@ public class SyncState {
}
public DateTime NextSyncResponseStartedAt { get; set; } = DateTime.Now;
-
+
[JsonIgnore]
public AuthenticatedHomeserverGeneric Homeserver { get; set; }
@@ -39,53 +38,39 @@ public class SyncState {
NextSyncResponse?.IsFaulted,
Status = NextSyncResponse?.Status.GetDisplayName()
};
-
#endregion
- public void SendEphemeralTimelineEventInRoom(string roomId, StateEventResponse @event) {
- SyncQueue.Enqueue(new() {
- NextBatch = NextBatch ?? "null",
- Rooms = new() {
- Join = new() {
- {
- roomId,
- new() {
- Timeline = new() {
- Events = new() {
- @event
- }
- }
- }
- }
- }
- }
- });
+ public SyncResponse SendEphemeralTimelineEventInRoom(string roomId, StateEventResponse @event, SyncResponse? existingResponse = null) {
+ if(existingResponse is null)
+ SyncQueue.Enqueue(existingResponse = new());
+ existingResponse.Rooms ??= new();
+ existingResponse.Rooms.Join ??= new();
+ existingResponse.Rooms.Join.TryAdd(roomId, new());
+ existingResponse.Rooms.Join[roomId].Timeline ??= new();
+ existingResponse.Rooms.Join[roomId].Timeline.Events ??= new();
+ existingResponse.Rooms.Join[roomId].Timeline.Events.Add(@event);
+ return existingResponse;
}
- public void SendStatusMessage(string text) {
- SyncQueue.Enqueue(new() {
- NextBatch = NextBatch ?? "null",
- Presence = new() {
- Events = new() {
- new StateEventResponse {
- TypedContent = new PresenceEventContent {
- DisplayName = "MxApiExtensions",
- Presence = "online",
- StatusMessage = text,
- // AvatarUrl = (await syncState.Homeserver.GetProfile(syncState.Homeserver.WhoAmI.UserId)).AvatarUrl
- AvatarUrl = "",
- LastActiveAgo = 15,
- CurrentlyActive = true
- },
- Type = "m.presence",
- StateKey = Homeserver.WhoAmI.UserId,
- Sender = Homeserver.WhoAmI.UserId,
- EventId = Guid.NewGuid().ToString(),
- OriginServerTs = 0
- }
- }
- }
+ public SyncResponse SendStatusMessage(string text, SyncResponse? existingResponse = null) {
+ if(existingResponse is null)
+ SyncQueue.Enqueue(existingResponse = new());
+ existingResponse.Presence ??= new();
+ // existingResponse.Presence.Events ??= new();
+ existingResponse.Presence.Events.RemoveAll(x => x.Sender == Homeserver.WhoAmI.UserId);
+ existingResponse.Presence.Events.Add(new StateEventResponse {
+ TypedContent = new PresenceEventContent {
+ Presence = "online",
+ StatusMessage = text,
+ LastActiveAgo = 15,
+ CurrentlyActive = true
+ },
+ Type = "m.presence",
+ StateKey = "",
+ Sender = Homeserver.WhoAmI.UserId,
+ OriginServerTs = 0
});
+ return existingResponse;
}
}
\ No newline at end of file
diff --git a/MxApiExtensions/Controllers/Client/Room/RoomsSendMessageController.cs b/MxApiExtensions/Controllers/Client/Room/RoomsSendMessageController.cs
index 3d1d4e2..b756582 100644
--- a/MxApiExtensions/Controllers/Client/Room/RoomsSendMessageController.cs
+++ b/MxApiExtensions/Controllers/Client/Room/RoomsSendMessageController.cs
@@ -13,6 +13,6 @@ public class RoomController(ILogger<LoginController> logger, HomeserverResolverS
public async Task<Dictionary<string, List<string>>> GetRoomMembersByHomeserver(string _, [FromRoute] string roomId, [FromQuery] bool joinedOnly = true) {
var hs = await hsProvider.GetHomeserver();
var room = hs.GetRoom(roomId);
- return await room.GetMembersByHomeserverAsync();
+ return await room.GetMembersByHomeserverAsync(joinedOnly);
}
}
\ No newline at end of file
diff --git a/MxApiExtensions/Controllers/Client/RoomsSendMessageController.cs b/MxApiExtensions/Controllers/Client/RoomsSendMessageController.cs
index 47d9899..e882c8a 100644
--- a/MxApiExtensions/Controllers/Client/RoomsSendMessageController.cs
+++ b/MxApiExtensions/Controllers/Client/RoomsSendMessageController.cs
@@ -13,23 +13,24 @@ using LibMatrix.Services;
using Microsoft.AspNetCore.Mvc;
using MxApiExtensions.Classes;
using MxApiExtensions.Classes.LibMatrix;
+using MxApiExtensions.Extensions;
using MxApiExtensions.Services;
namespace MxApiExtensions.Controllers;
[ApiController]
[Route("/")]
-public class RoomsSendMessageController(ILogger<LoginController> logger, HomeserverResolverService hsResolver, AuthenticationService auth, MxApiExtensionsConfiguration conf,
- AuthenticatedHomeserverProviderService hsProvider)
+public class RoomsSendMessageController(ILogger<LoginController> logger, UserContextService userContextService)
: ControllerBase {
[HttpPut("/_matrix/client/{_}/rooms/{roomId}/send/m.room.message/{txnId}")]
public async Task Proxy([FromBody] JsonObject request, [FromRoute] string roomId, [FromRoute] string txnId, string _) {
- var hs = await hsProvider.GetHomeserver();
+ var uc = await userContextService.GetCurrentUserContext();
+ // var hs = await hsProvider.GetHomeserver();
var msg = request.Deserialize<RoomMessageEventContent>();
if (msg is not null && msg.Body.StartsWith("mxae!")) {
#pragma warning disable CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed
- handleMxaeCommand(hs, roomId, msg);
+ handleMxaeCommand(uc, roomId, msg);
#pragma warning restore CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed
await Response.WriteAsJsonAsync(new EventIdResponse() {
EventId = "$" + string.Join("", Random.Shared.GetItems("abcdefghijklmnopqrstuvwxyzABCDEFGHIJLKMNOPQRSTUVWXYZ0123456789".ToCharArray(), 100))
@@ -38,13 +39,14 @@ public class RoomsSendMessageController(ILogger<LoginController> logger, Homeser
}
else {
try {
- var resp = await hs.ClientHttpClient.PutAsJsonAsync($"{Request.Path}{Request.QueryString}", request);
- var loginResp = await resp.Content.ReadAsStringAsync();
- Response.StatusCode = (int)resp.StatusCode;
- Response.ContentType = resp.Content.Headers.ContentType?.ToString() ?? "application/json";
- await Response.StartAsync();
- await Response.WriteAsync(loginResp);
- await Response.CompleteAsync();
+ var resp = await uc.Homeserver.ClientHttpClient.PutAsJsonAsync($"{Request.Path}{Request.QueryString}", request);
+ await Response.WriteHttpResponse(resp);
+ // var loginResp = await resp.Content.ReadAsStringAsync();
+ // Response.StatusCode = (int)resp.StatusCode;
+ // Response.ContentType = resp.Content.Headers.ContentType?.ToString() ?? "application/json";
+ // await Response.StartAsync();
+ // await Response.WriteAsync(loginResp);
+ // await Response.CompleteAsync();
}
catch (MatrixException e) {
await Response.StartAsync();
@@ -54,10 +56,9 @@ public class RoomsSendMessageController(ILogger<LoginController> logger, Homeser
}
}
- private async Task handleMxaeCommand(AuthenticatedHomeserverGeneric hs, string roomId, RoomMessageEventContent msg) {
- var syncState = SyncController.SyncStates.GetValueOrDefault(hs.AccessToken);
- if (syncState is null) return;
- syncState.SendEphemeralTimelineEventInRoom(roomId, new() {
+ private async Task handleMxaeCommand(UserContextService.UserContext hs, string roomId, RoomMessageEventContent msg) {
+ if (hs.SyncState is null) return;
+ hs.SyncState.SendEphemeralTimelineEventInRoom(roomId, new() {
Sender = "@mxae:" + Request.Host.Value,
Type = "m.room.message",
TypedContent = MessageFormatter.FormatSuccess("Thinking..."),
diff --git a/MxApiExtensions/Controllers/Client/SyncController.cs b/MxApiExtensions/Controllers/Client/SyncController.cs
index 8a5ba06..7f9ed1d 100644
--- a/MxApiExtensions/Controllers/Client/SyncController.cs
+++ b/MxApiExtensions/Controllers/Client/SyncController.cs
@@ -4,6 +4,7 @@ using System.Text.Json;
using System.Text.Json.Nodes;
using System.Text.Json.Serialization;
using System.Web;
+using System.Xml;
using ArcaneLibs;
using LibMatrix;
using LibMatrix.EventTypes.Spec.State;
@@ -21,102 +22,76 @@ namespace MxApiExtensions.Controllers;
[ApiController]
[Route("/")]
-public class SyncController(ILogger<SyncController> logger, MxApiExtensionsConfiguration config, AuthenticationService auth, AuthenticatedHomeserverProviderService hsProvider)
+public class SyncController(ILogger<SyncController> logger, MxApiExtensionsConfiguration config, AuthenticationService auth, AuthenticatedHomeserverProviderService hsProvider,
+ UserContextService userContextService)
: ControllerBase {
- public static readonly ConcurrentDictionary<string, SyncState> SyncStates = new();
-
- private static SemaphoreSlim _semaphoreSlim = new(1, 1);
+ private UserContextService.UserContext userContext;
private Stopwatch _syncElapsed = Stopwatch.StartNew();
+ private static SemaphoreSlim _semaphoreSlim = new(1, 1);
[HttpGet("/_matrix/client/{_}/sync")]
public async Task Sync(string _, [FromQuery] string? since, [FromQuery] int timeout = 1000) {
+ // temporary variables
+ bool startedNewTask = false;
Task? preloadTask = null;
- AuthenticatedHomeserverGeneric? hs = null;
- try {
- hs = await hsProvider.GetHomeserver();
- }
- catch (Exception e) {
- Console.WriteLine(e);
- }
+ // get user context based on authentication
+ userContext = await userContextService.GetCurrentUserContext();
var qs = HttpUtility.ParseQueryString(Request.QueryString.Value!);
qs.Remove("access_token");
if (since == "null") qs.Remove("since");
- if (!config.FastInitialSync.Enabled) {
- logger.LogInformation("Starting sync for {} on {} ({})", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken);
- var result = await hs.ClientHttpClient.GetAsync($"{Request.Path}?{qs}");
- await Response.WriteHttpResponse(result);
- return;
- }
+ // if (!userContext.UserConfiguration.InitialSyncPreload.Enable) {
+ // logger.LogInformation("Starting sync for {} on {} ({})", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken);
+ // var result = await hs.ClientHttpClient.GetAsync($"{Request.Path}?{qs}");
+ // await Response.WriteHttpResponse(result);
+ // return;
+ // }
+ //prevent duplicate initialisation
await _semaphoreSlim.WaitAsync();
- var syncState = SyncStates.GetOrAdd($"{hs.WhoAmI.UserId}/{hs.WhoAmI.DeviceId}/{hs.ServerName}:{hs.AccessToken}", _ => {
- logger.LogInformation("Started tracking sync state for {} on {} ({})", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken);
- var ss = new SyncState {
- IsInitialSync = string.IsNullOrWhiteSpace(since),
- Homeserver = hs
+
+ //if we don't have a sync state for this user...
+ if (userContext.SyncState is null) {
+ logger.LogInformation("Started tracking sync state for {} on {} ({})", userContext.Homeserver.WhoAmI.UserId, userContext.Homeserver.ServerName,
+ userContext.Homeserver.AccessToken);
+
+ //create a new sync state
+ userContext.SyncState = new SyncState {
+ Homeserver = userContext.Homeserver,
+ NextSyncResponse = Task.Run(async () => {
+ if (string.IsNullOrWhiteSpace(since) && userContext.UserConfiguration.InitialSyncPreload.Enable)
+ await Task.Delay(15_000);
+ logger.LogInformation("Sync for {} on {} ({}) starting", userContext.Homeserver.WhoAmI.UserId, userContext.Homeserver.ServerName,
+ userContext.Homeserver.AccessToken);
+ return await userContext.Homeserver.ClientHttpClient.GetAsync($"/_matrix/client/v3/sync?{qs}");
+ })
};
- if (ss.IsInitialSync) {
- preloadTask = EnqueuePreloadData(ss);
+ startedNewTask = true;
+
+ //if this is an initial sync, and the user has enabled this, preload data
+ if (string.IsNullOrWhiteSpace(since) && userContext.UserConfiguration.InitialSyncPreload.Enable) {
+ logger.LogInformation("Sync data preload for {} on {} ({}) starting", userContext.Homeserver.WhoAmI.UserId, userContext.Homeserver.ServerName,
+ userContext.Homeserver.AccessToken);
+ preloadTask = EnqueuePreloadData(userContext.SyncState);
}
-
- logger.LogInformation("Starting sync for {} on {} ({})", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken);
-
- ss.NextSyncResponseStartedAt = DateTime.Now;
- ss.NextSyncResponse = Task.Delay(15_000);
- ss.NextSyncResponse.ContinueWith(x => {
- logger.LogInformation("Sync for {} on {} ({}) starting", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken);
- ss.NextSyncResponse = hs.ClientHttpClient.GetAsync($"/_matrix/client/v3/sync?{qs}");
- (ss.NextSyncResponse as Task<HttpResponseMessage>).ContinueWith(async x => EnqueueSyncResponse(ss, await x));
- });
- return ss;
- });
- _semaphoreSlim.Release();
-
- if (syncState.SyncQueue.Count > 0) {
- logger.LogInformation("Sync for {} on {} ({}) has {} queued results", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken, syncState.SyncQueue.Count);
- syncState.SyncQueue.TryDequeue(out var result);
- while (result is null)
- syncState.SyncQueue.TryDequeue(out result);
- Response.StatusCode = StatusCodes.Status200OK;
- Response.ContentType = "application/json";
- await Response.StartAsync();
- result.NextBatch ??= since ?? syncState.NextBatch!;
- await JsonSerializer.SerializeAsync(Response.Body, result, new JsonSerializerOptions {
- WriteIndented = true,
- DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull
- });
- await Response.CompleteAsync();
- return;
}
- var newTimeout = Math.Clamp(timeout, 0, syncState.IsInitialSync ? syncState.SyncQueue.Count >= 2 ? 0 : 250 : timeout);
- logger.LogInformation("Sync for {} on {} ({}) is still running, waiting for {}ms, {} elapsed", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken, newTimeout,
- DateTime.Now.Subtract(syncState.NextSyncResponseStartedAt));
-
- try {
- if (syncState.NextSyncResponse is not null)
- await syncState.NextSyncResponse.WaitAsync(TimeSpan.FromMilliseconds(newTimeout));
- else {
- syncState.NextSyncResponse = hs.ClientHttpClient.GetAsync($"/_matrix/client/v3/sync?{qs}");
- (syncState.NextSyncResponse as Task<HttpResponseMessage>)!.ContinueWith(async x => EnqueueSyncResponse(syncState, await x));
- // await Task.Delay(250);
- }
+ if (userContext.SyncState.NextSyncResponse is null) {
+ userContext.SyncState.NextSyncResponse = userContext.Homeserver.ClientHttpClient.GetAsync($"/_matrix/client/v3/sync?{qs}");
+ startedNewTask = true;
}
- catch (TimeoutException) { }
- // if (_syncElapsed.ElapsedMilliseconds > timeout)
- if(syncState.NextSyncResponse?.IsCompleted == false)
- syncState.SendStatusMessage(
- $"M={Util.BytesToString(Process.GetCurrentProcess().WorkingSet64)} TE={DateTime.Now.Subtract(syncState.NextSyncResponseStartedAt)} S={syncState.NextSyncResponse?.Status} QL={syncState.SyncQueue.Count}");
+ _semaphoreSlim.Release();
+
+ //get the next sync response
+ var syncResponse = await GetNextSyncResponse(timeout);
+ //send it to the client
Response.StatusCode = StatusCodes.Status200OK;
Response.ContentType = "application/json";
await Response.StartAsync();
- var response = syncState.SyncQueue.FirstOrDefault();
- if (response is null)
- response = new();
- response.NextBatch ??= since ?? syncState.NextBatch!;
+ var response = syncResponse;
+ response.NextBatch ??= since ?? "null";
await JsonSerializer.SerializeAsync(Response.Body, response, new JsonSerializerOptions {
WriteIndented = true,
DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull
@@ -124,20 +99,117 @@ public class SyncController(ILogger<SyncController> logger, MxApiExtensionsConfi
await Response.CompleteAsync();
Response.Body.Close();
+
+ //await scope-local tasks in order to prevent disposal
if (preloadTask is not null) {
await preloadTask;
preloadTask.Dispose();
}
+
+ if (startedNewTask && userContext.SyncState?.NextSyncResponse is not null) {
+ var resp = await userContext.SyncState.NextSyncResponse;
+ var sr = await resp.Content.ReadFromJsonAsync<JsonObject>();
+ if (sr!.ContainsKey("error")) throw sr.Deserialize<MatrixException>()!;
+ userContext.SyncState.NextBatch = sr["next_batch"]!.GetValue<string>();
+ // userContext.SyncState.IsInitialSync = false;
+ var syncResp = sr.Deserialize<SyncResponse>();
+ userContext.SyncState.SyncQueue.Enqueue(syncResp);
+ userContext.SyncState.NextSyncResponse.Dispose();
+ userContext.SyncState.NextSyncResponse = null;
+ }
+ }
+
+ private async Task<SyncResponse> GetNextSyncResponse(int timeout = 0) {
+ do {
+ if (userContext.SyncState is null) throw new NullReferenceException("syncState is null!");
+ // if (userContext.SyncState.NextSyncResponse is null) throw new NullReferenceException("NextSyncResponse is null");
+
+ //check if upstream has responded, if so, return upstream response
+ // if (userContext.SyncState.NextSyncResponse is { IsCompleted: true } syncResponse) {
+ // var resp = await syncResponse;
+ // var sr = await resp.Content.ReadFromJsonAsync<JsonObject>();
+ // if (sr!.ContainsKey("error")) throw sr.Deserialize<MatrixException>()!;
+ // userContext.SyncState.NextBatch = sr["next_batch"]!.GetValue<string>();
+ // // userContext.SyncState.IsInitialSync = false;
+ // var syncResp = sr.Deserialize<SyncResponse>();
+ // return syncResp;
+ // }
+
+ //else, return the first item in queue, if any
+ if (userContext.SyncState.SyncQueue.Count > 0) {
+ logger.LogInformation("Sync for {} on {} ({}) has {} queued results", userContext.SyncState.Homeserver.WhoAmI.UserId, userContext.SyncState.Homeserver.ServerName,
+ userContext.SyncState.Homeserver.AccessToken, userContext.SyncState.SyncQueue.Count);
+ userContext.SyncState.SyncQueue.TryDequeue(out var result);
+ while (result is null)
+ userContext.SyncState.SyncQueue.TryDequeue(out result);
+ return result;
+ }
+
+ // await Task.Delay(Math.Clamp(timeout, 25, 250)); //wait 25-250ms between checks
+ await Task.Delay(Math.Clamp(userContextService.SessionCount * 10 ,25, 500));
+ } while (_syncElapsed.ElapsedMilliseconds < timeout + 500); //... while we haven't gone >500ms over expected timeout
+
+ //we didn't get a response, send a bogus response
+ return userContext.SyncState.SendStatusMessage(
+ $"M={Util.BytesToString(Process.GetCurrentProcess().WorkingSet64)} TE={DateTime.Now.Subtract(userContext.SyncState.NextSyncResponseStartedAt)} QL={userContext.SyncState.SyncQueue.Count}",
+ new());
}
+
private async Task EnqueuePreloadData(SyncState syncState) {
+ await EnqueuePreloadAccountData(syncState);
+ await EnqueuePreloadRooms(syncState);
+ }
+
+ private static List<string> CommonAccountDataKeys = new() {
+ "gay.rory.dm_space",
+ "im.fluffychat.account_bundles",
+ "im.ponies.emote_rooms",
+ "im.vector.analytics",
+ "im.vector.setting.breadcrumbs",
+ "im.vector.setting.integration_provisioning",
+ "im.vector.web.settings",
+ "io.element.recent_emoji",
+ "m.cross_signing.master",
+ "m.cross_signing.self_signing",
+ "m.cross_signing.user_signing",
+ "m.direct",
+ "m.megolm_backup.v1",
+ "m.push_rules",
+ "m.secret_storage.default_key",
+ "gay.rory.mxapiextensions.userconfig"
+ };
+ //enqueue common account data
+ private async Task EnqueuePreloadAccountData(SyncState syncState) {
+ var syncMsg = new SyncResponse() {
+ AccountData = new() {
+ Events = new()
+ }
+ };
+ foreach (var key in CommonAccountDataKeys) {
+ try {
+ syncMsg.AccountData.Events.Add(new() {
+ Type = key,
+ RawContent = await syncState.Homeserver.GetAccountDataAsync<JsonObject>(key)
+ });
+ }
+ catch {}
+ }
+ syncState.SyncQueue.Enqueue(syncMsg);
+ }
+
+ private async Task EnqueuePreloadRooms(SyncState syncState) {
+ //get the users's rooms
var rooms = await syncState.Homeserver.GetJoinedRooms();
- var dmRooms = (await syncState.Homeserver.GetAccountDataAsync<Dictionary<string, List<string>>>("m.direct")).Aggregate(new List<string>(), (list, entry) => {
- list.AddRange(entry.Value);
- return list;
- });
+
+ //get the user's DM rooms
+ var mDirectContent = await syncState.Homeserver.GetAccountDataAsync<Dictionary<string, List<string>>>("m.direct");
+ var dmRooms = mDirectContent.SelectMany(pair => pair.Value);
+ //get our own homeserver's server_name
var ownHs = syncState.Homeserver.WhoAmI!.UserId!.Split(':')[1];
+
+ //order rooms by expected state size, since large rooms take a long time to return
rooms = rooms.OrderBy(x => {
if (dmRooms.Contains(x.RoomId)) return -1;
var parts = x.RoomId.Split(':');
@@ -145,38 +217,35 @@ public class SyncController(ILogger<SyncController> logger, MxApiExtensionsConfi
if (HomeserverWeightEstimation.EstimatedSize.ContainsKey(parts[1])) return HomeserverWeightEstimation.EstimatedSize[parts[1]] + parts[0].Length;
return 5000;
}).ToList();
+
+ //start all fetch tasks
var roomDataTasks = rooms.Select(room => EnqueueRoomData(syncState, room)).ToList();
logger.LogInformation("Preloading data for {} rooms on {} ({})", roomDataTasks.Count, syncState.Homeserver.ServerName, syncState.Homeserver.AccessToken);
+ //wait for all of them to finish
await Task.WhenAll(roomDataTasks);
}
- private readonly SemaphoreSlim _roomDataSemaphore = new(32, 32);
+ private static readonly SemaphoreSlim _roomDataSemaphore = new(4, 4);
private async Task EnqueueRoomData(SyncState syncState, GenericRoom room) {
+ //limit concurrent requests, to not overload upstream
await _roomDataSemaphore.WaitAsync();
+ //get the room's state
var roomState = room.GetFullStateAsync();
+ //get the room's timeline, reversed
var timeline = await room.GetMessagesAsync(limit: 100, dir: "b");
timeline.Chunk.Reverse();
+ //queue up this data as a sync response
var syncResponse = new SyncResponse {
Rooms = new() {
Join = new() {
{
room.RoomId,
new SyncResponse.RoomsDataStructure.JoinedRoomDataStructure {
- AccountData = new() {
- Events = new()
- },
- Ephemeral = new() {
- Events = new()
- },
State = new() {
Events = timeline.State
},
- UnreadNotifications = new() {
- HighlightCount = 0,
- NotificationCount = 0
- },
Timeline = new() {
Events = timeline.Chunk,
Limited = false,
@@ -190,57 +259,24 @@ public class SyncController(ILogger<SyncController> logger, MxApiExtensionsConfi
}
}
}
- },
- Presence = new() {
- Events = new() {
- await GetStatusMessage(syncState, $"{DateTime.Now.Subtract(syncState.NextSyncResponseStartedAt)} {syncState.NextSyncResponse.Status} {room.RoomId}")
- }
- },
- NextBatch = ""
+ }
};
+ //calculate invited/joined member counts, and add other events to state
await foreach (var stateEvent in roomState) {
+ if (stateEvent is { Type: "m.room.member" }) {
+ if (stateEvent.TypedContent is RoomMemberEventContent { Membership: "join" })
+ syncResponse.Rooms.Join[room.RoomId].Summary.JoinedMemberCount++;
+ else if (stateEvent.TypedContent is RoomMemberEventContent { Membership: "invite" })
+ syncResponse.Rooms.Join[room.RoomId].Summary.InvitedMemberCount++;
+ else continue;
+ }
+
syncResponse.Rooms.Join[room.RoomId].State!.Events!.Add(stateEvent!);
}
- var joinRoom = syncResponse.Rooms.Join[room.RoomId];
- joinRoom.Summary!.Heroes.AddRange(joinRoom.State!.Events!
- .Where(x =>
- x.Type == "m.room.member"
- && x.StateKey != syncState.Homeserver.WhoAmI!.UserId
- && (x.TypedContent as RoomMemberEventContent)!.Membership == "join"
- )
- .Select(x => x.StateKey));
- joinRoom.Summary.JoinedMemberCount = joinRoom.Summary.Heroes.Count;
-
+ //finally, actually put the response in queue
syncState.SyncQueue.Enqueue(syncResponse);
_roomDataSemaphore.Release();
}
-
- private async Task<StateEventResponse> GetStatusMessage(SyncState syncState, string message) {
- return new StateEventResponse {
- TypedContent = new PresenceEventContent {
- DisplayName = "MxApiExtensions",
- Presence = "online",
- StatusMessage = message,
- // AvatarUrl = (await syncState.Homeserver.GetProfile(syncState.Homeserver.WhoAmI.UserId)).AvatarUrl
- AvatarUrl = ""
- },
- Type = "m.presence",
- StateKey = syncState.Homeserver.WhoAmI!.UserId!,
- Sender = syncState.Homeserver.WhoAmI!.UserId!,
- EventId = Guid.NewGuid().ToString(),
- OriginServerTs = 0
- };
- }
-
- private async Task EnqueueSyncResponse(SyncState ss, HttpResponseMessage task) {
- var sr = await task.Content.ReadFromJsonAsync<JsonObject>();
- if (sr!.ContainsKey("error")) throw sr.Deserialize<MatrixException>()!;
- ss.NextBatch = sr["next_batch"]!.GetValue<string>();
- ss.IsInitialSync = false;
- ss.SyncQueue.Enqueue(sr.Deserialize<SyncResponse>()!);
- task.Dispose();
- ss.NextSyncResponse = null;
- }
}
\ No newline at end of file
diff --git a/MxApiExtensions/Controllers/Extensions/DebugController.cs b/MxApiExtensions/Controllers/Extensions/DebugController.cs
index c65df56..ae9ecc5 100644
--- a/MxApiExtensions/Controllers/Extensions/DebugController.cs
+++ b/MxApiExtensions/Controllers/Extensions/DebugController.cs
@@ -7,24 +7,17 @@ namespace MxApiExtensions.Controllers.Extensions;
[ApiController]
[Route("/")]
-public class DebugController : ControllerBase {
- private readonly ILogger _logger;
- private readonly MxApiExtensionsConfiguration _config;
- private readonly AuthenticationService _authenticationService;
+public class DebugController(ILogger<ProxyConfigurationController> logger, MxApiExtensionsConfiguration config, UserContextService userContextService)
+ : ControllerBase {
+ private readonly ILogger _logger = logger;
private static ConcurrentDictionary<string, RoomInfoEntry> _roomInfoCache = new();
- public DebugController(ILogger<ProxyConfigurationController> logger, MxApiExtensionsConfiguration config, AuthenticationService authenticationService,
- AuthenticatedHomeserverProviderService authenticatedHomeserverProviderService) {
- _logger = logger;
- _config = config;
- _authenticationService = authenticationService;
- }
-
[HttpGet("debug")]
public async Task<object?> GetDebug() {
- var mxid = await _authenticationService.GetMxidFromToken();
- if(!_config.Admins.Contains(mxid)) {
+ var user = await userContextService.GetCurrentUserContext();
+ var mxid = user.Homeserver.UserId;
+ if(!config.Admins.Contains(mxid)) {
_logger.LogWarning("Got debug request for {user}, but they are not an admin", mxid);
Response.StatusCode = StatusCodes.Status403Forbidden;
Response.ContentType = "application/json";
@@ -38,8 +31,6 @@ public class DebugController : ControllerBase {
}
_logger.LogInformation("Got debug request for {user}", mxid);
- return new {
- SyncStates = SyncController.SyncStates
- };
+ return UserContextService.UserContextStore;
}
}
diff --git a/MxApiExtensions/Controllers/Other/MediaProxyController.cs b/MxApiExtensions/Controllers/Other/MediaProxyController.cs
index 03b68ba..fb40aa2 100644
--- a/MxApiExtensions/Controllers/Other/MediaProxyController.cs
+++ b/MxApiExtensions/Controllers/Other/MediaProxyController.cs
@@ -23,7 +23,7 @@ public class MediaProxyController(ILogger<GenericController> logger, MxApiExtens
private static SemaphoreSlim _semaphore = new(1, 1);
[HttpGet("/_matrix/media/{_}/download/{serverName}/{mediaId}")]
- public async Task Proxy(string? _, string serverName, string mediaId) {
+ public async Task ProxyMedia(string? _, string serverName, string mediaId) {
try {
logger.LogInformation("Proxying media: {}{}", serverName, mediaId);
@@ -75,4 +75,7 @@ public class MediaProxyController(ILogger<GenericController> logger, MxApiExtens
await Response.CompleteAsync();
}
}
+
+ [HttpGet("/_matrix/media/{_}/thumbnail/{serverName}/{mediaId}")]
+ public async Task ProxyThumbnail(string? _, string serverName, string mediaId) => await ProxyMedia(_, serverName, mediaId);
}
diff --git a/MxApiExtensions/MxApiExtensionsConfiguration.cs b/MxApiExtensions/MxApiExtensionsConfiguration.cs
index c3b6297..8069e81 100644
--- a/MxApiExtensions/MxApiExtensionsConfiguration.cs
+++ b/MxApiExtensions/MxApiExtensionsConfiguration.cs
@@ -1,8 +1,12 @@
+using ArcaneLibs.Extensions;
+using MxApiExtensions.Classes;
+
namespace MxApiExtensions;
public class MxApiExtensionsConfiguration {
public MxApiExtensionsConfiguration(IConfiguration config) {
config.GetRequiredSection("MxApiExtensions").Bind(this);
+ if (DefaultUserConfiguration is null) throw new ArgumentNullException(nameof(DefaultUserConfiguration), $"Default user configuration not configured! Example: {new MxApiExtensionsUserConfiguration().ToJson()}");
}
public List<string> AuthHomeservers { get; set; } = new();
@@ -11,7 +15,7 @@ public class MxApiExtensionsConfiguration {
public FastInitialSyncConfiguration FastInitialSync { get; set; } = new();
public CacheConfiguration Cache { get; set; } = new();
-
+ public MxApiExtensionsUserConfiguration DefaultUserConfiguration { get; set; }
public class FastInitialSyncConfiguration {
public bool Enabled { get; set; } = true;
@@ -26,4 +30,5 @@ public class MxApiExtensionsConfiguration {
public TimeSpan ExtraTtlPerState { get; set; } = TimeSpan.FromMilliseconds(100);
}
}
+
}
diff --git a/MxApiExtensions/Program.cs b/MxApiExtensions/Program.cs
index 72a5dc9..21d8ba4 100644
--- a/MxApiExtensions/Program.cs
+++ b/MxApiExtensions/Program.cs
@@ -3,7 +3,9 @@ using LibMatrix;
using LibMatrix.Services;
using Microsoft.AspNetCore.Diagnostics;
using Microsoft.AspNetCore.Http.Timeouts;
+using Microsoft.Extensions.Logging.Console;
using MxApiExtensions;
+using MxApiExtensions.Classes;
using MxApiExtensions.Classes.LibMatrix;
using MxApiExtensions.Services;
@@ -22,6 +24,7 @@ builder.Services.AddSingleton<MxApiExtensionsConfiguration>();
builder.Services.AddScoped<AuthenticationService>();
builder.Services.AddScoped<AuthenticatedHomeserverProviderService>();
+builder.Services.AddScoped<UserContextService>();
builder.Services.AddSingleton<TieredStorageService>(x => {
var config = x.GetRequiredService<MxApiExtensionsConfiguration>();
@@ -54,6 +57,7 @@ builder.Services.AddCors(options => {
"Open",
policy => policy.AllowAnyOrigin().AllowAnyHeader());
});
+// builder.Logging.AddConsole(x => x.FormatterName = "custom").AddConsoleFormatter<CustomLogFormatter, SimpleConsoleFormatterOptions>();
var app = builder.Build();
diff --git a/MxApiExtensions/Services/AuthenticatedHomeserverProviderService.cs b/MxApiExtensions/Services/AuthenticatedHomeserverProviderService.cs
index e0f9db5..741beb3 100644
--- a/MxApiExtensions/Services/AuthenticatedHomeserverProviderService.cs
+++ b/MxApiExtensions/Services/AuthenticatedHomeserverProviderService.cs
@@ -6,6 +6,7 @@ using MxApiExtensions.Classes.LibMatrix;
namespace MxApiExtensions.Services;
public class AuthenticatedHomeserverProviderService(AuthenticationService authenticationService, HomeserverProviderService homeserverProviderService, IHttpContextAccessor request) {
+ public HttpContext? _context = request.HttpContext;
public async Task<RemoteHomeserver?> TryGetRemoteHomeserver() {
try {
return await GetRemoteHomeserver();
@@ -21,7 +22,8 @@ public class AuthenticatedHomeserverProviderService(AuthenticationService authen
}
catch (MxApiMatrixException e) {
if (e is not { ErrorCode: "M_MISSING_TOKEN" }) throw;
- if (!request.HttpContext!.Request.Headers.Keys.Any(x=>x.ToUpper() == "MXAE_UPSTREAM"))
+ if (request is null) throw new MxApiMatrixException() { ErrorCode = "M_UNKNOWN", Error = "[MxApiExtensions] Request was null for unauthenticated request!" };
+ if (!_context.Request.Headers.Keys.Any(x=>x.ToUpper() == "MXAE_UPSTREAM"))
throw new MxApiMatrixException() {
ErrorCode = "MXAE_MISSING_UPSTREAM",
Error = "[MxApiExtensions] Missing MXAE_UPSTREAM header for unauthenticated request, this should be a server_name!"
diff --git a/MxApiExtensions/Services/AuthenticationService.cs b/MxApiExtensions/Services/AuthenticationService.cs
index 0dcc8b1..7430fcd 100644
--- a/MxApiExtensions/Services/AuthenticationService.cs
+++ b/MxApiExtensions/Services/AuthenticationService.cs
@@ -1,3 +1,5 @@
+using ArcaneLibs.Extensions;
+using LibMatrix;
using LibMatrix.Services;
using MxApiExtensions.Classes.LibMatrix;
@@ -17,7 +19,7 @@ public class AuthenticationService(ILogger<AuthenticationService> logger, MxApiE
token = _request.Query["access_token"];
}
- if (token == null && fail) {
+ if (string.IsNullOrWhiteSpace(token) && fail) {
throw new MxApiMatrixException {
ErrorCode = "M_MISSING_TOKEN",
Error = "Missing access token"
@@ -29,7 +31,7 @@ public class AuthenticationService(ILogger<AuthenticationService> logger, MxApiE
public async Task<string> GetMxidFromToken(string? token = null, bool fail = true) {
token ??= GetToken(fail);
- if (token == null) {
+ if (string.IsNullOrWhiteSpace(token)) {
if (fail) {
throw new MxApiMatrixException {
ErrorCode = "M_MISSING_TOKEN",
@@ -44,15 +46,34 @@ public class AuthenticationService(ILogger<AuthenticationService> logger, MxApiE
_tokenMap = (await File.ReadAllLinesAsync("token_map"))
.Select(l => l.Split('\t'))
.ToDictionary(l => l[0], l => l[1]);
+
+ //THIS IS BROKEN, DO NOT USE!
+ // foreach (var (mapToken, mapUser) in _tokenMap) {
+ // try {
+ // var hs = await homeserverProviderService.GetAuthenticatedWithToken(mapUser.Split(':', 2)[1], mapToken);
+ // }
+ // catch (MatrixException e) {
+ // if (e is { ErrorCode: "M_UNKNOWN_TOKEN" }) _tokenMap[mapToken] = "";
+ // }
+ // catch {
+ // // ignored
+ // }
+ // }
+ // _tokenMap.RemoveAll((x, y) => string.IsNullOrWhiteSpace(y));
+ // await File.WriteAllTextAsync("token_map", _tokenMap.Aggregate("", (x, y) => $"{y.Key}\t{y.Value}\n"));
}
+
if (_tokenMap.TryGetValue(token, out var mxid)) return mxid;
var lookupTasks = new Dictionary<string, Task<string?>>();
foreach (var homeserver in config.AuthHomeservers) {
- lookupTasks.Add(homeserver, GetMxidFromToken(token, homeserver));
- await lookupTasks[homeserver].WaitAsync(TimeSpan.FromMilliseconds(250));
- if(lookupTasks[homeserver].IsCompletedSuccessfully && !string.IsNullOrWhiteSpace(lookupTasks[homeserver].Result)) break;
+ try {
+ lookupTasks.Add(homeserver, GetMxidFromToken(token, homeserver));
+ await lookupTasks[homeserver].WaitAsync(TimeSpan.FromMilliseconds(500));
+ if (lookupTasks[homeserver].IsCompletedSuccessfully && !string.IsNullOrWhiteSpace(lookupTasks[homeserver].Result)) break;
+ }
+ catch {}
}
await Task.WhenAll(lookupTasks.Values);
diff --git a/MxApiExtensions/Services/UserContextService.cs b/MxApiExtensions/Services/UserContextService.cs
new file mode 100644
index 0000000..ef19ced
--- /dev/null
+++ b/MxApiExtensions/Services/UserContextService.cs
@@ -0,0 +1,44 @@
+using System.Collections.Concurrent;
+using System.Text.Json.Serialization;
+using ArcaneLibs.Extensions;
+using LibMatrix;
+using LibMatrix.Homeservers;
+using MxApiExtensions.Classes;
+
+namespace MxApiExtensions.Services;
+
+public class UserContextService(MxApiExtensionsConfiguration config, AuthenticatedHomeserverProviderService hsProvider) {
+ internal static ConcurrentDictionary<string, UserContext> UserContextStore { get; set; } = new();
+ public int SessionCount = UserContextStore.Count;
+
+ public class UserContext {
+ public SyncState? SyncState { get; set; }
+ [JsonIgnore]
+ public AuthenticatedHomeserverGeneric Homeserver { get; set; }
+ public MxApiExtensionsUserConfiguration UserConfiguration { get; set; }
+ }
+
+ private readonly SemaphoreSlim _getUserContextSemaphore = new SemaphoreSlim(1, 1);
+ public async Task<UserContext> GetCurrentUserContext() {
+ var hs = await hsProvider.GetHomeserver();
+ // await _getUserContextSemaphore.WaitAsync();
+ var ucs = await UserContextStore.GetOrCreateAsync($"{hs.WhoAmI.UserId}/{hs.WhoAmI.DeviceId}/{hs.ServerName}:{hs.AccessToken}", async x => {
+ var userContext = new UserContext() {
+ Homeserver = hs
+ };
+ try {
+ userContext.UserConfiguration = await hs.GetAccountDataAsync<MxApiExtensionsUserConfiguration>(MxApiExtensionsUserConfiguration.EventId);
+ }
+ catch (MatrixException e) {
+ if (e is not { ErrorCode: "M_NOT_FOUND" }) throw;
+ userContext.UserConfiguration = config.DefaultUserConfiguration;
+ }
+
+ await hs.SetAccountDataAsync(MxApiExtensionsUserConfiguration.EventId, userContext.UserConfiguration);
+
+ return userContext;
+ }, _getUserContextSemaphore);
+ // _getUserContextSemaphore.Release();
+ return ucs;
+ }
+}
\ No newline at end of file
diff --git a/MxApiExtensions/appsettings.json b/MxApiExtensions/appsettings.json
index 5a73cbe..b16968a 100644
--- a/MxApiExtensions/appsettings.json
+++ b/MxApiExtensions/appsettings.json
@@ -29,6 +29,13 @@
"BaseTtl": "00:01:00",
"ExtraTtlPerState": "00:00:00.1000000"
}
+ },
+ "DefaultUserConfiguration": {
+ "ProtocolChanges": {
+ "DisableThreads": false,
+ "DisableVoip": false,
+ "AutoFollowTombstones": false
+ }
}
}
}
|