about summary refs log tree commit diff
path: root/MatrixUtils.Web/Classes/RmuSessionStore.cs
diff options
context:
space:
mode:
authorRory& <root@rory.gay>2025-04-15 18:13:56 +0200
committerRory& <root@rory.gay>2025-04-15 18:13:56 +0200
commit425d0eb88c1b04e93bc32fcd900d8fcfa3a29410 (patch)
treeebed3904bec4afac8dd2e0b876601591facbe41a /MatrixUtils.Web/Classes/RmuSessionStore.cs
parentRefactor session store (WIP) (diff)
downloadMatrixUtils-425d0eb88c1b04e93bc32fcd900d8fcfa3a29410.tar.xz
Fix session store
Diffstat (limited to 'MatrixUtils.Web/Classes/RmuSessionStore.cs')
-rw-r--r--MatrixUtils.Web/Classes/RmuSessionStore.cs105
1 files changed, 59 insertions, 46 deletions
diff --git a/MatrixUtils.Web/Classes/RmuSessionStore.cs b/MatrixUtils.Web/Classes/RmuSessionStore.cs

index 14aa1db..7e5b155 100644 --- a/MatrixUtils.Web/Classes/RmuSessionStore.cs +++ b/MatrixUtils.Web/Classes/RmuSessionStore.cs
@@ -15,47 +15,19 @@ public class RmuSessionStore( private SessionInfo? CurrentSession { get; set; } private Dictionary<string, SessionInfo> SessionCache { get; set; } = []; - private bool _isInitialized = false; + private bool _isInitialized; private static readonly SemaphoreSlim InitSemaphore = new(1, 1); - public async Task EnsureInitialized() { - if (_isInitialized) return; - await InitSemaphore.WaitAsync(); - if (_isInitialized) { - InitSemaphore.Release(); - return; - } - - try { - await RunMigrations(); - SessionCache = (await GetAllSessions()) - .Select(x => (x.Key, Value: new SessionInfo { - SessionId = x.Key, - Auth = x.Value.Auth - })).ToDictionary(x => x.Key, x => x.Value); - CurrentSession = await GetCurrentSession(); - } - catch (Exception e) { - logger.LogError("Failed to initialize RmuSessionStore: {e}", e); - } - finally { - _isInitialized = true; - InitSemaphore.Release(); - } - } - #region Sessions -#region Session implementation details - -#endregion - public async Task<Dictionary<string, SessionInfo>> GetAllSessions() { + await LoadStorage(); logger.LogTrace("Getting all tokens."); return SessionCache; } public async Task<SessionInfo?> GetSession(string sessionId) { + await LoadStorage(); if (SessionCache.TryGetValue(sessionId, out var cachedSession)) return cachedSession; @@ -64,6 +36,7 @@ public class RmuSessionStore( } public async Task<SessionInfo?> GetCurrentSession(bool log = true) { + await LoadStorage(); if (log) logger.LogTrace("Getting current token."); if (CurrentSession is not null) return CurrentSession; @@ -72,28 +45,23 @@ public class RmuSessionStore( } public async Task<string> AddSession(UserAuth auth) { + await LoadStorage(); logger.LogTrace("Adding token."); - // var sessions = await GetAllSessions() ?? []; var sessionId = auth.GetHashCode().ToString(); - // sessions.Add(sessionId, auth); SessionCache[sessionId] = new() { Auth = auth, SessionId = sessionId }; - await storageService.DataStorageProvider!.SaveObjectAsync("rmu.sessions", - SessionCache.ToDictionary( - x => x.Key, - x => x.Value.Auth - ) - ); if (CurrentSession == null) await SetCurrentSession(sessionId); + else await SaveStorage(); return sessionId; } public async Task RemoveSession(string sessionId) { + await LoadStorage(); logger.LogTrace("Removing session {sessionId}.", sessionId); var tokens = await GetAllSessions(); if (tokens == null) { @@ -104,13 +72,14 @@ public class RmuSessionStore( await SetCurrentSession(tokens.First(x => x.Key != sessionId).Key); if (tokens.Remove(sessionId)) - await storageService.DataStorageProvider!.SaveObjectAsync("rmu.tokens", tokens); + await SaveStorage(); } public async Task SetCurrentSession(string? sessionId) { + await LoadStorage(); logger.LogTrace("Setting current session to {sessionId}.", sessionId); CurrentSession = await GetSession(sessionId); - await storageService.DataStorageProvider!.SaveObjectAsync("rmu.session", sessionId); + await SaveStorage(); } #endregion @@ -118,6 +87,7 @@ public class RmuSessionStore( #region Homeservers public async Task<AuthenticatedHomeserverGeneric?> GetHomeserver(string session, bool log = true) { + await LoadStorage(); if (log) logger.LogTrace("Getting session."); if (!SessionCache.TryGetValue(session, out var cachedSession)) return null; if (cachedSession.Homeserver is not null) return cachedSession.Homeserver; @@ -137,6 +107,7 @@ public class RmuSessionStore( } public async Task<AuthenticatedHomeserverGeneric?> GetCurrentHomeserver(bool log = true, bool navigateOnFailure = false) { + await LoadStorage(); if (log) logger.LogTrace("Getting current session."); if (CurrentSession?.Homeserver is not null) return CurrentSession.Homeserver; @@ -167,16 +138,59 @@ public class RmuSessionStore( #endregion -#region Internal +#region Storage + + private async Task LoadStorage(bool hasMigrated = false) { + if (!await storageService.DataStorageProvider!.ObjectExistsAsync("rmu.sessions") || !await storageService.DataStorageProvider.ObjectExistsAsync("rmu.session")) { + if (!hasMigrated) + await MigrateFromMRU(); + else + logger.LogWarning("No sessions found in storage."); + return; + } + + SessionCache = (await storageService.DataStorageProvider.LoadObjectAsync<Dictionary<string, UserAuth>>("rmu.sessions") ?? throw new Exception("Failed to load sessions")) + .ToDictionary(x => x.Key, x => new SessionInfo { + SessionId = x.Key, + Auth = x.Value + }); + + var currentSessionId = await storageService.DataStorageProvider.LoadObjectAsync<string>("rmu.session"); + if (currentSessionId == null) { + logger.LogWarning("No current session found in storage."); + return; + } + + if (!SessionCache.TryGetValue(currentSessionId, out var currentSession)) { + logger.LogWarning("Current session {currentSessionId} not found in storage.", currentSessionId); + return; + } + + CurrentSession = currentSession; + } + + private async Task SaveStorage() { + await storageService.DataStorageProvider!.SaveObjectAsync("rmu.sessions", + SessionCache.ToDictionary( + x => x.Key, + x => x.Value.Auth + ) + ); + await storageService.DataStorageProvider.SaveObjectAsync("rmu.session", CurrentSession?.SessionId); + } + +#endregion + +#region Migrations public async Task RunMigrations() { + await LoadStorage(); await MigrateFromMRU(); await MigrateAccountsToKeyedStorage(); } -#region Migrations - private async Task MigrateFromMRU() { + await LoadStorage(); logger.LogInformation("Migrating from MRU token namespace!"); var dsp = storageService.DataStorageProvider!; if (await dsp.ObjectExistsAsync("token")) { @@ -205,6 +219,7 @@ public class RmuSessionStore( } private async Task MigrateAccountsToKeyedStorage() { + await LoadStorage(); logger.LogInformation("Migrating accounts to keyed storage!"); var dsp = storageService.DataStorageProvider!; if (await dsp.ObjectExistsAsync("rmu.tokens")) { @@ -234,8 +249,6 @@ public class RmuSessionStore( #endregion -#endregion - public class Settings { public DeveloperSettings DeveloperSettings { get; set; } = new(); }