diff options
author | TheArcaneBrony <myrainbowdash949@gmail.com> | 2023-11-05 17:59:38 +0100 |
---|---|---|
committer | TheArcaneBrony <myrainbowdash949@gmail.com> | 2023-11-05 17:59:38 +0100 |
commit | 2abb132234546e61bb0aff3897dc49e72ea84f5d (patch) | |
tree | c885c03d35e7a0a6b8fc21bd0b259216c61c877c /MxApiExtensions | |
parent | Update (diff) | |
download | MxApiExtensions-2abb132234546e61bb0aff3897dc49e72ea84f5d.tar.xz |
Working sync proxy
Diffstat (limited to 'MxApiExtensions')
16 files changed, 709 insertions, 431 deletions
diff --git a/MxApiExtensions/Classes/SyncState.cs b/MxApiExtensions/Classes/SyncState.cs index 6950954..e44d35c 100644 --- a/MxApiExtensions/Classes/SyncState.cs +++ b/MxApiExtensions/Classes/SyncState.cs @@ -1,15 +1,91 @@ using System.Collections.Concurrent; +using System.Text.Json.Serialization; +using LibMatrix; +using LibMatrix.EventTypes.Spec.State; using LibMatrix.Helpers; using LibMatrix.Homeservers; using LibMatrix.Responses; +using Microsoft.OpenApi.Extensions; namespace MxApiExtensions.Classes; public class SyncState { + private Task? _nextSyncResponse; public string? NextBatch { get; set; } public ConcurrentQueue<SyncResponse> SyncQueue { get; set; } = new(); public bool IsInitialSync { get; set; } - public Task? NextSyncResponse { get; set; } + + [JsonIgnore] + public Task? NextSyncResponse { + get => _nextSyncResponse; + set { + _nextSyncResponse = value; + NextSyncResponseStartedAt = DateTime.Now; + } + } + public DateTime NextSyncResponseStartedAt { get; set; } = DateTime.Now; + + [JsonIgnore] public AuthenticatedHomeserverGeneric Homeserver { get; set; } -} + +#region Debug stuff + + public object NextSyncResponseTaskInfo => new { + NextSyncResponse?.Id, + NextSyncResponse?.IsCompleted, + NextSyncResponse?.IsCompletedSuccessfully, + NextSyncResponse?.IsCanceled, + NextSyncResponse?.IsFaulted, + Status = NextSyncResponse?.Status.GetDisplayName() + }; + + +#endregion + + public void SendEphemeralTimelineEventInRoom(string roomId, StateEventResponse @event) { + SyncQueue.Enqueue(new() { + NextBatch = NextBatch ?? "null", + Rooms = new() { + Join = new() { + { + roomId, + new() { + Timeline = new() { + Events = new() { + @event + } + } + } + } + } + } + }); + } + + public void SendStatusMessage(string text) { + SyncQueue.Enqueue(new() { + NextBatch = NextBatch ?? "null", + Presence = new() { + Events = new() { + new StateEventResponse { + TypedContent = new PresenceEventContent { + DisplayName = "MxApiExtensions", + Presence = "online", + StatusMessage = text, + // AvatarUrl = (await syncState.Homeserver.GetProfile(syncState.Homeserver.WhoAmI.UserId)).AvatarUrl + AvatarUrl = "", + LastActiveAgo = 15, + CurrentlyActive = true + }, + Type = "m.presence", + StateKey = Homeserver.WhoAmI.UserId, + Sender = Homeserver.WhoAmI.UserId, + EventId = Guid.NewGuid().ToString(), + OriginServerTs = 0 + } + } + } + }); + } +} \ No newline at end of file diff --git a/MxApiExtensions/Controllers/ClientVersionsController.cs b/MxApiExtensions/Controllers/Client/ClientVersionsController.cs index d29e3b2..d29e3b2 100644 --- a/MxApiExtensions/Controllers/ClientVersionsController.cs +++ b/MxApiExtensions/Controllers/Client/ClientVersionsController.cs diff --git a/MxApiExtensions/Controllers/Client/LoginController.cs b/MxApiExtensions/Controllers/Client/LoginController.cs new file mode 100644 index 0000000..009aaef --- /dev/null +++ b/MxApiExtensions/Controllers/Client/LoginController.cs @@ -0,0 +1,73 @@ +using System.Net.Http.Headers; +using ArcaneLibs.Extensions; +using LibMatrix; +using LibMatrix.Extensions; +using LibMatrix.Responses; +using LibMatrix.Services; +using Microsoft.AspNetCore.Mvc; +using MxApiExtensions.Classes.LibMatrix; +using MxApiExtensions.Services; + +namespace MxApiExtensions.Controllers; + +[ApiController] +[Route("/")] +public class LoginController(ILogger<LoginController> logger, HomeserverProviderService hsProvider, HomeserverResolverService hsResolver, AuthenticationService auth, + MxApiExtensionsConfiguration conf) + : ControllerBase { + private readonly ILogger _logger = logger; + private readonly HomeserverProviderService _hsProvider = hsProvider; + private readonly MxApiExtensionsConfiguration _conf = conf; + + [HttpPost("/_matrix/client/{_}/login")] + public async Task Proxy([FromBody] LoginRequest request, string _) { + string hsCanonical = null; + if (Request.Headers.Keys.Any(x => x.ToUpper() == "MXAE_UPSTREAM")) { + hsCanonical = Request.Headers.GetByCaseInsensitiveKey("MXAE_UPSTREAM")[0]!; + _logger.LogInformation("Found upstream: {}", hsCanonical); + } + else { + if (!request.Identifier.User.Contains("#")) { + Response.StatusCode = (int)StatusCodes.Status403Forbidden; + Response.ContentType = "application/json"; + await Response.StartAsync(); + await Response.WriteAsync(new MxApiMatrixException { + ErrorCode = "M_FORBIDDEN", + Error = "[MxApiExtensions] Invalid username, must be of the form @user#domain:" + Request.Host.Value + }.GetAsJson() ?? ""); + await Response.CompleteAsync(); + } + + hsCanonical = request.Identifier.User.Split('#')[1].Split(':')[0]; + request.Identifier.User = request.Identifier.User.Split(':')[0].Replace('#', ':'); + if (!request.Identifier.User.StartsWith('@')) request.Identifier.User = '@' + request.Identifier.User; + } + + var hs = await hsResolver.ResolveHomeserverFromWellKnown(hsCanonical); + //var hs = await _hsProvider.Login(hsCanonical, mxid, request.Password); + var hsClient = new MatrixHttpClient { BaseAddress = new Uri(hs.Client) }; + //hsClient.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", hsClient.DefaultRequestHeaders.Authorization!.Parameter); + if (!string.IsNullOrWhiteSpace(request.InitialDeviceDisplayName)) + request.InitialDeviceDisplayName += $" (via MxApiExtensions at {Request.Host.Value})"; + var resp = await hsClient.PostAsJsonAsync("/_matrix/client/r0/login", request); + var loginResp = await resp.Content.ReadAsStringAsync(); + Response.StatusCode = (int)resp.StatusCode; + Response.ContentType = resp.Content.Headers.ContentType?.ToString() ?? "application/json"; + await Response.StartAsync(); + await Response.WriteAsync(loginResp); + await Response.CompleteAsync(); + var token = (await resp.Content.ReadFromJsonAsync<LoginResponse>())!.AccessToken; + await auth.SaveMxidForToken(token, request.Identifier.User); + } + + [HttpGet("/_matrix/client/{_}/login")] + public async Task<object> Proxy(string? _) { + return new { + flows = new[] { + new { + type = "m.login.password" + } + } + }; + } +} \ No newline at end of file diff --git a/MxApiExtensions/Controllers/Client/Room/RoomsSendMessageController.cs b/MxApiExtensions/Controllers/Client/Room/RoomsSendMessageController.cs new file mode 100644 index 0000000..3d1d4e2 --- /dev/null +++ b/MxApiExtensions/Controllers/Client/Room/RoomsSendMessageController.cs @@ -0,0 +1,18 @@ +using LibMatrix.Services; +using Microsoft.AspNetCore.Mvc; +using MxApiExtensions.Services; + +namespace MxApiExtensions.Controllers.Client.Room; + +[ApiController] +[Route("/")] +public class RoomController(ILogger<LoginController> logger, HomeserverResolverService hsResolver, AuthenticationService auth, MxApiExtensionsConfiguration conf, + AuthenticatedHomeserverProviderService hsProvider) + : ControllerBase { + [HttpGet("/_matrix/client/{_}/rooms/{roomId}/members_by_homeserver")] + public async Task<Dictionary<string, List<string>>> GetRoomMembersByHomeserver(string _, [FromRoute] string roomId, [FromQuery] bool joinedOnly = true) { + var hs = await hsProvider.GetHomeserver(); + var room = hs.GetRoom(roomId); + return await room.GetMembersByHomeserverAsync(); + } +} \ No newline at end of file diff --git a/MxApiExtensions/Controllers/Client/RoomsSendMessageController.cs b/MxApiExtensions/Controllers/Client/RoomsSendMessageController.cs new file mode 100644 index 0000000..6d3a774 --- /dev/null +++ b/MxApiExtensions/Controllers/Client/RoomsSendMessageController.cs @@ -0,0 +1,72 @@ +using System.Buffers.Text; +using System.Net.Http.Headers; +using System.Text.Json; +using System.Text.Json.Nodes; +using ArcaneLibs.Extensions; +using LibMatrix; +using LibMatrix.EventTypes.Spec; +using LibMatrix.Extensions; +using LibMatrix.Helpers; +using LibMatrix.Homeservers; +using LibMatrix.Responses; +using LibMatrix.Services; +using Microsoft.AspNetCore.Mvc; +using MxApiExtensions.Classes; +using MxApiExtensions.Classes.LibMatrix; +using MxApiExtensions.Services; + +namespace MxApiExtensions.Controllers; + +[ApiController] +[Route("/")] +public class RoomsSendMessageController(ILogger<LoginController> logger, HomeserverResolverService hsResolver, AuthenticationService auth, MxApiExtensionsConfiguration conf, + AuthenticatedHomeserverProviderService hsProvider) + : ControllerBase { + [HttpPut("/_matrix/client/{_}/rooms/{roomId}/send/m.room.message/{txnId}")] + public async Task Proxy([FromBody] JsonObject request, [FromRoute] string roomId, [FromRoute] string txnId, string _) { + var hs = await hsProvider.GetHomeserver(); + + var msg = request.Deserialize<RoomMessageEventContent>(); + if (msg is not null && msg.Body.StartsWith("mxae!")) { +#pragma warning disable CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed + handleMxaeCommand(hs, roomId, msg); +#pragma warning restore CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed + await Response.WriteAsJsonAsync(new EventIdResponse() { + EventId = "$" + string.Join("", Random.Shared.GetItems("abcdefghijklmnopqrstuvwxyzABCDEFGHIJLKMNOPQRSTUVWXYZ0123456789".ToCharArray(), 100)) + }); + await Response.CompleteAsync(); + } + else { + try { + var resp = await hs.ClientHttpClient.PutAsJsonAsync($"{Request.Path}{Request.QueryString}", request); + var loginResp = await resp.Content.ReadAsStringAsync(); + Response.StatusCode = (int)resp.StatusCode; + Response.ContentType = resp.Content.Headers.ContentType?.ToString() ?? "application/json"; + await Response.StartAsync(); + await Response.WriteAsync(loginResp); + await Response.CompleteAsync(); + } + catch (MatrixException e) { + await Response.StartAsync(); + await Response.WriteAsync(e.GetAsJson()); + await Response.CompleteAsync(); + } + } + } + + private async Task handleMxaeCommand(AuthenticatedHomeserverGeneric hs, string roomId, RoomMessageEventContent msg) { + var syncState = SyncController._syncStates.GetValueOrDefault(hs.AccessToken); + if (syncState is null) return; + syncState.SendEphemeralTimelineEventInRoom(roomId, new() { + Sender = "@mxae:" + Request.Host.Value, + Type = "m.room.message", + TypedContent = MessageFormatter.FormatSuccess("Thinking..."), + OriginServerTs = (ulong)new DateTimeOffset(DateTime.UtcNow.ToUniversalTime()).ToUnixTimeMilliseconds(), + Unsigned = new() { + Age = 1 + }, + RoomId = roomId, + EventId = "$" + string.Join("", Random.Shared.GetItems("abcdefghijklmnopqrstuvwxyzABCDEFGHIJLKMNOPQRSTUVWXYZ0123456789".ToCharArray(), 100)) + }); + } +} \ No newline at end of file diff --git a/MxApiExtensions/Controllers/Client/SyncController.cs b/MxApiExtensions/Controllers/Client/SyncController.cs new file mode 100644 index 0000000..2944c3b --- /dev/null +++ b/MxApiExtensions/Controllers/Client/SyncController.cs @@ -0,0 +1,243 @@ +using System.Collections.Concurrent; +using System.Diagnostics; +using System.Text.Json; +using System.Text.Json.Nodes; +using System.Text.Json.Serialization; +using System.Web; +using ArcaneLibs; +using LibMatrix; +using LibMatrix.EventTypes.Spec.State; +using LibMatrix.Helpers; +using LibMatrix.Homeservers; +using LibMatrix.Responses; +using LibMatrix.RoomTypes; +using Microsoft.AspNetCore.Mvc; +using MxApiExtensions.Classes; +using MxApiExtensions.Classes.LibMatrix; +using MxApiExtensions.Extensions; +using MxApiExtensions.Services; + +namespace MxApiExtensions.Controllers; + +[ApiController] +[Route("/")] +public class SyncController(ILogger<SyncController> logger, MxApiExtensionsConfiguration config, AuthenticationService auth, AuthenticatedHomeserverProviderService hsProvider) + : ControllerBase { + public static readonly ConcurrentDictionary<string, SyncState> _syncStates = new(); + + private static SemaphoreSlim _semaphoreSlim = new(1, 1); + private Stopwatch _syncElapsed = Stopwatch.StartNew(); + + [HttpGet("/_matrix/client/{_}/sync")] + public async Task Sync(string _, [FromQuery] string? since, [FromQuery] int timeout = 1000) { + Task? preloadTask = null; + AuthenticatedHomeserverGeneric? hs = null; + try { + hs = await hsProvider.GetHomeserver(); + } + catch (Exception e) { + Console.WriteLine(e); + } + + var qs = HttpUtility.ParseQueryString(Request.QueryString.Value!); + qs.Remove("access_token"); + if (since == "null") qs.Remove("since"); + + if (!config.FastInitialSync.Enabled) { + logger.LogInformation("Starting sync for {} on {} ({})", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken); + var result = await hs.ClientHttpClient.GetAsync($"{Request.Path}?{qs}"); + await Response.WriteHttpResponse(result); + return; + } + + await _semaphoreSlim.WaitAsync(); + var syncState = _syncStates.GetOrAdd($"{hs.WhoAmI.UserId}/{hs.WhoAmI.DeviceId}/{hs.ServerName}:{hs.AccessToken}", _ => { + logger.LogInformation("Started tracking sync state for {} on {} ({})", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken); + var ss = new SyncState { + IsInitialSync = string.IsNullOrWhiteSpace(since), + Homeserver = hs + }; + if (ss.IsInitialSync) { + preloadTask = EnqueuePreloadData(ss); + } + + logger.LogInformation("Starting sync for {} on {} ({})", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken); + + ss.NextSyncResponseStartedAt = DateTime.Now; + ss.NextSyncResponse = Task.Delay(15_000); + ss.NextSyncResponse.ContinueWith(async x => { + logger.LogInformation("Sync for {} on {} ({}) starting", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken); + ss.NextSyncResponse = hs.ClientHttpClient.GetAsync($"/_matrix/client/v3/sync?{qs}"); + (ss.NextSyncResponse as Task<HttpResponseMessage>).ContinueWith(async x => EnqueueSyncResponse(ss, await x)); + }); + return ss; + }); + _semaphoreSlim.Release(); + + if (syncState.SyncQueue.Count > 0) { + logger.LogInformation("Sync for {} on {} ({}) has {} queued results", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken, syncState.SyncQueue.Count); + syncState.SyncQueue.TryDequeue(out var result); + while (result is null) + syncState.SyncQueue.TryDequeue(out result); + Response.StatusCode = StatusCodes.Status200OK; + Response.ContentType = "application/json"; + await Response.StartAsync(); + result.NextBatch ??= since ?? syncState.NextBatch; + await JsonSerializer.SerializeAsync(Response.Body, result, new JsonSerializerOptions { + WriteIndented = true, + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull + }); + await Response.CompleteAsync(); + return; + } + + var newTimeout = Math.Clamp(timeout, 0, syncState.IsInitialSync ? syncState.SyncQueue.Count >= 2 ? 0 : 250 : timeout); + logger.LogInformation("Sync for {} on {} ({}) is still running, waiting for {}ms, {} elapsed", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken, newTimeout, + DateTime.Now.Subtract(syncState.NextSyncResponseStartedAt)); + + try { + if (syncState.NextSyncResponse is not null) + await syncState.NextSyncResponse.WaitAsync(TimeSpan.FromMilliseconds(newTimeout)); + else { + syncState.NextSyncResponse = hs.ClientHttpClient.GetAsync($"/_matrix/client/v3/sync?{qs}"); + (syncState.NextSyncResponse as Task<HttpResponseMessage>).ContinueWith(async x => EnqueueSyncResponse(syncState, await x)); + // await Task.Delay(250); + } + } + catch (TimeoutException) { } + + // if (_syncElapsed.ElapsedMilliseconds > timeout) + if(syncState.NextSyncResponse?.IsCompleted == false) + syncState.SendStatusMessage( + $"M={Util.BytesToString(Process.GetCurrentProcess().WorkingSet64)} TE={DateTime.Now.Subtract(syncState.NextSyncResponseStartedAt)} S={syncState.NextSyncResponse?.Status} QL={syncState.SyncQueue.Count}"); + Response.StatusCode = StatusCodes.Status200OK; + Response.ContentType = "application/json"; + await Response.StartAsync(); + var response = syncState.SyncQueue.FirstOrDefault(); + if (response is null) + response = new(); + response.NextBatch ??= since ?? syncState.NextBatch; + await JsonSerializer.SerializeAsync(Response.Body, response, new JsonSerializerOptions { + WriteIndented = true, + DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull + }); + await Response.CompleteAsync(); + + Response.Body.Close(); + if (preloadTask is not null) + await preloadTask; + } + + private async Task EnqueuePreloadData(SyncState syncState) { + var rooms = await syncState.Homeserver.GetJoinedRooms(); + var dm_rooms = (await syncState.Homeserver.GetAccountDataAsync<Dictionary<string, List<string>>>("m.direct")).Aggregate(new List<string>(), (list, entry) => { + list.AddRange(entry.Value); + return list; + }); + + var ownHs = syncState.Homeserver.WhoAmI.UserId.Split(':')[1]; + rooms = rooms.OrderBy(x => { + if (dm_rooms.Contains(x.RoomId)) return -1; + var parts = x.RoomId.Split(':'); + if (parts[1] == ownHs) return 200; + if (HomeserverWeightEstimation.EstimatedSize.ContainsKey(parts[1])) return HomeserverWeightEstimation.EstimatedSize[parts[1]] + parts[0].Length; + return 5000; + }).ToList(); + var roomDataTasks = rooms.Select(room => EnqueueRoomData(syncState, room)).ToList(); + logger.LogInformation("Preloading data for {} rooms on {} ({})", roomDataTasks.Count, syncState.Homeserver.ServerName, syncState.Homeserver.AccessToken); + + await Task.WhenAll(roomDataTasks); + } + + private SemaphoreSlim _roomDataSemaphore = new(32, 32); + + private async Task EnqueueRoomData(SyncState syncState, GenericRoom room) { + await _roomDataSemaphore.WaitAsync(); + var roomState = room.GetFullStateAsync(); + var timeline = await room.GetMessagesAsync(limit: 100, dir: "b"); + timeline.Chunk.Reverse(); + var SyncResponse = new SyncResponse { + Rooms = new() { + Join = new() { + { + room.RoomId, + new SyncResponse.RoomsDataStructure.JoinedRoomDataStructure { + AccountData = new() { + Events = new() + }, + Ephemeral = new() { + Events = new() + }, + State = new() { + Events = timeline.State + }, + UnreadNotifications = new() { + HighlightCount = 0, + NotificationCount = 0 + }, + Timeline = new() { + Events = timeline.Chunk, + Limited = false, + PrevBatch = timeline.Start + }, + Summary = new() { + Heroes = new(), + InvitedMemberCount = 0, + JoinedMemberCount = 1 + } + } + } + } + }, + Presence = new() { + Events = new() { + await GetStatusMessage(syncState, $"{DateTime.Now.Subtract(syncState.NextSyncResponseStartedAt)} {syncState.NextSyncResponse.Status} {room.RoomId}") + } + }, + NextBatch = "" + }; + + await foreach (var stateEvent in roomState) { + SyncResponse.Rooms.Join[room.RoomId].State.Events.Add(stateEvent); + } + + var joinRoom = SyncResponse.Rooms.Join[room.RoomId]; + joinRoom.Summary.Heroes.AddRange(joinRoom.State.Events + .Where(x => + x.Type == "m.room.member" + && x.StateKey != syncState.Homeserver.WhoAmI.UserId + && (x.TypedContent as RoomMemberEventContent).Membership == "join" + ) + .Select(x => x.StateKey)); + joinRoom.Summary.JoinedMemberCount = joinRoom.Summary.Heroes.Count; + + syncState.SyncQueue.Enqueue(SyncResponse); + _roomDataSemaphore.Release(); + } + + private async Task<StateEventResponse> GetStatusMessage(SyncState syncState, string message) { + return new StateEventResponse { + TypedContent = new PresenceEventContent { + DisplayName = "MxApiExtensions", + Presence = "online", + StatusMessage = message, + // AvatarUrl = (await syncState.Homeserver.GetProfile(syncState.Homeserver.WhoAmI.UserId)).AvatarUrl + AvatarUrl = "" + }, + Type = "m.presence", + StateKey = syncState.Homeserver.WhoAmI.UserId, + Sender = syncState.Homeserver.WhoAmI.UserId, + EventId = Guid.NewGuid().ToString(), + OriginServerTs = 0 + }; + } + + private async Task EnqueueSyncResponse(SyncState ss, HttpResponseMessage task) { + var sr = await task.Content.ReadFromJsonAsync<JsonObject>(); + if (sr.ContainsKey("error")) throw sr.Deserialize<MatrixException>()!; + ss.NextBatch = sr["next_batch"].GetValue<string>(); + ss.IsInitialSync = false; + ss.SyncQueue.Enqueue(sr.Deserialize<SyncResponse>()); + ss.NextSyncResponse = null; + } +} \ No newline at end of file diff --git a/MxApiExtensions/Controllers/Extensions/DebugController.cs b/MxApiExtensions/Controllers/Extensions/DebugController.cs new file mode 100644 index 0000000..79ed2f0 --- /dev/null +++ b/MxApiExtensions/Controllers/Extensions/DebugController.cs @@ -0,0 +1,45 @@ +using System.Collections.Concurrent; +using Microsoft.AspNetCore.Mvc; +using MxApiExtensions.Classes.LibMatrix; +using MxApiExtensions.Services; + +namespace MxApiExtensions.Controllers.Extensions; + +[ApiController] +[Route("/")] +public class DebugController : ControllerBase { + private readonly ILogger _logger; + private readonly MxApiExtensionsConfiguration _config; + private readonly AuthenticationService _authenticationService; + + private static ConcurrentDictionary<string, RoomInfoEntry> _roomInfoCache = new(); + + public DebugController(ILogger<ProxyConfigurationController> logger, MxApiExtensionsConfiguration config, AuthenticationService authenticationService, + AuthenticatedHomeserverProviderService authenticatedHomeserverProviderService) { + _logger = logger; + _config = config; + _authenticationService = authenticationService; + } + + [HttpGet("debug")] + public async Task<object?> GetDebug() { + var mxid = await _authenticationService.GetMxidFromToken(); + if(!_config.Admins.Contains(mxid)) { + _logger.LogWarning("Got debug request for {user}, but they are not an admin", mxid); + Response.StatusCode = StatusCodes.Status403Forbidden; + Response.ContentType = "application/json"; + + await Response.WriteAsJsonAsync(new { + ErrorCode = "M_FORBIDDEN", + Error = "You are not an admin" + }); + await Response.CompleteAsync(); + return null; + } + + _logger.LogInformation("Got debug request for {user}", mxid); + return new { + SyncStates = SyncController._syncStates + }; + } +} diff --git a/MxApiExtensions/Controllers/LoginController.cs b/MxApiExtensions/Controllers/LoginController.cs deleted file mode 100644 index bd354ef..0000000 --- a/MxApiExtensions/Controllers/LoginController.cs +++ /dev/null @@ -1,70 +0,0 @@ -using System.Net.Http.Headers; -using LibMatrix; -using LibMatrix.Extensions; -using LibMatrix.Responses; -using LibMatrix.Services; -using Microsoft.AspNetCore.Mvc; -using MxApiExtensions.Classes.LibMatrix; -using MxApiExtensions.Services; - -namespace MxApiExtensions.Controllers; - -[ApiController] -[Route("/")] -public class LoginController : ControllerBase { - private readonly ILogger _logger; - private readonly HomeserverProviderService _hsProvider; - private readonly HomeserverResolverService _hsResolver; - private readonly AuthenticationService _auth; - private readonly MxApiExtensionsConfiguration _conf; - - public LoginController(ILogger<LoginController> logger, HomeserverProviderService hsProvider, HomeserverResolverService hsResolver, AuthenticationService auth, MxApiExtensionsConfiguration conf) { - _logger = logger; - _hsProvider = hsProvider; - _hsResolver = hsResolver; - _auth = auth; - _conf = conf; - } - - [HttpPost("/_matrix/client/{_}/login")] - public async Task Proxy([FromBody] LoginRequest request, string _) { - if (!request.Identifier.User.Contains("#")) { - Response.StatusCode = (int)StatusCodes.Status403Forbidden; - Response.ContentType = "application/json"; - await Response.StartAsync(); - await Response.WriteAsync(new MxApiMatrixException { - ErrorCode = "M_FORBIDDEN", - Error = "[MxApiExtensions] Invalid username, must be of the form @user#domain:" + Request.Host.Value - }.GetAsJson() ?? ""); - await Response.CompleteAsync(); - } - var hsCanonical = request.Identifier.User.Split('#')[1].Split(':')[0]; - request.Identifier.User = request.Identifier.User.Split(':')[0].Replace('#', ':'); - if(!request.Identifier.User.StartsWith('@')) request.Identifier.User = '@' + request.Identifier.User; - var hs = await _hsResolver.ResolveHomeserverFromWellKnown(hsCanonical); - //var hs = await _hsProvider.Login(hsCanonical, mxid, request.Password); - var hsClient = new MatrixHttpClient { BaseAddress = new Uri(hs.client) }; - //hsClient.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", hsClient.DefaultRequestHeaders.Authorization!.Parameter); - var resp = await hsClient.PostAsJsonAsync("/_matrix/client/r0/login", request); - var loginResp = await resp.Content.ReadAsStringAsync(); - Response.StatusCode = (int)resp.StatusCode; - Response.ContentType = resp.Content.Headers.ContentType?.ToString() ?? "application/json"; - await Response.StartAsync(); - await Response.WriteAsync(loginResp); - await Response.CompleteAsync(); - var token = (await resp.Content.ReadFromJsonAsync<LoginResponse>())!.AccessToken; - await _auth.SaveMxidForToken(token, request.Identifier.User); - } - - - [HttpGet("/_matrix/client/{_}/login")] - public async Task<object> Proxy(string? _) { - return new { - flows = new[] { - new { - type = "m.login.password" - } - } - }; - } -} diff --git a/MxApiExtensions/Controllers/GenericProxyController.cs b/MxApiExtensions/Controllers/Other/GenericProxyController.cs index c004fcb..bae07c0 100644 --- a/MxApiExtensions/Controllers/GenericProxyController.cs +++ b/MxApiExtensions/Controllers/Other/GenericProxyController.cs @@ -7,29 +7,17 @@ namespace MxApiExtensions.Controllers; [ApiController] [Route("/{*_}")] -public class GenericController : ControllerBase { - private readonly ILogger<GenericController> _logger; - private readonly MxApiExtensionsConfiguration _config; - private readonly AuthenticationService _authenticationService; - private readonly AuthenticatedHomeserverProviderService _authenticatedHomeserverProviderService; - private static Dictionary<string, string> _tokenMap = new(); - - public GenericController(ILogger<GenericController> logger, MxApiExtensionsConfiguration config, AuthenticationService authenticationService, - AuthenticatedHomeserverProviderService authenticatedHomeserverProviderService) { - _logger = logger; - _config = config; - _authenticationService = authenticationService; - _authenticatedHomeserverProviderService = authenticatedHomeserverProviderService; - } - +public class GenericController(ILogger<GenericController> logger, MxApiExtensionsConfiguration config, AuthenticationService authenticationService, + AuthenticatedHomeserverProviderService authenticatedHomeserverProviderService) + : ControllerBase { [HttpGet] public async Task Proxy([FromQuery] string? access_token, string? _) { try { - access_token ??= _authenticationService.GetToken(fail: false); - var mxid = await _authenticationService.GetMxidFromToken(fail: false); - var hs = await _authenticatedHomeserverProviderService.GetHomeserver(); + // access_token ??= _authenticationService.GetToken(fail: false); + // var mxid = await _authenticationService.GetMxidFromToken(fail: false); + var hs = await authenticatedHomeserverProviderService.GetRemoteHomeserver(); - _logger.LogInformation("Proxying request for {}: {}{}", mxid, Request.Path, Request.QueryString); + logger.LogInformation("Proxying request: {}{}", Request.Path, Request.QueryString); //remove access_token from query string Request.QueryString = new QueryString( @@ -55,7 +43,7 @@ public class GenericController : ControllerBase { await Response.CompleteAsync(); } catch (MxApiMatrixException e) { - _logger.LogError(e, "Matrix error"); + logger.LogError(e, "Matrix error"); Response.StatusCode = StatusCodes.Status500InternalServerError; Response.ContentType = "application/json"; @@ -63,7 +51,7 @@ public class GenericController : ControllerBase { await Response.CompleteAsync(); } catch (Exception e) { - _logger.LogError(e, "Unhandled error"); + logger.LogError(e, "Unhandled error"); Response.StatusCode = StatusCodes.Status500InternalServerError; Response.ContentType = "text/plain"; @@ -75,11 +63,11 @@ public class GenericController : ControllerBase { [HttpPost] public async Task ProxyPost([FromQuery] string? access_token, string _) { try { - access_token ??= _authenticationService.GetToken(fail: false); - var mxid = await _authenticationService.GetMxidFromToken(fail: false); - var hs = await _authenticatedHomeserverProviderService.GetHomeserver(); + access_token ??= authenticationService.GetToken(fail: false); + var mxid = await authenticationService.GetMxidFromToken(fail: false); + var hs = await authenticatedHomeserverProviderService.GetHomeserver(); - _logger.LogInformation("Proxying request for {}: {}{}", mxid, Request.Path, Request.QueryString); + logger.LogInformation("Proxying request for {}: {}{}", mxid, Request.Path, Request.QueryString); using var hc = new HttpClient(); hc.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", access_token); @@ -112,7 +100,7 @@ public class GenericController : ControllerBase { await Response.CompleteAsync(); } catch (MxApiMatrixException e) { - _logger.LogError(e, "Matrix error"); + logger.LogError(e, "Matrix error"); Response.StatusCode = StatusCodes.Status500InternalServerError; Response.ContentType = "application/json"; @@ -120,7 +108,7 @@ public class GenericController : ControllerBase { await Response.CompleteAsync(); } catch (Exception e) { - _logger.LogError(e, "Unhandled error"); + logger.LogError(e, "Unhandled error"); Response.StatusCode = StatusCodes.Status500InternalServerError; Response.ContentType = "text/plain"; @@ -132,11 +120,11 @@ public class GenericController : ControllerBase { [HttpPut] public async Task ProxyPut([FromQuery] string? access_token, string _) { try { - access_token ??= _authenticationService.GetToken(fail: false); - var mxid = await _authenticationService.GetMxidFromToken(fail: false); - var hs = await _authenticatedHomeserverProviderService.GetHomeserver(); + access_token ??= authenticationService.GetToken(fail: false); + var mxid = await authenticationService.GetMxidFromToken(fail: false); + var hs = await authenticatedHomeserverProviderService.GetHomeserver(); - _logger.LogInformation("Proxying request for {}: {}{}", mxid, Request.Path, Request.QueryString); + logger.LogInformation("Proxying request for {}: {}{}", mxid, Request.Path, Request.QueryString); using var hc = new HttpClient(); hc.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", access_token); @@ -169,7 +157,7 @@ public class GenericController : ControllerBase { await Response.CompleteAsync(); } catch (MxApiMatrixException e) { - _logger.LogError(e, "Matrix error"); + logger.LogError(e, "Matrix error"); Response.StatusCode = StatusCodes.Status500InternalServerError; Response.ContentType = "application/json"; @@ -177,7 +165,7 @@ public class GenericController : ControllerBase { await Response.CompleteAsync(); } catch (Exception e) { - _logger.LogError(e, "Unhandled error"); + logger.LogError(e, "Unhandled error"); Response.StatusCode = StatusCodes.Status500InternalServerError; Response.ContentType = "text/plain"; diff --git a/MxApiExtensions/Controllers/Other/MediaProxyController.cs b/MxApiExtensions/Controllers/Other/MediaProxyController.cs new file mode 100644 index 0000000..03b68ba --- /dev/null +++ b/MxApiExtensions/Controllers/Other/MediaProxyController.cs @@ -0,0 +1,78 @@ +using System.Net.Http.Headers; +using LibMatrix.Homeservers; +using LibMatrix.Services; +using Microsoft.AspNetCore.Mvc; +using MxApiExtensions.Classes.LibMatrix; +using MxApiExtensions.Services; + +namespace MxApiExtensions.Controllers; + +[ApiController] +[Route("/")] +public class MediaProxyController(ILogger<GenericController> logger, MxApiExtensionsConfiguration config, AuthenticationService authenticationService, + AuthenticatedHomeserverProviderService authenticatedHomeserverProviderService, HomeserverProviderService hsProvider) + : ControllerBase { + private class MediaCacheEntry { + public DateTime LastRequested { get; set; } = DateTime.Now; + public byte[] Data { get; set; } + public string ContentType { get; set; } + public long Size => Data.LongCount(); + } + + private static Dictionary<string, MediaCacheEntry> _mediaCache = new(); + private static SemaphoreSlim _semaphore = new(1, 1); + + [HttpGet("/_matrix/media/{_}/download/{serverName}/{mediaId}")] + public async Task Proxy(string? _, string serverName, string mediaId) { + try { + logger.LogInformation("Proxying media: {}{}", serverName, mediaId); + + await _semaphore.WaitAsync(); + MediaCacheEntry entry; + if (!_mediaCache.ContainsKey($"{serverName}/{mediaId}")) { + _mediaCache.Add($"{serverName}/{mediaId}", entry = new()); + List<RemoteHomeserver> FeasibleHomeservers = new(); + { + var a = await authenticatedHomeserverProviderService.TryGetRemoteHomeserver(); + if(a is not null) + FeasibleHomeservers.Add(a); + } + + FeasibleHomeservers.Add(await hsProvider.GetRemoteHomeserver(serverName)); + + foreach (var homeserver in FeasibleHomeservers) { + var resp = await homeserver.ClientHttpClient.GetAsync($"{Request.Path}"); + if(!resp.IsSuccessStatusCode) continue; + entry.ContentType = resp.Content.Headers.ContentType?.ToString() ?? "application/json"; + entry.Data = await resp.Content.ReadAsByteArrayAsync(); + break; + } + } + else entry = _mediaCache[$"{serverName}/{mediaId}"]; + _semaphore.Release(); + + Response.StatusCode = 200; + Response.ContentType = entry.ContentType; + await Response.StartAsync(); + await Response.Body.WriteAsync(entry.Data, 0, entry.Data.Length); + await Response.Body.FlushAsync(); + await Response.CompleteAsync(); + } + catch (MxApiMatrixException e) { + logger.LogError(e, "Matrix error"); + Response.StatusCode = StatusCodes.Status500InternalServerError; + Response.ContentType = "application/json"; + + await Response.WriteAsync(e.GetAsJson()); + await Response.CompleteAsync(); + } + catch (Exception e) { + logger.LogError(e, "Unhandled error"); + Response.StatusCode = StatusCodes.Status500InternalServerError; + Response.ContentType = "text/plain"; + + await Response.WriteAsync(e.ToString()); + await Response.CompleteAsync(); + } + } +} diff --git a/MxApiExtensions/Controllers/WellKnownController.cs b/MxApiExtensions/Controllers/Other/WellKnownController.cs index b27451f..c0e255f 100644 --- a/MxApiExtensions/Controllers/WellKnownController.cs +++ b/MxApiExtensions/Controllers/Other/WellKnownController.cs @@ -5,18 +5,14 @@ namespace MxApiExtensions.Controllers; [ApiController] [Route("/")] -public class WellKnownController : ControllerBase { - private readonly MxApiExtensionsConfiguration _config; - - public WellKnownController(MxApiExtensionsConfiguration config) { - _config = config; - } +public class WellKnownController(MxApiExtensionsConfiguration config) : ControllerBase { + private readonly MxApiExtensionsConfiguration _config = config; [HttpGet("/.well-known/matrix/client")] public object GetWellKnown() { var res = new JsonObject(); res.Add("m.homeserver", new JsonObject { - { "base_url", Request.Scheme + "://" + Request.Host + "/" }, + { "base_url", Request.Scheme + "://" + Request.Host }, }); return res; } diff --git a/MxApiExtensions/Controllers/SyncController.cs b/MxApiExtensions/Controllers/SyncController.cs deleted file mode 100644 index 0b0007f..0000000 --- a/MxApiExtensions/Controllers/SyncController.cs +++ /dev/null @@ -1,274 +0,0 @@ -using System.Collections.Concurrent; -using System.Text.Json; -using System.Text.Json.Serialization; -using System.Web; -using LibMatrix; -using LibMatrix.EventTypes.Spec.State; -using LibMatrix.Helpers; -using LibMatrix.Homeservers; -using LibMatrix.Responses; -using LibMatrix.RoomTypes; -using Microsoft.AspNetCore.Mvc; -using MxApiExtensions.Classes; -using MxApiExtensions.Classes.LibMatrix; -using MxApiExtensions.Extensions; -using MxApiExtensions.Services; - -namespace MxApiExtensions.Controllers; - -[ApiController] -[Route("/")] -public class SyncController : ControllerBase { - private readonly ILogger<SyncController> _logger; - private readonly MxApiExtensionsConfiguration _config; - private readonly AuthenticationService _auth; - private readonly AuthenticatedHomeserverProviderService _hs; - - private static readonly ConcurrentDictionary<string, SyncState> _syncStates = new(); - - public SyncController(ILogger<SyncController> logger, MxApiExtensionsConfiguration config, AuthenticationService auth, AuthenticatedHomeserverProviderService hs) { - _logger = logger; - _config = config; - _auth = auth; - _hs = hs; - } - - [HttpGet("/_matrix/client/v3/sync")] - public async Task Sync([FromQuery] string? since, [FromQuery] int timeout = 1000) { - Task? preloadTask = null; - AuthenticatedHomeserverGeneric? hs = null; - try { - hs = await _hs.GetHomeserver(); - } - catch (Exception e) { - Console.WriteLine(); - } - var qs = HttpUtility.ParseQueryString(Request.QueryString.Value!); - qs.Remove("access_token"); - - if (!_config.FastInitialSync.Enabled) { - _logger.LogInformation("Starting sync for {} on {} ({})", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken); - var result = await hs.ClientHttpClient.GetAsync($"{Request.Path}?{qs}"); - await Response.WriteHttpResponse(result); - return; - } - - try { - var syncState = _syncStates.GetOrAdd(hs.AccessToken, _ => { - _logger.LogInformation("Started tracking sync state for {} on {} ({})", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken); - return new SyncState { - IsInitialSync = string.IsNullOrWhiteSpace(since), - Homeserver = hs - }; - }); - - if (syncState.NextSyncResponse is null) { - _logger.LogInformation("Starting sync for {} on {} ({})", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken); - - if (syncState.IsInitialSync) { - preloadTask = EnqueuePreloadData(syncState); - } - - syncState.NextSyncResponseStartedAt = DateTime.Now; - syncState.NextSyncResponse = Task.Delay(30_000); - syncState.NextSyncResponse.ContinueWith(x => { - _logger.LogInformation("Sync for {} on {} ({}) starting", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken); - syncState.NextSyncResponse = hs.ClientHttpClient.GetAsync($"{Request.Path}?{qs}"); - }); - } - - if (syncState.SyncQueue.Count > 0) { - _logger.LogInformation("Sync for {} on {} ({}) has {} queued results", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken, syncState.SyncQueue.Count); - syncState.SyncQueue.TryDequeue(out var result); - - Response.StatusCode = StatusCodes.Status200OK; - Response.ContentType = "application/json"; - await Response.StartAsync(); - await JsonSerializer.SerializeAsync(Response.Body, result, new JsonSerializerOptions { - WriteIndented = true, - DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull - }); - await Response.CompleteAsync(); - return; - } - - timeout = Math.Clamp(timeout, 0, 100); - _logger.LogInformation("Sync for {} on {} ({}) is still running, waiting for {}ms, {} elapsed", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken, timeout, - DateTime.Now.Subtract(syncState.NextSyncResponseStartedAt)); - - try { - await syncState.NextSyncResponse.WaitAsync(TimeSpan.FromMilliseconds(timeout)); - } - catch { } - - if (syncState.NextSyncResponse is Task<HttpResponseMessage> { IsCompleted: true } response) { - _logger.LogInformation("Sync for {} on {} ({}) completed", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken); - var resp = await response; - await Response.WriteHttpResponse(resp); - return; - } - - // await Task.Delay(timeout); - _logger.LogInformation("Sync for {} on {} ({}): sending bogus response", hs.WhoAmI.UserId, hs.ServerName, hs.AccessToken); - Response.StatusCode = StatusCodes.Status200OK; - Response.ContentType = "application/json"; - await Response.StartAsync(); - var SyncResponse = new SyncResponse { - // NextBatch = "MxApiExtensions::Next" + Random.Shared.NextInt64(), - NextBatch = since ?? "", - Presence = new() { - Events = new() { - await GetStatusMessage(syncState, $"{DateTime.Now.Subtract(syncState.NextSyncResponseStartedAt)} {syncState.NextSyncResponse.Status}") - } - }, - Rooms = new() { - Invite = new(), - Join = new() - } - }; - await JsonSerializer.SerializeAsync(Response.Body, SyncResponse, new JsonSerializerOptions { - WriteIndented = true, - DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull - }); - await Response.CompleteAsync(); - } - catch (MxApiMatrixException e) { - _logger.LogError(e, "Error while syncing for {} on {} ({})", _hs.GetHomeserver().Result.WhoAmI.UserId, - _hs.GetHomeserver().Result.ServerName, _hs.GetHomeserver().Result.AccessToken); - - Response.StatusCode = StatusCodes.Status500InternalServerError; - Response.ContentType = "application/json"; - - await Response.WriteAsJsonAsync(e.GetAsJson()); - await Response.CompleteAsync(); - } - - catch (Exception e) { - //catch SSL connection errors and retry - if (e.InnerException is HttpRequestException && e.InnerException.Message.Contains("The SSL connection could not be established")) { - _logger.LogWarning("Caught SSL connection error, retrying sync for {} on {} ({})", _hs.GetHomeserver().Result.WhoAmI.UserId, - _hs.GetHomeserver().Result.ServerName, _hs.GetHomeserver().Result.AccessToken); - await Sync(since, timeout); - return; - } - - _logger.LogError(e, "Error while syncing for {} on {} ({})", _hs.GetHomeserver().Result.WhoAmI.UserId, - _hs.GetHomeserver().Result.ServerName, _hs.GetHomeserver().Result.AccessToken); - - Response.StatusCode = StatusCodes.Status500InternalServerError; - Response.ContentType = "text/plain"; - - await Response.WriteAsync(e.ToString()); - await Response.CompleteAsync(); - } - - Response.Body.Close(); - if (preloadTask is not null) - await preloadTask; - } - - private async Task EnqueuePreloadData(SyncState syncState) { - var rooms = await syncState.Homeserver.GetJoinedRooms(); - var dm_rooms = (await syncState.Homeserver.GetAccountDataAsync<Dictionary<string, List<string>>>("m.direct")).Aggregate(new List<string>(), (list, entry) => { - list.AddRange(entry.Value); - return list; - }); - - var ownHs = syncState.Homeserver.WhoAmI.UserId.Split(':')[1]; - rooms = rooms.OrderBy(x => { - if (dm_rooms.Contains(x.RoomId)) return -1; - var parts = x.RoomId.Split(':'); - if (parts[1] == ownHs) return 200; - if (HomeserverWeightEstimation.EstimatedSize.ContainsKey(parts[1])) return HomeserverWeightEstimation.EstimatedSize[parts[1]] + parts[0].Length; - return 5000; - }).ToList(); - var roomDataTasks = rooms.Select(room => EnqueueRoomData(syncState, room)).ToList(); - _logger.LogInformation("Preloading data for {} rooms on {} ({})", roomDataTasks.Count, syncState.Homeserver.ServerName, syncState.Homeserver.AccessToken); - - await Task.WhenAll(roomDataTasks); - } - - private SemaphoreSlim _roomDataSemaphore = new(4, 4); - - private async Task EnqueueRoomData(SyncState syncState, GenericRoom room) { - await _roomDataSemaphore.WaitAsync(); - var roomState = room.GetFullStateAsync(); - var timeline = await room.GetMessagesAsync(limit: 100, dir: "b"); - timeline.Chunk.Reverse(); - var SyncResponse = new SyncResponse { - Rooms = new() { - Join = new() { - { - room.RoomId, - new SyncResponse.RoomsDataStructure.JoinedRoomDataStructure { - AccountData = new() { - Events = new() - }, - Ephemeral = new() { - Events = new() - }, - State = new() { - Events = timeline.State - }, - UnreadNotifications = new() { - HighlightCount = 0, - NotificationCount = 0 - }, - Timeline = new() { - Events = timeline.Chunk, - Limited = false, - PrevBatch = timeline.Start - }, - Summary = new() { - Heroes = new(), - InvitedMemberCount = 0, - JoinedMemberCount = 1 - } - } - } - } - }, - Presence = new() { - Events = new() { - await GetStatusMessage(syncState, $"{DateTime.Now.Subtract(syncState.NextSyncResponseStartedAt)} {syncState.NextSyncResponse.Status} {room.RoomId}") - } - }, - NextBatch = "" - }; - - await foreach (var stateEvent in roomState) { - SyncResponse.Rooms.Join[room.RoomId].State.Events.Add(stateEvent); - } - - var joinRoom = SyncResponse.Rooms.Join[room.RoomId]; - joinRoom.Summary.Heroes.AddRange(joinRoom.State.Events - .Where(x => - x.Type == "m.room.member" - && x.StateKey != syncState.Homeserver.WhoAmI.UserId - && (x.TypedContent as RoomMemberEventContent).Membership == "join" - ) - .Select(x => x.StateKey)); - joinRoom.Summary.JoinedMemberCount = joinRoom.Summary.Heroes.Count; - - syncState.SyncQueue.Enqueue(SyncResponse); - _roomDataSemaphore.Release(); - } - - private async Task<StateEventResponse> GetStatusMessage(SyncState syncState, string message) { - return new StateEventResponse { - TypedContent = new PresenceEventContent { - DisplayName = "MxApiExtensions", - Presence = "online", - StatusMessage = message, - // AvatarUrl = (await syncState.Homeserver.GetProfile(syncState.Homeserver.WhoAmI.UserId)).AvatarUrl - AvatarUrl = "" - }, - Type = "m.presence", - StateKey = syncState.Homeserver.WhoAmI.UserId, - Sender = syncState.Homeserver.WhoAmI.UserId, - UserId = syncState.Homeserver.WhoAmI.UserId, - EventId = Guid.NewGuid().ToString(), - OriginServerTs = 0 - }; - } -} diff --git a/MxApiExtensions/MxApiExtensions.csproj b/MxApiExtensions/MxApiExtensions.csproj index 92474e0..b34ef78 100644 --- a/MxApiExtensions/MxApiExtensions.csproj +++ b/MxApiExtensions/MxApiExtensions.csproj @@ -10,26 +10,15 @@ <ItemGroup> <PackageReference Include="ArcaneLibs" Version="1.0.0-preview6437853305.78f6d30" /> + <PackageReference Include="EasyCompressor.LZMA" Version="1.4.0" /> <PackageReference Include="Microsoft.AspNetCore.OpenApi" Version="8.0.0-preview.7.23375.9" /> <PackageReference Include="Swashbuckle.AspNetCore" Version="6.4.0" /> </ItemGroup> - <ItemGroup> <ProjectReference Include="..\..\ArcaneLibs\ArcaneLibs\ArcaneLibs.csproj" /> <ProjectReference Include="..\..\LibMatrix\LibMatrix\LibMatrix.csproj" /> <ProjectReference Include="..\MxApiExtensions.Classes.LibMatrix\MxApiExtensions.Classes.LibMatrix.csproj" /> <ProjectReference Include="..\MxApiExtensions.Classes\MxApiExtensions.Classes.csproj" /> - - - - </ItemGroup> - - - - - - - </Project> diff --git a/MxApiExtensions/Program.cs b/MxApiExtensions/Program.cs index cdbdb59..72a5dc9 100644 --- a/MxApiExtensions/Program.cs +++ b/MxApiExtensions/Program.cs @@ -1,4 +1,7 @@ +using System.Net.Mime; +using LibMatrix; using LibMatrix.Services; +using Microsoft.AspNetCore.Diagnostics; using Microsoft.AspNetCore.Http.Timeouts; using MxApiExtensions; using MxApiExtensions.Classes.LibMatrix; @@ -8,9 +11,7 @@ var builder = WebApplication.CreateBuilder(args); // Add services to the container. -builder.Services.AddControllers().AddJsonOptions(options => { - options.JsonSerializerOptions.WriteIndented = true; -}); +builder.Services.AddControllers().AddJsonOptions(options => { options.JsonSerializerOptions.WriteIndented = true; }); // Learn more about configuring Swagger/OpenAPI at https://aka.ms/aspnetcore/swashbuckle builder.Services.AddEndpointsApiExplorer(); builder.Services.AddSwaggerGen(); @@ -47,6 +48,13 @@ builder.Services.AddRequestTimeouts(x => { }; }); +// builder.Services.AddCors(x => x.AddDefaultPolicy(y => y.AllowAnyHeader().AllowCredentials().AllowAnyOrigin().AllowAnyMethod())); +builder.Services.AddCors(options => { + options.AddPolicy( + "Open", + policy => policy.AllowAnyOrigin().AllowAnyHeader()); +}); + var app = builder.Build(); // Configure the HTTP request pipeline. @@ -56,9 +64,38 @@ if (app.Environment.IsDevelopment()) { } // app.UseHttpsRedirection(); +app.UseCors("Open"); + +app.UseExceptionHandler(exceptionHandlerApp => +{ + exceptionHandlerApp.Run(async context => + { + + var exceptionHandlerPathFeature = + context.Features.Get<IExceptionHandlerPathFeature>(); + + if (exceptionHandlerPathFeature?.Error is MatrixException mxe) { + context.Response.StatusCode = mxe.ErrorCode switch { + "M_NOT_FOUND" => StatusCodes.Status404NotFound, + _ => StatusCodes.Status500InternalServerError + }; + context.Response.ContentType = MediaTypeNames.Application.Json; + await context.Response.WriteAsync(mxe.GetAsJson()!); + } + else { + context.Response.StatusCode = StatusCodes.Status500InternalServerError; + context.Response.ContentType = MediaTypeNames.Application.Json; + await context.Response.WriteAsync(new MxApiMatrixException() { + ErrorCode = "M_UNKNOWN", + Error = exceptionHandlerPathFeature?.Error.ToString() + }.GetAsJson()); + } + }); +}); + -app.UseAuthorization(); +// app.UseAuthorization(); app.MapControllers(); -app.Run(); +app.Run(); \ No newline at end of file diff --git a/MxApiExtensions/Services/AuthenticatedHomeserverProviderService.cs b/MxApiExtensions/Services/AuthenticatedHomeserverProviderService.cs index dc8a8dc..e0f9db5 100644 --- a/MxApiExtensions/Services/AuthenticatedHomeserverProviderService.cs +++ b/MxApiExtensions/Services/AuthenticatedHomeserverProviderService.cs @@ -1,20 +1,37 @@ +using ArcaneLibs.Extensions; using LibMatrix.Homeservers; using LibMatrix.Services; using MxApiExtensions.Classes.LibMatrix; namespace MxApiExtensions.Services; -public class AuthenticatedHomeserverProviderService { - private readonly AuthenticationService _authenticationService; - private readonly HomeserverProviderService _homeserverProviderService; - - public AuthenticatedHomeserverProviderService(AuthenticationService authenticationService, HomeserverProviderService homeserverProviderService) { - _authenticationService = authenticationService; - _homeserverProviderService = homeserverProviderService; +public class AuthenticatedHomeserverProviderService(AuthenticationService authenticationService, HomeserverProviderService homeserverProviderService, IHttpContextAccessor request) { + public async Task<RemoteHomeserver?> TryGetRemoteHomeserver() { + try { + return await GetRemoteHomeserver(); + } + catch { + return null; + } + } + + public async Task<RemoteHomeserver> GetRemoteHomeserver() { + try { + return await GetHomeserver(); + } + catch (MxApiMatrixException e) { + if (e is not { ErrorCode: "M_MISSING_TOKEN" }) throw; + if (!request.HttpContext!.Request.Headers.Keys.Any(x=>x.ToUpper() == "MXAE_UPSTREAM")) + throw new MxApiMatrixException() { + ErrorCode = "MXAE_MISSING_UPSTREAM", + Error = "[MxApiExtensions] Missing MXAE_UPSTREAM header for unauthenticated request, this should be a server_name!" + }; + return await homeserverProviderService.GetRemoteHomeserver(request.HttpContext.Request.Headers.GetByCaseInsensitiveKey("MXAE_UPSTREAM")[0]); + } } public async Task<AuthenticatedHomeserverGeneric> GetHomeserver() { - var token = _authenticationService.GetToken(); + var token = authenticationService.GetToken(); if (token == null) { throw new MxApiMatrixException { ErrorCode = "M_MISSING_TOKEN", @@ -22,7 +39,7 @@ public class AuthenticatedHomeserverProviderService { }; } - var mxid = await _authenticationService.GetMxidFromToken(token); + var mxid = await authenticationService.GetMxidFromToken(token); if (mxid == "@anonymous:*") { throw new MxApiMatrixException { ErrorCode = "M_MISSING_TOKEN", @@ -31,6 +48,6 @@ public class AuthenticatedHomeserverProviderService { } var hsCanonical = string.Join(":", mxid.Split(':').Skip(1)); - return await _homeserverProviderService.GetAuthenticatedWithToken(hsCanonical, token); + return await homeserverProviderService.GetAuthenticatedWithToken(hsCanonical, token); } } diff --git a/MxApiExtensions/Services/AuthenticationService.cs b/MxApiExtensions/Services/AuthenticationService.cs index 9eac20a..0dcc8b1 100644 --- a/MxApiExtensions/Services/AuthenticationService.cs +++ b/MxApiExtensions/Services/AuthenticationService.cs @@ -3,21 +3,11 @@ using MxApiExtensions.Classes.LibMatrix; namespace MxApiExtensions.Services; -public class AuthenticationService { - private readonly ILogger<AuthenticationService> _logger; - private readonly MxApiExtensionsConfiguration _config; - private readonly HomeserverProviderService _homeserverProviderService; - private readonly HttpRequest _request; +public class AuthenticationService(ILogger<AuthenticationService> logger, MxApiExtensionsConfiguration config, IHttpContextAccessor request, HomeserverProviderService homeserverProviderService) { + private readonly HttpRequest _request = request.HttpContext!.Request; private static Dictionary<string, string> _tokenMap = new(); - public AuthenticationService(ILogger<AuthenticationService> logger, MxApiExtensionsConfiguration config, IHttpContextAccessor request, HomeserverProviderService homeserverProviderService) { - _logger = logger; - _config = config; - _homeserverProviderService = homeserverProviderService; - _request = request.HttpContext!.Request; - } - internal string? GetToken(bool fail = true) { string? token; if (_request.Headers.TryGetValue("Authorization", out var tokens)) { @@ -59,7 +49,7 @@ public class AuthenticationService { if (_tokenMap.TryGetValue(token, out var mxid)) return mxid; var lookupTasks = new Dictionary<string, Task<string?>>(); - foreach (var homeserver in _config.AuthHomeservers) { + foreach (var homeserver in config.AuthHomeservers) { lookupTasks.Add(homeserver, GetMxidFromToken(token, homeserver)); await lookupTasks[homeserver].WaitAsync(TimeSpan.FromMilliseconds(250)); if(lookupTasks[homeserver].IsCompletedSuccessfully && !string.IsNullOrWhiteSpace(lookupTasks[homeserver].Result)) break; @@ -70,7 +60,7 @@ public class AuthenticationService { if(mxid is null) { throw new MxApiMatrixException { ErrorCode = "M_UNKNOWN_TOKEN", - Error = "Token not found on any configured homeservers: " + string.Join(", ", _config.AuthHomeservers) + Error = "Token not found on any configured homeservers: " + string.Join(", ", config.AuthHomeservers) }; } @@ -93,17 +83,17 @@ public class AuthenticationService { // // var json = (await JsonDocument.ParseAsync(await resp.Content.ReadAsStreamAsync())).RootElement; // var mxid = json.GetProperty("user_id").GetString()!; - _logger.LogInformation("Got mxid {} from token {}", mxid, token); + logger.LogInformation("Got mxid {} from token {}", mxid, token); await SaveMxidForToken(token, mxid); return mxid; } private async Task<string?> GetMxidFromToken(string token, string hsDomain) { - _logger.LogInformation("Looking up mxid for token {} on {}", token, hsDomain); - var hs = await _homeserverProviderService.GetAuthenticatedWithToken(hsDomain, token); + logger.LogInformation("Looking up mxid for token {} on {}", token, hsDomain); + var hs = await homeserverProviderService.GetAuthenticatedWithToken(hsDomain, token); try { var res = hs.WhoAmI.UserId; - _logger.LogInformation("Got mxid {} for token {} on {}", res, token, hsDomain); + logger.LogInformation("Got mxid {} for token {} on {}", res, token, hsDomain); return res; } catch (MxApiMatrixException e) { |