summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/12985.misc1
-rw-r--r--mypy.ini4
-rw-r--r--synapse/state/v2.py57
-rw-r--r--tests/state/test_v2.py125
4 files changed, 129 insertions, 58 deletions
diff --git a/changelog.d/12985.misc b/changelog.d/12985.misc
new file mode 100644
index 0000000000..d5ab9eedea
--- /dev/null
+++ b/changelog.d/12985.misc
@@ -0,0 +1 @@
+Add type annotations to `tests.state.test_v2`.
diff --git a/mypy.ini b/mypy.ini
index fe3e3f9b8e..7973f2ac01 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -56,7 +56,6 @@ exclude = (?x)
    |tests/rest/media/v1/test_media_storage.py
    |tests/server.py
    |tests/server_notices/test_resource_limits_server_notices.py
-   |tests/state/test_v2.py
    |tests/test_metrics.py
    |tests/test_server.py
    |tests/test_state.py
@@ -115,6 +114,9 @@ disallow_untyped_defs = False
 [mypy-tests.handlers.test_user_directory]
 disallow_untyped_defs = True
 
+[mypy-tests.state.test_profile]
+disallow_untyped_defs = True
+
 [mypy-tests.storage.test_profile]
 disallow_untyped_defs = True
 
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index c618df2fde..0e609114ef 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -17,12 +17,14 @@ import itertools
 import logging
 from typing import (
     Any,
+    Awaitable,
     Callable,
     Collection,
     Dict,
     Generator,
     Iterable,
     List,
+    Mapping,
     Optional,
     Sequence,
     Set,
@@ -30,33 +32,58 @@ from typing import (
     overload,
 )
 
-from typing_extensions import Literal
+from typing_extensions import Literal, Protocol
 
-import synapse.state
 from synapse import event_auth
 from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
 from synapse.api.room_versions import RoomVersion
 from synapse.events import EventBase
 from synapse.types import MutableStateMap, StateMap
-from synapse.util import Clock
 
 logger = logging.getLogger(__name__)
 
 
+class Clock(Protocol):
+    # This is usually synapse.util.Clock, but it's replaced with a FakeClock in tests.
+    # We only ever sleep(0) though, so that other async functions can make forward
+    # progress without waiting for stateres to complete.
+    def sleep(self, duration_ms: float) -> Awaitable[None]:
+        ...
+
+
+class StateResolutionStore(Protocol):
+    # This is usually synapse.state.StateResolutionStore, but it's replaced with a
+    # TestStateResolutionStore in tests.
+    def get_events(
+        self, event_ids: Collection[str], allow_rejected: bool = False
+    ) -> Awaitable[Dict[str, EventBase]]:
+        ...
+
+    def get_auth_chain_difference(
+        self, room_id: str, state_sets: List[Set[str]]
+    ) -> Awaitable[Set[str]]:
+        ...
+
+
 # We want to await to the reactor occasionally during state res when dealing
 # with large data sets, so that we don't exhaust the reactor. This is done by
 # awaiting to reactor during loops every N iterations.
 _AWAIT_AFTER_ITERATIONS = 100
 
 
+__all__ = [
+    "resolve_events_with_store",
+]
+
+
 async def resolve_events_with_store(
     clock: Clock,
     room_id: str,
     room_version: RoomVersion,
     state_sets: Sequence[StateMap[str]],
     event_map: Optional[Dict[str, EventBase]],
-    state_res_store: "synapse.state.StateResolutionStore",
+    state_res_store: StateResolutionStore,
 ) -> StateMap[str]:
     """Resolves the state using the v2 state resolution algorithm
 
@@ -194,7 +221,7 @@ async def _get_power_level_for_sender(
     room_id: str,
     event_id: str,
     event_map: Dict[str, EventBase],
-    state_res_store: "synapse.state.StateResolutionStore",
+    state_res_store: StateResolutionStore,
 ) -> int:
     """Return the power level of the sender of the given event according to
     their auth events.
@@ -243,9 +270,9 @@ async def _get_power_level_for_sender(
 
 async def _get_auth_chain_difference(
     room_id: str,
-    state_sets: Sequence[StateMap[str]],
+    state_sets: Sequence[Mapping[Any, str]],
     event_map: Dict[str, EventBase],
-    state_res_store: "synapse.state.StateResolutionStore",
+    state_res_store: StateResolutionStore,
 ) -> Set[str]:
     """Compare the auth chains of each state set and return the set of events
     that only appear in some but not all of the auth chains.
@@ -406,7 +433,7 @@ async def _add_event_and_auth_chain_to_graph(
     room_id: str,
     event_id: str,
     event_map: Dict[str, EventBase],
-    state_res_store: "synapse.state.StateResolutionStore",
+    state_res_store: StateResolutionStore,
     auth_diff: Set[str],
 ) -> None:
     """Helper function for _reverse_topological_power_sort that add the event
@@ -440,7 +467,7 @@ async def _reverse_topological_power_sort(
     room_id: str,
     event_ids: Iterable[str],
     event_map: Dict[str, EventBase],
-    state_res_store: "synapse.state.StateResolutionStore",
+    state_res_store: StateResolutionStore,
     auth_diff: Set[str],
 ) -> List[str]:
     """Returns a list of the event_ids sorted by reverse topological ordering,
@@ -501,7 +528,7 @@ async def _iterative_auth_checks(
     event_ids: List[str],
     base_state: StateMap[str],
     event_map: Dict[str, EventBase],
-    state_res_store: "synapse.state.StateResolutionStore",
+    state_res_store: StateResolutionStore,
 ) -> MutableStateMap[str]:
     """Sequentially apply auth checks to each event in given list, updating the
     state as it goes along.
@@ -570,7 +597,7 @@ async def _mainline_sort(
     event_ids: List[str],
     resolved_power_event_id: Optional[str],
     event_map: Dict[str, EventBase],
-    state_res_store: "synapse.state.StateResolutionStore",
+    state_res_store: StateResolutionStore,
 ) -> List[str]:
     """Returns a sorted list of event_ids sorted by mainline ordering based on
     the given event resolved_power_event_id
@@ -639,7 +666,7 @@ async def _get_mainline_depth_for_event(
     event: EventBase,
     mainline_map: Dict[str, int],
     event_map: Dict[str, EventBase],
-    state_res_store: "synapse.state.StateResolutionStore",
+    state_res_store: StateResolutionStore,
 ) -> int:
     """Get the mainline depths for the given event based on the mainline map
 
@@ -683,7 +710,7 @@ async def _get_event(
     room_id: str,
     event_id: str,
     event_map: Dict[str, EventBase],
-    state_res_store: "synapse.state.StateResolutionStore",
+    state_res_store: StateResolutionStore,
     allow_none: Literal[False] = False,
 ) -> EventBase:
     ...
@@ -694,7 +721,7 @@ async def _get_event(
     room_id: str,
     event_id: str,
     event_map: Dict[str, EventBase],
-    state_res_store: "synapse.state.StateResolutionStore",
+    state_res_store: StateResolutionStore,
     allow_none: Literal[True],
 ) -> Optional[EventBase]:
     ...
@@ -704,7 +731,7 @@ async def _get_event(
     room_id: str,
     event_id: str,
     event_map: Dict[str, EventBase],
-    state_res_store: "synapse.state.StateResolutionStore",
+    state_res_store: StateResolutionStore,
     allow_none: bool = False,
 ) -> Optional[EventBase]:
     """Helper function to look up event in event_map, falling back to looking
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index 8370a27195..78b83d97b6 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -13,7 +13,17 @@
 # limitations under the License.
 
 import itertools
-from typing import List
+from typing import (
+    Collection,
+    Dict,
+    Iterable,
+    List,
+    Mapping,
+    Optional,
+    Set,
+    Tuple,
+    TypeVar,
+)
 
 import attr
 
@@ -22,13 +32,13 @@ from twisted.internet import defer
 from synapse.api.constants import EventTypes, JoinRules, Membership
 from synapse.api.room_versions import RoomVersions
 from synapse.event_auth import auth_types_for_event
-from synapse.events import make_event_from_dict
+from synapse.events import EventBase, make_event_from_dict
 from synapse.state.v2 import (
     _get_auth_chain_difference,
     lexicographical_topological_sort,
     resolve_events_with_store,
 )
-from synapse.types import EventID
+from synapse.types import EventID, StateMap
 
 from tests import unittest
 
@@ -48,7 +58,7 @@ ORIGIN_SERVER_TS = 0
 
 
 class FakeClock:
-    def sleep(self, msec):
+    def sleep(self, msec: float) -> "defer.Deferred[None]":
         return defer.succeed(None)
 
 
@@ -60,7 +70,14 @@ class FakeEvent:
     as domain.
     """
 
-    def __init__(self, id, sender, type, state_key, content):
+    def __init__(
+        self,
+        id: str,
+        sender: str,
+        type: str,
+        state_key: Optional[str],
+        content: Mapping[str, object],
+    ):
         self.node_id = id
         self.event_id = EventID(id, "example.com").to_string()
         self.sender = sender
@@ -69,12 +86,12 @@ class FakeEvent:
         self.content = content
         self.room_id = ROOM_ID
 
-    def to_event(self, auth_events, prev_events):
+    def to_event(self, auth_events: List[str], prev_events: List[str]) -> EventBase:
         """Given the auth_events and prev_events, convert to a Frozen Event
 
         Args:
-            auth_events (list[str]): list of event_ids
-            prev_events (list[str]): list of event_ids
+            auth_events: list of event_ids
+            prev_events: list of event_ids
 
         Returns:
             FrozenEvent
@@ -164,7 +181,7 @@ INITIAL_EDGES = ["START", "IMZ", "IMC", "IMB", "IJR", "IPOWER", "IMA", "CREATE"]
 
 
 class StateTestCase(unittest.TestCase):
-    def test_ban_vs_pl(self):
+    def test_ban_vs_pl(self) -> None:
         events = [
             FakeEvent(
                 id="PA",
@@ -202,7 +219,7 @@ class StateTestCase(unittest.TestCase):
 
         self.do_check(events, edges, expected_state_ids)
 
-    def test_join_rule_evasion(self):
+    def test_join_rule_evasion(self) -> None:
         events = [
             FakeEvent(
                 id="JR",
@@ -226,7 +243,7 @@ class StateTestCase(unittest.TestCase):
 
         self.do_check(events, edges, expected_state_ids)
 
-    def test_offtopic_pl(self):
+    def test_offtopic_pl(self) -> None:
         events = [
             FakeEvent(
                 id="PA",
@@ -257,7 +274,7 @@ class StateTestCase(unittest.TestCase):
 
         self.do_check(events, edges, expected_state_ids)
 
-    def test_topic_basic(self):
+    def test_topic_basic(self) -> None:
         events = [
             FakeEvent(
                 id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
@@ -297,7 +314,7 @@ class StateTestCase(unittest.TestCase):
 
         self.do_check(events, edges, expected_state_ids)
 
-    def test_topic_reset(self):
+    def test_topic_reset(self) -> None:
         events = [
             FakeEvent(
                 id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
@@ -327,7 +344,7 @@ class StateTestCase(unittest.TestCase):
 
         self.do_check(events, edges, expected_state_ids)
 
-    def test_topic(self):
+    def test_topic(self) -> None:
         events = [
             FakeEvent(
                 id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
@@ -380,7 +397,7 @@ class StateTestCase(unittest.TestCase):
 
         self.do_check(events, edges, expected_state_ids)
 
-    def test_mainline_sort(self):
+    def test_mainline_sort(self) -> None:
         """Tests that the mainline ordering works correctly."""
 
         events = [
@@ -434,22 +451,26 @@ class StateTestCase(unittest.TestCase):
 
         self.do_check(events, edges, expected_state_ids)
 
-    def do_check(self, events, edges, expected_state_ids):
+    def do_check(
+        self,
+        events: List[FakeEvent],
+        edges: List[List[str]],
+        expected_state_ids: List[str],
+    ) -> None:
         """Take a list of events and edges and calculate the state of the
         graph at END, and asserts it matches `expected_state_ids`
 
         Args:
-            events (list[FakeEvent])
-            edges (list[list[str]]): A list of chains of event edges, e.g.
+            events
+            edges: A list of chains of event edges, e.g.
                 `[[A, B, C]]` are edges A->B and B->C.
-            expected_state_ids (list[str]): The expected state at END, (excluding
+            expected_state_ids: The expected state at END, (excluding
                 the keys that haven't changed since START).
         """
         # We want to sort the events into topological order for processing.
-        graph = {}
+        graph: Dict[str, Set[str]] = {}
 
-        # node_id -> FakeEvent
-        fake_event_map = {}
+        fake_event_map: Dict[str, FakeEvent] = {}
 
         for ev in itertools.chain(INITIAL_EVENTS, events):
             graph[ev.node_id] = set()
@@ -462,10 +483,8 @@ class StateTestCase(unittest.TestCase):
             for a, b in pairwise(edge_list):
                 graph[a].add(b)
 
-        # event_id -> FrozenEvent
-        event_map = {}
-        # node_id -> state
-        state_at_event = {}
+        event_map: Dict[str, EventBase] = {}
+        state_at_event: Dict[str, StateMap[str]] = {}
 
         # We copy the map as the sort consumes the graph
         graph_copy = {k: set(v) for k, v in graph.items()}
@@ -496,7 +515,16 @@ class StateTestCase(unittest.TestCase):
             if fake_event.state_key is not None:
                 state_after[(fake_event.type, fake_event.state_key)] = event_id
 
-            auth_types = set(auth_types_for_event(RoomVersions.V6, fake_event))
+            # This type ignore is a bit sad. Things we have tried:
+            # 1. Define a `GenericEvent` Protocol satisfied by FakeEvent, EventBase and
+            #    EventBuilder. But this is Hard because the relevant attributes are
+            #    DictProperty[T] descriptors on EventBase but normal Ts on FakeEvent.
+            # 2. Define a `GenericEvent` Protocol describing `FakeEvent` only, and
+            #    change this function to accept Union[Event, EventBase, EventBuilder].
+            #    This seems reasonable to me, but mypy isn't happy. I think that's
+            #    a mypy bug, see https://github.com/python/mypy/issues/5570
+            # Instead, resort to a type-ignore.
+            auth_types = set(auth_types_for_event(RoomVersions.V6, fake_event))  # type: ignore[arg-type]
 
             auth_events = []
             for key in auth_types:
@@ -530,8 +558,14 @@ class StateTestCase(unittest.TestCase):
 
 
 class LexicographicalTestCase(unittest.TestCase):
-    def test_simple(self):
-        graph = {"l": {"o"}, "m": {"n", "o"}, "n": {"o"}, "o": set(), "p": {"o"}}
+    def test_simple(self) -> None:
+        graph: Dict[str, Set[str]] = {
+            "l": {"o"},
+            "m": {"n", "o"},
+            "n": {"o"},
+            "o": set(),
+            "p": {"o"},
+        }
 
         res = list(lexicographical_topological_sort(graph, key=lambda x: x))
 
@@ -539,7 +573,7 @@ class LexicographicalTestCase(unittest.TestCase):
 
 
 class SimpleParamStateTestCase(unittest.TestCase):
-    def setUp(self):
+    def setUp(self) -> None:
         # We build up a simple DAG.
 
         event_map = {}
@@ -627,7 +661,7 @@ class SimpleParamStateTestCase(unittest.TestCase):
             ]
         }
 
-    def test_event_map_none(self):
+    def test_event_map_none(self) -> None:
         # Test that we correctly handle passing `None` as the event_map
 
         state_d = resolve_events_with_store(
@@ -649,7 +683,7 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
     events.
     """
 
-    def test_simple(self):
+    def test_simple(self) -> None:
         # Test getting the auth difference for a simple chain with a single
         # unpersisted event:
         #
@@ -695,7 +729,7 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
 
         self.assertEqual(difference, {c.event_id})
 
-    def test_multiple_unpersisted_chain(self):
+    def test_multiple_unpersisted_chain(self) -> None:
         # Test getting the auth difference for a simple chain with multiple
         # unpersisted events:
         #
@@ -752,7 +786,7 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
 
         self.assertEqual(difference, {d.event_id, c.event_id})
 
-    def test_unpersisted_events_different_sets(self):
+    def test_unpersisted_events_different_sets(self) -> None:
         # Test getting the auth difference for with multiple unpersisted events
         # in different branches:
         #
@@ -820,7 +854,10 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
         self.assertEqual(difference, {d.event_id, e.event_id})
 
 
-def pairwise(iterable):
+T = TypeVar("T")
+
+
+def pairwise(iterable: Iterable[T]) -> Iterable[Tuple[T, T]]:
     "s -> (s0,s1), (s1,s2), (s2, s3), ..."
     a, b = itertools.tee(iterable)
     next(b, None)
@@ -829,24 +866,26 @@ def pairwise(iterable):
 
 @attr.s
 class TestStateResolutionStore:
-    event_map = attr.ib()
+    event_map: Dict[str, EventBase] = attr.ib()
 
-    def get_events(self, event_ids, allow_rejected=False):
+    def get_events(
+        self, event_ids: Collection[str], allow_rejected: bool = False
+    ) -> "defer.Deferred[Dict[str, EventBase]]":
         """Get events from the database
 
         Args:
-            event_ids (list): The event_ids of the events to fetch
-            allow_rejected (bool): If True return rejected events.
+            event_ids: The event_ids of the events to fetch
+            allow_rejected: If True return rejected events.
 
         Returns:
-            Deferred[dict[str, FrozenEvent]]: Dict from event_id to event.
+            Dict from event_id to event.
         """
 
         return defer.succeed(
             {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map}
         )
 
-    def _get_auth_chain(self, event_ids: List[str]) -> List[str]:
+    def _get_auth_chain(self, event_ids: Iterable[str]) -> List[str]:
         """Gets the full auth chain for a set of events (including rejected
         events).
 
@@ -880,7 +919,9 @@ class TestStateResolutionStore:
 
         return list(result)
 
-    def get_auth_chain_difference(self, room_id, auth_sets):
+    def get_auth_chain_difference(
+        self, room_id: str, auth_sets: List[Set[str]]
+    ) -> "defer.Deferred[Set[str]]":
         chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets]
 
         common = set(chains[0]).intersection(*chains[1:])