summary refs log tree commit diff
path: root/synapse/state
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2020-12-04 15:52:49 +0000
committerGitHub <noreply@github.com>2020-12-04 15:52:49 +0000
commitdf4b1e9c74d56d79c274149b0dfb0fd5305c7659 (patch)
treedfdc8d1a66cb36975ded614070288ab47f8eee06 /synapse/state
parentAdd additional validation to pusher URLs. (#8865) (diff)
downloadsynapse-df4b1e9c74d56d79c274149b0dfb0fd5305c7659.tar.xz
Pass room_id to get_auth_chain_difference (#8879)
This is so that we can choose which algorithm to use based on the room ID.

Diffstat (limited to 'synapse/state')
-rw-r--r--synapse/state/__init__.py4
-rw-r--r--synapse/state/v2.py9
2 files changed, 9 insertions, 4 deletions
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 1fa3b280b4..84f59c7d85 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -783,7 +783,7 @@ class StateResolutionStore:
         )
 
     def get_auth_chain_difference(
-        self, state_sets: List[Set[str]]
+        self, room_id: str, state_sets: List[Set[str]]
     ) -> Awaitable[Set[str]]:
         """Given sets of state events figure out the auth chain difference (as
         per state res v2 algorithm).
@@ -796,4 +796,4 @@ class StateResolutionStore:
             An awaitable that resolves to a set of event IDs.
         """
 
-        return self.store.get_auth_chain_difference(state_sets)
+        return self.store.get_auth_chain_difference(room_id, state_sets)
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index ffc504ce77..f85124bf81 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -97,7 +97,9 @@ async def resolve_events_with_store(
 
     # Also fetch all auth events that appear in only some of the state sets'
     # auth chains.
-    auth_diff = await _get_auth_chain_difference(state_sets, event_map, state_res_store)
+    auth_diff = await _get_auth_chain_difference(
+        room_id, state_sets, event_map, state_res_store
+    )
 
     full_conflicted_set = set(
         itertools.chain(
@@ -236,6 +238,7 @@ async def _get_power_level_for_sender(
 
 
 async def _get_auth_chain_difference(
+    room_id: str,
     state_sets: Sequence[StateMap[str]],
     event_map: Dict[str, EventBase],
     state_res_store: "synapse.state.StateResolutionStore",
@@ -332,7 +335,9 @@ async def _get_auth_chain_difference(
         difference_from_event_map = ()
         state_sets_ids = [set(state_set.values()) for state_set in state_sets]
 
-    difference = await state_res_store.get_auth_chain_difference(state_sets_ids)
+    difference = await state_res_store.get_auth_chain_difference(
+        room_id, state_sets_ids
+    )
     difference.update(difference_from_event_map)
 
     return difference