diff --git a/MatrixUtils.Web/Classes/RmuSessionStore.cs b/MatrixUtils.Web/Classes/RmuSessionStore.cs
index 14aa1db..7e5b155 100644
--- a/MatrixUtils.Web/Classes/RmuSessionStore.cs
+++ b/MatrixUtils.Web/Classes/RmuSessionStore.cs
@@ -15,47 +15,19 @@ public class RmuSessionStore(
private SessionInfo? CurrentSession { get; set; }
private Dictionary<string, SessionInfo> SessionCache { get; set; } = [];
- private bool _isInitialized = false;
+ private bool _isInitialized;
private static readonly SemaphoreSlim InitSemaphore = new(1, 1);
- public async Task EnsureInitialized() {
- if (_isInitialized) return;
- await InitSemaphore.WaitAsync();
- if (_isInitialized) {
- InitSemaphore.Release();
- return;
- }
-
- try {
- await RunMigrations();
- SessionCache = (await GetAllSessions())
- .Select(x => (x.Key, Value: new SessionInfo {
- SessionId = x.Key,
- Auth = x.Value.Auth
- })).ToDictionary(x => x.Key, x => x.Value);
- CurrentSession = await GetCurrentSession();
- }
- catch (Exception e) {
- logger.LogError("Failed to initialize RmuSessionStore: {e}", e);
- }
- finally {
- _isInitialized = true;
- InitSemaphore.Release();
- }
- }
-
#region Sessions
-#region Session implementation details
-
-#endregion
-
public async Task<Dictionary<string, SessionInfo>> GetAllSessions() {
+ await LoadStorage();
logger.LogTrace("Getting all tokens.");
return SessionCache;
}
public async Task<SessionInfo?> GetSession(string sessionId) {
+ await LoadStorage();
if (SessionCache.TryGetValue(sessionId, out var cachedSession))
return cachedSession;
@@ -64,6 +36,7 @@ public class RmuSessionStore(
}
public async Task<SessionInfo?> GetCurrentSession(bool log = true) {
+ await LoadStorage();
if (log) logger.LogTrace("Getting current token.");
if (CurrentSession is not null) return CurrentSession;
@@ -72,28 +45,23 @@ public class RmuSessionStore(
}
public async Task<string> AddSession(UserAuth auth) {
+ await LoadStorage();
logger.LogTrace("Adding token.");
- // var sessions = await GetAllSessions() ?? [];
var sessionId = auth.GetHashCode().ToString();
- // sessions.Add(sessionId, auth);
SessionCache[sessionId] = new() {
Auth = auth,
SessionId = sessionId
};
- await storageService.DataStorageProvider!.SaveObjectAsync("rmu.sessions",
- SessionCache.ToDictionary(
- x => x.Key,
- x => x.Value.Auth
- )
- );
if (CurrentSession == null) await SetCurrentSession(sessionId);
+ else await SaveStorage();
return sessionId;
}
public async Task RemoveSession(string sessionId) {
+ await LoadStorage();
logger.LogTrace("Removing session {sessionId}.", sessionId);
var tokens = await GetAllSessions();
if (tokens == null) {
@@ -104,13 +72,14 @@ public class RmuSessionStore(
await SetCurrentSession(tokens.First(x => x.Key != sessionId).Key);
if (tokens.Remove(sessionId))
- await storageService.DataStorageProvider!.SaveObjectAsync("rmu.tokens", tokens);
+ await SaveStorage();
}
public async Task SetCurrentSession(string? sessionId) {
+ await LoadStorage();
logger.LogTrace("Setting current session to {sessionId}.", sessionId);
CurrentSession = await GetSession(sessionId);
- await storageService.DataStorageProvider!.SaveObjectAsync("rmu.session", sessionId);
+ await SaveStorage();
}
#endregion
@@ -118,6 +87,7 @@ public class RmuSessionStore(
#region Homeservers
public async Task<AuthenticatedHomeserverGeneric?> GetHomeserver(string session, bool log = true) {
+ await LoadStorage();
if (log) logger.LogTrace("Getting session.");
if (!SessionCache.TryGetValue(session, out var cachedSession)) return null;
if (cachedSession.Homeserver is not null) return cachedSession.Homeserver;
@@ -137,6 +107,7 @@ public class RmuSessionStore(
}
public async Task<AuthenticatedHomeserverGeneric?> GetCurrentHomeserver(bool log = true, bool navigateOnFailure = false) {
+ await LoadStorage();
if (log) logger.LogTrace("Getting current session.");
if (CurrentSession?.Homeserver is not null) return CurrentSession.Homeserver;
@@ -167,16 +138,59 @@ public class RmuSessionStore(
#endregion
-#region Internal
+#region Storage
+
+ private async Task LoadStorage(bool hasMigrated = false) {
+ if (!await storageService.DataStorageProvider!.ObjectExistsAsync("rmu.sessions") || !await storageService.DataStorageProvider.ObjectExistsAsync("rmu.session")) {
+ if (!hasMigrated)
+ await MigrateFromMRU();
+ else
+ logger.LogWarning("No sessions found in storage.");
+ return;
+ }
+
+ SessionCache = (await storageService.DataStorageProvider.LoadObjectAsync<Dictionary<string, UserAuth>>("rmu.sessions") ?? throw new Exception("Failed to load sessions"))
+ .ToDictionary(x => x.Key, x => new SessionInfo {
+ SessionId = x.Key,
+ Auth = x.Value
+ });
+
+ var currentSessionId = await storageService.DataStorageProvider.LoadObjectAsync<string>("rmu.session");
+ if (currentSessionId == null) {
+ logger.LogWarning("No current session found in storage.");
+ return;
+ }
+
+ if (!SessionCache.TryGetValue(currentSessionId, out var currentSession)) {
+ logger.LogWarning("Current session {currentSessionId} not found in storage.", currentSessionId);
+ return;
+ }
+
+ CurrentSession = currentSession;
+ }
+
+ private async Task SaveStorage() {
+ await storageService.DataStorageProvider!.SaveObjectAsync("rmu.sessions",
+ SessionCache.ToDictionary(
+ x => x.Key,
+ x => x.Value.Auth
+ )
+ );
+ await storageService.DataStorageProvider.SaveObjectAsync("rmu.session", CurrentSession?.SessionId);
+ }
+
+#endregion
+
+#region Migrations
public async Task RunMigrations() {
+ await LoadStorage();
await MigrateFromMRU();
await MigrateAccountsToKeyedStorage();
}
-#region Migrations
-
private async Task MigrateFromMRU() {
+ await LoadStorage();
logger.LogInformation("Migrating from MRU token namespace!");
var dsp = storageService.DataStorageProvider!;
if (await dsp.ObjectExistsAsync("token")) {
@@ -205,6 +219,7 @@ public class RmuSessionStore(
}
private async Task MigrateAccountsToKeyedStorage() {
+ await LoadStorage();
logger.LogInformation("Migrating accounts to keyed storage!");
var dsp = storageService.DataStorageProvider!;
if (await dsp.ObjectExistsAsync("rmu.tokens")) {
@@ -234,8 +249,6 @@ public class RmuSessionStore(
#endregion
-#endregion
-
public class Settings {
public DeveloperSettings DeveloperSettings { get; set; } = new();
}
|