diff --git a/MatrixUtils.Web/Classes/RMUStorageWrapper.cs b/MatrixUtils.Web/Classes/RMUStorageWrapper.cs
deleted file mode 100644
index 1fc4dd1..0000000
--- a/MatrixUtils.Web/Classes/RMUStorageWrapper.cs
+++ /dev/null
@@ -1,156 +0,0 @@
-using LibMatrix;
-using LibMatrix.Homeservers;
-using LibMatrix.Services;
-using Microsoft.AspNetCore.Components;
-
-namespace MatrixUtils.Web.Classes;
-
-public class RMUStorageWrapper(
- ILogger<RMUStorageWrapper> logger,
- TieredStorageService storageService,
- HomeserverProviderService homeserverProviderService,
- NavigationManager navigationManager) {
- public async Task<List<UserAuth>?> GetAllTokens() {
- logger.LogTrace("Getting all tokens.");
- return await storageService.DataStorageProvider.LoadObjectAsync<List<UserAuth>>("rmu.tokens") ??
- new List<UserAuth>();
- }
-
- public async Task<UserAuth?> GetCurrentToken(bool log = true) {
- if (log) logger.LogTrace("Getting current token.");
- var currentToken = await storageService.DataStorageProvider.LoadObjectAsync<UserAuth>("rmu.token");
- var allTokens = await GetAllTokens();
- if (allTokens is null or { Count: 0 }) {
- await SetCurrentToken(null);
- return null;
- }
-
- if (currentToken is null) {
- await SetCurrentToken(currentToken = allTokens[0]);
- }
-
- if (currentToken is null) {
- await SetCurrentToken(currentToken = allTokens[0]);
- }
-
- if (!allTokens.Any(x => x.AccessToken == currentToken.AccessToken)) {
- await SetCurrentToken(currentToken = allTokens[0]);
- }
-
- return currentToken;
- }
-
- public async Task AddToken(UserAuth UserAuth) {
- logger.LogTrace("Adding token.");
- var tokens = await GetAllTokens() ?? new List<UserAuth>();
-
- tokens.Add(UserAuth);
- await storageService.DataStorageProvider.SaveObjectAsync("rmu.tokens", tokens);
- }
-
- private async Task<AuthenticatedHomeserverGeneric?> GetCurrentSession(bool log = true) {
- if (log) logger.LogTrace("Getting current session.");
- var token = await GetCurrentToken(log: false);
- if (token == null) {
- return null;
- }
-
- return await GetSession(token);
- }
-
- public async Task<AuthenticatedHomeserverGeneric?> GetSession(UserAuth userAuth, bool log = true) {
- if (log) logger.LogTrace("Getting session.");
- AuthenticatedHomeserverGeneric hs;
- try {
- hs = await homeserverProviderService.GetAuthenticatedWithToken(userAuth.Homeserver, userAuth.AccessToken, userAuth.Proxy);
- }
- catch (Exception e) {
- logger.LogError("Failed to get info for {0} via {1}: {2}", userAuth.UserId, userAuth.Homeserver, e);
- logger.LogError("Continuing with server-less session");
- hs = await homeserverProviderService.GetAuthenticatedWithToken(userAuth.Homeserver, userAuth.AccessToken, userAuth.Proxy, useGeneric: true, enableServer: false);
- }
-
- return hs;
- }
-
- public async Task<AuthenticatedHomeserverGeneric?> GetCurrentSessionOrNavigate(bool log = true) {
- if (log) logger.LogTrace("Getting current session or navigating.");
- AuthenticatedHomeserverGeneric? session = null;
-
- try {
- //catch if the token is invalid
- session = await GetCurrentSession();
- }
- catch (MatrixException e) {
- if (e.ErrorCode == "M_UNKNOWN_TOKEN") {
- var token = await GetCurrentToken();
- logger.LogWarning("Encountered invalid token for {user} on {homeserver}", token.UserId, token.Homeserver);
- navigationManager.NavigateTo("/InvalidSession?ctx=" + token.AccessToken);
- return null;
- }
-
- throw;
- }
-
- if (session is null) {
- logger.LogInformation("No session found. Navigating to login.");
- navigationManager.NavigateTo("/Login");
- }
-
- return session;
- }
-
- public class Settings {
- public DeveloperSettings DeveloperSettings { get; set; } = new();
- }
-
- public class DeveloperSettings {
- public bool EnableLogViewers { get; set; }
- public bool EnableConsoleLogging { get; set; } = true;
- public bool EnablePortableDevtools { get; set; }
- }
-
- public async Task RemoveToken(UserAuth auth) {
- logger.LogTrace("Removing token.");
- var tokens = await GetAllTokens();
- if (tokens == null) {
- return;
- }
-
- tokens.RemoveAll(x => x.AccessToken == auth.AccessToken);
- await storageService.DataStorageProvider.SaveObjectAsync("rmu.tokens", tokens);
- }
-
- public async Task SetCurrentToken(UserAuth? auth) {
- logger.LogTrace("Setting current token.");
- await storageService.DataStorageProvider.SaveObjectAsync("rmu.token", auth);
- }
-
- public async Task MigrateFromMRU() {
- logger.LogInformation("Migrating from MRU token namespace!");
- var dsp = storageService.DataStorageProvider!;
- if (await dsp.ObjectExistsAsync("token")) {
- var oldToken = await dsp.LoadObjectAsync<UserAuth>("token");
- if (oldToken != null) {
- await dsp.SaveObjectAsync("rmu.token", oldToken);
- await dsp.DeleteObjectAsync("tokens");
- }
- }
-
- if (await dsp.ObjectExistsAsync("tokens")) {
- var oldTokens = await dsp.LoadObjectAsync<List<UserAuth>>("tokens");
- if (oldTokens != null) {
- await dsp.SaveObjectAsync("rmu.tokens", oldTokens);
- await dsp.DeleteObjectAsync("tokens");
- }
- }
-
- if (await dsp.ObjectExistsAsync("mru.tokens")) {
- var oldTokens = await dsp.LoadObjectAsync<List<UserAuth>>("mru.tokens");
- if (oldTokens != null) {
- await dsp.SaveObjectAsync("rmu.tokens", oldTokens);
- await dsp.DeleteObjectAsync("mru.tokens");
- }
- }
- }
-}
\ No newline at end of file
diff --git a/MatrixUtils.Web/Classes/RmuSessionStore.cs b/MatrixUtils.Web/Classes/RmuSessionStore.cs
new file mode 100644
index 0000000..14aa1db
--- /dev/null
+++ b/MatrixUtils.Web/Classes/RmuSessionStore.cs
@@ -0,0 +1,254 @@
+using System.Text.Json;
+using System.Text.Json.Nodes;
+using LibMatrix;
+using LibMatrix.Homeservers;
+using LibMatrix.Services;
+using Microsoft.AspNetCore.Components;
+
+namespace MatrixUtils.Web.Classes;
+
+public class RmuSessionStore(
+ ILogger<RmuSessionStore> logger,
+ TieredStorageService storageService,
+ HomeserverProviderService homeserverProviderService,
+ NavigationManager navigationManager) {
+ private SessionInfo? CurrentSession { get; set; }
+ private Dictionary<string, SessionInfo> SessionCache { get; set; } = [];
+
+ private bool _isInitialized = false;
+ private static readonly SemaphoreSlim InitSemaphore = new(1, 1);
+
+ public async Task EnsureInitialized() {
+ if (_isInitialized) return;
+ await InitSemaphore.WaitAsync();
+ if (_isInitialized) {
+ InitSemaphore.Release();
+ return;
+ }
+
+ try {
+ await RunMigrations();
+ SessionCache = (await GetAllSessions())
+ .Select(x => (x.Key, Value: new SessionInfo {
+ SessionId = x.Key,
+ Auth = x.Value.Auth
+ })).ToDictionary(x => x.Key, x => x.Value);
+ CurrentSession = await GetCurrentSession();
+ }
+ catch (Exception e) {
+ logger.LogError("Failed to initialize RmuSessionStore: {e}", e);
+ }
+ finally {
+ _isInitialized = true;
+ InitSemaphore.Release();
+ }
+ }
+
+#region Sessions
+
+#region Session implementation details
+
+#endregion
+
+ public async Task<Dictionary<string, SessionInfo>> GetAllSessions() {
+ logger.LogTrace("Getting all tokens.");
+ return SessionCache;
+ }
+
+ public async Task<SessionInfo?> GetSession(string sessionId) {
+ if (SessionCache.TryGetValue(sessionId, out var cachedSession))
+ return cachedSession;
+
+ logger.LogWarning("Session {sessionId} not found in all tokens.", sessionId);
+ return null;
+ }
+
+ public async Task<SessionInfo?> GetCurrentSession(bool log = true) {
+ if (log) logger.LogTrace("Getting current token.");
+ if (CurrentSession is not null) return CurrentSession;
+
+ var currentSessionId = await storageService.DataStorageProvider!.LoadObjectAsync<string>("rmu.session");
+ return await GetSession(currentSessionId);
+ }
+
+ public async Task<string> AddSession(UserAuth auth) {
+ logger.LogTrace("Adding token.");
+ // var sessions = await GetAllSessions() ?? [];
+
+ var sessionId = auth.GetHashCode().ToString();
+ // sessions.Add(sessionId, auth);
+ SessionCache[sessionId] = new() {
+ Auth = auth,
+ SessionId = sessionId
+ };
+ await storageService.DataStorageProvider!.SaveObjectAsync("rmu.sessions",
+ SessionCache.ToDictionary(
+ x => x.Key,
+ x => x.Value.Auth
+ )
+ );
+
+ if (CurrentSession == null) await SetCurrentSession(sessionId);
+
+ return sessionId;
+ }
+
+ public async Task RemoveSession(string sessionId) {
+ logger.LogTrace("Removing session {sessionId}.", sessionId);
+ var tokens = await GetAllSessions();
+ if (tokens == null) {
+ return;
+ }
+
+ if ((await GetCurrentSession())?.SessionId == sessionId)
+ await SetCurrentSession(tokens.First(x => x.Key != sessionId).Key);
+
+ if (tokens.Remove(sessionId))
+ await storageService.DataStorageProvider!.SaveObjectAsync("rmu.tokens", tokens);
+ }
+
+ public async Task SetCurrentSession(string? sessionId) {
+ logger.LogTrace("Setting current session to {sessionId}.", sessionId);
+ CurrentSession = await GetSession(sessionId);
+ await storageService.DataStorageProvider!.SaveObjectAsync("rmu.session", sessionId);
+ }
+
+#endregion
+
+#region Homeservers
+
+ public async Task<AuthenticatedHomeserverGeneric?> GetHomeserver(string session, bool log = true) {
+ if (log) logger.LogTrace("Getting session.");
+ if (!SessionCache.TryGetValue(session, out var cachedSession)) return null;
+ if (cachedSession.Homeserver is not null) return cachedSession.Homeserver;
+
+ try {
+ cachedSession.Homeserver =
+ await homeserverProviderService.GetAuthenticatedWithToken(cachedSession.Auth.Homeserver, cachedSession.Auth.AccessToken, cachedSession.Auth.Proxy);
+ }
+ catch (Exception e) {
+ logger.LogError("Failed to get info for {0} via {1}: {2}", cachedSession.Auth.UserId, cachedSession.Auth.Homeserver, e);
+ logger.LogError("Continuing with server-less session");
+ cachedSession.Homeserver = await homeserverProviderService.GetAuthenticatedWithToken(cachedSession.Auth.Homeserver, cachedSession.Auth.AccessToken,
+ cachedSession.Auth.Proxy, useGeneric: true, enableServer: false);
+ }
+
+ return cachedSession.Homeserver;
+ }
+
+ public async Task<AuthenticatedHomeserverGeneric?> GetCurrentHomeserver(bool log = true, bool navigateOnFailure = false) {
+ if (log) logger.LogTrace("Getting current session.");
+ if (CurrentSession?.Homeserver is not null) return CurrentSession.Homeserver;
+
+ var currentSession = CurrentSession ??= await GetCurrentSession(log: false);
+ if (currentSession == null) {
+ if (navigateOnFailure) {
+ logger.LogInformation("No session found. Navigating to login.");
+ navigationManager.NavigateTo("/Login");
+ }
+
+ return null;
+ }
+
+ try {
+ return currentSession.Homeserver ??= await GetHomeserver(currentSession.SessionId);
+ }
+ catch (MatrixException e) {
+ if (e.ErrorCode == "M_UNKNOWN_TOKEN" && navigateOnFailure) {
+ logger.LogWarning("Encountered invalid token for {user} on {homeserver}", currentSession.Auth.UserId, currentSession.Auth.Homeserver);
+ if (navigateOnFailure) {
+ navigationManager.NavigateTo("/InvalidSession?ctx=" + currentSession.SessionId);
+ }
+ }
+
+ throw;
+ }
+ }
+
+#endregion
+
+#region Internal
+
+ public async Task RunMigrations() {
+ await MigrateFromMRU();
+ await MigrateAccountsToKeyedStorage();
+ }
+
+#region Migrations
+
+ private async Task MigrateFromMRU() {
+ logger.LogInformation("Migrating from MRU token namespace!");
+ var dsp = storageService.DataStorageProvider!;
+ if (await dsp.ObjectExistsAsync("token")) {
+ var oldToken = await dsp.LoadObjectAsync<UserAuth>("token");
+ if (oldToken != null) {
+ await dsp.SaveObjectAsync("rmu.token", oldToken);
+ await dsp.DeleteObjectAsync("tokens");
+ }
+ }
+
+ if (await dsp.ObjectExistsAsync("tokens")) {
+ var oldTokens = await dsp.LoadObjectAsync<List<UserAuth>>("tokens");
+ if (oldTokens != null) {
+ await dsp.SaveObjectAsync("rmu.tokens", oldTokens);
+ await dsp.DeleteObjectAsync("tokens");
+ }
+ }
+
+ if (await dsp.ObjectExistsAsync("mru.tokens")) {
+ var oldTokens = await dsp.LoadObjectAsync<List<UserAuth>>("mru.tokens");
+ if (oldTokens != null) {
+ await dsp.SaveObjectAsync("rmu.tokens", oldTokens);
+ await dsp.DeleteObjectAsync("mru.tokens");
+ }
+ }
+ }
+
+ private async Task MigrateAccountsToKeyedStorage() {
+ logger.LogInformation("Migrating accounts to keyed storage!");
+ var dsp = storageService.DataStorageProvider!;
+ if (await dsp.ObjectExistsAsync("rmu.tokens")) {
+ var tokens = await dsp.LoadObjectAsync<JsonNode>("rmu.tokens") ?? throw new Exception("Failed to load tokens");
+ if (tokens is JsonArray array) {
+ var keyedTokens = array
+ .Deserialize<UserAuth[]>()!
+ .ToDictionary(x => x.GetHashCode().ToString(), x => x);
+ await dsp.SaveObjectAsync("rmu.sessions", keyedTokens);
+ await dsp.DeleteObjectAsync("rmu.tokens");
+ }
+ }
+
+ if (await dsp.ObjectExistsAsync("rmu.token")) {
+ var token = await dsp.LoadObjectAsync<UserAuth>("rmu.token") ?? throw new Exception("Failed to load tokens");
+ var sessionId = (await GetAllSessions())
+ .FirstOrDefault(x => x.Value.Equals(token)).Key;
+
+ if (sessionId is not null) {
+ await dsp.SaveObjectAsync("rmu.session", sessionId);
+ }
+ else AddSession(token);
+
+ await dsp.DeleteObjectAsync("rmu.token");
+ }
+ }
+
+#endregion
+
+#endregion
+
+ public class Settings {
+ public DeveloperSettings DeveloperSettings { get; set; } = new();
+ }
+
+ public class DeveloperSettings {
+ public bool EnableLogViewers { get; set; }
+ public bool EnableConsoleLogging { get; set; } = true;
+ public bool EnablePortableDevtools { get; set; }
+ }
+
+ public class SessionInfo {
+ public required string SessionId { get; set; }
+ public required UserAuth Auth { get; set; }
+ public AuthenticatedHomeserverGeneric? Homeserver { get; set; }
+ }
+}
\ No newline at end of file
|