summary refs log tree commit diff
path: root/MxApiExtensions/Controllers
diff options
context:
space:
mode:
authorTheArcaneBrony <myrainbowdash949@gmail.com>2023-11-05 17:59:38 +0100
committerTheArcaneBrony <myrainbowdash949@gmail.com>2023-11-05 17:59:38 +0100
commit2abb132234546e61bb0aff3897dc49e72ea84f5d (patch)
treec885c03d35e7a0a6b8fc21bd0b259216c61c877c /MxApiExtensions/Controllers
parentUpdate (diff)
downloadMxApiExtensions-2abb132234546e61bb0aff3897dc49e72ea84f5d.tar.xz
Working sync proxy
Diffstat (limited to 'MxApiExtensions/Controllers')
-rw-r--r--MxApiExtensions/Controllers/Client/ClientVersionsController.cs (renamed from MxApiExtensions/Controllers/ClientVersionsController.cs)0
-rw-r--r--MxApiExtensions/Controllers/Client/LoginController.cs73
-rw-r--r--MxApiExtensions/Controllers/Client/Room/RoomsSendMessageController.cs18
-rw-r--r--MxApiExtensions/Controllers/Client/RoomsSendMessageController.cs72
-rw-r--r--MxApiExtensions/Controllers/Client/SyncController.cs243
-rw-r--r--MxApiExtensions/Controllers/Extensions/DebugController.cs45
-rw-r--r--MxApiExtensions/Controllers/LoginController.cs70
-rw-r--r--MxApiExtensions/Controllers/Other/GenericProxyController.cs (renamed from MxApiExtensions/Controllers/GenericProxyController.cs)54
-rw-r--r--MxApiExtensions/Controllers/Other/MediaProxyController.cs78
-rw-r--r--MxApiExtensions/Controllers/Other/WellKnownController.cs (renamed from MxApiExtensions/Controllers/WellKnownController.cs)10
-rw-r--r--MxApiExtensions/Controllers/SyncController.cs274
11 files changed, 553 insertions, 384 deletions
diff --git a/MxApiExtensions/Controllers/ClientVersionsController.cs b/MxApiExtensions/Controllers/Client/ClientVersionsController.cs

index d29e3b2..d29e3b2 100644 --- a/MxApiExtensions/Controllers/ClientVersionsController.cs +++ b/MxApiExtensions/Controllers/Client/ClientVersionsController.cs
diff --git a/MxApiExtensions/Controllers/Client/LoginController.cs b/MxApiExtensions/Controllers/Client/LoginController.cs new file mode 100644
index 0000000..009aaef --- /dev/null +++ b/MxApiExtensions/Controllers/Client/LoginController.cs
@@ -0,0 +1,73 @@ +using System.Net.Http.Headers; +using ArcaneLibs.Extensions; +using LibMatrix; +using LibMatrix.Extensions; +using LibMatrix.Responses; +using LibMatrix.Services; +using Microsoft.AspNetCore.Mvc; +using MxApiExtensions.Classes.LibMatrix; +using MxApiExtensions.Services; + +namespace MxApiExtensions.Controllers; + +[ApiController] +[Route("/")] +public class LoginController(ILogger<LoginController> logger, HomeserverProviderService hsProvider, HomeserverResolverService hsResolver, AuthenticationService auth, + MxApiExtensionsConfiguration conf) + : ControllerBase { + private readonly ILogger _logger = logger; + private readonly HomeserverProviderService _hsProvider = hsProvider; + private readonly MxApiExtensionsConfiguration _conf = conf; + + [HttpPost("/_matrix/client/{_}/login")] + public async Task Proxy([FromBody] LoginRequest request, string _) { + string hsCanonical = null; + if (Request.Headers.Keys.Any(x => x.ToUpper() == "MXAE_UPSTREAM")) { + hsCanonical = Request.Headers.GetByCaseInsensitiveKey("MXAE_UPSTREAM")[0]!; + _logger.LogInformation("Found upstream: {}", hsCanonical); + } + else { + if (!request.Identifier.User.Contains("#")) { + Response.StatusCode = (int)StatusCodes.Status403Forbidden; + Response.ContentType = "application/json"; + await Response.StartAsync(); + await Response.WriteAsync(new MxApiMatrixException { + ErrorCode = "M_FORBIDDEN", + Error = "[MxApiExtensions] Invalid username, must be of the form @user#domain:" + Request.Host.Value + }.GetAsJson() ?? ""); + await Response.CompleteAsync(); + } + + hsCanonical = request.Identifier.User.Split('#')[1].Split(':')[0]; + request.Identifier.User = request.Identifier.User.Split(':')[0].Replace('#', ':'); + if (!request.Identifier.User.StartsWith('@')) request.Identifier.User = '@' + request.Identifier.User; + } + + var hs = await hsResolver.ResolveHomeserverFromWellKnown(hsCanonical); + //var hs = await _hsProvider.Login(hsCanonical, mxid, request.Password); + var hsClient = new MatrixHttpClient { BaseAddress = new Uri(hs.Client) }; + //hsClient.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", hsClient.DefaultRequestHeaders.Authorization!.Parameter); + if (!string.IsNullOrWhiteSpace(request.InitialDeviceDisplayName)) + request.InitialDeviceDisplayName += $" (via MxApiExtensions at {Request.Host.Value})"; + var resp = await hsClient.PostAsJsonAsync("/_matrix/client/r0/login", request); + var loginResp = await resp.Content.ReadAsStringAsync(); + Response.StatusCode = (int)resp.StatusCode; + Response.ContentType = resp.Content.Headers.ContentType?.ToString() ?? "application/json"; + await Response.StartAsync(); + await Response.WriteAsync(loginResp); + await Response.CompleteAsync(); + var token = (await resp.Content.ReadFromJsonAsync<LoginResponse>())!.AccessToken; + await auth.SaveMxidForToken(token, request.Identifier.User); + } + + [HttpGet("/_matrix/client/{_}/login")] + public async Task<object> Proxy(string? _) { + return new { + flows = new[] { + new { + type = "m.login.password" + } + } + }; + } +} \ No newline at end of file diff --git a/MxApiExtensions/Controllers/Client/Room/RoomsSendMessageController.cs b/MxApiExtensions/Controllers/Client/Room/RoomsSendMessageController.cs new file mode 100644
index 0000000..3d1d4e2 --- /dev/null +++ b/MxApiExtensions/Controllers/Client/Room/RoomsSendMessageController.cs
@@ -0,0 +1,18 @@ +using LibMatrix.Services; +using Microsoft.AspNetCore.Mvc; +using MxApiExtensions.Services; + +namespace MxApiExtensions.Controllers.Client.Room; + +[ApiController] +[Route("/")] +public class RoomController(ILogger<LoginController> logger, HomeserverResolverService hsResolver, AuthenticationService auth, MxApiExtensionsConfiguration conf, + AuthenticatedHomeserverProviderService hsProvider) + : ControllerBase { + [HttpGet("/_matrix/client/{_}/rooms/{roomId}/members_by_homeserver")] + public async Task<Dictionary<string, List<string>>> GetRoomMembersByHomeserver(string _, [FromRoute] string roomId, [FromQuery] bool joinedOnly = true) { + var hs = await hsProvider.GetHomeserver(); + var room = hs.GetRoom(roomId); + return await room.GetMembersByHomeserverAsync(); + } +} \ No newline at end of file diff --git a/MxApiExtensions/Controllers/Client/RoomsSendMessageController.cs b/MxApiExtensions/Controllers/Client/RoomsSendMessageController.cs new file mode 100644
index 0000000..6d3a774 --- /dev/null +++ b/MxApiExtensions/Controllers/Client/RoomsSendMessageController.cs
@@ -0,0 +1,72 @@ +using System.Buffers.Text; +using System.Net.Http.Headers; +using System.Text.Json; +using System.Text.Json.Nodes; +using ArcaneLibs.Extensions; +using LibMatrix; +using LibMatrix.EventTypes.Spec; +using LibMatrix.Extensions; +using LibMatrix.Helpers; +using LibMatrix.Homeservers; +using LibMatrix.Responses; +using LibMatrix.Services; +using Microsoft.AspNetCore.Mvc; +using MxApiExtensions.Classes; +using MxApiExtensions.Classes.LibMatrix; +using MxApiExtensions.Services; + +namespace MxApiExtensions.Controllers; + +[ApiController] +[Route("/")] +public class RoomsSendMessageController(ILogger<LoginController> logger, HomeserverResolverService hsResolver, AuthenticationService auth, MxApiExtensionsConfiguration conf, + AuthenticatedHomeserverProviderService hsProvider) + : ControllerBase { + [HttpPut("/_matrix/client/{_}/rooms/{roomId}/send/m.room.message/{txnId}")] + public async Task Proxy([FromBody] JsonObject request, [FromRoute] string roomId, [FromRoute] string txnId, string _) { + var hs = await hsProvider.GetHomeserver(); + + var msg = request.Deserialize<RoomMessageEventContent>(); + if (msg is not null && msg.Body.StartsWith("mxae!")) { +#pragma warning disable CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed + handleMxaeCommand(hs, roomId, msg); +#pragma warning restore CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed + await Response.WriteAsJsonAsync(new EventIdResponse() { + EventId = "$" + string.Join("", Random.Shared.GetItems("abcdefghijklmnopqrstuvwxyzABCDEFGHIJLKMNOPQRSTUVWXYZ0123456789".ToCharArray(), 100)) + }); + await Response.CompleteAsync(); + } + else { + try { + var resp = await hs.ClientHttpClient.PutAsJsonAsync($"{Request.Path}{Request.QueryString}", request); + var loginResp = await resp.Content.ReadAsStringAsync(); + Response.StatusCode = (int)resp.StatusCode; + Response.ContentType = resp.Content.Headers.ContentType?.ToString() ?? "application/json"; + await Response.StartAsync(); + await Response.WriteAsync(loginResp); + await Response.CompleteAsync(); + } + catch (MatrixException e) { + await Response.StartAsync(); + await Response.WriteAsync(e.GetAsJson()); + await Response.CompleteAsync(); + } + } + } + + private async Task handleMxaeCommand(AuthenticatedHomeserverGeneric hs, string roomId, RoomMessageEventContent msg) { + var syncState = SyncController._syncStates.GetValueOrDefault(hs.AccessToken); + if (syncState is null) return; + syncState.SendEphemeralTimelineEventInRoom(roomId, new() { + Sender = "@mxae:" + Request.Host.Value, + Type = "m.room.message", + TypedContent = MessageFormatter.FormatSuccess("Thinking..."), + OriginServerTs = (ulong)new DateTimeOffset(DateTime.UtcNow.ToUniversalTime()).ToUnixTimeMilliseconds(), + Unsigned = new() { + Age = 1 + }, + RoomId = roomId, + EventId = "$" + string.Join("", Random.Shared.GetItems("abcdefghijklmnopqrstuvwxyzABCDEFGHIJLKMNOPQRSTUVWXYZ0123456789".ToCharArray(), 100)) + }); + } +} \ No newline at end of file diff --git a/MxApiExtensions/Controllers/Client/SyncController.cs b/MxApiExtensions/Controllers/Client/SyncController.cs new file mode 100644
index 0000000..2944c3b --- /dev/null +++ b/MxApiExtensions/Controllers/Client/SyncController.cs
@@ -0,0 +1,243 @@ +using System.Collections.Concurrent; +using System.Diagnostics; +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Text.Json.Serialization; +using System.Web; +using ArcaneLibs; +using LibMatrix; +using LibMatrix.EventTypes.Spec.State; +using LibMatrix.Helpers; +using LibMatrix.Homeservers; +using LibMatrix.Responses; +using LibMatrix.RoomTypes; +using Microsoft.AspNetCore.Mvc; +using MxApiExtensions.Classes; +using MxApiExtensions.Classes.LibMatrix; +using MxApiExtensions.Extensions; +using MxApiExtensions.Services; + +namespace MxApiExtensions.Controllers; + +[ApiController] +[Route("/")] +public class SyncController(ILogger<SyncController> logger, MxApiExtensionsConfiguration config, AuthenticationService auth, AuthenticatedHomeserverProviderService hsProvider) + : ControllerBase { + public static readonly ConcurrentDictionary<string, SyncState> _syncStates = new(); + + private static SemaphoreSlim _semaphoreSlim = new(1, 1); + private Stopwatch _syncElapsed = Stopwatch.StartNew(); + + [HttpGet("/_matrix/client/{_}/sync")] + public async Task Sync(string _, [FromQuery] string? since, [FromQuery] int timeout = 1000) { + Task? preloadTask = null; + AuthenticatedHomeserverGeneric? hs = null; + try { + hs = await hsProvider.GetHomeserver(); + } + catch (Exception e) { + Console.WriteLine(e); + } + + var qs = HttpUtility.ParseQueryString(Request.QueryString.Value!); + qs.Remove("access_token"); + if (since == "null") qs.Remove("since"); + + if (!config.FastInitialSync.Enabled) { + logger.LogInformation("Starting sync for {} on {} ({})", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken); + var result = await hs.ClientHttpClient.GetAsync($"{Request.Path}?{qs}"); + await Response.WriteHttpResponse(result); + return; + } + + await _semaphoreSlim.WaitAsync(); + var syncState = _syncStates.GetOrAdd($"{hs.WhoAmI.UserId}/{hs.WhoAmI.DeviceId}/{hs.ServerName}:{hs.AccessToken}", _ => { + logger.LogInformation("Started tracking sync state for {} on {} ({})", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken); + var ss = new SyncState { + IsInitialSync = string.IsNullOrWhiteSpace(since), + Homeserver = hs + }; + if (ss.IsInitialSync) { + preloadTask = EnqueuePreloadData(ss); + } + + logger.LogInformation("Starting sync for {} on {} ({})", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken); + + ss.NextSyncResponseStartedAt = DateTime.Now; + ss.NextSyncResponse = Task.Delay(15_000); + ss.NextSyncResponse.ContinueWith(async x => { + logger.LogInformation("Sync for {} on {} ({}) starting", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken); + ss.NextSyncResponse = hs.ClientHttpClient.GetAsync($"/_matrix/client/v3/sync?{qs}"); + (ss.NextSyncResponse as Task<HttpResponseMessage>).ContinueWith(async x => EnqueueSyncResponse(ss, await x)); + }); + return ss; + }); + _semaphoreSlim.Release(); + + if (syncState.SyncQueue.Count > 0) { + logger.LogInformation("Sync for {} on {} ({}) has {} queued results", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken, syncState.SyncQueue.Count); + syncState.SyncQueue.TryDequeue(out var result); + while (result is null) + syncState.SyncQueue.TryDequeue(out result); + Response.StatusCode = StatusCodes.Status200OK; + Response.ContentType = "application/json"; + await Response.StartAsync(); + result.NextBatch ??= since ?? syncState.NextBatch; + await JsonSerializer.SerializeAsync(Response.Body, result, new JsonSerializerOptions { + WriteIndented = true, + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull + }); + await Response.CompleteAsync(); + return; + } + + var newTimeout = Math.Clamp(timeout, 0, syncState.IsInitialSync ? syncState.SyncQueue.Count >= 2 ? 0 : 250 : timeout); + logger.LogInformation("Sync for {} on {} ({}) is still running, waiting for {}ms, {} elapsed", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken, newTimeout, + DateTime.Now.Subtract(syncState.NextSyncResponseStartedAt)); + + try { + if (syncState.NextSyncResponse is not null) + await syncState.NextSyncResponse.WaitAsync(TimeSpan.FromMilliseconds(newTimeout)); + else { + syncState.NextSyncResponse = hs.ClientHttpClient.GetAsync($"/_matrix/client/v3/sync?{qs}"); + (syncState.NextSyncResponse as Task<HttpResponseMessage>).ContinueWith(async x => EnqueueSyncResponse(syncState, await x)); + // await Task.Delay(250); + } + } + catch (TimeoutException) { } + + // if (_syncElapsed.ElapsedMilliseconds > timeout) + if(syncState.NextSyncResponse?.IsCompleted == false) + syncState.SendStatusMessage( + $"M={Util.BytesToString(Process.GetCurrentProcess().WorkingSet64)} TE={DateTime.Now.Subtract(syncState.NextSyncResponseStartedAt)} S={syncState.NextSyncResponse?.Status} QL={syncState.SyncQueue.Count}"); + Response.StatusCode = StatusCodes.Status200OK; + Response.ContentType = "application/json"; + await Response.StartAsync(); + var response = syncState.SyncQueue.FirstOrDefault(); + if (response is null) + response = new(); + response.NextBatch ??= since ?? syncState.NextBatch; + await JsonSerializer.SerializeAsync(Response.Body, response, new JsonSerializerOptions { + WriteIndented = true, + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull + }); + await Response.CompleteAsync(); + + Response.Body.Close(); + if (preloadTask is not null) + await preloadTask; + } + + private async Task EnqueuePreloadData(SyncState syncState) { + var rooms = await syncState.Homeserver.GetJoinedRooms(); + var dm_rooms = (await syncState.Homeserver.GetAccountDataAsync<Dictionary<string, List<string>>>("m.direct")).Aggregate(new List<string>(), (list, entry) => { + list.AddRange(entry.Value); + return list; + }); + + var ownHs = syncState.Homeserver.WhoAmI.UserId.Split(':')[1]; + rooms = rooms.OrderBy(x => { + if (dm_rooms.Contains(x.RoomId)) return -1; + var parts = x.RoomId.Split(':'); + if (parts[1] == ownHs) return 200; + if (HomeserverWeightEstimation.EstimatedSize.ContainsKey(parts[1])) return HomeserverWeightEstimation.EstimatedSize[parts[1]] + parts[0].Length; + return 5000; + }).ToList(); + var roomDataTasks = rooms.Select(room => EnqueueRoomData(syncState, room)).ToList(); + logger.LogInformation("Preloading data for {} rooms on {} ({})", roomDataTasks.Count, syncState.Homeserver.ServerName, syncState.Homeserver.AccessToken); + + await Task.WhenAll(roomDataTasks); + } + + private SemaphoreSlim _roomDataSemaphore = new(32, 32); + + private async Task EnqueueRoomData(SyncState syncState, GenericRoom room) { + await _roomDataSemaphore.WaitAsync(); + var roomState = room.GetFullStateAsync(); + var timeline = await room.GetMessagesAsync(limit: 100, dir: "b"); + timeline.Chunk.Reverse(); + var SyncResponse = new SyncResponse { + Rooms = new() { + Join = new() { + { + room.RoomId, + new SyncResponse.RoomsDataStructure.JoinedRoomDataStructure { + AccountData = new() { + Events = new() + }, + Ephemeral = new() { + Events = new() + }, + State = new() { + Events = timeline.State + }, + UnreadNotifications = new() { + HighlightCount = 0, + NotificationCount = 0 + }, + Timeline = new() { + Events = timeline.Chunk, + Limited = false, + PrevBatch = timeline.Start + }, + Summary = new() { + Heroes = new(), + InvitedMemberCount = 0, + JoinedMemberCount = 1 + } + } + } + } + }, + Presence = new() { + Events = new() { + await GetStatusMessage(syncState, $"{DateTime.Now.Subtract(syncState.NextSyncResponseStartedAt)} {syncState.NextSyncResponse.Status} {room.RoomId}") + } + }, + NextBatch = "" + }; + + await foreach (var stateEvent in roomState) { + SyncResponse.Rooms.Join[room.RoomId].State.Events.Add(stateEvent); + } + + var joinRoom = SyncResponse.Rooms.Join[room.RoomId]; + joinRoom.Summary.Heroes.AddRange(joinRoom.State.Events + .Where(x => + x.Type == "m.room.member" + && x.StateKey != syncState.Homeserver.WhoAmI.UserId + && (x.TypedContent as RoomMemberEventContent).Membership == "join" + ) + .Select(x => x.StateKey)); + joinRoom.Summary.JoinedMemberCount = joinRoom.Summary.Heroes.Count; + + syncState.SyncQueue.Enqueue(SyncResponse); + _roomDataSemaphore.Release(); + } + + private async Task<StateEventResponse> GetStatusMessage(SyncState syncState, string message) { + return new StateEventResponse { + TypedContent = new PresenceEventContent { + DisplayName = "MxApiExtensions", + Presence = "online", + StatusMessage = message, + // AvatarUrl = (await syncState.Homeserver.GetProfile(syncState.Homeserver.WhoAmI.UserId)).AvatarUrl + AvatarUrl = "" + }, + Type = "m.presence", + StateKey = syncState.Homeserver.WhoAmI.UserId, + Sender = syncState.Homeserver.WhoAmI.UserId, + EventId = Guid.NewGuid().ToString(), + OriginServerTs = 0 + }; + } + + private async Task EnqueueSyncResponse(SyncState ss, HttpResponseMessage task) { + var sr = await task.Content.ReadFromJsonAsync<JsonObject>(); + if (sr.ContainsKey("error")) throw sr.Deserialize<MatrixException>()!; + ss.NextBatch = sr["next_batch"].GetValue<string>(); + ss.IsInitialSync = false; + ss.SyncQueue.Enqueue(sr.Deserialize<SyncResponse>()); + ss.NextSyncResponse = null; + } +} \ No newline at end of file diff --git a/MxApiExtensions/Controllers/Extensions/DebugController.cs b/MxApiExtensions/Controllers/Extensions/DebugController.cs new file mode 100644
index 0000000..79ed2f0 --- /dev/null +++ b/MxApiExtensions/Controllers/Extensions/DebugController.cs
@@ -0,0 +1,45 @@ +using System.Collections.Concurrent; +using Microsoft.AspNetCore.Mvc; +using MxApiExtensions.Classes.LibMatrix; +using MxApiExtensions.Services; + +namespace MxApiExtensions.Controllers.Extensions; + +[ApiController] +[Route("/")] +public class DebugController : ControllerBase { + private readonly ILogger _logger; + private readonly MxApiExtensionsConfiguration _config; + private readonly AuthenticationService _authenticationService; + + private static ConcurrentDictionary<string, RoomInfoEntry> _roomInfoCache = new(); + + public DebugController(ILogger<ProxyConfigurationController> logger, MxApiExtensionsConfiguration config, AuthenticationService authenticationService, + AuthenticatedHomeserverProviderService authenticatedHomeserverProviderService) { + _logger = logger; + _config = config; + _authenticationService = authenticationService; + } + + [HttpGet("debug")] + public async Task<object?> GetDebug() { + var mxid = await _authenticationService.GetMxidFromToken(); + if(!_config.Admins.Contains(mxid)) { + _logger.LogWarning("Got debug request for {user}, but they are not an admin", mxid); + Response.StatusCode = StatusCodes.Status403Forbidden; + Response.ContentType = "application/json"; + + await Response.WriteAsJsonAsync(new { + ErrorCode = "M_FORBIDDEN", + Error = "You are not an admin" + }); + await Response.CompleteAsync(); + return null; + } + + _logger.LogInformation("Got debug request for {user}", mxid); + return new { + SyncStates = SyncController._syncStates + }; + } +} diff --git a/MxApiExtensions/Controllers/LoginController.cs b/MxApiExtensions/Controllers/LoginController.cs deleted file mode 100644
index bd354ef..0000000 --- a/MxApiExtensions/Controllers/LoginController.cs +++ /dev/null
@@ -1,70 +0,0 @@ -using System.Net.Http.Headers; -using LibMatrix; -using LibMatrix.Extensions; -using LibMatrix.Responses; -using LibMatrix.Services; -using Microsoft.AspNetCore.Mvc; -using MxApiExtensions.Classes.LibMatrix; -using MxApiExtensions.Services; - -namespace MxApiExtensions.Controllers; - -[ApiController] -[Route("/")] -public class LoginController : ControllerBase { - private readonly ILogger _logger; - private readonly HomeserverProviderService _hsProvider; - private readonly HomeserverResolverService _hsResolver; - private readonly AuthenticationService _auth; - private readonly MxApiExtensionsConfiguration _conf; - - public LoginController(ILogger<LoginController> logger, HomeserverProviderService hsProvider, HomeserverResolverService hsResolver, AuthenticationService auth, MxApiExtensionsConfiguration conf) { - _logger = logger; - _hsProvider = hsProvider; - _hsResolver = hsResolver; - _auth = auth; - _conf = conf; - } - - [HttpPost("/_matrix/client/{_}/login")] - public async Task Proxy([FromBody] LoginRequest request, string _) { - if (!request.Identifier.User.Contains("#")) { - Response.StatusCode = (int)StatusCodes.Status403Forbidden; - Response.ContentType = "application/json"; - await Response.StartAsync(); - await Response.WriteAsync(new MxApiMatrixException { - ErrorCode = "M_FORBIDDEN", - Error = "[MxApiExtensions] Invalid username, must be of the form @user#domain:" + Request.Host.Value - }.GetAsJson() ?? ""); - await Response.CompleteAsync(); - } - var hsCanonical = request.Identifier.User.Split('#')[1].Split(':')[0]; - request.Identifier.User = request.Identifier.User.Split(':')[0].Replace('#', ':'); - if(!request.Identifier.User.StartsWith('@')) request.Identifier.User = '@' + request.Identifier.User; - var hs = await _hsResolver.ResolveHomeserverFromWellKnown(hsCanonical); - //var hs = await _hsProvider.Login(hsCanonical, mxid, request.Password); - var hsClient = new MatrixHttpClient { BaseAddress = new Uri(hs.client) }; - //hsClient.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", hsClient.DefaultRequestHeaders.Authorization!.Parameter); - var resp = await hsClient.PostAsJsonAsync("/_matrix/client/r0/login", request); - var loginResp = await resp.Content.ReadAsStringAsync(); - Response.StatusCode = (int)resp.StatusCode; - Response.ContentType = resp.Content.Headers.ContentType?.ToString() ?? "application/json"; - await Response.StartAsync(); - await Response.WriteAsync(loginResp); - await Response.CompleteAsync(); - var token = (await resp.Content.ReadFromJsonAsync<LoginResponse>())!.AccessToken; - await _auth.SaveMxidForToken(token, request.Identifier.User); - } - - - [HttpGet("/_matrix/client/{_}/login")] - public async Task<object> Proxy(string? _) { - return new { - flows = new[] { - new { - type = "m.login.password" - } - } - }; - } -} diff --git a/MxApiExtensions/Controllers/GenericProxyController.cs b/MxApiExtensions/Controllers/Other/GenericProxyController.cs
index c004fcb..bae07c0 100644 --- a/MxApiExtensions/Controllers/GenericProxyController.cs +++ b/MxApiExtensions/Controllers/Other/GenericProxyController.cs
@@ -7,29 +7,17 @@ namespace MxApiExtensions.Controllers; [ApiController] [Route("/{*_}")] -public class GenericController : ControllerBase { - private readonly ILogger<GenericController> _logger; - private readonly MxApiExtensionsConfiguration _config; - private readonly AuthenticationService _authenticationService; - private readonly AuthenticatedHomeserverProviderService _authenticatedHomeserverProviderService; - private static Dictionary<string, string> _tokenMap = new(); - - public GenericController(ILogger<GenericController> logger, MxApiExtensionsConfiguration config, AuthenticationService authenticationService, - AuthenticatedHomeserverProviderService authenticatedHomeserverProviderService) { - _logger = logger; - _config = config; - _authenticationService = authenticationService; - _authenticatedHomeserverProviderService = authenticatedHomeserverProviderService; - } - +public class GenericController(ILogger<GenericController> logger, MxApiExtensionsConfiguration config, AuthenticationService authenticationService, + AuthenticatedHomeserverProviderService authenticatedHomeserverProviderService) + : ControllerBase { [HttpGet] public async Task Proxy([FromQuery] string? access_token, string? _) { try { - access_token ??= _authenticationService.GetToken(fail: false); - var mxid = await _authenticationService.GetMxidFromToken(fail: false); - var hs = await _authenticatedHomeserverProviderService.GetHomeserver(); + // access_token ??= _authenticationService.GetToken(fail: false); + // var mxid = await _authenticationService.GetMxidFromToken(fail: false); + var hs = await authenticatedHomeserverProviderService.GetRemoteHomeserver(); - _logger.LogInformation("Proxying request for {}: {}{}", mxid, Request.Path, Request.QueryString); + logger.LogInformation("Proxying request: {}{}", Request.Path, Request.QueryString); //remove access_token from query string Request.QueryString = new QueryString( @@ -55,7 +43,7 @@ public class GenericController : ControllerBase { await Response.CompleteAsync(); } catch (MxApiMatrixException e) { - _logger.LogError(e, "Matrix error"); + logger.LogError(e, "Matrix error"); Response.StatusCode = StatusCodes.Status500InternalServerError; Response.ContentType = "application/json"; @@ -63,7 +51,7 @@ public class GenericController : ControllerBase { await Response.CompleteAsync(); } catch (Exception e) { - _logger.LogError(e, "Unhandled error"); + logger.LogError(e, "Unhandled error"); Response.StatusCode = StatusCodes.Status500InternalServerError; Response.ContentType = "text/plain"; @@ -75,11 +63,11 @@ public class GenericController : ControllerBase { [HttpPost] public async Task ProxyPost([FromQuery] string? access_token, string _) { try { - access_token ??= _authenticationService.GetToken(fail: false); - var mxid = await _authenticationService.GetMxidFromToken(fail: false); - var hs = await _authenticatedHomeserverProviderService.GetHomeserver(); + access_token ??= authenticationService.GetToken(fail: false); + var mxid = await authenticationService.GetMxidFromToken(fail: false); + var hs = await authenticatedHomeserverProviderService.GetHomeserver(); - _logger.LogInformation("Proxying request for {}: {}{}", mxid, Request.Path, Request.QueryString); + logger.LogInformation("Proxying request for {}: {}{}", mxid, Request.Path, Request.QueryString); using var hc = new HttpClient(); hc.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", access_token); @@ -112,7 +100,7 @@ public class GenericController : ControllerBase { await Response.CompleteAsync(); } catch (MxApiMatrixException e) { - _logger.LogError(e, "Matrix error"); + logger.LogError(e, "Matrix error"); Response.StatusCode = StatusCodes.Status500InternalServerError; Response.ContentType = "application/json"; @@ -120,7 +108,7 @@ public class GenericController : ControllerBase { await Response.CompleteAsync(); } catch (Exception e) { - _logger.LogError(e, "Unhandled error"); + logger.LogError(e, "Unhandled error"); Response.StatusCode = StatusCodes.Status500InternalServerError; Response.ContentType = "text/plain"; @@ -132,11 +120,11 @@ public class GenericController : ControllerBase { [HttpPut] public async Task ProxyPut([FromQuery] string? access_token, string _) { try { - access_token ??= _authenticationService.GetToken(fail: false); - var mxid = await _authenticationService.GetMxidFromToken(fail: false); - var hs = await _authenticatedHomeserverProviderService.GetHomeserver(); + access_token ??= authenticationService.GetToken(fail: false); + var mxid = await authenticationService.GetMxidFromToken(fail: false); + var hs = await authenticatedHomeserverProviderService.GetHomeserver(); - _logger.LogInformation("Proxying request for {}: {}{}", mxid, Request.Path, Request.QueryString); + logger.LogInformation("Proxying request for {}: {}{}", mxid, Request.Path, Request.QueryString); using var hc = new HttpClient(); hc.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", access_token); @@ -169,7 +157,7 @@ public class GenericController : ControllerBase { await Response.CompleteAsync(); } catch (MxApiMatrixException e) { - _logger.LogError(e, "Matrix error"); + logger.LogError(e, "Matrix error"); Response.StatusCode = StatusCodes.Status500InternalServerError; Response.ContentType = "application/json"; @@ -177,7 +165,7 @@ public class GenericController : ControllerBase { await Response.CompleteAsync(); } catch (Exception e) { - _logger.LogError(e, "Unhandled error"); + logger.LogError(e, "Unhandled error"); Response.StatusCode = StatusCodes.Status500InternalServerError; Response.ContentType = "text/plain"; diff --git a/MxApiExtensions/Controllers/Other/MediaProxyController.cs b/MxApiExtensions/Controllers/Other/MediaProxyController.cs new file mode 100644
index 0000000..03b68ba --- /dev/null +++ b/MxApiExtensions/Controllers/Other/MediaProxyController.cs
@@ -0,0 +1,78 @@ +using System.Net.Http.Headers; +using LibMatrix.Homeservers; +using LibMatrix.Services; +using Microsoft.AspNetCore.Mvc; +using MxApiExtensions.Classes.LibMatrix; +using MxApiExtensions.Services; + +namespace MxApiExtensions.Controllers; + +[ApiController] +[Route("/")] +public class MediaProxyController(ILogger<GenericController> logger, MxApiExtensionsConfiguration config, AuthenticationService authenticationService, + AuthenticatedHomeserverProviderService authenticatedHomeserverProviderService, HomeserverProviderService hsProvider) + : ControllerBase { + private class MediaCacheEntry { + public DateTime LastRequested { get; set; } = DateTime.Now; + public byte[] Data { get; set; } + public string ContentType { get; set; } + public long Size => Data.LongCount(); + } + + private static Dictionary<string, MediaCacheEntry> _mediaCache = new(); + private static SemaphoreSlim _semaphore = new(1, 1); + + [HttpGet("/_matrix/media/{_}/download/{serverName}/{mediaId}")] + public async Task Proxy(string? _, string serverName, string mediaId) { + try { + logger.LogInformation("Proxying media: {}{}", serverName, mediaId); + + await _semaphore.WaitAsync(); + MediaCacheEntry entry; + if (!_mediaCache.ContainsKey($"{serverName}/{mediaId}")) { + _mediaCache.Add($"{serverName}/{mediaId}", entry = new()); + List<RemoteHomeserver> FeasibleHomeservers = new(); + { + var a = await authenticatedHomeserverProviderService.TryGetRemoteHomeserver(); + if(a is not null) + FeasibleHomeservers.Add(a); + } + + FeasibleHomeservers.Add(await hsProvider.GetRemoteHomeserver(serverName)); + + foreach (var homeserver in FeasibleHomeservers) { + var resp = await homeserver.ClientHttpClient.GetAsync($"{Request.Path}"); + if(!resp.IsSuccessStatusCode) continue; + entry.ContentType = resp.Content.Headers.ContentType?.ToString() ?? "application/json"; + entry.Data = await resp.Content.ReadAsByteArrayAsync(); + break; + } + } + else entry = _mediaCache[$"{serverName}/{mediaId}"]; + _semaphore.Release(); + + Response.StatusCode = 200; + Response.ContentType = entry.ContentType; + await Response.StartAsync(); + await Response.Body.WriteAsync(entry.Data, 0, entry.Data.Length); + await Response.Body.FlushAsync(); + await Response.CompleteAsync(); + } + catch (MxApiMatrixException e) { + logger.LogError(e, "Matrix error"); + Response.StatusCode = StatusCodes.Status500InternalServerError; + Response.ContentType = "application/json"; + + await Response.WriteAsync(e.GetAsJson()); + await Response.CompleteAsync(); + } + catch (Exception e) { + logger.LogError(e, "Unhandled error"); + Response.StatusCode = StatusCodes.Status500InternalServerError; + Response.ContentType = "text/plain"; + + await Response.WriteAsync(e.ToString()); + await Response.CompleteAsync(); + } + } +} diff --git a/MxApiExtensions/Controllers/WellKnownController.cs b/MxApiExtensions/Controllers/Other/WellKnownController.cs
index b27451f..c0e255f 100644 --- a/MxApiExtensions/Controllers/WellKnownController.cs +++ b/MxApiExtensions/Controllers/Other/WellKnownController.cs
@@ -5,18 +5,14 @@ namespace MxApiExtensions.Controllers; [ApiController] [Route("/")] -public class WellKnownController : ControllerBase { - private readonly MxApiExtensionsConfiguration _config; - - public WellKnownController(MxApiExtensionsConfiguration config) { - _config = config; - } +public class WellKnownController(MxApiExtensionsConfiguration config) : ControllerBase { + private readonly MxApiExtensionsConfiguration _config = config; [HttpGet("/.well-known/matrix/client")] public object GetWellKnown() { var res = new JsonObject(); res.Add("m.homeserver", new JsonObject { - { "base_url", Request.Scheme + "://" + Request.Host + "/" }, + { "base_url", Request.Scheme + "://" + Request.Host }, }); return res; } diff --git a/MxApiExtensions/Controllers/SyncController.cs b/MxApiExtensions/Controllers/SyncController.cs deleted file mode 100644
index 0b0007f..0000000 --- a/MxApiExtensions/Controllers/SyncController.cs +++ /dev/null
@@ -1,274 +0,0 @@ -using System.Collections.Concurrent; -using System.Text.Json; -using System.Text.Json.Serialization; -using System.Web; -using LibMatrix; -using LibMatrix.EventTypes.Spec.State; -using LibMatrix.Helpers; -using LibMatrix.Homeservers; -using LibMatrix.Responses; -using LibMatrix.RoomTypes; -using Microsoft.AspNetCore.Mvc; -using MxApiExtensions.Classes; -using MxApiExtensions.Classes.LibMatrix; -using MxApiExtensions.Extensions; -using MxApiExtensions.Services; - -namespace MxApiExtensions.Controllers; - -[ApiController] -[Route("/")] -public class SyncController : ControllerBase { - private readonly ILogger<SyncController> _logger; - private readonly MxApiExtensionsConfiguration _config; - private readonly AuthenticationService _auth; - private readonly AuthenticatedHomeserverProviderService _hs; - - private static readonly ConcurrentDictionary<string, SyncState> _syncStates = new(); - - public SyncController(ILogger<SyncController> logger, MxApiExtensionsConfiguration config, AuthenticationService auth, AuthenticatedHomeserverProviderService hs) { - _logger = logger; - _config = config; - _auth = auth; - _hs = hs; - } - - [HttpGet("/_matrix/client/v3/sync")] - public async Task Sync([FromQuery] string? since, [FromQuery] int timeout = 1000) { - Task? preloadTask = null; - AuthenticatedHomeserverGeneric? hs = null; - try { - hs = await _hs.GetHomeserver(); - } - catch (Exception e) { - Console.WriteLine(); - } - var qs = HttpUtility.ParseQueryString(Request.QueryString.Value!); - qs.Remove("access_token"); - - if (!_config.FastInitialSync.Enabled) { - _logger.LogInformation("Starting sync for {} on {} ({})", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken); - var result = await hs.ClientHttpClient.GetAsync($"{Request.Path}?{qs}"); - await Response.WriteHttpResponse(result); - return; - } - - try { - var syncState = _syncStates.GetOrAdd(hs.AccessToken, _ => { - _logger.LogInformation("Started tracking sync state for {} on {} ({})", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken); - return new SyncState { - IsInitialSync = string.IsNullOrWhiteSpace(since), - Homeserver = hs - }; - }); - - if (syncState.NextSyncResponse is null) { - _logger.LogInformation("Starting sync for {} on {} ({})", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken); - - if (syncState.IsInitialSync) { - preloadTask = EnqueuePreloadData(syncState); - } - - syncState.NextSyncResponseStartedAt = DateTime.Now; - syncState.NextSyncResponse = Task.Delay(30_000); - syncState.NextSyncResponse.ContinueWith(x => { - _logger.LogInformation("Sync for {} on {} ({}) starting", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken); - syncState.NextSyncResponse = hs.ClientHttpClient.GetAsync($"{Request.Path}?{qs}"); - }); - } - - if (syncState.SyncQueue.Count > 0) { - _logger.LogInformation("Sync for {} on {} ({}) has {} queued results", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken, syncState.SyncQueue.Count); - syncState.SyncQueue.TryDequeue(out var result); - - Response.StatusCode = StatusCodes.Status200OK; - Response.ContentType = "application/json"; - await Response.StartAsync(); - await JsonSerializer.SerializeAsync(Response.Body, result, new JsonSerializerOptions { - WriteIndented = true, - DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull - }); - await Response.CompleteAsync(); - return; - } - - timeout = Math.Clamp(timeout, 0, 100); - _logger.LogInformation("Sync for {} on {} ({}) is still running, waiting for {}ms, {} elapsed", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken, timeout, - DateTime.Now.Subtract(syncState.NextSyncResponseStartedAt)); - - try { - await syncState.NextSyncResponse.WaitAsync(TimeSpan.FromMilliseconds(timeout)); - } - catch { } - - if (syncState.NextSyncResponse is Task<HttpResponseMessage> { IsCompleted: true } response) { - _logger.LogInformation("Sync for {} on {} ({}) completed", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken); - var resp = await response; - await Response.WriteHttpResponse(resp); - return; - } - - // await Task.Delay(timeout); - _logger.LogInformation("Sync for {} on {} ({}): sending bogus response", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken); - Response.StatusCode = StatusCodes.Status200OK; - Response.ContentType = "application/json"; - await Response.StartAsync(); - var SyncResponse = new SyncResponse { - // NextBatch = "MxApiExtensions::Next" + Random.Shared.NextInt64(), - NextBatch = since ?? "", - Presence = new() { - Events = new() { - await GetStatusMessage(syncState, $"{DateTime.Now.Subtract(syncState.NextSyncResponseStartedAt)} {syncState.NextSyncResponse.Status}") - } - }, - Rooms = new() { - Invite = new(), - Join = new() - } - }; - await JsonSerializer.SerializeAsync(Response.Body, SyncResponse, new JsonSerializerOptions { - WriteIndented = true, - DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull - }); - await Response.CompleteAsync(); - } - catch (MxApiMatrixException e) { - _logger.LogError(e, "Error while syncing for {} on {} ({})", _hs.GetHomeserver().Result.WhoAmI.UserId, - _hs.GetHomeserver().Result.ServerName, _hs.GetHomeserver().Result.AccessToken); - - Response.StatusCode = StatusCodes.Status500InternalServerError; - Response.ContentType = "application/json"; - - await Response.WriteAsJsonAsync(e.GetAsJson()); - await Response.CompleteAsync(); - } - - catch (Exception e) { - //catch SSL connection errors and retry - if (e.InnerException is HttpRequestException && e.InnerException.Message.Contains("The SSL connection could not be established")) { - _logger.LogWarning("Caught SSL connection error, retrying sync for {} on {} ({})", _hs.GetHomeserver().Result.WhoAmI.UserId, - _hs.GetHomeserver().Result.ServerName, _hs.GetHomeserver().Result.AccessToken); - await Sync(since, timeout); - return; - } - - _logger.LogError(e, "Error while syncing for {} on {} ({})", _hs.GetHomeserver().Result.WhoAmI.UserId, - _hs.GetHomeserver().Result.ServerName, _hs.GetHomeserver().Result.AccessToken); - - Response.StatusCode = StatusCodes.Status500InternalServerError; - Response.ContentType = "text/plain"; - - await Response.WriteAsync(e.ToString()); - await Response.CompleteAsync(); - } - - Response.Body.Close(); - if (preloadTask is not null) - await preloadTask; - } - - private async Task EnqueuePreloadData(SyncState syncState) { - var rooms = await syncState.Homeserver.GetJoinedRooms(); - var dm_rooms = (await syncState.Homeserver.GetAccountDataAsync<Dictionary<string, List<string>>>("m.direct")).Aggregate(new List<string>(), (list, entry) => { - list.AddRange(entry.Value); - return list; - }); - - var ownHs = syncState.Homeserver.WhoAmI.UserId.Split(':')[1]; - rooms = rooms.OrderBy(x => { - if (dm_rooms.Contains(x.RoomId)) return -1; - var parts = x.RoomId.Split(':'); - if (parts[1] == ownHs) return 200; - if (HomeserverWeightEstimation.EstimatedSize.ContainsKey(parts[1])) return HomeserverWeightEstimation.EstimatedSize[parts[1]] + parts[0].Length; - return 5000; - }).ToList(); - var roomDataTasks = rooms.Select(room => EnqueueRoomData(syncState, room)).ToList(); - _logger.LogInformation("Preloading data for {} rooms on {} ({})", roomDataTasks.Count, syncState.Homeserver.ServerName, syncState.Homeserver.AccessToken); - - await Task.WhenAll(roomDataTasks); - } - - private SemaphoreSlim _roomDataSemaphore = new(4, 4); - - private async Task EnqueueRoomData(SyncState syncState, GenericRoom room) { - await _roomDataSemaphore.WaitAsync(); - var roomState = room.GetFullStateAsync(); - var timeline = await room.GetMessagesAsync(limit: 100, dir: "b"); - timeline.Chunk.Reverse(); - var SyncResponse = new SyncResponse { - Rooms = new() { - Join = new() { - { - room.RoomId, - new SyncResponse.RoomsDataStructure.JoinedRoomDataStructure { - AccountData = new() { - Events = new() - }, - Ephemeral = new() { - Events = new() - }, - State = new() { - Events = timeline.State - }, - UnreadNotifications = new() { - HighlightCount = 0, - NotificationCount = 0 - }, - Timeline = new() { - Events = timeline.Chunk, - Limited = false, - PrevBatch = timeline.Start - }, - Summary = new() { - Heroes = new(), - InvitedMemberCount = 0, - JoinedMemberCount = 1 - } - } - } - } - }, - Presence = new() { - Events = new() { - await GetStatusMessage(syncState, $"{DateTime.Now.Subtract(syncState.NextSyncResponseStartedAt)} {syncState.NextSyncResponse.Status} {room.RoomId}") - } - }, - NextBatch = "" - }; - - await foreach (var stateEvent in roomState) { - SyncResponse.Rooms.Join[room.RoomId].State.Events.Add(stateEvent); - } - - var joinRoom = SyncResponse.Rooms.Join[room.RoomId]; - joinRoom.Summary.Heroes.AddRange(joinRoom.State.Events - .Where(x => - x.Type == "m.room.member" - && x.StateKey != syncState.Homeserver.WhoAmI.UserId - && (x.TypedContent as RoomMemberEventContent).Membership == "join" - ) - .Select(x => x.StateKey)); - joinRoom.Summary.JoinedMemberCount = joinRoom.Summary.Heroes.Count; - - syncState.SyncQueue.Enqueue(SyncResponse); - _roomDataSemaphore.Release(); - } - - private async Task<StateEventResponse> GetStatusMessage(SyncState syncState, string message) { - return new StateEventResponse { - TypedContent = new PresenceEventContent { - DisplayName = "MxApiExtensions", - Presence = "online", - StatusMessage = message, - // AvatarUrl = (await syncState.Homeserver.GetProfile(syncState.Homeserver.WhoAmI.UserId)).AvatarUrl - AvatarUrl = "" - }, - Type = "m.presence", - StateKey = syncState.Homeserver.WhoAmI.UserId, - Sender = syncState.Homeserver.WhoAmI.UserId, - UserId = syncState.Homeserver.WhoAmI.UserId, - EventId = Guid.NewGuid().ToString(), - OriginServerTs = 0 - }; - } -}