author     Erik Johnston <erik@matrix.org>  2024-05-08 16:31:59 +0100
committer  Erik Johnston <erik@matrix.org>  2024-05-09 10:58:00 +0100
commit     ca79b4d87df814ae69dd093253d500108d48e461 (patch)
tree       b4f9a40b84f19450af20dad92b9e3dece3f72e5b
parent     Newsfile (diff)
download   synapse-ca79b4d87df814ae69dd093253d500108d48e461.tar.xz
Use a SortedSet instead
-rw-r--r--  synapse/storage/databases/main/event_federation.py  28
-rw-r--r--  tests/storage/test_purge.py                          74
2 files changed, 84 insertions(+), 18 deletions(-)
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 68f30d893c..3dd53f2038 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -39,6 +39,7 @@ from typing import (
 
 import attr
 from prometheus_client import Counter, Gauge
+from sortedcontainers import SortedSet
 
 from synapse.api.constants import MAX_DEPTH
 from synapse.api.errors import StoreError
@@ -373,24 +374,16 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
         # We fetch the links in batches. Separate batches will likely fetch the
         # same set of links (e.g. they'll always pull in the links to create
-        # event). To try and minimize the amount of redundant links, we sort the
-        # chain IDs in reverse, as there will be a correlation between the order
-        # of chain IDs and links (i.e., higher chain IDs are more likely to
-        # depend on lower chain IDs than vice versa).
+        # event). To try and minimize the amount of redundant links, we query
+        # the chain IDs in reverse order, as there will be a correlation between
+        # the order of chain IDs and links (i.e., higher chain IDs are more
+        # likely to depend on lower chain IDs than vice versa).
         BATCH_SIZE = 1000
-        chains_to_fetch_list = list(chains_to_fetch)
-        chains_to_fetch_list.sort(reverse=True)
+        chains_to_fetch_sorted = SortedSet(chains_to_fetch)
 
-        seen_chains: Set[int] = set()
-        while chains_to_fetch_list:
-            batch2 = [
-                c for c in chains_to_fetch_list[-BATCH_SIZE:] if c not in seen_chains
-            ]
-            chains_to_fetch_list = chains_to_fetch_list[:-BATCH_SIZE]
-            while len(batch2) < BATCH_SIZE and chains_to_fetch_list:
-                chain_id = chains_to_fetch_list.pop()
-                if chain_id not in seen_chains:
-                    batch2.append(chain_id)
+        while chains_to_fetch_sorted:
+            batch2 = list(chains_to_fetch_sorted.islice(-BATCH_SIZE))
+            chains_to_fetch_sorted.difference_update(batch2)
 
             clause, args = make_in_list_sql_clause(
                 txn.database_engine, "origin_chain_id", batch2
@@ -409,8 +402,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
                     (origin_sequence_number, target_chain_id, target_sequence_number)
                 )
 
-            seen_chains.update(links)
-            seen_chains.update(batch2)
+            chains_to_fetch_sorted.difference_update(links)
 
             yield links
 
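The new loop drains the SortedSet from the top: islice(-BATCH_SIZE) yields the
(up to) BATCH_SIZE largest chain IDs still pending, and difference_update()
removes both the batch just queried and any chain IDs reached via the fetched
links, so no chain is fetched twice. A minimal standalone sketch of the idiom
(illustrative only, not Synapse code; the values are made up):

    from sortedcontainers import SortedSet

    BATCH_SIZE = 3
    pending = SortedSet([1, 2, 3, 4, 5, 7, 9])

    while pending:
        # islice(-BATCH_SIZE) is a view of the largest pending elements.
        batch = list(pending.islice(-BATCH_SIZE))
        # Remove them so the loop makes progress and never refetches them.
        pending.difference_update(batch)
        print(batch)  # [5, 7, 9], then [2, 3, 4], then [1]

This folds the previous "sorted list plus separate seen_chains set" bookkeeping
into a single structure that stays sorted as items are removed.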
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index 080d5640a5..9fa69f6581 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -25,6 +25,7 @@ from synapse.rest.client import room
 from synapse.server import HomeServer
 from synapse.util import Clock
 
+from tests.test_utils.event_injection import inject_event
 from tests.unittest import HomeserverTestCase
 
 
@@ -128,3 +129,76 @@ class PurgeTests(HomeserverTestCase):
         self.store._invalidate_local_get_event_cache(create_event.event_id)
         self.get_failure(self.store.get_event(create_event.event_id), NotFoundError)
         self.get_failure(self.store.get_event(first["event_id"]), NotFoundError)
+
+    def test_state_groups_state_decreases(self) -> None:
+        response = self.helper.send(self.room_id, body="first")
+        first_event_id = response["event_id"]
+
+        batches = []
+
+        previous_event_id = first_event_id
+        for i in range(50):
+            state_event1 = self.get_success(
+                inject_event(
+                    self.hs,
+                    type="test.state",
+                    sender=self.user_id,
+                    state_key="",
+                    room_id=self.room_id,
+                    content={"key": i, "e": 1},
+                    prev_event_ids=[previous_event_id],
+                    origin_server_ts=1,
+                )
+            )
+
+            state_event2 = self.get_success(
+                inject_event(
+                    self.hs,
+                    type="test.state",
+                    sender=self.user_id,
+                    state_key="",
+                    room_id=self.room_id,
+                    content={"key": i, "e": 2},
+                    prev_event_ids=[previous_event_id],
+                    origin_server_ts=2,
+                )
+            )
+
+            message_event = self.get_success(
+                inject_event(
+                    self.hs,
+                    type="dummy_event",
+                    sender=self.user_id,
+                    room_id=self.room_id,
+                    content={},
+                    prev_event_ids=[state_event1.event_id, state_event2.event_id],
+                )
+            )
+
+            token = self.get_success(
+                self.store.get_topological_token_for_event(state_event1.event_id)
+            )
+            batches.append(token)
+
+            previous_event_id = message_event.event_id
+
+        self.helper.send(self.room_id, body="last event")
+
+        def count_state_groups() -> int:
+            sql = "SELECT COUNT(*) FROM state_groups_state WHERE room_id = ?"
+            rows = self.get_success(
+                self.store.db_pool.execute("count_state_groups", sql, self.room_id)
+            )
+            return rows[0][0]
+
+        # Each purge should free state: the row count in state_groups_state
+        # must never grow, and must be strictly lower by the end.
+        initial_count = count_state_groups()
+        previous_count = initial_count
+        for token in batches:
+            token_str = self.get_success(token.to_string(self.hs.get_datastores().main))
+            self.get_success(
+                self._storage_controllers.purge_events.purge_history(
+                    self.room_id, token_str, False
+                )
+            )
+            new_count = count_state_groups()
+            self.assertLessEqual(new_count, previous_count)
+            previous_count = new_count
+
+        self.assertLess(previous_count, initial_count)
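
A note on the purge call: the trailing False is purge_history's
delete_local_events argument, so locally-sent events before each token are
kept; the shrinking state_groups_state counts come from the purge dropping
state groups that are no longer referenced. A hedged, standalone sketch of the
same call, reusing the names from the test above (event_id stands in for any
event in the room):

    token = self.get_success(
        self.store.get_topological_token_for_event(event_id)
    )
    token_str = self.get_success(token.to_string(self.hs.get_datastores().main))
    self.get_success(
        self._storage_controllers.purge_events.purge_history(
            self.room_id, token_str, False  # delete_local_events=False
        )
    )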