about summary refs log tree commit diff
path: root/MatrixUtils.Web/Classes
diff options
context:
space:
mode:
Diffstat (limited to 'MatrixUtils.Web/Classes')
-rw-r--r--MatrixUtils.Web/Classes/RMUStorageWrapper.cs156
-rw-r--r--MatrixUtils.Web/Classes/RmuSessionStore.cs254
2 files changed, 254 insertions, 156 deletions
diff --git a/MatrixUtils.Web/Classes/RMUStorageWrapper.cs b/MatrixUtils.Web/Classes/RMUStorageWrapper.cs
deleted file mode 100644

index 1fc4dd1..0000000 --- a/MatrixUtils.Web/Classes/RMUStorageWrapper.cs +++ /dev/null
@@ -1,156 +0,0 @@ -using LibMatrix; -using LibMatrix.Homeservers; -using LibMatrix.Services; -using Microsoft.AspNetCore.Components; - -namespace MatrixUtils.Web.Classes; - -public class RMUStorageWrapper( - ILogger<RMUStorageWrapper> logger, - TieredStorageService storageService, - HomeserverProviderService homeserverProviderService, - NavigationManager navigationManager) { - public async Task<List<UserAuth>?> GetAllTokens() { - logger.LogTrace("Getting all tokens."); - return await storageService.DataStorageProvider.LoadObjectAsync<List<UserAuth>>("rmu.tokens") ?? - new List<UserAuth>(); - } - - public async Task<UserAuth?> GetCurrentToken(bool log = true) { - if (log) logger.LogTrace("Getting current token."); - var currentToken = await storageService.DataStorageProvider.LoadObjectAsync<UserAuth>("rmu.token"); - var allTokens = await GetAllTokens(); - if (allTokens is null or { Count: 0 }) { - await SetCurrentToken(null); - return null; - } - - if (currentToken is null) { - await SetCurrentToken(currentToken = allTokens[0]); - } - - if (currentToken is null) { - await SetCurrentToken(currentToken = allTokens[0]); - } - - if (!allTokens.Any(x => x.AccessToken == currentToken.AccessToken)) { - await SetCurrentToken(currentToken = allTokens[0]); - } - - return currentToken; - } - - public async Task AddToken(UserAuth UserAuth) { - logger.LogTrace("Adding token."); - var tokens = await GetAllTokens() ?? new List<UserAuth>(); - - tokens.Add(UserAuth); - await storageService.DataStorageProvider.SaveObjectAsync("rmu.tokens", tokens); - } - - private async Task<AuthenticatedHomeserverGeneric?> GetCurrentSession(bool log = true) { - if (log) logger.LogTrace("Getting current session."); - var token = await GetCurrentToken(log: false); - if (token == null) { - return null; - } - - return await GetSession(token); - } - - public async Task<AuthenticatedHomeserverGeneric?> GetSession(UserAuth userAuth, bool log = true) { - if (log) logger.LogTrace("Getting session."); - AuthenticatedHomeserverGeneric hs; - try { - hs = await homeserverProviderService.GetAuthenticatedWithToken(userAuth.Homeserver, userAuth.AccessToken, userAuth.Proxy); - } - catch (Exception e) { - logger.LogError("Failed to get info for {0} via {1}: {2}", userAuth.UserId, userAuth.Homeserver, e); - logger.LogError("Continuing with server-less session"); - hs = await homeserverProviderService.GetAuthenticatedWithToken(userAuth.Homeserver, userAuth.AccessToken, userAuth.Proxy, useGeneric: true, enableServer: false); - } - - return hs; - } - - public async Task<AuthenticatedHomeserverGeneric?> GetCurrentSessionOrNavigate(bool log = true) { - if (log) logger.LogTrace("Getting current session or navigating."); - AuthenticatedHomeserverGeneric? session = null; - - try { - //catch if the token is invalid - session = await GetCurrentSession(); - } - catch (MatrixException e) { - if (e.ErrorCode == "M_UNKNOWN_TOKEN") { - var token = await GetCurrentToken(); - logger.LogWarning("Encountered invalid token for {user} on {homeserver}", token.UserId, token.Homeserver); - navigationManager.NavigateTo("/InvalidSession?ctx=" + token.AccessToken); - return null; - } - - throw; - } - - if (session is null) { - logger.LogInformation("No session found. Navigating to login."); - navigationManager.NavigateTo("/Login"); - } - - return session; - } - - public class Settings { - public DeveloperSettings DeveloperSettings { get; set; } = new(); - } - - public class DeveloperSettings { - public bool EnableLogViewers { get; set; } - public bool EnableConsoleLogging { get; set; } = true; - public bool EnablePortableDevtools { get; set; } - } - - public async Task RemoveToken(UserAuth auth) { - logger.LogTrace("Removing token."); - var tokens = await GetAllTokens(); - if (tokens == null) { - return; - } - - tokens.RemoveAll(x => x.AccessToken == auth.AccessToken); - await storageService.DataStorageProvider.SaveObjectAsync("rmu.tokens", tokens); - } - - public async Task SetCurrentToken(UserAuth? auth) { - logger.LogTrace("Setting current token."); - await storageService.DataStorageProvider.SaveObjectAsync("rmu.token", auth); - } - - public async Task MigrateFromMRU() { - logger.LogInformation("Migrating from MRU token namespace!"); - var dsp = storageService.DataStorageProvider!; - if (await dsp.ObjectExistsAsync("token")) { - var oldToken = await dsp.LoadObjectAsync<UserAuth>("token"); - if (oldToken != null) { - await dsp.SaveObjectAsync("rmu.token", oldToken); - await dsp.DeleteObjectAsync("tokens"); - } - } - - if (await dsp.ObjectExistsAsync("tokens")) { - var oldTokens = await dsp.LoadObjectAsync<List<UserAuth>>("tokens"); - if (oldTokens != null) { - await dsp.SaveObjectAsync("rmu.tokens", oldTokens); - await dsp.DeleteObjectAsync("tokens"); - } - } - - if (await dsp.ObjectExistsAsync("mru.tokens")) { - var oldTokens = await dsp.LoadObjectAsync<List<UserAuth>>("mru.tokens"); - if (oldTokens != null) { - await dsp.SaveObjectAsync("rmu.tokens", oldTokens); - await dsp.DeleteObjectAsync("mru.tokens"); - } - } - } -} \ No newline at end of file diff --git a/MatrixUtils.Web/Classes/RmuSessionStore.cs b/MatrixUtils.Web/Classes/RmuSessionStore.cs new file mode 100644
index 0000000..14aa1db --- /dev/null +++ b/MatrixUtils.Web/Classes/RmuSessionStore.cs
@@ -0,0 +1,254 @@ +using System.Text.Json; +using System.Text.Json.Nodes; +using LibMatrix; +using LibMatrix.Homeservers; +using LibMatrix.Services; +using Microsoft.AspNetCore.Components; + +namespace MatrixUtils.Web.Classes; + +public class RmuSessionStore( + ILogger<RmuSessionStore> logger, + TieredStorageService storageService, + HomeserverProviderService homeserverProviderService, + NavigationManager navigationManager) { + private SessionInfo? CurrentSession { get; set; } + private Dictionary<string, SessionInfo> SessionCache { get; set; } = []; + + private bool _isInitialized = false; + private static readonly SemaphoreSlim InitSemaphore = new(1, 1); + + public async Task EnsureInitialized() { + if (_isInitialized) return; + await InitSemaphore.WaitAsync(); + if (_isInitialized) { + InitSemaphore.Release(); + return; + } + + try { + await RunMigrations(); + SessionCache = (await GetAllSessions()) + .Select(x => (x.Key, Value: new SessionInfo { + SessionId = x.Key, + Auth = x.Value.Auth + })).ToDictionary(x => x.Key, x => x.Value); + CurrentSession = await GetCurrentSession(); + } + catch (Exception e) { + logger.LogError("Failed to initialize RmuSessionStore: {e}", e); + } + finally { + _isInitialized = true; + InitSemaphore.Release(); + } + } + +#region Sessions + +#region Session implementation details + +#endregion + + public async Task<Dictionary<string, SessionInfo>> GetAllSessions() { + logger.LogTrace("Getting all tokens."); + return SessionCache; + } + + public async Task<SessionInfo?> GetSession(string sessionId) { + if (SessionCache.TryGetValue(sessionId, out var cachedSession)) + return cachedSession; + + logger.LogWarning("Session {sessionId} not found in all tokens.", sessionId); + return null; + } + + public async Task<SessionInfo?> GetCurrentSession(bool log = true) { + if (log) logger.LogTrace("Getting current token."); + if (CurrentSession is not null) return CurrentSession; + + var currentSessionId = await storageService.DataStorageProvider!.LoadObjectAsync<string>("rmu.session"); + return await GetSession(currentSessionId); + } + + public async Task<string> AddSession(UserAuth auth) { + logger.LogTrace("Adding token."); + // var sessions = await GetAllSessions() ?? []; + + var sessionId = auth.GetHashCode().ToString(); + // sessions.Add(sessionId, auth); + SessionCache[sessionId] = new() { + Auth = auth, + SessionId = sessionId + }; + await storageService.DataStorageProvider!.SaveObjectAsync("rmu.sessions", + SessionCache.ToDictionary( + x => x.Key, + x => x.Value.Auth + ) + ); + + if (CurrentSession == null) await SetCurrentSession(sessionId); + + return sessionId; + } + + public async Task RemoveSession(string sessionId) { + logger.LogTrace("Removing session {sessionId}.", sessionId); + var tokens = await GetAllSessions(); + if (tokens == null) { + return; + } + + if ((await GetCurrentSession())?.SessionId == sessionId) + await SetCurrentSession(tokens.First(x => x.Key != sessionId).Key); + + if (tokens.Remove(sessionId)) + await storageService.DataStorageProvider!.SaveObjectAsync("rmu.tokens", tokens); + } + + public async Task SetCurrentSession(string? sessionId) { + logger.LogTrace("Setting current session to {sessionId}.", sessionId); + CurrentSession = await GetSession(sessionId); + await storageService.DataStorageProvider!.SaveObjectAsync("rmu.session", sessionId); + } + +#endregion + +#region Homeservers + + public async Task<AuthenticatedHomeserverGeneric?> GetHomeserver(string session, bool log = true) { + if (log) logger.LogTrace("Getting session."); + if (!SessionCache.TryGetValue(session, out var cachedSession)) return null; + if (cachedSession.Homeserver is not null) return cachedSession.Homeserver; + + try { + cachedSession.Homeserver = + await homeserverProviderService.GetAuthenticatedWithToken(cachedSession.Auth.Homeserver, cachedSession.Auth.AccessToken, cachedSession.Auth.Proxy); + } + catch (Exception e) { + logger.LogError("Failed to get info for {0} via {1}: {2}", cachedSession.Auth.UserId, cachedSession.Auth.Homeserver, e); + logger.LogError("Continuing with server-less session"); + cachedSession.Homeserver = await homeserverProviderService.GetAuthenticatedWithToken(cachedSession.Auth.Homeserver, cachedSession.Auth.AccessToken, + cachedSession.Auth.Proxy, useGeneric: true, enableServer: false); + } + + return cachedSession.Homeserver; + } + + public async Task<AuthenticatedHomeserverGeneric?> GetCurrentHomeserver(bool log = true, bool navigateOnFailure = false) { + if (log) logger.LogTrace("Getting current session."); + if (CurrentSession?.Homeserver is not null) return CurrentSession.Homeserver; + + var currentSession = CurrentSession ??= await GetCurrentSession(log: false); + if (currentSession == null) { + if (navigateOnFailure) { + logger.LogInformation("No session found. Navigating to login."); + navigationManager.NavigateTo("/Login"); + } + + return null; + } + + try { + return currentSession.Homeserver ??= await GetHomeserver(currentSession.SessionId); + } + catch (MatrixException e) { + if (e.ErrorCode == "M_UNKNOWN_TOKEN" && navigateOnFailure) { + logger.LogWarning("Encountered invalid token for {user} on {homeserver}", currentSession.Auth.UserId, currentSession.Auth.Homeserver); + if (navigateOnFailure) { + navigationManager.NavigateTo("/InvalidSession?ctx=" + currentSession.SessionId); + } + } + + throw; + } + } + +#endregion + +#region Internal + + public async Task RunMigrations() { + await MigrateFromMRU(); + await MigrateAccountsToKeyedStorage(); + } + +#region Migrations + + private async Task MigrateFromMRU() { + logger.LogInformation("Migrating from MRU token namespace!"); + var dsp = storageService.DataStorageProvider!; + if (await dsp.ObjectExistsAsync("token")) { + var oldToken = await dsp.LoadObjectAsync<UserAuth>("token"); + if (oldToken != null) { + await dsp.SaveObjectAsync("rmu.token", oldToken); + await dsp.DeleteObjectAsync("tokens"); + } + } + + if (await dsp.ObjectExistsAsync("tokens")) { + var oldTokens = await dsp.LoadObjectAsync<List<UserAuth>>("tokens"); + if (oldTokens != null) { + await dsp.SaveObjectAsync("rmu.tokens", oldTokens); + await dsp.DeleteObjectAsync("tokens"); + } + } + + if (await dsp.ObjectExistsAsync("mru.tokens")) { + var oldTokens = await dsp.LoadObjectAsync<List<UserAuth>>("mru.tokens"); + if (oldTokens != null) { + await dsp.SaveObjectAsync("rmu.tokens", oldTokens); + await dsp.DeleteObjectAsync("mru.tokens"); + } + } + } + + private async Task MigrateAccountsToKeyedStorage() { + logger.LogInformation("Migrating accounts to keyed storage!"); + var dsp = storageService.DataStorageProvider!; + if (await dsp.ObjectExistsAsync("rmu.tokens")) { + var tokens = await dsp.LoadObjectAsync<JsonNode>("rmu.tokens") ?? throw new Exception("Failed to load tokens"); + if (tokens is JsonArray array) { + var keyedTokens = array + .Deserialize<UserAuth[]>()! + .ToDictionary(x => x.GetHashCode().ToString(), x => x); + await dsp.SaveObjectAsync("rmu.sessions", keyedTokens); + await dsp.DeleteObjectAsync("rmu.tokens"); + } + } + + if (await dsp.ObjectExistsAsync("rmu.token")) { + var token = await dsp.LoadObjectAsync<UserAuth>("rmu.token") ?? throw new Exception("Failed to load tokens"); + var sessionId = (await GetAllSessions()) + .FirstOrDefault(x => x.Value.Equals(token)).Key; + + if (sessionId is not null) { + await dsp.SaveObjectAsync("rmu.session", sessionId); + } + else AddSession(token); + + await dsp.DeleteObjectAsync("rmu.token"); + } + } + +#endregion + +#endregion + + public class Settings { + public DeveloperSettings DeveloperSettings { get; set; } = new(); + } + + public class DeveloperSettings { + public bool EnableLogViewers { get; set; } + public bool EnableConsoleLogging { get; set; } = true; + public bool EnablePortableDevtools { get; set; } + } + + public class SessionInfo { + public required string SessionId { get; set; } + public required UserAuth Auth { get; set; } + public AuthenticatedHomeserverGeneric? Homeserver { get; set; } + } +} \ No newline at end of file