summary refs log tree commit diff
path: root/tests/state/test_v2.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/state/test_v2.py')
-rw-r--r--tests/state/test_v2.py32
1 files changed, 21 insertions, 11 deletions
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index a44960203e..ad9bbef9d2 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -14,11 +14,12 @@
 # limitations under the License.
 
 import itertools
-
-from six.moves import zip
+from typing import List
 
 import attr
 
+from twisted.internet import defer
+
 from synapse.api.constants import EventTypes, JoinRules, Membership
 from synapse.api.room_versions import RoomVersions
 from synapse.event_auth import auth_types_for_event
@@ -43,7 +44,12 @@ MEMBERSHIP_CONTENT_BAN = {"membership": Membership.BAN}
 ORIGIN_SERVER_TS = 0
 
 
-class FakeEvent(object):
+class FakeClock:
+    def sleep(self, msec):
+        return defer.succeed(None)
+
+
+class FakeEvent:
     """A fake event we use as a convenience.
 
     NOTE: Again as a convenience we use "node_ids" rather than event_ids to
@@ -419,6 +425,7 @@ class StateTestCase(unittest.TestCase):
                 state_before = dict(state_at_event[prev_events[0]])
             else:
                 state_d = resolve_events_with_store(
+                    FakeClock(),
                     ROOM_ID,
                     RoomVersions.V2.identifier,
                     [state_at_event[n] for n in prev_events],
@@ -426,7 +433,7 @@ class StateTestCase(unittest.TestCase):
                     state_res_store=TestStateResolutionStore(event_map),
                 )
 
-                state_before = self.successResultOf(state_d)
+                state_before = self.successResultOf(defer.ensureDeferred(state_d))
 
             state_after = dict(state_before)
             if fake_event.state_key is not None:
@@ -567,6 +574,7 @@ class SimpleParamStateTestCase(unittest.TestCase):
         # Test that we correctly handle passing `None` as the event_map
 
         state_d = resolve_events_with_store(
+            FakeClock(),
             ROOM_ID,
             RoomVersions.V2.identifier,
             [self.state_at_bob, self.state_at_charlie],
@@ -574,7 +582,7 @@ class SimpleParamStateTestCase(unittest.TestCase):
             state_res_store=TestStateResolutionStore(self.event_map),
         )
 
-        state = self.successResultOf(state_d)
+        state = self.successResultOf(defer.ensureDeferred(state_d))
 
         self.assert_dict(self.expected_combined_state, state)
 
@@ -587,7 +595,7 @@ def pairwise(iterable):
 
 
 @attr.s
-class TestStateResolutionStore(object):
+class TestStateResolutionStore:
     event_map = attr.ib()
 
     def get_events(self, event_ids, allow_rejected=False):
@@ -601,9 +609,11 @@ class TestStateResolutionStore(object):
             Deferred[dict[str, FrozenEvent]]: Dict from event_id to event.
         """
 
-        return {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map}
+        return defer.succeed(
+            {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map}
+        )
 
-    def _get_auth_chain(self, event_ids):
+    def _get_auth_chain(self, event_ids: List[str]) -> List[str]:
         """Gets the full auth chain for a set of events (including rejected
         events).
 
@@ -615,10 +625,10 @@ class TestStateResolutionStore(object):
                presence of rejected events
 
         Args:
-            event_ids (list): The event IDs of the events to fetch the auth
+            event_ids: The event IDs of the events to fetch the auth
                 chain for. Must be state events.
         Returns:
-            Deferred[list[str]]: List of event IDs of the auth chain.
+            List of event IDs of the auth chain.
         """
 
         # Simple DFS for auth chain
@@ -641,4 +651,4 @@ class TestStateResolutionStore(object):
         chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets]
 
         common = set(chains[0]).intersection(*chains[1:])
-        return set(chains[0]).union(*chains[1:]) - common
+        return defer.succeed(set(chains[0]).union(*chains[1:]) - common)