diff options
author | TheArcaneBrony <myrainbowdash949@gmail.com> | 2023-11-09 07:36:02 +0100 |
---|---|---|
committer | TheArcaneBrony <myrainbowdash949@gmail.com> | 2023-11-09 07:36:02 +0100 |
commit | 2e8aa30daa4a33fa33622bccb344dfc24483e320 (patch) | |
tree | ee08ce4e83382b81a1fbabaac85c763971408bbe /MxApiExtensions | |
parent | Fix some null checks (diff) | |
download | MxApiExtensions-2e8aa30daa4a33fa33622bccb344dfc24483e320.tar.xz |
Fix sync
Diffstat (limited to 'MxApiExtensions')
-rw-r--r-- | MxApiExtensions/Classes/CustomLogFormatter.cs | 12 | ||||
-rw-r--r-- | MxApiExtensions/Classes/MxApiExtensionsUserConfiguration.cs | 21 | ||||
-rw-r--r-- | MxApiExtensions/Classes/SyncFilter.cs | 7 | ||||
-rw-r--r-- | MxApiExtensions/Classes/SyncState.cs | 77 | ||||
-rw-r--r-- | MxApiExtensions/Controllers/Client/Room/RoomsSendMessageController.cs | 2 | ||||
-rw-r--r-- | MxApiExtensions/Controllers/Client/RoomsSendMessageController.cs | 31 | ||||
-rw-r--r-- | MxApiExtensions/Controllers/Client/SyncController.cs | 302 | ||||
-rw-r--r-- | MxApiExtensions/Controllers/Extensions/DebugController.cs | 23 | ||||
-rw-r--r-- | MxApiExtensions/Controllers/Other/MediaProxyController.cs | 5 | ||||
-rw-r--r-- | MxApiExtensions/MxApiExtensionsConfiguration.cs | 7 | ||||
-rw-r--r-- | MxApiExtensions/Program.cs | 4 | ||||
-rw-r--r-- | MxApiExtensions/Services/AuthenticatedHomeserverProviderService.cs | 4 | ||||
-rw-r--r-- | MxApiExtensions/Services/AuthenticationService.cs | 31 | ||||
-rw-r--r-- | MxApiExtensions/Services/UserContextService.cs | 44 | ||||
-rw-r--r-- | MxApiExtensions/appsettings.json | 7 |
15 files changed, 358 insertions, 219 deletions
diff --git a/MxApiExtensions/Classes/CustomLogFormatter.cs b/MxApiExtensions/Classes/CustomLogFormatter.cs new file mode 100644 index 0000000..69812e5 --- /dev/null +++ b/MxApiExtensions/Classes/CustomLogFormatter.cs @@ -0,0 +1,12 @@ +using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.Extensions.Logging.Console; + +namespace MxApiExtensions.Classes; + +public class CustomLogFormatter : ConsoleFormatter { + public CustomLogFormatter(string name) : base(name) { } + + public override void Write<TState>(in LogEntry<TState> logEntry, IExternalScopeProvider? scopeProvider, TextWriter textWriter) { + Console.WriteLine("Log message"); + } +} \ No newline at end of file diff --git a/MxApiExtensions/Classes/MxApiExtensionsUserConfiguration.cs b/MxApiExtensions/Classes/MxApiExtensionsUserConfiguration.cs new file mode 100644 index 0000000..db83637 --- /dev/null +++ b/MxApiExtensions/Classes/MxApiExtensionsUserConfiguration.cs @@ -0,0 +1,21 @@ +using LibMatrix.EventTypes; +using LibMatrix.Interfaces; + +namespace MxApiExtensions.Classes; + +[MatrixEvent(EventName = EventId)] +public class MxApiExtensionsUserConfiguration : EventContent { + public const string EventId = "gay.rory.mxapiextensions.userconfig"; + public ProtocolChangeConfiguration ProtocolChanges { get; set; } = new(); + public InitialSyncConfiguration InitialSyncPreload { get; set; } = new(); + + public class InitialSyncConfiguration { + public bool Enable { get; set; } = true; + } + + public class ProtocolChangeConfiguration { + public bool DisableThreads { get; set; } = false; + public bool DisableVoip { get; set; } = false; + public bool AutoFollowTombstones { get; set; } = false; + } +} \ No newline at end of file diff --git a/MxApiExtensions/Classes/SyncFilter.cs b/MxApiExtensions/Classes/SyncFilter.cs new file mode 100644 index 0000000..7f2bd08 --- /dev/null +++ b/MxApiExtensions/Classes/SyncFilter.cs @@ -0,0 +1,7 @@ +using LibMatrix.Responses; + +namespace MxApiExtensions.Classes; + +public interface ISyncFilter { + public Task<SyncResponse> Apply(SyncResponse syncResponse); +} \ No newline at end of file diff --git a/MxApiExtensions/Classes/SyncState.cs b/MxApiExtensions/Classes/SyncState.cs index e44d35c..733f26d 100644 --- a/MxApiExtensions/Classes/SyncState.cs +++ b/MxApiExtensions/Classes/SyncState.cs @@ -10,13 +10,12 @@ using Microsoft.OpenApi.Extensions; namespace MxApiExtensions.Classes; public class SyncState { - private Task? _nextSyncResponse; + private Task<HttpResponseMessage>? _nextSyncResponse; public string? NextBatch { get; set; } public ConcurrentQueue<SyncResponse> SyncQueue { get; set; } = new(); - public bool IsInitialSync { get; set; } [JsonIgnore] - public Task? NextSyncResponse { + public Task<HttpResponseMessage>? NextSyncResponse { get => _nextSyncResponse; set { _nextSyncResponse = value; @@ -25,7 +24,7 @@ public class SyncState { } public DateTime NextSyncResponseStartedAt { get; set; } = DateTime.Now; - + [JsonIgnore] public AuthenticatedHomeserverGeneric Homeserver { get; set; } @@ -39,53 +38,39 @@ public class SyncState { NextSyncResponse?.IsFaulted, Status = NextSyncResponse?.Status.GetDisplayName() }; - #endregion - public void SendEphemeralTimelineEventInRoom(string roomId, StateEventResponse @event) { - SyncQueue.Enqueue(new() { - NextBatch = NextBatch ?? "null", - Rooms = new() { - Join = new() { - { - roomId, - new() { - Timeline = new() { - Events = new() { - @event - } - } - } - } - } - } - }); + public SyncResponse SendEphemeralTimelineEventInRoom(string roomId, StateEventResponse @event, SyncResponse? existingResponse = null) { + if(existingResponse is null) + SyncQueue.Enqueue(existingResponse = new()); + existingResponse.Rooms ??= new(); + existingResponse.Rooms.Join ??= new(); + existingResponse.Rooms.Join.TryAdd(roomId, new()); + existingResponse.Rooms.Join[roomId].Timeline ??= new(); + existingResponse.Rooms.Join[roomId].Timeline.Events ??= new(); + existingResponse.Rooms.Join[roomId].Timeline.Events.Add(@event); + return existingResponse; } - public void SendStatusMessage(string text) { - SyncQueue.Enqueue(new() { - NextBatch = NextBatch ?? "null", - Presence = new() { - Events = new() { - new StateEventResponse { - TypedContent = new PresenceEventContent { - DisplayName = "MxApiExtensions", - Presence = "online", - StatusMessage = text, - // AvatarUrl = (await syncState.Homeserver.GetProfile(syncState.Homeserver.WhoAmI.UserId)).AvatarUrl - AvatarUrl = "", - LastActiveAgo = 15, - CurrentlyActive = true - }, - Type = "m.presence", - StateKey = Homeserver.WhoAmI.UserId, - Sender = Homeserver.WhoAmI.UserId, - EventId = Guid.NewGuid().ToString(), - OriginServerTs = 0 - } - } - } + public SyncResponse SendStatusMessage(string text, SyncResponse? existingResponse = null) { + if(existingResponse is null) + SyncQueue.Enqueue(existingResponse = new()); + existingResponse.Presence ??= new(); + // existingResponse.Presence.Events ??= new(); + existingResponse.Presence.Events.RemoveAll(x => x.Sender == Homeserver.WhoAmI.UserId); + existingResponse.Presence.Events.Add(new StateEventResponse { + TypedContent = new PresenceEventContent { + Presence = "online", + StatusMessage = text, + LastActiveAgo = 15, + CurrentlyActive = true + }, + Type = "m.presence", + StateKey = "", + Sender = Homeserver.WhoAmI.UserId, + OriginServerTs = 0 }); + return existingResponse; } } \ No newline at end of file diff --git a/MxApiExtensions/Controllers/Client/Room/RoomsSendMessageController.cs b/MxApiExtensions/Controllers/Client/Room/RoomsSendMessageController.cs index 3d1d4e2..b756582 100644 --- a/MxApiExtensions/Controllers/Client/Room/RoomsSendMessageController.cs +++ b/MxApiExtensions/Controllers/Client/Room/RoomsSendMessageController.cs @@ -13,6 +13,6 @@ public class RoomController(ILogger<LoginController> logger, HomeserverResolverS public async Task<Dictionary<string, List<string>>> GetRoomMembersByHomeserver(string _, [FromRoute] string roomId, [FromQuery] bool joinedOnly = true) { var hs = await hsProvider.GetHomeserver(); var room = hs.GetRoom(roomId); - return await room.GetMembersByHomeserverAsync(); + return await room.GetMembersByHomeserverAsync(joinedOnly); } } \ No newline at end of file diff --git a/MxApiExtensions/Controllers/Client/RoomsSendMessageController.cs b/MxApiExtensions/Controllers/Client/RoomsSendMessageController.cs index 47d9899..e882c8a 100644 --- a/MxApiExtensions/Controllers/Client/RoomsSendMessageController.cs +++ b/MxApiExtensions/Controllers/Client/RoomsSendMessageController.cs @@ -13,23 +13,24 @@ using LibMatrix.Services; using Microsoft.AspNetCore.Mvc; using MxApiExtensions.Classes; using MxApiExtensions.Classes.LibMatrix; +using MxApiExtensions.Extensions; using MxApiExtensions.Services; namespace MxApiExtensions.Controllers; [ApiController] [Route("/")] -public class RoomsSendMessageController(ILogger<LoginController> logger, HomeserverResolverService hsResolver, AuthenticationService auth, MxApiExtensionsConfiguration conf, - AuthenticatedHomeserverProviderService hsProvider) +public class RoomsSendMessageController(ILogger<LoginController> logger, UserContextService userContextService) : ControllerBase { [HttpPut("/_matrix/client/{_}/rooms/{roomId}/send/m.room.message/{txnId}")] public async Task Proxy([FromBody] JsonObject request, [FromRoute] string roomId, [FromRoute] string txnId, string _) { - var hs = await hsProvider.GetHomeserver(); + var uc = await userContextService.GetCurrentUserContext(); + // var hs = await hsProvider.GetHomeserver(); var msg = request.Deserialize<RoomMessageEventContent>(); if (msg is not null && msg.Body.StartsWith("mxae!")) { #pragma warning disable CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed - handleMxaeCommand(hs, roomId, msg); + handleMxaeCommand(uc, roomId, msg); #pragma warning restore CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed await Response.WriteAsJsonAsync(new EventIdResponse() { EventId = "$" + string.Join("", Random.Shared.GetItems("abcdefghijklmnopqrstuvwxyzABCDEFGHIJLKMNOPQRSTUVWXYZ0123456789".ToCharArray(), 100)) @@ -38,13 +39,14 @@ public class RoomsSendMessageController(ILogger<LoginController> logger, Homeser } else { try { - var resp = await hs.ClientHttpClient.PutAsJsonAsync($"{Request.Path}{Request.QueryString}", request); - var loginResp = await resp.Content.ReadAsStringAsync(); - Response.StatusCode = (int)resp.StatusCode; - Response.ContentType = resp.Content.Headers.ContentType?.ToString() ?? "application/json"; - await Response.StartAsync(); - await Response.WriteAsync(loginResp); - await Response.CompleteAsync(); + var resp = await uc.Homeserver.ClientHttpClient.PutAsJsonAsync($"{Request.Path}{Request.QueryString}", request); + await Response.WriteHttpResponse(resp); + // var loginResp = await resp.Content.ReadAsStringAsync(); + // Response.StatusCode = (int)resp.StatusCode; + // Response.ContentType = resp.Content.Headers.ContentType?.ToString() ?? "application/json"; + // await Response.StartAsync(); + // await Response.WriteAsync(loginResp); + // await Response.CompleteAsync(); } catch (MatrixException e) { await Response.StartAsync(); @@ -54,10 +56,9 @@ public class RoomsSendMessageController(ILogger<LoginController> logger, Homeser } } - private async Task handleMxaeCommand(AuthenticatedHomeserverGeneric hs, string roomId, RoomMessageEventContent msg) { - var syncState = SyncController.SyncStates.GetValueOrDefault(hs.AccessToken); - if (syncState is null) return; - syncState.SendEphemeralTimelineEventInRoom(roomId, new() { + private async Task handleMxaeCommand(UserContextService.UserContext hs, string roomId, RoomMessageEventContent msg) { + if (hs.SyncState is null) return; + hs.SyncState.SendEphemeralTimelineEventInRoom(roomId, new() { Sender = "@mxae:" + Request.Host.Value, Type = "m.room.message", TypedContent = MessageFormatter.FormatSuccess("Thinking..."), diff --git a/MxApiExtensions/Controllers/Client/SyncController.cs b/MxApiExtensions/Controllers/Client/SyncController.cs index 8a5ba06..7f9ed1d 100644 --- a/MxApiExtensions/Controllers/Client/SyncController.cs +++ b/MxApiExtensions/Controllers/Client/SyncController.cs @@ -4,6 +4,7 @@ using System.Text.Json; using System.Text.Json.Nodes; using System.Text.Json.Serialization; using System.Web; +using System.Xml; using ArcaneLibs; using LibMatrix; using LibMatrix.EventTypes.Spec.State; @@ -21,102 +22,76 @@ namespace MxApiExtensions.Controllers; [ApiController] [Route("/")] -public class SyncController(ILogger<SyncController> logger, MxApiExtensionsConfiguration config, AuthenticationService auth, AuthenticatedHomeserverProviderService hsProvider) +public class SyncController(ILogger<SyncController> logger, MxApiExtensionsConfiguration config, AuthenticationService auth, AuthenticatedHomeserverProviderService hsProvider, + UserContextService userContextService) : ControllerBase { - public static readonly ConcurrentDictionary<string, SyncState> SyncStates = new(); - - private static SemaphoreSlim _semaphoreSlim = new(1, 1); + private UserContextService.UserContext userContext; private Stopwatch _syncElapsed = Stopwatch.StartNew(); + private static SemaphoreSlim _semaphoreSlim = new(1, 1); [HttpGet("/_matrix/client/{_}/sync")] public async Task Sync(string _, [FromQuery] string? since, [FromQuery] int timeout = 1000) { + // temporary variables + bool startedNewTask = false; Task? preloadTask = null; - AuthenticatedHomeserverGeneric? hs = null; - try { - hs = await hsProvider.GetHomeserver(); - } - catch (Exception e) { - Console.WriteLine(e); - } + // get user context based on authentication + userContext = await userContextService.GetCurrentUserContext(); var qs = HttpUtility.ParseQueryString(Request.QueryString.Value!); qs.Remove("access_token"); if (since == "null") qs.Remove("since"); - if (!config.FastInitialSync.Enabled) { - logger.LogInformation("Starting sync for {} on {} ({})", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken); - var result = await hs.ClientHttpClient.GetAsync($"{Request.Path}?{qs}"); - await Response.WriteHttpResponse(result); - return; - } + // if (!userContext.UserConfiguration.InitialSyncPreload.Enable) { + // logger.LogInformation("Starting sync for {} on {} ({})", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken); + // var result = await hs.ClientHttpClient.GetAsync($"{Request.Path}?{qs}"); + // await Response.WriteHttpResponse(result); + // return; + // } + //prevent duplicate initialisation await _semaphoreSlim.WaitAsync(); - var syncState = SyncStates.GetOrAdd($"{hs.WhoAmI.UserId}/{hs.WhoAmI.DeviceId}/{hs.ServerName}:{hs.AccessToken}", _ => { - logger.LogInformation("Started tracking sync state for {} on {} ({})", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken); - var ss = new SyncState { - IsInitialSync = string.IsNullOrWhiteSpace(since), - Homeserver = hs + + //if we don't have a sync state for this user... + if (userContext.SyncState is null) { + logger.LogInformation("Started tracking sync state for {} on {} ({})", userContext.Homeserver.WhoAmI.UserId, userContext.Homeserver.ServerName, + userContext.Homeserver.AccessToken); + + //create a new sync state + userContext.SyncState = new SyncState { + Homeserver = userContext.Homeserver, + NextSyncResponse = Task.Run(async () => { + if (string.IsNullOrWhiteSpace(since) && userContext.UserConfiguration.InitialSyncPreload.Enable) + await Task.Delay(15_000); + logger.LogInformation("Sync for {} on {} ({}) starting", userContext.Homeserver.WhoAmI.UserId, userContext.Homeserver.ServerName, + userContext.Homeserver.AccessToken); + return await userContext.Homeserver.ClientHttpClient.GetAsync($"/_matrix/client/v3/sync?{qs}"); + }) }; - if (ss.IsInitialSync) { - preloadTask = EnqueuePreloadData(ss); + startedNewTask = true; + + //if this is an initial sync, and the user has enabled this, preload data + if (string.IsNullOrWhiteSpace(since) && userContext.UserConfiguration.InitialSyncPreload.Enable) { + logger.LogInformation("Sync data preload for {} on {} ({}) starting", userContext.Homeserver.WhoAmI.UserId, userContext.Homeserver.ServerName, + userContext.Homeserver.AccessToken); + preloadTask = EnqueuePreloadData(userContext.SyncState); } - - logger.LogInformation("Starting sync for {} on {} ({})", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken); - - ss.NextSyncResponseStartedAt = DateTime.Now; - ss.NextSyncResponse = Task.Delay(15_000); - ss.NextSyncResponse.ContinueWith(x => { - logger.LogInformation("Sync for {} on {} ({}) starting", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken); - ss.NextSyncResponse = hs.ClientHttpClient.GetAsync($"/_matrix/client/v3/sync?{qs}"); - (ss.NextSyncResponse as Task<HttpResponseMessage>).ContinueWith(async x => EnqueueSyncResponse(ss, await x)); - }); - return ss; - }); - _semaphoreSlim.Release(); - - if (syncState.SyncQueue.Count > 0) { - logger.LogInformation("Sync for {} on {} ({}) has {} queued results", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken, syncState.SyncQueue.Count); - syncState.SyncQueue.TryDequeue(out var result); - while (result is null) - syncState.SyncQueue.TryDequeue(out result); - Response.StatusCode = StatusCodes.Status200OK; - Response.ContentType = "application/json"; - await Response.StartAsync(); - result.NextBatch ??= since ?? syncState.NextBatch!; - await JsonSerializer.SerializeAsync(Response.Body, result, new JsonSerializerOptions { - WriteIndented = true, - DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull - }); - await Response.CompleteAsync(); - return; } - var newTimeout = Math.Clamp(timeout, 0, syncState.IsInitialSync ? syncState.SyncQueue.Count >= 2 ? 0 : 250 : timeout); - logger.LogInformation("Sync for {} on {} ({}) is still running, waiting for {}ms, {} elapsed", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken, newTimeout, - DateTime.Now.Subtract(syncState.NextSyncResponseStartedAt)); - - try { - if (syncState.NextSyncResponse is not null) - await syncState.NextSyncResponse.WaitAsync(TimeSpan.FromMilliseconds(newTimeout)); - else { - syncState.NextSyncResponse = hs.ClientHttpClient.GetAsync($"/_matrix/client/v3/sync?{qs}"); - (syncState.NextSyncResponse as Task<HttpResponseMessage>)!.ContinueWith(async x => EnqueueSyncResponse(syncState, await x)); - // await Task.Delay(250); - } + if (userContext.SyncState.NextSyncResponse is null) { + userContext.SyncState.NextSyncResponse = userContext.Homeserver.ClientHttpClient.GetAsync($"/_matrix/client/v3/sync?{qs}"); + startedNewTask = true; } - catch (TimeoutException) { } - // if (_syncElapsed.ElapsedMilliseconds > timeout) - if(syncState.NextSyncResponse?.IsCompleted == false) - syncState.SendStatusMessage( - $"M={Util.BytesToString(Process.GetCurrentProcess().WorkingSet64)} TE={DateTime.Now.Subtract(syncState.NextSyncResponseStartedAt)} S={syncState.NextSyncResponse?.Status} QL={syncState.SyncQueue.Count}"); + _semaphoreSlim.Release(); + + //get the next sync response + var syncResponse = await GetNextSyncResponse(timeout); + //send it to the client Response.StatusCode = StatusCodes.Status200OK; Response.ContentType = "application/json"; await Response.StartAsync(); - var response = syncState.SyncQueue.FirstOrDefault(); - if (response is null) - response = new(); - response.NextBatch ??= since ?? syncState.NextBatch!; + var response = syncResponse; + response.NextBatch ??= since ?? "null"; await JsonSerializer.SerializeAsync(Response.Body, response, new JsonSerializerOptions { WriteIndented = true, DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull @@ -124,20 +99,117 @@ public class SyncController(ILogger<SyncController> logger, MxApiExtensionsConfi await Response.CompleteAsync(); Response.Body.Close(); + + //await scope-local tasks in order to prevent disposal if (preloadTask is not null) { await preloadTask; preloadTask.Dispose(); } + + if (startedNewTask && userContext.SyncState?.NextSyncResponse is not null) { + var resp = await userContext.SyncState.NextSyncResponse; + var sr = await resp.Content.ReadFromJsonAsync<JsonObject>(); + if (sr!.ContainsKey("error")) throw sr.Deserialize<MatrixException>()!; + userContext.SyncState.NextBatch = sr["next_batch"]!.GetValue<string>(); + // userContext.SyncState.IsInitialSync = false; + var syncResp = sr.Deserialize<SyncResponse>(); + userContext.SyncState.SyncQueue.Enqueue(syncResp); + userContext.SyncState.NextSyncResponse.Dispose(); + userContext.SyncState.NextSyncResponse = null; + } + } + + private async Task<SyncResponse> GetNextSyncResponse(int timeout = 0) { + do { + if (userContext.SyncState is null) throw new NullReferenceException("syncState is null!"); + // if (userContext.SyncState.NextSyncResponse is null) throw new NullReferenceException("NextSyncResponse is null"); + + //check if upstream has responded, if so, return upstream response + // if (userContext.SyncState.NextSyncResponse is { IsCompleted: true } syncResponse) { + // var resp = await syncResponse; + // var sr = await resp.Content.ReadFromJsonAsync<JsonObject>(); + // if (sr!.ContainsKey("error")) throw sr.Deserialize<MatrixException>()!; + // userContext.SyncState.NextBatch = sr["next_batch"]!.GetValue<string>(); + // // userContext.SyncState.IsInitialSync = false; + // var syncResp = sr.Deserialize<SyncResponse>(); + // return syncResp; + // } + + //else, return the first item in queue, if any + if (userContext.SyncState.SyncQueue.Count > 0) { + logger.LogInformation("Sync for {} on {} ({}) has {} queued results", userContext.SyncState.Homeserver.WhoAmI.UserId, userContext.SyncState.Homeserver.ServerName, + userContext.SyncState.Homeserver.AccessToken, userContext.SyncState.SyncQueue.Count); + userContext.SyncState.SyncQueue.TryDequeue(out var result); + while (result is null) + userContext.SyncState.SyncQueue.TryDequeue(out result); + return result; + } + + // await Task.Delay(Math.Clamp(timeout, 25, 250)); //wait 25-250ms between checks + await Task.Delay(Math.Clamp(userContextService.SessionCount * 10 ,25, 500)); + } while (_syncElapsed.ElapsedMilliseconds < timeout + 500); //... while we haven't gone >500ms over expected timeout + + //we didn't get a response, send a bogus response + return userContext.SyncState.SendStatusMessage( + $"M={Util.BytesToString(Process.GetCurrentProcess().WorkingSet64)} TE={DateTime.Now.Subtract(userContext.SyncState.NextSyncResponseStartedAt)} QL={userContext.SyncState.SyncQueue.Count}", + new()); } + private async Task EnqueuePreloadData(SyncState syncState) { + await EnqueuePreloadAccountData(syncState); + await EnqueuePreloadRooms(syncState); + } + + private static List<string> CommonAccountDataKeys = new() { + "gay.rory.dm_space", + "im.fluffychat.account_bundles", + "im.ponies.emote_rooms", + "im.vector.analytics", + "im.vector.setting.breadcrumbs", + "im.vector.setting.integration_provisioning", + "im.vector.web.settings", + "io.element.recent_emoji", + "m.cross_signing.master", + "m.cross_signing.self_signing", + "m.cross_signing.user_signing", + "m.direct", + "m.megolm_backup.v1", + "m.push_rules", + "m.secret_storage.default_key", + "gay.rory.mxapiextensions.userconfig" + }; + //enqueue common account data + private async Task EnqueuePreloadAccountData(SyncState syncState) { + var syncMsg = new SyncResponse() { + AccountData = new() { + Events = new() + } + }; + foreach (var key in CommonAccountDataKeys) { + try { + syncMsg.AccountData.Events.Add(new() { + Type = key, + RawContent = await syncState.Homeserver.GetAccountDataAsync<JsonObject>(key) + }); + } + catch {} + } + syncState.SyncQueue.Enqueue(syncMsg); + } + + private async Task EnqueuePreloadRooms(SyncState syncState) { + //get the users's rooms var rooms = await syncState.Homeserver.GetJoinedRooms(); - var dmRooms = (await syncState.Homeserver.GetAccountDataAsync<Dictionary<string, List<string>>>("m.direct")).Aggregate(new List<string>(), (list, entry) => { - list.AddRange(entry.Value); - return list; - }); + + //get the user's DM rooms + var mDirectContent = await syncState.Homeserver.GetAccountDataAsync<Dictionary<string, List<string>>>("m.direct"); + var dmRooms = mDirectContent.SelectMany(pair => pair.Value); + //get our own homeserver's server_name var ownHs = syncState.Homeserver.WhoAmI!.UserId!.Split(':')[1]; + + //order rooms by expected state size, since large rooms take a long time to return rooms = rooms.OrderBy(x => { if (dmRooms.Contains(x.RoomId)) return -1; var parts = x.RoomId.Split(':'); @@ -145,38 +217,35 @@ public class SyncController(ILogger<SyncController> logger, MxApiExtensionsConfi if (HomeserverWeightEstimation.EstimatedSize.ContainsKey(parts[1])) return HomeserverWeightEstimation.EstimatedSize[parts[1]] + parts[0].Length; return 5000; }).ToList(); + + //start all fetch tasks var roomDataTasks = rooms.Select(room => EnqueueRoomData(syncState, room)).ToList(); logger.LogInformation("Preloading data for {} rooms on {} ({})", roomDataTasks.Count, syncState.Homeserver.ServerName, syncState.Homeserver.AccessToken); + //wait for all of them to finish await Task.WhenAll(roomDataTasks); } - private readonly SemaphoreSlim _roomDataSemaphore = new(32, 32); + private static readonly SemaphoreSlim _roomDataSemaphore = new(4, 4); private async Task EnqueueRoomData(SyncState syncState, GenericRoom room) { + //limit concurrent requests, to not overload upstream await _roomDataSemaphore.WaitAsync(); + //get the room's state var roomState = room.GetFullStateAsync(); + //get the room's timeline, reversed var timeline = await room.GetMessagesAsync(limit: 100, dir: "b"); timeline.Chunk.Reverse(); + //queue up this data as a sync response var syncResponse = new SyncResponse { Rooms = new() { Join = new() { { room.RoomId, new SyncResponse.RoomsDataStructure.JoinedRoomDataStructure { - AccountData = new() { - Events = new() - }, - Ephemeral = new() { - Events = new() - }, State = new() { Events = timeline.State }, - UnreadNotifications = new() { - HighlightCount = 0, - NotificationCount = 0 - }, Timeline = new() { Events = timeline.Chunk, Limited = false, @@ -190,57 +259,24 @@ public class SyncController(ILogger<SyncController> logger, MxApiExtensionsConfi } } } - }, - Presence = new() { - Events = new() { - await GetStatusMessage(syncState, $"{DateTime.Now.Subtract(syncState.NextSyncResponseStartedAt)} {syncState.NextSyncResponse.Status} {room.RoomId}") - } - }, - NextBatch = "" + } }; + //calculate invited/joined member counts, and add other events to state await foreach (var stateEvent in roomState) { + if (stateEvent is { Type: "m.room.member" }) { + if (stateEvent.TypedContent is RoomMemberEventContent { Membership: "join" }) + syncResponse.Rooms.Join[room.RoomId].Summary.JoinedMemberCount++; + else if (stateEvent.TypedContent is RoomMemberEventContent { Membership: "invite" }) + syncResponse.Rooms.Join[room.RoomId].Summary.InvitedMemberCount++; + else continue; + } + syncResponse.Rooms.Join[room.RoomId].State!.Events!.Add(stateEvent!); } - var joinRoom = syncResponse.Rooms.Join[room.RoomId]; - joinRoom.Summary!.Heroes.AddRange(joinRoom.State!.Events! - .Where(x => - x.Type == "m.room.member" - && x.StateKey != syncState.Homeserver.WhoAmI!.UserId - && (x.TypedContent as RoomMemberEventContent)!.Membership == "join" - ) - .Select(x => x.StateKey)); - joinRoom.Summary.JoinedMemberCount = joinRoom.Summary.Heroes.Count; - + //finally, actually put the response in queue syncState.SyncQueue.Enqueue(syncResponse); _roomDataSemaphore.Release(); } - - private async Task<StateEventResponse> GetStatusMessage(SyncState syncState, string message) { - return new StateEventResponse { - TypedContent = new PresenceEventContent { - DisplayName = "MxApiExtensions", - Presence = "online", - StatusMessage = message, - // AvatarUrl = (await syncState.Homeserver.GetProfile(syncState.Homeserver.WhoAmI.UserId)).AvatarUrl - AvatarUrl = "" - }, - Type = "m.presence", - StateKey = syncState.Homeserver.WhoAmI!.UserId!, - Sender = syncState.Homeserver.WhoAmI!.UserId!, - EventId = Guid.NewGuid().ToString(), - OriginServerTs = 0 - }; - } - - private async Task EnqueueSyncResponse(SyncState ss, HttpResponseMessage task) { - var sr = await task.Content.ReadFromJsonAsync<JsonObject>(); - if (sr!.ContainsKey("error")) throw sr.Deserialize<MatrixException>()!; - ss.NextBatch = sr["next_batch"]!.GetValue<string>(); - ss.IsInitialSync = false; - ss.SyncQueue.Enqueue(sr.Deserialize<SyncResponse>()!); - task.Dispose(); - ss.NextSyncResponse = null; - } } \ No newline at end of file diff --git a/MxApiExtensions/Controllers/Extensions/DebugController.cs b/MxApiExtensions/Controllers/Extensions/DebugController.cs index c65df56..ae9ecc5 100644 --- a/MxApiExtensions/Controllers/Extensions/DebugController.cs +++ b/MxApiExtensions/Controllers/Extensions/DebugController.cs @@ -7,24 +7,17 @@ namespace MxApiExtensions.Controllers.Extensions; [ApiController] [Route("/")] -public class DebugController : ControllerBase { - private readonly ILogger _logger; - private readonly MxApiExtensionsConfiguration _config; - private readonly AuthenticationService _authenticationService; +public class DebugController(ILogger<ProxyConfigurationController> logger, MxApiExtensionsConfiguration config, UserContextService userContextService) + : ControllerBase { + private readonly ILogger _logger = logger; private static ConcurrentDictionary<string, RoomInfoEntry> _roomInfoCache = new(); - public DebugController(ILogger<ProxyConfigurationController> logger, MxApiExtensionsConfiguration config, AuthenticationService authenticationService, - AuthenticatedHomeserverProviderService authenticatedHomeserverProviderService) { - _logger = logger; - _config = config; - _authenticationService = authenticationService; - } - [HttpGet("debug")] public async Task<object?> GetDebug() { - var mxid = await _authenticationService.GetMxidFromToken(); - if(!_config.Admins.Contains(mxid)) { + var user = await userContextService.GetCurrentUserContext(); + var mxid = user.Homeserver.UserId; + if(!config.Admins.Contains(mxid)) { _logger.LogWarning("Got debug request for {user}, but they are not an admin", mxid); Response.StatusCode = StatusCodes.Status403Forbidden; Response.ContentType = "application/json"; @@ -38,8 +31,6 @@ public class DebugController : ControllerBase { } _logger.LogInformation("Got debug request for {user}", mxid); - return new { - SyncStates = SyncController.SyncStates - }; + return UserContextService.UserContextStore; } } diff --git a/MxApiExtensions/Controllers/Other/MediaProxyController.cs b/MxApiExtensions/Controllers/Other/MediaProxyController.cs index 03b68ba..fb40aa2 100644 --- a/MxApiExtensions/Controllers/Other/MediaProxyController.cs +++ b/MxApiExtensions/Controllers/Other/MediaProxyController.cs @@ -23,7 +23,7 @@ public class MediaProxyController(ILogger<GenericController> logger, MxApiExtens private static SemaphoreSlim _semaphore = new(1, 1); [HttpGet("/_matrix/media/{_}/download/{serverName}/{mediaId}")] - public async Task Proxy(string? _, string serverName, string mediaId) { + public async Task ProxyMedia(string? _, string serverName, string mediaId) { try { logger.LogInformation("Proxying media: {}{}", serverName, mediaId); @@ -75,4 +75,7 @@ public class MediaProxyController(ILogger<GenericController> logger, MxApiExtens await Response.CompleteAsync(); } } + + [HttpGet("/_matrix/media/{_}/thumbnail/{serverName}/{mediaId}")] + public async Task ProxyThumbnail(string? _, string serverName, string mediaId) => await ProxyMedia(_, serverName, mediaId); } diff --git a/MxApiExtensions/MxApiExtensionsConfiguration.cs b/MxApiExtensions/MxApiExtensionsConfiguration.cs index c3b6297..8069e81 100644 --- a/MxApiExtensions/MxApiExtensionsConfiguration.cs +++ b/MxApiExtensions/MxApiExtensionsConfiguration.cs @@ -1,8 +1,12 @@ +using ArcaneLibs.Extensions; +using MxApiExtensions.Classes; + namespace MxApiExtensions; public class MxApiExtensionsConfiguration { public MxApiExtensionsConfiguration(IConfiguration config) { config.GetRequiredSection("MxApiExtensions").Bind(this); + if (DefaultUserConfiguration is null) throw new ArgumentNullException(nameof(DefaultUserConfiguration), $"Default user configuration not configured! Example: {new MxApiExtensionsUserConfiguration().ToJson()}"); } public List<string> AuthHomeservers { get; set; } = new(); @@ -11,7 +15,7 @@ public class MxApiExtensionsConfiguration { public FastInitialSyncConfiguration FastInitialSync { get; set; } = new(); public CacheConfiguration Cache { get; set; } = new(); - + public MxApiExtensionsUserConfiguration DefaultUserConfiguration { get; set; } public class FastInitialSyncConfiguration { public bool Enabled { get; set; } = true; @@ -26,4 +30,5 @@ public class MxApiExtensionsConfiguration { public TimeSpan ExtraTtlPerState { get; set; } = TimeSpan.FromMilliseconds(100); } } + } diff --git a/MxApiExtensions/Program.cs b/MxApiExtensions/Program.cs index 72a5dc9..21d8ba4 100644 --- a/MxApiExtensions/Program.cs +++ b/MxApiExtensions/Program.cs @@ -3,7 +3,9 @@ using LibMatrix; using LibMatrix.Services; using Microsoft.AspNetCore.Diagnostics; using Microsoft.AspNetCore.Http.Timeouts; +using Microsoft.Extensions.Logging.Console; using MxApiExtensions; +using MxApiExtensions.Classes; using MxApiExtensions.Classes.LibMatrix; using MxApiExtensions.Services; @@ -22,6 +24,7 @@ builder.Services.AddSingleton<MxApiExtensionsConfiguration>(); builder.Services.AddScoped<AuthenticationService>(); builder.Services.AddScoped<AuthenticatedHomeserverProviderService>(); +builder.Services.AddScoped<UserContextService>(); builder.Services.AddSingleton<TieredStorageService>(x => { var config = x.GetRequiredService<MxApiExtensionsConfiguration>(); @@ -54,6 +57,7 @@ builder.Services.AddCors(options => { "Open", policy => policy.AllowAnyOrigin().AllowAnyHeader()); }); +// builder.Logging.AddConsole(x => x.FormatterName = "custom").AddConsoleFormatter<CustomLogFormatter, SimpleConsoleFormatterOptions>(); var app = builder.Build(); diff --git a/MxApiExtensions/Services/AuthenticatedHomeserverProviderService.cs b/MxApiExtensions/Services/AuthenticatedHomeserverProviderService.cs index e0f9db5..741beb3 100644 --- a/MxApiExtensions/Services/AuthenticatedHomeserverProviderService.cs +++ b/MxApiExtensions/Services/AuthenticatedHomeserverProviderService.cs @@ -6,6 +6,7 @@ using MxApiExtensions.Classes.LibMatrix; namespace MxApiExtensions.Services; public class AuthenticatedHomeserverProviderService(AuthenticationService authenticationService, HomeserverProviderService homeserverProviderService, IHttpContextAccessor request) { + public HttpContext? _context = request.HttpContext; public async Task<RemoteHomeserver?> TryGetRemoteHomeserver() { try { return await GetRemoteHomeserver(); @@ -21,7 +22,8 @@ public class AuthenticatedHomeserverProviderService(AuthenticationService authen } catch (MxApiMatrixException e) { if (e is not { ErrorCode: "M_MISSING_TOKEN" }) throw; - if (!request.HttpContext!.Request.Headers.Keys.Any(x=>x.ToUpper() == "MXAE_UPSTREAM")) + if (request is null) throw new MxApiMatrixException() { ErrorCode = "M_UNKNOWN", Error = "[MxApiExtensions] Request was null for unauthenticated request!" }; + if (!_context.Request.Headers.Keys.Any(x=>x.ToUpper() == "MXAE_UPSTREAM")) throw new MxApiMatrixException() { ErrorCode = "MXAE_MISSING_UPSTREAM", Error = "[MxApiExtensions] Missing MXAE_UPSTREAM header for unauthenticated request, this should be a server_name!" diff --git a/MxApiExtensions/Services/AuthenticationService.cs b/MxApiExtensions/Services/AuthenticationService.cs index 0dcc8b1..7430fcd 100644 --- a/MxApiExtensions/Services/AuthenticationService.cs +++ b/MxApiExtensions/Services/AuthenticationService.cs @@ -1,3 +1,5 @@ +using ArcaneLibs.Extensions; +using LibMatrix; using LibMatrix.Services; using MxApiExtensions.Classes.LibMatrix; @@ -17,7 +19,7 @@ public class AuthenticationService(ILogger<AuthenticationService> logger, MxApiE token = _request.Query["access_token"]; } - if (token == null && fail) { + if (string.IsNullOrWhiteSpace(token) && fail) { throw new MxApiMatrixException { ErrorCode = "M_MISSING_TOKEN", Error = "Missing access token" @@ -29,7 +31,7 @@ public class AuthenticationService(ILogger<AuthenticationService> logger, MxApiE public async Task<string> GetMxidFromToken(string? token = null, bool fail = true) { token ??= GetToken(fail); - if (token == null) { + if (string.IsNullOrWhiteSpace(token)) { if (fail) { throw new MxApiMatrixException { ErrorCode = "M_MISSING_TOKEN", @@ -44,15 +46,34 @@ public class AuthenticationService(ILogger<AuthenticationService> logger, MxApiE _tokenMap = (await File.ReadAllLinesAsync("token_map")) .Select(l => l.Split('\t')) .ToDictionary(l => l[0], l => l[1]); + + //THIS IS BROKEN, DO NOT USE! + // foreach (var (mapToken, mapUser) in _tokenMap) { + // try { + // var hs = await homeserverProviderService.GetAuthenticatedWithToken(mapUser.Split(':', 2)[1], mapToken); + // } + // catch (MatrixException e) { + // if (e is { ErrorCode: "M_UNKNOWN_TOKEN" }) _tokenMap[mapToken] = ""; + // } + // catch { + // // ignored + // } + // } + // _tokenMap.RemoveAll((x, y) => string.IsNullOrWhiteSpace(y)); + // await File.WriteAllTextAsync("token_map", _tokenMap.Aggregate("", (x, y) => $"{y.Key}\t{y.Value}\n")); } + if (_tokenMap.TryGetValue(token, out var mxid)) return mxid; var lookupTasks = new Dictionary<string, Task<string?>>(); foreach (var homeserver in config.AuthHomeservers) { - lookupTasks.Add(homeserver, GetMxidFromToken(token, homeserver)); - await lookupTasks[homeserver].WaitAsync(TimeSpan.FromMilliseconds(250)); - if(lookupTasks[homeserver].IsCompletedSuccessfully && !string.IsNullOrWhiteSpace(lookupTasks[homeserver].Result)) break; + try { + lookupTasks.Add(homeserver, GetMxidFromToken(token, homeserver)); + await lookupTasks[homeserver].WaitAsync(TimeSpan.FromMilliseconds(500)); + if (lookupTasks[homeserver].IsCompletedSuccessfully && !string.IsNullOrWhiteSpace(lookupTasks[homeserver].Result)) break; + } + catch {} } await Task.WhenAll(lookupTasks.Values); diff --git a/MxApiExtensions/Services/UserContextService.cs b/MxApiExtensions/Services/UserContextService.cs new file mode 100644 index 0000000..ef19ced --- /dev/null +++ b/MxApiExtensions/Services/UserContextService.cs @@ -0,0 +1,44 @@ +using System.Collections.Concurrent; +using System.Text.Json.Serialization; +using ArcaneLibs.Extensions; +using LibMatrix; +using LibMatrix.Homeservers; +using MxApiExtensions.Classes; + +namespace MxApiExtensions.Services; + +public class UserContextService(MxApiExtensionsConfiguration config, AuthenticatedHomeserverProviderService hsProvider) { + internal static ConcurrentDictionary<string, UserContext> UserContextStore { get; set; } = new(); + public int SessionCount = UserContextStore.Count; + + public class UserContext { + public SyncState? SyncState { get; set; } + [JsonIgnore] + public AuthenticatedHomeserverGeneric Homeserver { get; set; } + public MxApiExtensionsUserConfiguration UserConfiguration { get; set; } + } + + private readonly SemaphoreSlim _getUserContextSemaphore = new SemaphoreSlim(1, 1); + public async Task<UserContext> GetCurrentUserContext() { + var hs = await hsProvider.GetHomeserver(); + // await _getUserContextSemaphore.WaitAsync(); + var ucs = await UserContextStore.GetOrCreateAsync($"{hs.WhoAmI.UserId}/{hs.WhoAmI.DeviceId}/{hs.ServerName}:{hs.AccessToken}", async x => { + var userContext = new UserContext() { + Homeserver = hs + }; + try { + userContext.UserConfiguration = await hs.GetAccountDataAsync<MxApiExtensionsUserConfiguration>(MxApiExtensionsUserConfiguration.EventId); + } + catch (MatrixException e) { + if (e is not { ErrorCode: "M_NOT_FOUND" }) throw; + userContext.UserConfiguration = config.DefaultUserConfiguration; + } + + await hs.SetAccountDataAsync(MxApiExtensionsUserConfiguration.EventId, userContext.UserConfiguration); + + return userContext; + }, _getUserContextSemaphore); + // _getUserContextSemaphore.Release(); + return ucs; + } +} \ No newline at end of file diff --git a/MxApiExtensions/appsettings.json b/MxApiExtensions/appsettings.json index 5a73cbe..b16968a 100644 --- a/MxApiExtensions/appsettings.json +++ b/MxApiExtensions/appsettings.json @@ -29,6 +29,13 @@ "BaseTtl": "00:01:00", "ExtraTtlPerState": "00:00:00.1000000" } + }, + "DefaultUserConfiguration": { + "ProtocolChanges": { + "DisableThreads": false, + "DisableVoip": false, + "AutoFollowTombstones": false + } } } } |