summary refs log tree commit diff
diff options
context:
space:
mode:
authorDavid Robertson <davidr@element.io>2022-06-13 18:14:30 +0100
committerDavid Robertson <davidr@element.io>2022-06-13 20:41:06 +0100
commite93d5d6b1fece600ad9c84787517aee8e2ea2d77 (patch)
tree8c418993feee1520ca85a0c85821cb0f20924a91
parentUniformize spam-checker API, part 4: port other spam-checker callbacks to ret... (diff)
downloadsynapse-e93d5d6b1fece600ad9c84787517aee8e2ea2d77.tar.xz
Maybe make stateres a bit faster?
-rw-r--r--synapse/state/v2.py38
1 files changed, 23 insertions, 15 deletions
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index 6a16f38a15..255173864b 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -29,6 +29,7 @@ from typing import (
     Sequence,
     Set,
     Tuple,
+    cast,
     overload,
 )
 
@@ -119,20 +120,23 @@ async def resolve_events_with_store(
     if not conflicted_state:
         return unconflicted_state
 
-    logger.debug("%d conflicted state entries", len(conflicted_state))
+    logger.debug("%d events in conflicted state set", len(conflicted_state))
     logger.debug("Calculating auth chain difference")
 
     # Also fetch all auth events that appear in only some of the state sets'
     # auth chains.
+    conflicts_in_state_sets = [
+        # Note: we discard events in the unconflicted state, because they will never
+        # contribute to the auth difference. See
+        # https://github.com/matrix-org/matrix-spec/issues/1118
+        {k: v for k, v in state_set.items() if v in conflicted_state}
+        for state_set in state_sets
+    ]
     auth_diff = await _get_auth_chain_difference(
-        room_id, state_sets, event_map, state_res_store
+        room_id, conflicts_in_state_sets, event_map, state_res_store
     )
 
-    full_conflicted_set = set(
-        itertools.chain(
-            itertools.chain.from_iterable(conflicted_state.values()), auth_diff
-        )
-    )
+    full_conflicted_set = set(itertools.chain(conflicted_state, auth_diff))
 
     events = await state_res_store.get_events(
         [eid for eid in full_conflicted_set if eid not in event_map],
@@ -376,7 +380,7 @@ async def _get_auth_chain_difference(
 
 def _seperate(
     state_sets: Iterable[StateMap[str]],
-) -> Tuple[StateMap[str], StateMap[Set[str]]]:
+) -> Tuple[StateMap[str], Set[str]]:
     """Return the unconflicted and conflicted state. This is different than in
     the original algorithm, as this defines a key to be conflicted if one of
     the state sets doesn't have that key.
@@ -388,20 +392,24 @@ def _seperate(
         A tuple of unconflicted and conflicted state. The conflicted state dict
         is a map from type/state_key to set of event IDs
     """
-    unconflicted_state = {}
-    conflicted_state = {}
+    unconflicted_state: MutableStateMap[str] = {}
+    conflicted_state: Set[str] = set()
 
     for key in set(itertools.chain.from_iterable(state_sets)):
         event_ids = {state_set.get(key) for state_set in state_sets}
         if len(event_ids) == 1:
-            unconflicted_state[key] = event_ids.pop()
+            # Cast safety: mypy warns that event_ids.pop() could be None. Because `key`
+            # comes from one of the `state_sets`, `event_ids` always contains a non-None
+            # event ID. Since len(event_ids) is 1, there isn't room to have anything
+            # else in the set. Thus the call to `event_ids.pop()` returns a string.
+            unconflicted_state[key] = cast(str, event_ids.pop())
         else:
             event_ids.discard(None)
-            conflicted_state[key] = event_ids
+            # Cast safety: mypy can't infer that discarding None above means that
+            # event_ids is Set[str] and not Set[Optional[str]].
+            conflicted_state.update(cast(Set[str], event_ids))
 
-    # mypy doesn't understand that discarding None above means that conflicted
-    # state is StateMap[Set[str]], not StateMap[Set[Optional[Str]]].
-    return unconflicted_state, conflicted_state  # type: ignore
+    return unconflicted_state, conflicted_state
 
 
 def _is_power_event(event: EventBase) -> bool: