diff --git a/MatrixUtils.Web/Classes/RmuSessionStore.cs b/MatrixUtils.Web/Classes/RmuSessionStore.cs
index 9df8837..1611b83 100644
--- a/MatrixUtils.Web/Classes/RmuSessionStore.cs
+++ b/MatrixUtils.Web/Classes/RmuSessionStore.cs
@@ -26,6 +26,11 @@ public class RmuSessionStore(
public async Task<SessionInfo?> GetSession(string sessionId) {
await LoadStorage();
+ if (string.IsNullOrEmpty(sessionId)) {
+ logger.LogWarning("No session ID provided.");
+ return null;
+ }
+
if (SessionCache.TryGetValue(sessionId, out var cachedSession))
return cachedSession;
@@ -39,6 +44,11 @@ public class RmuSessionStore(
if (CurrentSession is not null) return CurrentSession;
var currentSessionId = await storageService.DataStorageProvider!.LoadObjectAsync<string>("rmu.session");
+ if (currentSessionId == null) {
+ if (log) logger.LogWarning("No current session ID found in storage.");
+ return null;
+ }
+
return await GetSession(currentSessionId);
}
@@ -52,25 +62,31 @@ public class RmuSessionStore(
SessionId = sessionId
};
+ await SaveStorage();
if (CurrentSession == null) await SetCurrentSession(sessionId);
- else await SaveStorage();
return sessionId;
}
public async Task RemoveSession(string sessionId) {
await LoadStorage();
- logger.LogTrace("Removing session {sessionId}.", sessionId);
- var tokens = await GetAllSessions();
- if (tokens == null) {
+ if (SessionCache.Count == 0) {
+ logger.LogWarning("No sessions found.");
return;
}
+ logger.LogTrace("Removing session {sessionId}.", sessionId);
+
if ((await GetCurrentSession())?.SessionId == sessionId)
- await SetCurrentSession(tokens.First(x => x.Key != sessionId).Key);
+ await SetCurrentSession(SessionCache.FirstOrDefault(x => x.Key != sessionId).Key);
- if (tokens.Remove(sessionId))
- await SaveStorage();
+ if (SessionCache.Remove(sessionId)) {
+ logger.LogInformation("RemoveSession: Removed session {sessionId}.", sessionId);
+ logger.LogInformation("RemoveSession: Remaining sessions: {sessionIds}.", string.Join(", ", SessionCache.Keys));
+ await SaveStorage(log: true);
+ }
+ else
+ logger.LogWarning("RemoveSession: Session {sessionId} not found.", sessionId);
}
public async Task SetCurrentSession(string? sessionId) {
@@ -134,6 +150,53 @@ public class RmuSessionStore(
}
}
+ public async IAsyncEnumerable<AuthenticatedHomeserverGeneric> TryGetAllHomeservers(bool log = true, bool ignoreFailures = true) {
+ await LoadStorage();
+ if (log) logger.LogTrace("Getting all homeservers.");
+ var tasks = SessionCache.Values.Select(async session => {
+ if (ignoreFailures && session.Auth.LastFailureReason != null && session.Auth.LastFailureReason != UserAuth.FailureReason.None) {
+ if (log) logger.LogTrace("Skipping session {sessionId} due to previous failure: {reason}", session.SessionId, session.Auth.LastFailureReason);
+ return null;
+ }
+
+ try {
+ var hs = await GetHomeserver(session.SessionId, log: false);
+ if (session.Auth.LastFailureReason != null) {
+ SessionCache[session.SessionId].Auth.LastFailureReason = null;
+ await SaveStorage();
+ }
+
+ return hs;
+ }
+ catch (Exception e) {
+ logger.LogError("TryGetAllHomeservers: Failed to get homeserver for {userId} via {homeserver}: {ex}", session.Auth.UserId, session.Auth.Homeserver, e);
+ var reason = SessionCache[session.SessionId].Auth.LastFailureReason = e switch {
+ MatrixException { ErrorCode: MatrixException.ErrorCodes.M_UNKNOWN_TOKEN } => UserAuth.FailureReason.InvalidToken,
+ HttpRequestException => UserAuth.FailureReason.NetworkError,
+ _ => UserAuth.FailureReason.UnknownError
+ };
+ await SaveStorage(log: true);
+
+ // await LoadStorage(true);
+ if (SessionCache[session.SessionId].Auth.LastFailureReason != reason) {
+ await Console.Error.WriteLineAsync(
+ $"Warning: Session {session.SessionId} failure reason changed during reload from {reason} to {SessionCache[session.SessionId].Auth.LastFailureReason}");
+ }
+
+ throw;
+ }
+ }).ToList();
+
+ while (tasks.Count != 0) {
+ var finished = await Task.WhenAny(tasks);
+ tasks.Remove(finished);
+ if (finished.IsFaulted) continue;
+
+ var result = await finished;
+ if (result != null) yield return result;
+ }
+ }
+
#endregion
#region Storage
@@ -170,7 +233,8 @@ public class RmuSessionStore(
CurrentSession = currentSession;
}
- private async Task SaveStorage() {
+ private async Task SaveStorage(bool log = false) {
+ if (log) logger.LogWarning("Saving {count} sessions to storage.", SessionCache.Count);
await storageService.DataStorageProvider!.SaveObjectAsync("rmu.sessions",
SessionCache.ToDictionary(
x => x.Key,
@@ -178,6 +242,7 @@ public class RmuSessionStore(
)
);
await storageService.DataStorageProvider.SaveObjectAsync("rmu.session", CurrentSession?.SessionId);
+ if (log) logger.LogWarning("{count} sessions saved to storage.", SessionCache.Count);
}
#endregion
@@ -190,29 +255,42 @@ public class RmuSessionStore(
}
private async Task MigrateFromMru() {
- logger.LogInformation("Migrating from MRU token namespace!");
var dsp = storageService.DataStorageProvider!;
- if (await dsp.ObjectExistsAsync("token")) {
- var oldToken = await dsp.LoadObjectAsync<UserAuth>("token");
- if (oldToken != null) {
- await dsp.SaveObjectAsync("rmu.token", oldToken);
- await dsp.DeleteObjectAsync("token");
+ if (await dsp.ObjectExistsAsync("token") || await dsp.ObjectExistsAsync("tokens")) {
+ logger.LogInformation("Migrating from unnamespaced localstorage!");
+ if (await dsp.ObjectExistsAsync("token")) {
+ var oldToken = await dsp.LoadObjectAsync<UserAuth>("token");
+ if (oldToken != null) {
+ await dsp.SaveObjectAsync("mru.token", oldToken);
+ await dsp.DeleteObjectAsync("token");
+ }
}
- }
- if (await dsp.ObjectExistsAsync("tokens")) {
- var oldTokens = await dsp.LoadObjectAsync<List<UserAuth>>("tokens");
- if (oldTokens != null) {
- await dsp.SaveObjectAsync("rmu.tokens", oldTokens);
- await dsp.DeleteObjectAsync("tokens");
+ if (await dsp.ObjectExistsAsync("tokens")) {
+ var oldTokens = await dsp.LoadObjectAsync<List<UserAuth>>("tokens");
+ if (oldTokens != null) {
+ await dsp.SaveObjectAsync("mru.tokens", oldTokens);
+ await dsp.DeleteObjectAsync("tokens");
+ }
}
}
- if (await dsp.ObjectExistsAsync("mru.tokens")) {
- var oldTokens = await dsp.LoadObjectAsync<List<UserAuth>>("mru.tokens");
- if (oldTokens != null) {
- await dsp.SaveObjectAsync("rmu.tokens", oldTokens);
- await dsp.DeleteObjectAsync("mru.tokens");
+ if (await dsp.ObjectExistsAsync("mru.token") || await dsp.ObjectExistsAsync("mru.tokens")) {
+ logger.LogInformation("Migrating from MRU token namespace!");
+ if (await dsp.ObjectExistsAsync("mru.token")) {
+ var oldToken = await dsp.LoadObjectAsync<UserAuth>("mru.token");
+ if (oldToken != null) {
+ await dsp.SaveObjectAsync("rmu.token", oldToken);
+ await dsp.DeleteObjectAsync("mru.token");
+ }
+ }
+
+ if (await dsp.ObjectExistsAsync("mru.tokens")) {
+ var oldTokens = await dsp.LoadObjectAsync<List<UserAuth>>("mru.tokens");
+ if (oldTokens != null) {
+ await dsp.SaveObjectAsync("rmu.tokens", oldTokens);
+ await dsp.DeleteObjectAsync("mru.tokens");
+ }
}
}
}
|