about summary refs log tree commit diff
path: root/MatrixUtils.Web/Classes/RmuSessionStore.cs
diff options
context:
space:
mode:
Diffstat (limited to 'MatrixUtils.Web/Classes/RmuSessionStore.cs')
-rw-r--r--MatrixUtils.Web/Classes/RmuSessionStore.cs128
1 files changed, 103 insertions, 25 deletions
diff --git a/MatrixUtils.Web/Classes/RmuSessionStore.cs b/MatrixUtils.Web/Classes/RmuSessionStore.cs

index 9df8837..1611b83 100644 --- a/MatrixUtils.Web/Classes/RmuSessionStore.cs +++ b/MatrixUtils.Web/Classes/RmuSessionStore.cs
@@ -26,6 +26,11 @@ public class RmuSessionStore( public async Task<SessionInfo?> GetSession(string sessionId) { await LoadStorage(); + if (string.IsNullOrEmpty(sessionId)) { + logger.LogWarning("No session ID provided."); + return null; + } + if (SessionCache.TryGetValue(sessionId, out var cachedSession)) return cachedSession; @@ -39,6 +44,11 @@ public class RmuSessionStore( if (CurrentSession is not null) return CurrentSession; var currentSessionId = await storageService.DataStorageProvider!.LoadObjectAsync<string>("rmu.session"); + if (currentSessionId == null) { + if (log) logger.LogWarning("No current session ID found in storage."); + return null; + } + return await GetSession(currentSessionId); } @@ -52,25 +62,31 @@ public class RmuSessionStore( SessionId = sessionId }; + await SaveStorage(); if (CurrentSession == null) await SetCurrentSession(sessionId); - else await SaveStorage(); return sessionId; } public async Task RemoveSession(string sessionId) { await LoadStorage(); - logger.LogTrace("Removing session {sessionId}.", sessionId); - var tokens = await GetAllSessions(); - if (tokens == null) { + if (SessionCache.Count == 0) { + logger.LogWarning("No sessions found."); return; } + logger.LogTrace("Removing session {sessionId}.", sessionId); + if ((await GetCurrentSession())?.SessionId == sessionId) - await SetCurrentSession(tokens.First(x => x.Key != sessionId).Key); + await SetCurrentSession(SessionCache.FirstOrDefault(x => x.Key != sessionId).Key); - if (tokens.Remove(sessionId)) - await SaveStorage(); + if (SessionCache.Remove(sessionId)) { + logger.LogInformation("RemoveSession: Removed session {sessionId}.", sessionId); + logger.LogInformation("RemoveSession: Remaining sessions: {sessionIds}.", string.Join(", ", SessionCache.Keys)); + await SaveStorage(log: true); + } + else + logger.LogWarning("RemoveSession: Session {sessionId} not found.", sessionId); } public async Task SetCurrentSession(string? sessionId) { @@ -134,6 +150,53 @@ public class RmuSessionStore( } } + public async IAsyncEnumerable<AuthenticatedHomeserverGeneric> TryGetAllHomeservers(bool log = true, bool ignoreFailures = true) { + await LoadStorage(); + if (log) logger.LogTrace("Getting all homeservers."); + var tasks = SessionCache.Values.Select(async session => { + if (ignoreFailures && session.Auth.LastFailureReason != null && session.Auth.LastFailureReason != UserAuth.FailureReason.None) { + if (log) logger.LogTrace("Skipping session {sessionId} due to previous failure: {reason}", session.SessionId, session.Auth.LastFailureReason); + return null; + } + + try { + var hs = await GetHomeserver(session.SessionId, log: false); + if (session.Auth.LastFailureReason != null) { + SessionCache[session.SessionId].Auth.LastFailureReason = null; + await SaveStorage(); + } + + return hs; + } + catch (Exception e) { + logger.LogError("TryGetAllHomeservers: Failed to get homeserver for {userId} via {homeserver}: {ex}", session.Auth.UserId, session.Auth.Homeserver, e); + var reason = SessionCache[session.SessionId].Auth.LastFailureReason = e switch { + MatrixException { ErrorCode: MatrixException.ErrorCodes.M_UNKNOWN_TOKEN } => UserAuth.FailureReason.InvalidToken, + HttpRequestException => UserAuth.FailureReason.NetworkError, + _ => UserAuth.FailureReason.UnknownError + }; + await SaveStorage(log: true); + + // await LoadStorage(true); + if (SessionCache[session.SessionId].Auth.LastFailureReason != reason) { + await Console.Error.WriteLineAsync( + $"Warning: Session {session.SessionId} failure reason changed during reload from {reason} to {SessionCache[session.SessionId].Auth.LastFailureReason}"); + } + + throw; + } + }).ToList(); + + while (tasks.Count != 0) { + var finished = await Task.WhenAny(tasks); + tasks.Remove(finished); + if (finished.IsFaulted) continue; + + var result = await finished; + if (result != null) yield return result; + } + } + #endregion #region Storage @@ -170,7 +233,8 @@ public class RmuSessionStore( CurrentSession = currentSession; } - private async Task SaveStorage() { + private async Task SaveStorage(bool log = false) { + if (log) logger.LogWarning("Saving {count} sessions to storage.", SessionCache.Count); await storageService.DataStorageProvider!.SaveObjectAsync("rmu.sessions", SessionCache.ToDictionary( x => x.Key, @@ -178,6 +242,7 @@ public class RmuSessionStore( ) ); await storageService.DataStorageProvider.SaveObjectAsync("rmu.session", CurrentSession?.SessionId); + if (log) logger.LogWarning("{count} sessions saved to storage.", SessionCache.Count); } #endregion @@ -190,29 +255,42 @@ public class RmuSessionStore( } private async Task MigrateFromMru() { - logger.LogInformation("Migrating from MRU token namespace!"); var dsp = storageService.DataStorageProvider!; - if (await dsp.ObjectExistsAsync("token")) { - var oldToken = await dsp.LoadObjectAsync<UserAuth>("token"); - if (oldToken != null) { - await dsp.SaveObjectAsync("rmu.token", oldToken); - await dsp.DeleteObjectAsync("token"); + if (await dsp.ObjectExistsAsync("token") || await dsp.ObjectExistsAsync("tokens")) { + logger.LogInformation("Migrating from unnamespaced localstorage!"); + if (await dsp.ObjectExistsAsync("token")) { + var oldToken = await dsp.LoadObjectAsync<UserAuth>("token"); + if (oldToken != null) { + await dsp.SaveObjectAsync("mru.token", oldToken); + await dsp.DeleteObjectAsync("token"); + } } - } - if (await dsp.ObjectExistsAsync("tokens")) { - var oldTokens = await dsp.LoadObjectAsync<List<UserAuth>>("tokens"); - if (oldTokens != null) { - await dsp.SaveObjectAsync("rmu.tokens", oldTokens); - await dsp.DeleteObjectAsync("tokens"); + if (await dsp.ObjectExistsAsync("tokens")) { + var oldTokens = await dsp.LoadObjectAsync<List<UserAuth>>("tokens"); + if (oldTokens != null) { + await dsp.SaveObjectAsync("mru.tokens", oldTokens); + await dsp.DeleteObjectAsync("tokens"); + } } } - if (await dsp.ObjectExistsAsync("mru.tokens")) { - var oldTokens = await dsp.LoadObjectAsync<List<UserAuth>>("mru.tokens"); - if (oldTokens != null) { - await dsp.SaveObjectAsync("rmu.tokens", oldTokens); - await dsp.DeleteObjectAsync("mru.tokens"); + if (await dsp.ObjectExistsAsync("mru.token") || await dsp.ObjectExistsAsync("mru.tokens")) { + logger.LogInformation("Migrating from MRU token namespace!"); + if (await dsp.ObjectExistsAsync("mru.token")) { + var oldToken = await dsp.LoadObjectAsync<UserAuth>("mru.token"); + if (oldToken != null) { + await dsp.SaveObjectAsync("rmu.token", oldToken); + await dsp.DeleteObjectAsync("mru.token"); + } + } + + if (await dsp.ObjectExistsAsync("mru.tokens")) { + var oldTokens = await dsp.LoadObjectAsync<List<UserAuth>>("mru.tokens"); + if (oldTokens != null) { + await dsp.SaveObjectAsync("rmu.tokens", oldTokens); + await dsp.DeleteObjectAsync("mru.tokens"); + } } } }