diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index a44960203e..ad9bbef9d2 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -14,11 +14,12 @@
# limitations under the License.
import itertools
-
-from six.moves import zip
+from typing import List
import attr
+from twisted.internet import defer
+
from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.api.room_versions import RoomVersions
from synapse.event_auth import auth_types_for_event
@@ -43,7 +44,12 @@ MEMBERSHIP_CONTENT_BAN = {"membership": Membership.BAN}
ORIGIN_SERVER_TS = 0
-class FakeEvent(object):
+class FakeClock:
+ def sleep(self, msec):
+ return defer.succeed(None)
+
+
+class FakeEvent:
"""A fake event we use as a convenience.
NOTE: Again as a convenience we use "node_ids" rather than event_ids to
@@ -419,6 +425,7 @@ class StateTestCase(unittest.TestCase):
state_before = dict(state_at_event[prev_events[0]])
else:
state_d = resolve_events_with_store(
+ FakeClock(),
ROOM_ID,
RoomVersions.V2.identifier,
[state_at_event[n] for n in prev_events],
@@ -426,7 +433,7 @@ class StateTestCase(unittest.TestCase):
state_res_store=TestStateResolutionStore(event_map),
)
- state_before = self.successResultOf(state_d)
+ state_before = self.successResultOf(defer.ensureDeferred(state_d))
state_after = dict(state_before)
if fake_event.state_key is not None:
@@ -567,6 +574,7 @@ class SimpleParamStateTestCase(unittest.TestCase):
# Test that we correctly handle passing `None` as the event_map
state_d = resolve_events_with_store(
+ FakeClock(),
ROOM_ID,
RoomVersions.V2.identifier,
[self.state_at_bob, self.state_at_charlie],
@@ -574,7 +582,7 @@ class SimpleParamStateTestCase(unittest.TestCase):
state_res_store=TestStateResolutionStore(self.event_map),
)
- state = self.successResultOf(state_d)
+ state = self.successResultOf(defer.ensureDeferred(state_d))
self.assert_dict(self.expected_combined_state, state)
@@ -587,7 +595,7 @@ def pairwise(iterable):
@attr.s
-class TestStateResolutionStore(object):
+class TestStateResolutionStore:
event_map = attr.ib()
def get_events(self, event_ids, allow_rejected=False):
@@ -601,9 +609,11 @@ class TestStateResolutionStore(object):
Deferred[dict[str, FrozenEvent]]: Dict from event_id to event.
"""
- return {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map}
+ return defer.succeed(
+ {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map}
+ )
- def _get_auth_chain(self, event_ids):
+ def _get_auth_chain(self, event_ids: List[str]) -> List[str]:
"""Gets the full auth chain for a set of events (including rejected
events).
@@ -615,10 +625,10 @@ class TestStateResolutionStore(object):
presence of rejected events
Args:
- event_ids (list): The event IDs of the events to fetch the auth
+ event_ids: The event IDs of the events to fetch the auth
chain for. Must be state events.
Returns:
- Deferred[list[str]]: List of event IDs of the auth chain.
+ List of event IDs of the auth chain.
"""
# Simple DFS for auth chain
@@ -641,4 +651,4 @@ class TestStateResolutionStore(object):
chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets]
common = set(chains[0]).intersection(*chains[1:])
- return set(chains[0]).union(*chains[1:]) - common
+ return defer.succeed(set(chains[0]).union(*chains[1:]) - common)
|