diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index 8370a27195..78b83d97b6 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -13,7 +13,17 @@
# limitations under the License.
import itertools
-from typing import List
+from typing import (
+ Collection,
+ Dict,
+ Iterable,
+ List,
+ Mapping,
+ Optional,
+ Set,
+ Tuple,
+ TypeVar,
+)
import attr
@@ -22,13 +32,13 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.api.room_versions import RoomVersions
from synapse.event_auth import auth_types_for_event
-from synapse.events import make_event_from_dict
+from synapse.events import EventBase, make_event_from_dict
from synapse.state.v2 import (
_get_auth_chain_difference,
lexicographical_topological_sort,
resolve_events_with_store,
)
-from synapse.types import EventID
+from synapse.types import EventID, StateMap
from tests import unittest
@@ -48,7 +58,7 @@ ORIGIN_SERVER_TS = 0
class FakeClock:
- def sleep(self, msec):
+ def sleep(self, msec: float) -> "defer.Deferred[None]":
return defer.succeed(None)
@@ -60,7 +70,14 @@ class FakeEvent:
as domain.
"""
- def __init__(self, id, sender, type, state_key, content):
+ def __init__(
+ self,
+ id: str,
+ sender: str,
+ type: str,
+ state_key: Optional[str],
+ content: Mapping[str, object],
+ ):
self.node_id = id
self.event_id = EventID(id, "example.com").to_string()
self.sender = sender
@@ -69,12 +86,12 @@ class FakeEvent:
self.content = content
self.room_id = ROOM_ID
- def to_event(self, auth_events, prev_events):
+ def to_event(self, auth_events: List[str], prev_events: List[str]) -> EventBase:
"""Given the auth_events and prev_events, convert to a Frozen Event
Args:
- auth_events (list[str]): list of event_ids
- prev_events (list[str]): list of event_ids
+ auth_events: list of event_ids
+ prev_events: list of event_ids
Returns:
FrozenEvent
@@ -164,7 +181,7 @@ INITIAL_EDGES = ["START", "IMZ", "IMC", "IMB", "IJR", "IPOWER", "IMA", "CREATE"]
class StateTestCase(unittest.TestCase):
- def test_ban_vs_pl(self):
+ def test_ban_vs_pl(self) -> None:
events = [
FakeEvent(
id="PA",
@@ -202,7 +219,7 @@ class StateTestCase(unittest.TestCase):
self.do_check(events, edges, expected_state_ids)
- def test_join_rule_evasion(self):
+ def test_join_rule_evasion(self) -> None:
events = [
FakeEvent(
id="JR",
@@ -226,7 +243,7 @@ class StateTestCase(unittest.TestCase):
self.do_check(events, edges, expected_state_ids)
- def test_offtopic_pl(self):
+ def test_offtopic_pl(self) -> None:
events = [
FakeEvent(
id="PA",
@@ -257,7 +274,7 @@ class StateTestCase(unittest.TestCase):
self.do_check(events, edges, expected_state_ids)
- def test_topic_basic(self):
+ def test_topic_basic(self) -> None:
events = [
FakeEvent(
id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
@@ -297,7 +314,7 @@ class StateTestCase(unittest.TestCase):
self.do_check(events, edges, expected_state_ids)
- def test_topic_reset(self):
+ def test_topic_reset(self) -> None:
events = [
FakeEvent(
id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
@@ -327,7 +344,7 @@ class StateTestCase(unittest.TestCase):
self.do_check(events, edges, expected_state_ids)
- def test_topic(self):
+ def test_topic(self) -> None:
events = [
FakeEvent(
id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
@@ -380,7 +397,7 @@ class StateTestCase(unittest.TestCase):
self.do_check(events, edges, expected_state_ids)
- def test_mainline_sort(self):
+ def test_mainline_sort(self) -> None:
"""Tests that the mainline ordering works correctly."""
events = [
@@ -434,22 +451,26 @@ class StateTestCase(unittest.TestCase):
self.do_check(events, edges, expected_state_ids)
- def do_check(self, events, edges, expected_state_ids):
+ def do_check(
+ self,
+ events: List[FakeEvent],
+ edges: List[List[str]],
+ expected_state_ids: List[str],
+ ) -> None:
"""Take a list of events and edges and calculate the state of the
graph at END, and asserts it matches `expected_state_ids`
Args:
- events (list[FakeEvent])
- edges (list[list[str]]): A list of chains of event edges, e.g.
+ events
+ edges: A list of chains of event edges, e.g.
`[[A, B, C]]` are edges A->B and B->C.
- expected_state_ids (list[str]): The expected state at END, (excluding
+ expected_state_ids: The expected state at END, (excluding
the keys that haven't changed since START).
"""
# We want to sort the events into topological order for processing.
- graph = {}
+ graph: Dict[str, Set[str]] = {}
- # node_id -> FakeEvent
- fake_event_map = {}
+ fake_event_map: Dict[str, FakeEvent] = {}
for ev in itertools.chain(INITIAL_EVENTS, events):
graph[ev.node_id] = set()
@@ -462,10 +483,8 @@ class StateTestCase(unittest.TestCase):
for a, b in pairwise(edge_list):
graph[a].add(b)
- # event_id -> FrozenEvent
- event_map = {}
- # node_id -> state
- state_at_event = {}
+ event_map: Dict[str, EventBase] = {}
+ state_at_event: Dict[str, StateMap[str]] = {}
# We copy the map as the sort consumes the graph
graph_copy = {k: set(v) for k, v in graph.items()}
@@ -496,7 +515,16 @@ class StateTestCase(unittest.TestCase):
if fake_event.state_key is not None:
state_after[(fake_event.type, fake_event.state_key)] = event_id
- auth_types = set(auth_types_for_event(RoomVersions.V6, fake_event))
+ # This type ignore is a bit sad. Things we have tried:
+ # 1. Define a `GenericEvent` Protocol satisfied by FakeEvent, EventBase and
+ # EventBuilder. But this is Hard because the relevant attributes are
+ # DictProperty[T] descriptors on EventBase but normal Ts on FakeEvent.
+ # 2. Define a `GenericEvent` Protocol describing `FakeEvent` only, and
+ # change this function to accept Union[Event, EventBase, EventBuilder].
+ # This seems reasonable to me, but mypy isn't happy. I think that's
+ # a mypy bug, see https://github.com/python/mypy/issues/5570
+ # Instead, resort to a type-ignore.
+ auth_types = set(auth_types_for_event(RoomVersions.V6, fake_event)) # type: ignore[arg-type]
auth_events = []
for key in auth_types:
@@ -530,8 +558,14 @@ class StateTestCase(unittest.TestCase):
class LexicographicalTestCase(unittest.TestCase):
- def test_simple(self):
- graph = {"l": {"o"}, "m": {"n", "o"}, "n": {"o"}, "o": set(), "p": {"o"}}
+ def test_simple(self) -> None:
+ graph: Dict[str, Set[str]] = {
+ "l": {"o"},
+ "m": {"n", "o"},
+ "n": {"o"},
+ "o": set(),
+ "p": {"o"},
+ }
res = list(lexicographical_topological_sort(graph, key=lambda x: x))
@@ -539,7 +573,7 @@ class LexicographicalTestCase(unittest.TestCase):
class SimpleParamStateTestCase(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
# We build up a simple DAG.
event_map = {}
@@ -627,7 +661,7 @@ class SimpleParamStateTestCase(unittest.TestCase):
]
}
- def test_event_map_none(self):
+ def test_event_map_none(self) -> None:
# Test that we correctly handle passing `None` as the event_map
state_d = resolve_events_with_store(
@@ -649,7 +683,7 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
events.
"""
- def test_simple(self):
+ def test_simple(self) -> None:
# Test getting the auth difference for a simple chain with a single
# unpersisted event:
#
@@ -695,7 +729,7 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
self.assertEqual(difference, {c.event_id})
- def test_multiple_unpersisted_chain(self):
+ def test_multiple_unpersisted_chain(self) -> None:
# Test getting the auth difference for a simple chain with multiple
# unpersisted events:
#
@@ -752,7 +786,7 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
self.assertEqual(difference, {d.event_id, c.event_id})
- def test_unpersisted_events_different_sets(self):
+ def test_unpersisted_events_different_sets(self) -> None:
# Test getting the auth difference for with multiple unpersisted events
# in different branches:
#
@@ -820,7 +854,10 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
self.assertEqual(difference, {d.event_id, e.event_id})
-def pairwise(iterable):
+T = TypeVar("T")
+
+
+def pairwise(iterable: Iterable[T]) -> Iterable[Tuple[T, T]]:
"s -> (s0,s1), (s1,s2), (s2, s3), ..."
a, b = itertools.tee(iterable)
next(b, None)
@@ -829,24 +866,26 @@ def pairwise(iterable):
@attr.s
class TestStateResolutionStore:
- event_map = attr.ib()
+ event_map: Dict[str, EventBase] = attr.ib()
- def get_events(self, event_ids, allow_rejected=False):
+ def get_events(
+ self, event_ids: Collection[str], allow_rejected: bool = False
+ ) -> "defer.Deferred[Dict[str, EventBase]]":
"""Get events from the database
Args:
- event_ids (list): The event_ids of the events to fetch
- allow_rejected (bool): If True return rejected events.
+ event_ids: The event_ids of the events to fetch
+ allow_rejected: If True return rejected events.
Returns:
- Deferred[dict[str, FrozenEvent]]: Dict from event_id to event.
+ Dict from event_id to event.
"""
return defer.succeed(
{eid: self.event_map[eid] for eid in event_ids if eid in self.event_map}
)
- def _get_auth_chain(self, event_ids: List[str]) -> List[str]:
+ def _get_auth_chain(self, event_ids: Iterable[str]) -> List[str]:
"""Gets the full auth chain for a set of events (including rejected
events).
@@ -880,7 +919,9 @@ class TestStateResolutionStore:
return list(result)
- def get_auth_chain_difference(self, room_id, auth_sets):
+ def get_auth_chain_difference(
+ self, room_id: str, auth_sets: List[Set[str]]
+ ) -> "defer.Deferred[Set[str]]":
chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets]
common = set(chains[0]).intersection(*chains[1:])
|