summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/8879.misc1
-rw-r--r--synapse/state/__init__.py4
-rw-r--r--synapse/state/v2.py9
-rw-r--r--synapse/storage/databases/main/event_federation.py4
-rw-r--r--tests/state/test_v2.py14
-rw-r--r--tests/storage/test_event_federation.py18
6 files changed, 33 insertions, 17 deletions
diff --git a/changelog.d/8879.misc b/changelog.d/8879.misc
new file mode 100644
index 0000000000..6f9516b314
--- /dev/null
+++ b/changelog.d/8879.misc
@@ -0,0 +1 @@
+Pass `room_id` to `get_auth_chain_difference`.
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 1fa3b280b4..84f59c7d85 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -783,7 +783,7 @@ class StateResolutionStore:
         )
 
     def get_auth_chain_difference(
-        self, state_sets: List[Set[str]]
+        self, room_id: str, state_sets: List[Set[str]]
     ) -> Awaitable[Set[str]]:
         """Given sets of state events figure out the auth chain difference (as
         per state res v2 algorithm).
@@ -796,4 +796,4 @@ class StateResolutionStore:
             An awaitable that resolves to a set of event IDs.
         """
 
-        return self.store.get_auth_chain_difference(state_sets)
+        return self.store.get_auth_chain_difference(room_id, state_sets)
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index ffc504ce77..f85124bf81 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -97,7 +97,9 @@ async def resolve_events_with_store(
 
     # Also fetch all auth events that appear in only some of the state sets'
     # auth chains.
-    auth_diff = await _get_auth_chain_difference(state_sets, event_map, state_res_store)
+    auth_diff = await _get_auth_chain_difference(
+        room_id, state_sets, event_map, state_res_store
+    )
 
     full_conflicted_set = set(
         itertools.chain(
@@ -236,6 +238,7 @@ async def _get_power_level_for_sender(
 
 
 async def _get_auth_chain_difference(
+    room_id: str,
     state_sets: Sequence[StateMap[str]],
     event_map: Dict[str, EventBase],
     state_res_store: "synapse.state.StateResolutionStore",
@@ -332,7 +335,9 @@ async def _get_auth_chain_difference(
         difference_from_event_map = ()
         state_sets_ids = [set(state_set.values()) for state_set in state_sets]
 
-    difference = await state_res_store.get_auth_chain_difference(state_sets_ids)
+    difference = await state_res_store.get_auth_chain_difference(
+        room_id, state_sets_ids
+    )
     difference.update(difference_from_event_map)
 
     return difference
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 2e07c37340..ebffd89251 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -137,7 +137,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
 
         return list(results)
 
-    async def get_auth_chain_difference(self, state_sets: List[Set[str]]) -> Set[str]:
+    async def get_auth_chain_difference(
+        self, room_id: str, state_sets: List[Set[str]]
+    ) -> Set[str]:
         """Given sets of state events figure out the auth chain difference (as
         per state res v2 algorithm).
 
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index f5c6db900d..09f4f32a02 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -623,7 +623,9 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
 
         store = TestStateResolutionStore(persisted_events)
 
-        diff_d = _get_auth_chain_difference(state_sets, unpersited_events, store)
+        diff_d = _get_auth_chain_difference(
+            ROOM_ID, state_sets, unpersited_events, store
+        )
         difference = self.successResultOf(defer.ensureDeferred(diff_d))
 
         self.assertEqual(difference, {c.event_id})
@@ -662,7 +664,9 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
 
         store = TestStateResolutionStore(persisted_events)
 
-        diff_d = _get_auth_chain_difference(state_sets, unpersited_events, store)
+        diff_d = _get_auth_chain_difference(
+            ROOM_ID, state_sets, unpersited_events, store
+        )
         difference = self.successResultOf(defer.ensureDeferred(diff_d))
 
         self.assertEqual(difference, {d.event_id, c.event_id})
@@ -707,7 +711,9 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
 
         store = TestStateResolutionStore(persisted_events)
 
-        diff_d = _get_auth_chain_difference(state_sets, unpersited_events, store)
+        diff_d = _get_auth_chain_difference(
+            ROOM_ID, state_sets, unpersited_events, store
+        )
         difference = self.successResultOf(defer.ensureDeferred(diff_d))
 
         self.assertEqual(difference, {d.event_id, e.event_id})
@@ -773,7 +779,7 @@ class TestStateResolutionStore:
 
         return list(result)
 
-    def get_auth_chain_difference(self, auth_sets):
+    def get_auth_chain_difference(self, room_id, auth_sets):
         chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets]
 
         common = set(chains[0]).intersection(*chains[1:])
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 71c21d8c75..482506d731 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -202,39 +202,41 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         # Now actually test that various combinations give the right result:
 
         difference = self.get_success(
-            self.store.get_auth_chain_difference([{"a"}, {"b"}])
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}])
         )
         self.assertSetEqual(difference, {"a", "b"})
 
         difference = self.get_success(
-            self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}])
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}])
         )
         self.assertSetEqual(difference, {"a", "b", "c", "e", "f"})
 
         difference = self.get_success(
-            self.store.get_auth_chain_difference([{"a", "c"}, {"b"}])
+            self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b"}])
         )
         self.assertSetEqual(difference, {"a", "b", "c"})
 
         difference = self.get_success(
-            self.store.get_auth_chain_difference([{"a", "c"}, {"b", "c"}])
+            self.store.get_auth_chain_difference(room_id, [{"a", "c"}, {"b", "c"}])
         )
         self.assertSetEqual(difference, {"a", "b"})
 
         difference = self.get_success(
-            self.store.get_auth_chain_difference([{"a"}, {"b"}, {"d"}])
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"d"}])
         )
         self.assertSetEqual(difference, {"a", "b", "d", "e"})
 
         difference = self.get_success(
-            self.store.get_auth_chain_difference([{"a"}, {"b"}, {"c"}, {"d"}])
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"c"}, {"d"}])
         )
         self.assertSetEqual(difference, {"a", "b", "c", "d", "e", "f"})
 
         difference = self.get_success(
-            self.store.get_auth_chain_difference([{"a"}, {"b"}, {"e"}])
+            self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}, {"e"}])
         )
         self.assertSetEqual(difference, {"a", "b"})
 
-        difference = self.get_success(self.store.get_auth_chain_difference([{"a"}]))
+        difference = self.get_success(
+            self.store.get_auth_chain_difference(room_id, [{"a"}])
+        )
         self.assertSetEqual(difference, set())