diff --git a/MatrixUtils.Web/Classes/LocalStorageProviderService.cs b/MatrixUtils.Web/Classes/LocalStorageProviderService.cs
index 0e99cd4..ddf3eed 100644
--- a/MatrixUtils.Web/Classes/LocalStorageProviderService.cs
+++ b/MatrixUtils.Web/Classes/LocalStorageProviderService.cs
@@ -3,26 +3,20 @@ using LibMatrix.Interfaces.Services;
namespace MatrixUtils.Web.Classes;
-public class LocalStorageProviderService : IStorageProvider {
- private readonly ILocalStorageService _localStorageService;
-
- public LocalStorageProviderService(ILocalStorageService localStorageService) {
- _localStorageService = localStorageService;
- }
-
+public class LocalStorageProviderService(ILocalStorageService localStorageService) : IStorageProvider {
Task IStorageProvider.SaveAllChildrenAsync<T>(string key, T value) {
throw new NotImplementedException();
}
Task<T?> IStorageProvider.LoadAllChildrenAsync<T>(string key) where T : default => throw new NotImplementedException();
- async Task IStorageProvider.SaveObjectAsync<T>(string key, T value) => await _localStorageService.SetItemAsync(key, value);
+ async Task IStorageProvider.SaveObjectAsync<T>(string key, T value) => await localStorageService.SetItemAsync(key, value);
- async Task<T?> IStorageProvider.LoadObjectAsync<T>(string key) where T : default => await _localStorageService.GetItemAsync<T>(key);
+ async Task<T?> IStorageProvider.LoadObjectAsync<T>(string key) where T : default => await localStorageService.GetItemAsync<T>(key);
- async Task<bool> IStorageProvider.ObjectExistsAsync(string key) => await _localStorageService.ContainKeyAsync(key);
+ async Task<bool> IStorageProvider.ObjectExistsAsync(string key) => await localStorageService.ContainKeyAsync(key);
- async Task<IEnumerable<string>> IStorageProvider.GetAllKeysAsync() => (await _localStorageService.KeysAsync()).ToList();
+ async Task<IEnumerable<string>> IStorageProvider.GetAllKeysAsync() => (await localStorageService.KeysAsync()).ToList();
- async Task IStorageProvider.DeleteObjectAsync(string key) => await _localStorageService.RemoveItemAsync(key);
+ async Task IStorageProvider.DeleteObjectAsync(string key) => await localStorageService.RemoveItemAsync(key);
}
diff --git a/MatrixUtils.Web/Classes/RmuSessionStore.cs b/MatrixUtils.Web/Classes/RmuSessionStore.cs
index 9df8837..1611b83 100644
--- a/MatrixUtils.Web/Classes/RmuSessionStore.cs
+++ b/MatrixUtils.Web/Classes/RmuSessionStore.cs
@@ -26,6 +26,11 @@ public class RmuSessionStore(
public async Task<SessionInfo?> GetSession(string sessionId) {
await LoadStorage();
+ if (string.IsNullOrEmpty(sessionId)) {
+ logger.LogWarning("No session ID provided.");
+ return null;
+ }
+
if (SessionCache.TryGetValue(sessionId, out var cachedSession))
return cachedSession;
@@ -39,6 +44,11 @@ public class RmuSessionStore(
if (CurrentSession is not null) return CurrentSession;
var currentSessionId = await storageService.DataStorageProvider!.LoadObjectAsync<string>("rmu.session");
+ if (currentSessionId == null) {
+ if (log) logger.LogWarning("No current session ID found in storage.");
+ return null;
+ }
+
return await GetSession(currentSessionId);
}
@@ -52,25 +62,31 @@ public class RmuSessionStore(
SessionId = sessionId
};
+ await SaveStorage();
if (CurrentSession == null) await SetCurrentSession(sessionId);
- else await SaveStorage();
return sessionId;
}
public async Task RemoveSession(string sessionId) {
await LoadStorage();
- logger.LogTrace("Removing session {sessionId}.", sessionId);
- var tokens = await GetAllSessions();
- if (tokens == null) {
+ if (SessionCache.Count == 0) {
+ logger.LogWarning("No sessions found.");
return;
}
+ logger.LogTrace("Removing session {sessionId}.", sessionId);
+
if ((await GetCurrentSession())?.SessionId == sessionId)
- await SetCurrentSession(tokens.First(x => x.Key != sessionId).Key);
+ await SetCurrentSession(SessionCache.FirstOrDefault(x => x.Key != sessionId).Key);
- if (tokens.Remove(sessionId))
- await SaveStorage();
+ if (SessionCache.Remove(sessionId)) {
+ logger.LogInformation("RemoveSession: Removed session {sessionId}.", sessionId);
+ logger.LogInformation("RemoveSession: Remaining sessions: {sessionIds}.", string.Join(", ", SessionCache.Keys));
+ await SaveStorage(log: true);
+ }
+ else
+ logger.LogWarning("RemoveSession: Session {sessionId} not found.", sessionId);
}
public async Task SetCurrentSession(string? sessionId) {
@@ -134,6 +150,53 @@ public class RmuSessionStore(
}
}
+ public async IAsyncEnumerable<AuthenticatedHomeserverGeneric> TryGetAllHomeservers(bool log = true, bool ignoreFailures = true) {
+ await LoadStorage();
+ if (log) logger.LogTrace("Getting all homeservers.");
+ var tasks = SessionCache.Values.Select(async session => {
+ if (ignoreFailures && session.Auth.LastFailureReason != null && session.Auth.LastFailureReason != UserAuth.FailureReason.None) {
+ if (log) logger.LogTrace("Skipping session {sessionId} due to previous failure: {reason}", session.SessionId, session.Auth.LastFailureReason);
+ return null;
+ }
+
+ try {
+ var hs = await GetHomeserver(session.SessionId, log: false);
+ if (session.Auth.LastFailureReason != null) {
+ SessionCache[session.SessionId].Auth.LastFailureReason = null;
+ await SaveStorage();
+ }
+
+ return hs;
+ }
+ catch (Exception e) {
+ logger.LogError("TryGetAllHomeservers: Failed to get homeserver for {userId} via {homeserver}: {ex}", session.Auth.UserId, session.Auth.Homeserver, e);
+ var reason = SessionCache[session.SessionId].Auth.LastFailureReason = e switch {
+ MatrixException { ErrorCode: MatrixException.ErrorCodes.M_UNKNOWN_TOKEN } => UserAuth.FailureReason.InvalidToken,
+ HttpRequestException => UserAuth.FailureReason.NetworkError,
+ _ => UserAuth.FailureReason.UnknownError
+ };
+ await SaveStorage(log: true);
+
+ // await LoadStorage(true);
+ if (SessionCache[session.SessionId].Auth.LastFailureReason != reason) {
+ await Console.Error.WriteLineAsync(
+ $"Warning: Session {session.SessionId} failure reason changed during reload from {reason} to {SessionCache[session.SessionId].Auth.LastFailureReason}");
+ }
+
+ throw;
+ }
+ }).ToList();
+
+ while (tasks.Count != 0) {
+ var finished = await Task.WhenAny(tasks);
+ tasks.Remove(finished);
+ if (finished.IsFaulted) continue;
+
+ var result = await finished;
+ if (result != null) yield return result;
+ }
+ }
+
#endregion
#region Storage
@@ -170,7 +233,8 @@ public class RmuSessionStore(
CurrentSession = currentSession;
}
- private async Task SaveStorage() {
+ private async Task SaveStorage(bool log = false) {
+ if (log) logger.LogWarning("Saving {count} sessions to storage.", SessionCache.Count);
await storageService.DataStorageProvider!.SaveObjectAsync("rmu.sessions",
SessionCache.ToDictionary(
x => x.Key,
@@ -178,6 +242,7 @@ public class RmuSessionStore(
)
);
await storageService.DataStorageProvider.SaveObjectAsync("rmu.session", CurrentSession?.SessionId);
+ if (log) logger.LogWarning("{count} sessions saved to storage.", SessionCache.Count);
}
#endregion
@@ -190,29 +255,42 @@ public class RmuSessionStore(
}
private async Task MigrateFromMru() {
- logger.LogInformation("Migrating from MRU token namespace!");
var dsp = storageService.DataStorageProvider!;
- if (await dsp.ObjectExistsAsync("token")) {
- var oldToken = await dsp.LoadObjectAsync<UserAuth>("token");
- if (oldToken != null) {
- await dsp.SaveObjectAsync("rmu.token", oldToken);
- await dsp.DeleteObjectAsync("token");
+ if (await dsp.ObjectExistsAsync("token") || await dsp.ObjectExistsAsync("tokens")) {
+ logger.LogInformation("Migrating from unnamespaced localstorage!");
+ if (await dsp.ObjectExistsAsync("token")) {
+ var oldToken = await dsp.LoadObjectAsync<UserAuth>("token");
+ if (oldToken != null) {
+ await dsp.SaveObjectAsync("mru.token", oldToken);
+ await dsp.DeleteObjectAsync("token");
+ }
}
- }
- if (await dsp.ObjectExistsAsync("tokens")) {
- var oldTokens = await dsp.LoadObjectAsync<List<UserAuth>>("tokens");
- if (oldTokens != null) {
- await dsp.SaveObjectAsync("rmu.tokens", oldTokens);
- await dsp.DeleteObjectAsync("tokens");
+ if (await dsp.ObjectExistsAsync("tokens")) {
+ var oldTokens = await dsp.LoadObjectAsync<List<UserAuth>>("tokens");
+ if (oldTokens != null) {
+ await dsp.SaveObjectAsync("mru.tokens", oldTokens);
+ await dsp.DeleteObjectAsync("tokens");
+ }
}
}
- if (await dsp.ObjectExistsAsync("mru.tokens")) {
- var oldTokens = await dsp.LoadObjectAsync<List<UserAuth>>("mru.tokens");
- if (oldTokens != null) {
- await dsp.SaveObjectAsync("rmu.tokens", oldTokens);
- await dsp.DeleteObjectAsync("mru.tokens");
+ if (await dsp.ObjectExistsAsync("mru.token") || await dsp.ObjectExistsAsync("mru.tokens")) {
+ logger.LogInformation("Migrating from MRU token namespace!");
+ if (await dsp.ObjectExistsAsync("mru.token")) {
+ var oldToken = await dsp.LoadObjectAsync<UserAuth>("mru.token");
+ if (oldToken != null) {
+ await dsp.SaveObjectAsync("rmu.token", oldToken);
+ await dsp.DeleteObjectAsync("mru.token");
+ }
+ }
+
+ if (await dsp.ObjectExistsAsync("mru.tokens")) {
+ var oldTokens = await dsp.LoadObjectAsync<List<UserAuth>>("mru.tokens");
+ if (oldTokens != null) {
+ await dsp.SaveObjectAsync("rmu.tokens", oldTokens);
+ await dsp.DeleteObjectAsync("mru.tokens");
+ }
}
}
}
diff --git a/MatrixUtils.Web/Classes/UserAuth.cs b/MatrixUtils.Web/Classes/UserAuth.cs
index 66476ae..16bb758 100644
--- a/MatrixUtils.Web/Classes/UserAuth.cs
+++ b/MatrixUtils.Web/Classes/UserAuth.cs
@@ -1,9 +1,11 @@
+using System.Text.Json.Serialization;
using LibMatrix.Responses;
namespace MatrixUtils.Web.Classes;
public class UserAuth : LoginResponse {
public UserAuth() { }
+
public UserAuth(LoginResponse login) {
Homeserver = login.Homeserver;
UserId = login.UserId;
@@ -12,4 +14,14 @@ public class UserAuth : LoginResponse {
}
public string? Proxy { get; set; }
-}
+
+ public FailureReason? LastFailureReason { get; set; }
+
+ [JsonConverter(typeof(JsonStringEnumConverter))]
+ public enum FailureReason {
+ None,
+ InvalidToken,
+ NetworkError,
+ UnknownError
+ }
+}
\ No newline at end of file
|