diff --git a/ArcaneLibs b/ArcaneLibs
-Subproject 3a2937246ba42cdc2c0ffc1c0c5ed9cddc2ef56
+Subproject 0fd0bd310bd32ab108671f3eb60d3c1aeb115da
diff --git a/LibMatrix/Extensions/EnumerableExtensions.cs b/LibMatrix/Extensions/EnumerableExtensions.cs
index ace2c0c..4dcf26e 100644
--- a/LibMatrix/Extensions/EnumerableExtensions.cs
+++ b/LibMatrix/Extensions/EnumerableExtensions.cs
@@ -4,9 +4,6 @@ using System.Collections.Immutable;
namespace LibMatrix.Extensions;
public static class EnumerableExtensions {
- public static int insertions = 0;
- public static int replacements = 0;
-
public static void MergeStateEventLists(this IList<StateEvent> oldState, IList<StateEvent> newState) {
// foreach (var stateEvent in newState) {
// var old = oldState.FirstOrDefault(x => x.Type == stateEvent.Type && x.StateKey == stateEvent.StateKey);
@@ -69,11 +66,9 @@ public static class EnumerableExtensions {
switch (FindIndex(e)) {
case -1:
oldState.Add(e);
- insertions++;
break;
case var index:
oldState[index] = e;
- replacements++;
break;
}
}
diff --git a/LibMatrix/Helpers/SyncHelper.cs b/LibMatrix/Helpers/SyncHelper.cs
index 9b1b921..bdbd0b4 100644
--- a/LibMatrix/Helpers/SyncHelper.cs
+++ b/LibMatrix/Helpers/SyncHelper.cs
@@ -63,7 +63,7 @@ public class SyncHelper(AuthenticatedHomeserverGeneric homeserver, ILogger? logg
public async Task<int> GetUnoptimisedStoreCount() {
if (storageProvider is null) return -1;
var keys = await storageProvider.GetAllKeysAsync();
- return keys.Count(x => !x.StartsWith("old/")) - 1;
+ return keys.Count(static x => !x.StartsWith("old/")) - 1;
}
private async Task UpdateFilterAsync() {
diff --git a/LibMatrix/Helpers/SyncStateResolver.cs b/LibMatrix/Helpers/SyncStateResolver.cs
index 5d25561..a8f5615 100644
--- a/LibMatrix/Helpers/SyncStateResolver.cs
+++ b/LibMatrix/Helpers/SyncStateResolver.cs
@@ -1,7 +1,9 @@
+using System.Collections.Concurrent;
using System.Collections.Frozen;
using System.Collections.Immutable;
using System.Diagnostics;
-using System.Text;
+using System.Text.Json;
+using System.Threading.Tasks.Dataflow;
using ArcaneLibs.Extensions;
using LibMatrix.Extensions;
using LibMatrix.Filters;
@@ -23,6 +25,19 @@ public class SyncStateResolver(AuthenticatedHomeserverGeneric homeserver, ILogge
private SyncHelper _syncHelper = new(homeserver, logger, storageProvider);
+ private async Task<SyncResponse?> LoadSyncResponse(string key) {
+ if (storageProvider is null) ArgumentNullException.ThrowIfNull(storageProvider);
+ var stream = await storageProvider.LoadStreamAsync(key);
+ return JsonSerializer.Deserialize<SyncResponse>(stream!, SyncResponseSerializerContext.Default.SyncResponse);
+ }
+
+ private async Task SaveSyncResponse(string key, SyncResponse value) {
+ ArgumentNullException.ThrowIfNull(storageProvider);
+ var ms = new MemoryStream();
+ await JsonSerializer.SerializeAsync(ms, value, SyncResponseSerializerContext.Default.SyncResponse);
+ await storageProvider.SaveStreamAsync(key, ms);
+ }
+
public async Task<(SyncResponse next, SyncResponse merged)> ContinueAsync(CancellationToken? cancellationToken = null) {
// copy properties
_syncHelper.Since = Since;
@@ -41,14 +56,129 @@ public class SyncStateResolver(AuthenticatedHomeserverGeneric homeserver, ILogge
return (sync, MergedState);
}
+ // private async IAsyncEnumerable<List<SyncResponse>> MergeP() {
+
+ // }
+
+ private async Task<SyncResponse?> OptimiseFrom(string start, Action<int, int>? progressCallback = null) {
+ var a = GetSerializedUnoptimisedResponses(start);
+ SyncResponse merged = null!;
+ int iters = 0;
+ var sw = Stopwatch.StartNew();
+ await foreach (var (key, resp) in a) {
+ if (resp is null) continue;
+ iters++;
+ // if (key == "init") _merged = resp;
+ // else _merged = await MergeSyncs(_merged, resp);
+ // Console.WriteLine($"{key} @ {resp.GetDerivedSyncTime()} -> {resp.NextBatch}");
+ }
+
+ Console.WriteLine($"OptimiseFrom {start} finished in {sw.Elapsed.TotalMilliseconds}ms with {iters} iterations");
+
+ return merged;
+ }
+
+ private async Task<List<string>> GetSerializedUnoptimisedKeysParallel(string start = "init") {
+ Dictionary<string, string> pairs = [];
+ var unoptimisedKeys = (await storageProvider.GetAllKeysAsync()).Where(static x => !x.Contains('/')).ToFrozenSet();
+ await Parallel.ForEachAsync(unoptimisedKeys, async (key, _) => {
+ var data = await storageProvider.LoadObjectAsync<SyncResponse>(key, SyncResponseSerializerContext.Default.SyncResponse);
+ if (data is null) return;
+ lock (pairs)
+ pairs.Add(key, data.NextBatch);
+ });
+
+ var serializedKeys = new List<string>();
+ var currentKey = start;
+ while (pairs.TryGetValue(currentKey, out var nextKey)) {
+ serializedKeys.Add(currentKey);
+ currentKey = nextKey;
+ }
+
+ return serializedKeys;
+ }
+
+ private async Task<SyncResponse> MergeRecursive(string[] keys, int depth = 0) {
+ if (keys.Length > 10) {
+ var newKeys = keys.Chunk((keys.Length / 2) + 1).ToArray();
+ var (left, right) = (MergeRecursive(newKeys[0], depth + 1), MergeRecursive(newKeys[1], depth + 1));
+ await Task.WhenAll(left, right);
+ return await MergeSyncs(await left, await right);
+ }
+
+ // Console.WriteLine("Hit max depth: " + depth);
+ SyncResponse merged = await LoadSyncResponse(keys[0]);
+ foreach (var key in keys[1..]) {
+ merged = await MergeSyncs(merged, await LoadSyncResponse(key));
+ }
+
+ return merged;
+ }
+
public async Task OptimiseStore(Action<int, int>? progressCallback = null) {
if (storageProvider is null) return;
if (!await storageProvider.ObjectExistsAsync("init")) return;
+ //
+ // {
+ // var a = GetSerializedUnoptimisedResponses();
+ // SyncResponse _merged = null!;
+ // await foreach (var (key, resp) in a) {
+ // if (resp is null) continue;
+ // // if (key == "init") _merged = resp;
+ // // else _merged = await MergeSyncs(_merged, resp);
+ // // Console.WriteLine($"{key} @ {resp.GetDerivedSyncTime()} -> {resp.NextBatch}");
+ // }
+ // Environment.Exit(0);
+ // }
+
+ {
+ // List<string> serialisedKeys = new(4000000);
+ // await foreach (var res in GetSerializedUnoptimisedResponses()) {
+ // if (res.resp is null) continue;
+ // serialisedKeys.Add(res.key);
+ // if (serialisedKeys.Count % 1000 == 0) _ = Console.Out.WriteAsync($"{serialisedKeys.Count}\r");
+ // }
+
+ List<string> serialisedKeys = await GetSerializedUnoptimisedKeysParallel();
+
+ await MergeRecursive(serialisedKeys.ToArray());
+
+ // var chunkSize = serialisedKeys.Count / Environment.ProcessorCount;
+ // var chunks = serialisedKeys.Chunk(chunkSize+1).Select(x => (x.First(), x.Length)).ToList();
+ // Console.WriteLine($"Got {chunks.Count} chunks:");
+ // foreach (var chunk in chunks) {
+ // Console.WriteLine($"Chunk {chunk.Item1} with length {chunk.Length}");
+ // }
+ //
+ // var mergeTasks = chunks.Select(async chunk => {
+ // var (startKey, length) = chunk;
+ // string currentKey = startKey;
+ // SyncResponse merged = await storageProvider.LoadObjectAsync<SyncResponse>(currentKey, SyncResponseSerializerContext.Default.SyncResponse);
+ // for (int i = 0; i < length; i++) {
+ // if (i % 1000 == 0) Console.Write($"{i}... \r");
+ // var newData = await storageProvider.LoadObjectAsync<SyncResponse>(currentKey, SyncResponseSerializerContext.Default.SyncResponse);
+ // merged = await MergeSyncs(merged, newData);
+ // currentKey = merged.NextBatch;
+ // }
+ //
+ // return merged;
+ // }).ToList();
+ //
+ // var mergedResults = await Task.WhenAll(mergeTasks);
+ // SyncResponse _merged = mergedResults[0];
+ // foreach (var key in mergedResults[1..]) {
+ // _merged = await MergeSyncs(_merged, key);
+ // }
+ }
+
+ Environment.Exit(0);
+
+ return;
var totalSw = Stopwatch.StartNew();
Console.Write("Optimising sync store...");
- var initLoadTask = storageProvider.LoadObjectAsync<SyncResponse>("init");
- var keys = (await storageProvider.GetAllKeysAsync()).Where(x => !x.StartsWith("old/")).ToFrozenSet();
+ var initLoadTask = LoadSyncResponse("init");
+ var keys = (await storageProvider.GetAllKeysAsync()).Where(static x => !x.StartsWith("old/")).ToFrozenSet();
var count = keys.Count - 1;
int total = count;
Console.WriteLine($"Found {count} entries to optimise in {totalSw.Elapsed}.");
@@ -60,6 +190,12 @@ public class SyncStateResolver(AuthenticatedHomeserverGeneric homeserver, ILogge
return;
}
+ // if (keys.Count > 100_000) {
+ // // batch data by core count
+ //
+ // return;
+ // }
+
// We back up old entries
var oldPath = $"old/{DateTimeOffset.UtcNow.ToUnixTimeMilliseconds()}";
await storageProvider.MoveObjectAsync("init", $"{oldPath}/init");
@@ -67,26 +203,38 @@ public class SyncStateResolver(AuthenticatedHomeserverGeneric homeserver, ILogge
var moveTasks = new List<Task>();
Dictionary<string, Dictionary<string, TimeSpan>> traces = [];
+ string[] loopTrace = new string[4];
while (keys.Contains(merged.NextBatch)) {
- Console.Write($"Merging {merged.NextBatch}, {--count} remaining... ");
+ loopTrace[0] = $"Merging {merged.NextBatch}, {--count} remaining";
var sw = Stopwatch.StartNew();
var swt = Stopwatch.StartNew();
- var next = await storageProvider.LoadObjectAsync<SyncResponse>(merged.NextBatch);
- Console.Write($"Load {sw.GetElapsedAndRestart().TotalMilliseconds}ms... ");
+ var next = await LoadSyncResponse(merged.NextBatch);
+ loopTrace[1] = $"Load {sw.GetElapsedAndRestart().TotalMilliseconds}ms";
if (next is null || merged.NextBatch == next.NextBatch) break;
- Console.Write($"Check {sw.GetElapsedAndRestart().TotalMilliseconds}ms... ");
// back up old entry
moveTasks.Add(storageProvider.MoveObjectAsync(merged.NextBatch, $"{oldPath}/{merged.NextBatch}"));
- Console.Write($"Move {sw.GetElapsedAndRestart().TotalMilliseconds}ms... ");
+
+ if (moveTasks.Count >= 250)
+ moveTasks.RemoveAll(t => t.IsCompleted);
+
+ if (moveTasks.Count >= 500) {
+ Console.Write("Reached 500 moveTasks... ");
+ moveTasks.RemoveAll(t => t.IsCompleted);
+ Console.WriteLine($"{moveTasks.Count} remaining");
+ }
var trace = new Dictionary<string, TimeSpan>();
traces[merged.NextBatch] = trace;
merged = await MergeSyncs(merged, next, trace);
- Console.Write($"Merge {sw.GetElapsedAndRestart().TotalMilliseconds}ms... ");
- Console.WriteLine($"Total {swt.Elapsed.TotalMilliseconds}ms");
- // Console.WriteLine($"Merged {merged.NextBatch}, {--count} remaining...");
- progressCallback?.Invoke(count, total);
+ loopTrace[2] = $"Merge {sw.GetElapsedAndRestart().TotalMilliseconds}ms";
+ loopTrace[3] = $"Total {swt.Elapsed.TotalMilliseconds}ms";
+
+ if (swt.ElapsedMilliseconds >= 25)
+ Console.WriteLine(string.Join("... ", loopTrace));
+
+ if (count % 50 == 0)
+ progressCallback?.Invoke(count, total);
#if WRITE_TRACE
var traceString = string.Join("\n", traces.Select(x => $"{x.Key}\t{x.Value.ToJson(indent: false, ignoreNull: true)}"));
var ms = new MemoryStream(Encoding.UTF8.GetBytes(traceString));
@@ -110,11 +258,10 @@ public class SyncStateResolver(AuthenticatedHomeserverGeneric homeserver, ILogge
#endif
}
- await storageProvider.SaveObjectAsync("init", merged);
+ await SaveSyncResponse("init", merged);
await Task.WhenAll(moveTasks);
Console.WriteLine($"Optimised store in {totalSw.Elapsed.TotalMilliseconds}ms");
- Console.WriteLine($"Insertions: {EnumerableExtensions.insertions}, replacements: {EnumerableExtensions.replacements}");
}
/// <summary>
@@ -197,6 +344,24 @@ public class SyncStateResolver(AuthenticatedHomeserverGeneric homeserver, ILogge
}
}
+ private async IAsyncEnumerable<(string key, SyncResponse? resp)> GetSerializedUnoptimisedResponses(string since = "init") {
+ if (storageProvider is null) yield break;
+ var nextKey = since;
+ var next = storageProvider.LoadObjectAsync<SyncResponse>(nextKey);
+ while (true) {
+ var data = await next;
+
+ if (data is null) break;
+ yield return (nextKey, data);
+ if (await storageProvider.ObjectExistsAsync(data.NextBatch)) {
+ nextKey = data.NextBatch;
+ }
+ else break;
+
+ next = storageProvider.LoadObjectAsync<SyncResponse>(nextKey);
+ }
+ }
+
public async Task<SyncResponse?> GetMergedUpTo(DateTime time) {
if (storageProvider is null) return null;
var unixTime = new DateTimeOffset(time.ToUniversalTime()).ToUnixTimeMilliseconds();
@@ -235,7 +400,7 @@ public class SyncStateResolver(AuthenticatedHomeserverGeneric homeserver, ILogge
map[checkpoint].Add(parts[2]);
}
- return map.OrderBy(x => x.Key).ToImmutableSortedDictionary(x => x.Key, x => x.Value.ToFrozenSet());
+ return map.OrderBy(static x => x.Key).ToImmutableSortedDictionary(static x => x.Key, x => x.Value.ToFrozenSet());
}
private async Task<SyncResponse> MergeSyncs(SyncResponse oldSync, SyncResponse newSync, Dictionary<string, TimeSpan>? trace = null) {
@@ -258,12 +423,13 @@ public class SyncStateResolver(AuthenticatedHomeserverGeneric homeserver, ILogge
var presenceTask = Task.Run(() => {
var sw = Stopwatch.StartNew();
- oldSync.Presence = MergeEventListBy(oldSync.Presence, newSync.Presence, (oldState, newState) => oldState.Sender == newState.Sender && oldState.Type == newState.Type);
+ oldSync.Presence = MergeEventListBy(oldSync.Presence, newSync.Presence,
+ static (oldState, newState) => oldState.Sender == newState.Sender && oldState.Type == newState.Type);
if (sw.ElapsedMilliseconds > 100) Console.WriteLine($"WARN: Presence took {sw.ElapsedMilliseconds}ms");
Trace("Presence", sw.GetElapsedAndRestart());
});
- var deviceOneTimeKeysTask = Task.Run(() => {
+ {
var sw = Stopwatch.StartNew();
// TODO: can this be cleaned up?
oldSync.DeviceOneTimeKeysCount ??= new();
@@ -272,7 +438,7 @@ public class SyncStateResolver(AuthenticatedHomeserverGeneric homeserver, ILogge
oldSync.DeviceOneTimeKeysCount[key] = value;
if (sw.ElapsedMilliseconds > 100) Console.WriteLine($"WARN: DeviceOneTimeKeysCount took {sw.ElapsedMilliseconds}ms");
Trace("DeviceOneTimeKeysCount", sw.GetElapsedAndRestart());
- });
+ }
var roomsTask = Task.Run(() => {
var sw = Stopwatch.StartNew();
@@ -284,7 +450,8 @@ public class SyncStateResolver(AuthenticatedHomeserverGeneric homeserver, ILogge
var toDeviceTask = Task.Run(() => {
var sw = Stopwatch.StartNew();
- oldSync.ToDevice = MergeEventList(oldSync.ToDevice, newSync.ToDevice);
+ // oldSync.ToDevice = MergeEventList(oldSync.ToDevice, newSync.ToDevice);
+ oldSync.ToDevice = AppendEventList(oldSync.ToDevice, newSync.ToDevice);
if (sw.ElapsedMilliseconds > 100) Console.WriteLine($"WARN: ToDevice took {sw.ElapsedMilliseconds}ms");
Trace("ToDevice", sw.GetElapsedAndRestart());
});
@@ -313,7 +480,7 @@ public class SyncStateResolver(AuthenticatedHomeserverGeneric homeserver, ILogge
Trace("DeviceLists.Left", sw.GetElapsedAndRestart());
});
- await Task.WhenAll(accountDataTask, presenceTask, deviceOneTimeKeysTask, roomsTask, toDeviceTask, deviceListsTask);
+ await Task.WhenAll(accountDataTask, presenceTask, roomsTask, toDeviceTask, deviceListsTask);
return oldSync;
}
@@ -454,7 +621,7 @@ public class SyncStateResolver(AuthenticatedHomeserverGeneric homeserver, ILogge
}
// oldState.Events.MergeStateEventLists(newState.Events);
- oldState = MergeEventListBy(oldState, newState, (oldEvt, newEvt) => oldEvt.Type == newEvt.Type && oldEvt.StateKey == newEvt.StateKey);
+ oldState = MergeEventListBy(oldState, newState, static (oldEvt, newEvt) => oldEvt.Type == newEvt.Type && oldEvt.StateKey == newEvt.StateKey);
return oldState;
}
diff --git a/LibMatrix/Interfaces/Services/IStorageProvider.cs b/LibMatrix/Interfaces/Services/IStorageProvider.cs
index fb7bb6d..f7e5488 100644
--- a/LibMatrix/Interfaces/Services/IStorageProvider.cs
+++ b/LibMatrix/Interfaces/Services/IStorageProvider.cs
@@ -1,3 +1,7 @@
+using System.Text.Json;
+using System.Text.Json.Serialization.Metadata;
+using LibMatrix.Responses;
+
namespace LibMatrix.Interfaces.Services;
public interface IStorageProvider {
@@ -17,6 +21,13 @@ public interface IStorageProvider {
Console.WriteLine($"StorageProvider<{GetType().Name}> does not implement SaveObject<T>(key, value)!");
throw new NotImplementedException();
}
+
+ public Task SaveObjectAsync<T>(string key, T value, JsonTypeInfo<T> jsonTypeInfo) {
+ Console.WriteLine($"StorageProvider<{GetType().Name}> does not implement SaveObjectAsync<T>(key, value, typeInfo), using default implementation w/ MemoryStream!");
+ var ms = new MemoryStream();
+ JsonSerializer.Serialize(ms, value, jsonTypeInfo);
+ return SaveStreamAsync(key, ms);
+ }
// load
public Task<T?> LoadObjectAsync<T>(string key) {
@@ -24,6 +35,12 @@ public interface IStorageProvider {
throw new NotImplementedException();
}
+ public async Task<T?> LoadObjectAsync<T>(string key, JsonTypeInfo<T> jsonTypeInfo) {
+ Console.WriteLine($"StorageProvider<{GetType().Name}> does not implement SaveObject<T>(key, typeInfo), using default implementation!");
+ await using var stream = await LoadStreamAsync(key);
+ return JsonSerializer.Deserialize(stream!, jsonTypeInfo);
+ }
+
// check if exists
public Task<bool> ObjectExistsAsync(string key) {
Console.WriteLine($"StorageProvider<{GetType().Name}> does not implement ObjectExists(key)!");
diff --git a/LibMatrix/LibMatrix.csproj b/LibMatrix/LibMatrix.csproj
index bbfaf38..3d10487 100644
--- a/LibMatrix/LibMatrix.csproj
+++ b/LibMatrix/LibMatrix.csproj
@@ -18,8 +18,9 @@
</ItemGroup>
<ItemGroup>
- <PackageReference Include="ArcaneLibs" Version="1.0.0-preview.20250313-104848" Condition="'$(Configuration)' == 'Release'" />
- <ProjectReference Include="..\ArcaneLibs\ArcaneLibs\ArcaneLibs.csproj" Condition="'$(Configuration)' == 'Debug'"/>
+<!-- <PackageReference Include="ArcaneLibs" Version="1.0.0-preview.20250313-104848" Condition="'$(Configuration)' == 'Release'" />-->
+<!-- <ProjectReference Include="..\ArcaneLibs\ArcaneLibs\ArcaneLibs.csproj" Condition="'$(Configuration)' == 'Debug'"/>-->
+ <ProjectReference Include="..\ArcaneLibs\ArcaneLibs\ArcaneLibs.csproj"/>
</ItemGroup>
</Project>
|