summary refs log tree commit diff
path: root/synapse/state/v2.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/state/v2.py')
-rw-r--r--synapse/state/v2.py58
1 files changed, 42 insertions, 16 deletions
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index c618df2fde..6a16f38a15 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -17,12 +17,14 @@ import itertools
 import logging
 from typing import (
     Any,
+    Awaitable,
     Callable,
     Collection,
     Dict,
     Generator,
     Iterable,
     List,
+    Mapping,
     Optional,
     Sequence,
     Set,
@@ -30,33 +32,58 @@ from typing import (
     overload,
 )
 
-from typing_extensions import Literal
+from typing_extensions import Literal, Protocol
 
-import synapse.state
 from synapse import event_auth
 from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
 from synapse.api.room_versions import RoomVersion
 from synapse.events import EventBase
 from synapse.types import MutableStateMap, StateMap
-from synapse.util import Clock
 
 logger = logging.getLogger(__name__)
 
 
+class Clock(Protocol):
+    # This is usually synapse.util.Clock, but it's replaced with a FakeClock in tests.
+    # We only ever sleep(0) though, so that other async functions can make forward
+    # progress without waiting for stateres to complete.
+    def sleep(self, duration_ms: float) -> Awaitable[None]:
+        ...
+
+
+class StateResolutionStore(Protocol):
+    # This is usually synapse.state.StateResolutionStore, but it's replaced with a
+    # TestStateResolutionStore in tests.
+    def get_events(
+        self, event_ids: Collection[str], allow_rejected: bool = False
+    ) -> Awaitable[Dict[str, EventBase]]:
+        ...
+
+    def get_auth_chain_difference(
+        self, room_id: str, state_sets: List[Set[str]]
+    ) -> Awaitable[Set[str]]:
+        ...
+
+
 # We want to await to the reactor occasionally during state res when dealing
 # with large data sets, so that we don't exhaust the reactor. This is done by
 # awaiting to reactor during loops every N iterations.
 _AWAIT_AFTER_ITERATIONS = 100
 
 
+__all__ = [
+    "resolve_events_with_store",
+]
+
+
 async def resolve_events_with_store(
     clock: Clock,
     room_id: str,
     room_version: RoomVersion,
     state_sets: Sequence[StateMap[str]],
     event_map: Optional[Dict[str, EventBase]],
-    state_res_store: "synapse.state.StateResolutionStore",
+    state_res_store: StateResolutionStore,
 ) -> StateMap[str]:
     """Resolves the state using the v2 state resolution algorithm
 
@@ -194,7 +221,7 @@ async def _get_power_level_for_sender(
     room_id: str,
     event_id: str,
     event_map: Dict[str, EventBase],
-    state_res_store: "synapse.state.StateResolutionStore",
+    state_res_store: StateResolutionStore,
 ) -> int:
     """Return the power level of the sender of the given event according to
     their auth events.
@@ -243,9 +270,9 @@ async def _get_power_level_for_sender(
 
 async def _get_auth_chain_difference(
     room_id: str,
-    state_sets: Sequence[StateMap[str]],
+    state_sets: Sequence[Mapping[Any, str]],
     event_map: Dict[str, EventBase],
-    state_res_store: "synapse.state.StateResolutionStore",
+    state_res_store: StateResolutionStore,
 ) -> Set[str]:
     """Compare the auth chains of each state set and return the set of events
     that only appear in some but not all of the auth chains.
@@ -406,7 +433,7 @@ async def _add_event_and_auth_chain_to_graph(
     room_id: str,
     event_id: str,
     event_map: Dict[str, EventBase],
-    state_res_store: "synapse.state.StateResolutionStore",
+    state_res_store: StateResolutionStore,
     auth_diff: Set[str],
 ) -> None:
     """Helper function for _reverse_topological_power_sort that add the event
@@ -440,7 +467,7 @@ async def _reverse_topological_power_sort(
     room_id: str,
     event_ids: Iterable[str],
     event_map: Dict[str, EventBase],
-    state_res_store: "synapse.state.StateResolutionStore",
+    state_res_store: StateResolutionStore,
     auth_diff: Set[str],
 ) -> List[str]:
     """Returns a list of the event_ids sorted by reverse topological ordering,
@@ -501,7 +528,7 @@ async def _iterative_auth_checks(
     event_ids: List[str],
     base_state: StateMap[str],
     event_map: Dict[str, EventBase],
-    state_res_store: "synapse.state.StateResolutionStore",
+    state_res_store: StateResolutionStore,
 ) -> MutableStateMap[str]:
     """Sequentially apply auth checks to each event in given list, updating the
     state as it goes along.
@@ -547,7 +574,6 @@ async def _iterative_auth_checks(
 
         try:
             event_auth.check_auth_rules_for_event(
-                room_version,
                 event,
                 auth_events.values(),
             )
@@ -570,7 +596,7 @@ async def _mainline_sort(
     event_ids: List[str],
     resolved_power_event_id: Optional[str],
     event_map: Dict[str, EventBase],
-    state_res_store: "synapse.state.StateResolutionStore",
+    state_res_store: StateResolutionStore,
 ) -> List[str]:
     """Returns a sorted list of event_ids sorted by mainline ordering based on
     the given event resolved_power_event_id
@@ -639,7 +665,7 @@ async def _get_mainline_depth_for_event(
     event: EventBase,
     mainline_map: Dict[str, int],
     event_map: Dict[str, EventBase],
-    state_res_store: "synapse.state.StateResolutionStore",
+    state_res_store: StateResolutionStore,
 ) -> int:
     """Get the mainline depths for the given event based on the mainline map
 
@@ -683,7 +709,7 @@ async def _get_event(
     room_id: str,
     event_id: str,
     event_map: Dict[str, EventBase],
-    state_res_store: "synapse.state.StateResolutionStore",
+    state_res_store: StateResolutionStore,
     allow_none: Literal[False] = False,
 ) -> EventBase:
     ...
@@ -694,7 +720,7 @@ async def _get_event(
     room_id: str,
     event_id: str,
     event_map: Dict[str, EventBase],
-    state_res_store: "synapse.state.StateResolutionStore",
+    state_res_store: StateResolutionStore,
     allow_none: Literal[True],
 ) -> Optional[EventBase]:
     ...
@@ -704,7 +730,7 @@ async def _get_event(
     room_id: str,
     event_id: str,
     event_map: Dict[str, EventBase],
-    state_res_store: "synapse.state.StateResolutionStore",
+    state_res_store: StateResolutionStore,
     allow_none: bool = False,
 ) -> Optional[EventBase]:
     """Helper function to look up event in event_map, falling back to looking