author    Rory& <root@rory.gay>  2024-09-04 05:00:48 +0200
committer Rory& <root@rory.gay>  2024-09-04 05:00:48 +0200
commit    a8d20e9d57857296e4600f44807893f4dcad72d1 (patch)
tree      d19e74def992078f8663f7bdc5b8133d5bf83fe3
parent    Synapse admin API stuff, a mass of other changes (diff)
download  LibMatrix-a8d20e9d57857296e4600f44807893f4dcad72d1.tar.xz
Sync optimisation changes
-rw-r--r--  LibMatrix/Helpers/SyncStateResolver.cs | 145
1 file changed, 130 insertions(+), 15 deletions(-)
diff --git a/LibMatrix/Helpers/SyncStateResolver.cs b/LibMatrix/Helpers/SyncStateResolver.cs
index e9c5938..5e34628 100644
--- a/LibMatrix/Helpers/SyncStateResolver.cs
+++ b/LibMatrix/Helpers/SyncStateResolver.cs
@@ -1,4 +1,5 @@
 using System.Collections.Frozen;
+using System.Collections.Immutable;
 using System.Diagnostics;
 using ArcaneLibs.Extensions;
 using LibMatrix.Extensions;
@@ -43,10 +44,11 @@ public class SyncStateResolver(AuthenticatedHomeserverGeneric homeserver, ILogge
         if (storageProvider is null) return;
         if (!await storageProvider.ObjectExistsAsync("init")) return;
 
+        var totalSw = Stopwatch.StartNew();
         Console.Write("Optimising sync store...");
         var initLoadTask = storageProvider.LoadObjectAsync<SyncResponse>("init");
-        var keys = (await storageProvider.GetAllKeysAsync()).ToFrozenSet();
-        var count = keys.Count(x => !x.StartsWith("old/")) - 1;
+        var keys = (await storageProvider.GetAllKeysAsync()).Where(x => !x.StartsWith("old/")).ToFrozenSet();
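+        // keys still includes the top-level "init" entry itself, hence the -1 below.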
+        var count = keys.Count - 1;
         Console.WriteLine($"Found {count} entries to optimise.");
 
         var merged = await initLoadTask;
@@ -83,6 +85,32 @@ public class SyncStateResolver(AuthenticatedHomeserverGeneric homeserver, ILogge
 
         await storageProvider.SaveObjectAsync("init", merged);
         await Task.WhenAll(moveTasks);
+
+        Console.WriteLine($"Optimised store in {totalSw.Elapsed.TotalMilliseconds}ms");
+    }
+
+    /// <summary>
+    /// Remove all snapshots except the initial sync and the last checkpoint.
+    /// </summary>
+    public async Task RemoveOldSnapshots() {
+        if (storageProvider is null) return;
+        var sw = Stopwatch.StartNew();
+
+        var map = await GetCheckpointMap();
+        if (map is null) return;
+        if (map.Count < 3) return;
+
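+        // Keep the first and last checkpoints; delete the archived init snapshots in between.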
+        var toRemove = map.Keys.Skip(1).Take(map.Count - 2).ToList();
+        Console.Write("Cleaning up old snapshots: ");
+        foreach (var key in toRemove) {
+            var path = $"old/{key}/init";
+            if (await storageProvider.ObjectExistsAsync(path)) {
+                Console.Write($"{key}... ");
+                await storageProvider.DeleteObjectAsync(path);
+            }
+        }
+        Console.WriteLine("Done!");
+        Console.WriteLine($"Removed {toRemove.Count} old snapshots in {sw.Elapsed.TotalMilliseconds}ms");
     }
 
     public async Task UnrollOptimisedStore() {
@@ -101,23 +129,110 @@ public class SyncStateResolver(AuthenticatedHomeserverGeneric homeserver, ILogge
     }
 
     public async Task dev() {
-        var keys = (await storageProvider?.GetAllKeysAsync()).ToFrozenSet();
-        var times = new Dictionary<long, List<string>>();
-        var values = keys.Select(async x => Task.Run(async () => (x, await storageProvider?.LoadObjectAsync<SyncResponse>(x)))).ToAsyncEnumerable();
-        await foreach (var task in values) {
-            var (key, data) = await task;
-            if (data is null) continue;
-            var derivTime = data.GetDerivedSyncTime();
-            if (!times.ContainsKey(derivTime)) times[derivTime] = new();
-            times[derivTime].Add(key);
+        // var keys = (await storageProvider?.GetAllKeysAsync()).ToFrozenSet();
+        // var times = new Dictionary<long, List<string>>();
+        // var values = keys.Select(async x => Task.Run(async () => (x, await storageProvider?.LoadObjectAsync<SyncResponse>(x)))).ToAsyncEnumerable();
+        // await foreach (var task in values) {
+        //     var (key, data) = await task;
+        //     if (data is null) continue;
+        //     var derivTime = data.GetDerivedSyncTime();
+        //     if (!times.ContainsKey(derivTime)) times[derivTime] = new();
+        //     times[derivTime].Add(key);
+        // }
+        //
+        // foreach (var (time, ckeys) in times.OrderBy(x => x.Key)) {
+        //     Console.WriteLine($"{time}: {ckeys.Count} keys");
+        // }
+
+        // var map = await GetCheckpointMap();
+        // if (map is null) return;
+        //
+        // var times = new Dictionary<long, List<string>>();
+        // foreach (var (time, keys) in map) {
+        //     Console.WriteLine($"{time}: {keys.Count} keys - calculating times");
+        //     Dictionary<string, Task<SyncResponse?>?> tasks = keys.ToDictionary(x => x, x => storageProvider?.LoadObjectAsync<SyncResponse>(x));
+        //     var nextKey = "init";
+        //     long lastTime = 0;
+        //     while (tasks.ContainsKey(nextKey)) {
+        //         var data = await tasks[nextKey];
+        //         if (data is null) break;
+        //         var derivTime = data.GetDerivedSyncTime();
+        //         if (derivTime == 0) derivTime = lastTime + 1;
+        //         if (!times.ContainsKey(derivTime)) times[derivTime] = new();
+        //         times[derivTime].Add(nextKey);
+        //         lastTime = derivTime;
+        //         nextKey = data.NextBatch;
+        //     }
+        // }
+        //
+        // foreach (var (time, ckeys) in times.OrderBy(x => x.Key)) {
+        //     Console.WriteLine($"{time}: {ckeys.Count} keys");
+        // }
+
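+        // Benchmark: walk the full serialised history and time the traversal.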
+        int i = 0;
+        var sw = Stopwatch.StartNew();
+        var hist = GetSerialisedHistory();
+        await foreach (var (key, resp) in hist) {
+            if (resp is null) continue;
+            // Console.WriteLine($"[{++i}] {key} -> {resp.NextBatch} ({resp.GetDerivedSyncTime()})");
+            i++;
+        }
+        Console.WriteLine($"Iterated {i} syncResponses in {sw.Elapsed}");
+        Environment.Exit(0);
+    }
+
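+    /// <summary>
+    /// Enumerate stored sync responses in order by following each response's NextBatch token.
+    /// </summary>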
+    private async IAsyncEnumerable<(string key, SyncResponse? resp)> GetSerialisedHistory() {
+        if (storageProvider is null) yield break;
+        var map = await GetCheckpointMap();
+        if (map is null || map.Count == 0) yield break;
+        var currentRange = map.First();
+        var nextKey = $"old/{currentRange.Key}/init";
+        var next = storageProvider.LoadObjectAsync<SyncResponse>(nextKey);
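+        // Follow NextBatch tokens: first within the current old/{checkpoint}/ range, then
+        // across later checkpoint ranges, and finally into the live (un-archived) keys.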
+        while (true) {
+            var data = await next;
+            if (data is null) break;
+            yield return (nextKey, data);
+            if (currentRange.Value.Contains(data.NextBatch)) {
+                nextKey = $"old/{currentRange.Key}/{data.NextBatch}";
+            }
+            else if (map.Any(x => x.Value.Contains(data.NextBatch))) {
+                currentRange = map.First(x => x.Value.Contains(data.NextBatch));
+                nextKey = $"old/{currentRange.Key}/{data.NextBatch}";
+            }
+            else if (await storageProvider.ObjectExistsAsync(data.NextBatch)) {
+                nextKey = data.NextBatch;
+            }
+            else break;
+
+            next = storageProvider.LoadObjectAsync<SyncResponse>(nextKey);
+        }
+    }
+
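+    /// <summary>
+    /// Merge all stored sync responses whose derived sync time is at or before the given time.
+    /// </summary>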
+    public async Task<SyncResponse?> GetMergedUpTo(DateTime time) {
+        if (storageProvider is null) return null;
+        var unixTime = new DateTimeOffset(time.ToUniversalTime()).ToUnixTimeMilliseconds();
+        var map = await GetCheckpointMap();
+        if (map is null) return null;
+        var stream = GetSerialisedHistory().GetAsyncEnumerator();
+        SyncResponse? merged = await stream.MoveNextAsync() ? stream.Current.resp : null;
+        if (merged is null) return null;
+
+        if (merged.GetDerivedSyncTime() > unixTime) {
+            Console.WriteLine("Initial sync is already past the target time!");
+            Console.WriteLine($"CURRENT: {merged.GetDerivedSyncTime()} (UTC: {DateTimeOffset.FromUnixTimeMilliseconds(merged.GetDerivedSyncTime())})");
+            Console.WriteLine($" TARGET: {unixTime} ({time.Kind}: {time}, UTC: {time.ToUniversalTime()})");
+            return null;
         }
 
-        foreach (var (time, ckeys) in times.OrderBy(x => x.Key)) {
-            Console.WriteLine($"{time}: {ckeys.Count} keys");
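+        // Replay history in order, merging until a response is newer than the target time.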
+        while (await stream.MoveNextAsync()) {
+            var (_, resp) = stream.Current;
+            if (resp is null) continue;
+            if (resp.GetDerivedSyncTime() > unixTime) break;
+            merged = MergeSyncs(merged, resp);
         }
+
+        return merged;
     }
 
-    private async Task<Dictionary<ulong, List<string>>?> GetCheckpointMap() {
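+    /// <summary>
+    /// Map each checkpoint timestamp to the set of response keys archived under old/{checkpoint}/.
+    /// </summary>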
+    private async Task<ImmutableSortedDictionary<ulong, FrozenSet<string>>?> GetCheckpointMap() {
         if (storageProvider is null) return null;
         var keys = (await storageProvider.GetAllKeysAsync()).ToFrozenSet();
         var map = new Dictionary<ulong, List<string>>();
@@ -132,7 +247,7 @@ public class SyncStateResolver(AuthenticatedHomeserverGeneric homeserver, ILogge
             map[checkpoint].Add(parts[2]);
         }
 
-        return map.OrderBy(x => x.Key).ToDictionary(x => x.Key, x => x.Value);
+        return map.ToImmutableSortedDictionary(x => x.Key, x => x.Value.ToFrozenSet());
     }
 
     private SyncResponse MergeSyncs(SyncResponse oldSync, SyncResponse newSync) {