diff options
author | David Robertson <davidr@element.io> | 2022-06-09 09:48:04 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-06-09 09:48:04 +0100 |
commit | 97053c94060ea31d3b9d41a129221ad4b2a76865 (patch) | |
tree | d1f40ebd953b13b7804914c7d7175a644e2c008b /tests/state/test_v2.py | |
parent | Use READ COMMITTED isolation level when inserting read receipts (#12957) (diff) | |
download | synapse-97053c94060ea31d3b9d41a129221ad4b2a76865.tar.xz |
Type annotations for `test_v2` (#12985)
Diffstat (limited to 'tests/state/test_v2.py')
-rw-r--r-- | tests/state/test_v2.py | 125 |
1 files changed, 83 insertions, 42 deletions
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index 8370a27195..78b83d97b6 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -13,7 +13,17 @@ # limitations under the License. import itertools -from typing import List +from typing import ( + Collection, + Dict, + Iterable, + List, + Mapping, + Optional, + Set, + Tuple, + TypeVar, +) import attr @@ -22,13 +32,13 @@ from twisted.internet import defer from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.room_versions import RoomVersions from synapse.event_auth import auth_types_for_event -from synapse.events import make_event_from_dict +from synapse.events import EventBase, make_event_from_dict from synapse.state.v2 import ( _get_auth_chain_difference, lexicographical_topological_sort, resolve_events_with_store, ) -from synapse.types import EventID +from synapse.types import EventID, StateMap from tests import unittest @@ -48,7 +58,7 @@ ORIGIN_SERVER_TS = 0 class FakeClock: - def sleep(self, msec): + def sleep(self, msec: float) -> "defer.Deferred[None]": return defer.succeed(None) @@ -60,7 +70,14 @@ class FakeEvent: as domain. """ - def __init__(self, id, sender, type, state_key, content): + def __init__( + self, + id: str, + sender: str, + type: str, + state_key: Optional[str], + content: Mapping[str, object], + ): self.node_id = id self.event_id = EventID(id, "example.com").to_string() self.sender = sender @@ -69,12 +86,12 @@ class FakeEvent: self.content = content self.room_id = ROOM_ID - def to_event(self, auth_events, prev_events): + def to_event(self, auth_events: List[str], prev_events: List[str]) -> EventBase: """Given the auth_events and prev_events, convert to a Frozen Event Args: - auth_events (list[str]): list of event_ids - prev_events (list[str]): list of event_ids + auth_events: list of event_ids + prev_events: list of event_ids Returns: FrozenEvent @@ -164,7 +181,7 @@ INITIAL_EDGES = ["START", "IMZ", "IMC", "IMB", "IJR", "IPOWER", "IMA", "CREATE"] class StateTestCase(unittest.TestCase): - def test_ban_vs_pl(self): + def test_ban_vs_pl(self) -> None: events = [ FakeEvent( id="PA", @@ -202,7 +219,7 @@ class StateTestCase(unittest.TestCase): self.do_check(events, edges, expected_state_ids) - def test_join_rule_evasion(self): + def test_join_rule_evasion(self) -> None: events = [ FakeEvent( id="JR", @@ -226,7 +243,7 @@ class StateTestCase(unittest.TestCase): self.do_check(events, edges, expected_state_ids) - def test_offtopic_pl(self): + def test_offtopic_pl(self) -> None: events = [ FakeEvent( id="PA", @@ -257,7 +274,7 @@ class StateTestCase(unittest.TestCase): self.do_check(events, edges, expected_state_ids) - def test_topic_basic(self): + def test_topic_basic(self) -> None: events = [ FakeEvent( id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={} @@ -297,7 +314,7 @@ class StateTestCase(unittest.TestCase): self.do_check(events, edges, expected_state_ids) - def test_topic_reset(self): + def test_topic_reset(self) -> None: events = [ FakeEvent( id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={} @@ -327,7 +344,7 @@ class StateTestCase(unittest.TestCase): self.do_check(events, edges, expected_state_ids) - def test_topic(self): + def test_topic(self) -> None: events = [ FakeEvent( id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={} @@ -380,7 +397,7 @@ class StateTestCase(unittest.TestCase): self.do_check(events, edges, expected_state_ids) - def test_mainline_sort(self): + def test_mainline_sort(self) -> None: """Tests that the mainline ordering works correctly.""" events = [ @@ -434,22 +451,26 @@ class StateTestCase(unittest.TestCase): self.do_check(events, edges, expected_state_ids) - def do_check(self, events, edges, expected_state_ids): + def do_check( + self, + events: List[FakeEvent], + edges: List[List[str]], + expected_state_ids: List[str], + ) -> None: """Take a list of events and edges and calculate the state of the graph at END, and asserts it matches `expected_state_ids` Args: - events (list[FakeEvent]) - edges (list[list[str]]): A list of chains of event edges, e.g. + events + edges: A list of chains of event edges, e.g. `[[A, B, C]]` are edges A->B and B->C. - expected_state_ids (list[str]): The expected state at END, (excluding + expected_state_ids: The expected state at END, (excluding the keys that haven't changed since START). """ # We want to sort the events into topological order for processing. - graph = {} + graph: Dict[str, Set[str]] = {} - # node_id -> FakeEvent - fake_event_map = {} + fake_event_map: Dict[str, FakeEvent] = {} for ev in itertools.chain(INITIAL_EVENTS, events): graph[ev.node_id] = set() @@ -462,10 +483,8 @@ class StateTestCase(unittest.TestCase): for a, b in pairwise(edge_list): graph[a].add(b) - # event_id -> FrozenEvent - event_map = {} - # node_id -> state - state_at_event = {} + event_map: Dict[str, EventBase] = {} + state_at_event: Dict[str, StateMap[str]] = {} # We copy the map as the sort consumes the graph graph_copy = {k: set(v) for k, v in graph.items()} @@ -496,7 +515,16 @@ class StateTestCase(unittest.TestCase): if fake_event.state_key is not None: state_after[(fake_event.type, fake_event.state_key)] = event_id - auth_types = set(auth_types_for_event(RoomVersions.V6, fake_event)) + # This type ignore is a bit sad. Things we have tried: + # 1. Define a `GenericEvent` Protocol satisfied by FakeEvent, EventBase and + # EventBuilder. But this is Hard because the relevant attributes are + # DictProperty[T] descriptors on EventBase but normal Ts on FakeEvent. + # 2. Define a `GenericEvent` Protocol describing `FakeEvent` only, and + # change this function to accept Union[Event, EventBase, EventBuilder]. + # This seems reasonable to me, but mypy isn't happy. I think that's + # a mypy bug, see https://github.com/python/mypy/issues/5570 + # Instead, resort to a type-ignore. + auth_types = set(auth_types_for_event(RoomVersions.V6, fake_event)) # type: ignore[arg-type] auth_events = [] for key in auth_types: @@ -530,8 +558,14 @@ class StateTestCase(unittest.TestCase): class LexicographicalTestCase(unittest.TestCase): - def test_simple(self): - graph = {"l": {"o"}, "m": {"n", "o"}, "n": {"o"}, "o": set(), "p": {"o"}} + def test_simple(self) -> None: + graph: Dict[str, Set[str]] = { + "l": {"o"}, + "m": {"n", "o"}, + "n": {"o"}, + "o": set(), + "p": {"o"}, + } res = list(lexicographical_topological_sort(graph, key=lambda x: x)) @@ -539,7 +573,7 @@ class LexicographicalTestCase(unittest.TestCase): class SimpleParamStateTestCase(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: # We build up a simple DAG. event_map = {} @@ -627,7 +661,7 @@ class SimpleParamStateTestCase(unittest.TestCase): ] } - def test_event_map_none(self): + def test_event_map_none(self) -> None: # Test that we correctly handle passing `None` as the event_map state_d = resolve_events_with_store( @@ -649,7 +683,7 @@ class AuthChainDifferenceTestCase(unittest.TestCase): events. """ - def test_simple(self): + def test_simple(self) -> None: # Test getting the auth difference for a simple chain with a single # unpersisted event: # @@ -695,7 +729,7 @@ class AuthChainDifferenceTestCase(unittest.TestCase): self.assertEqual(difference, {c.event_id}) - def test_multiple_unpersisted_chain(self): + def test_multiple_unpersisted_chain(self) -> None: # Test getting the auth difference for a simple chain with multiple # unpersisted events: # @@ -752,7 +786,7 @@ class AuthChainDifferenceTestCase(unittest.TestCase): self.assertEqual(difference, {d.event_id, c.event_id}) - def test_unpersisted_events_different_sets(self): + def test_unpersisted_events_different_sets(self) -> None: # Test getting the auth difference for with multiple unpersisted events # in different branches: # @@ -820,7 +854,10 @@ class AuthChainDifferenceTestCase(unittest.TestCase): self.assertEqual(difference, {d.event_id, e.event_id}) -def pairwise(iterable): +T = TypeVar("T") + + +def pairwise(iterable: Iterable[T]) -> Iterable[Tuple[T, T]]: "s -> (s0,s1), (s1,s2), (s2, s3), ..." a, b = itertools.tee(iterable) next(b, None) @@ -829,24 +866,26 @@ def pairwise(iterable): @attr.s class TestStateResolutionStore: - event_map = attr.ib() + event_map: Dict[str, EventBase] = attr.ib() - def get_events(self, event_ids, allow_rejected=False): + def get_events( + self, event_ids: Collection[str], allow_rejected: bool = False + ) -> "defer.Deferred[Dict[str, EventBase]]": """Get events from the database Args: - event_ids (list): The event_ids of the events to fetch - allow_rejected (bool): If True return rejected events. + event_ids: The event_ids of the events to fetch + allow_rejected: If True return rejected events. Returns: - Deferred[dict[str, FrozenEvent]]: Dict from event_id to event. + Dict from event_id to event. """ return defer.succeed( {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map} ) - def _get_auth_chain(self, event_ids: List[str]) -> List[str]: + def _get_auth_chain(self, event_ids: Iterable[str]) -> List[str]: """Gets the full auth chain for a set of events (including rejected events). @@ -880,7 +919,9 @@ class TestStateResolutionStore: return list(result) - def get_auth_chain_difference(self, room_id, auth_sets): + def get_auth_chain_difference( + self, room_id: str, auth_sets: List[Set[str]] + ) -> "defer.Deferred[Set[str]]": chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets] common = set(chains[0]).intersection(*chains[1:]) |