summary refs log tree commit diff
path: root/tests/state/test_v2.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/state/test_v2.py')
-rw-r--r--tests/state/test_v2.py17
1 files changed, 10 insertions, 7 deletions
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index 38f9b423ef..f2955a9c69 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import itertools
+from typing import List
 
 import attr
 
@@ -432,7 +433,7 @@ class StateTestCase(unittest.TestCase):
                     state_res_store=TestStateResolutionStore(event_map),
                 )
 
-                state_before = self.successResultOf(state_d)
+                state_before = self.successResultOf(defer.ensureDeferred(state_d))
 
             state_after = dict(state_before)
             if fake_event.state_key is not None:
@@ -581,7 +582,7 @@ class SimpleParamStateTestCase(unittest.TestCase):
             state_res_store=TestStateResolutionStore(self.event_map),
         )
 
-        state = self.successResultOf(state_d)
+        state = self.successResultOf(defer.ensureDeferred(state_d))
 
         self.assert_dict(self.expected_combined_state, state)
 
@@ -608,9 +609,11 @@ class TestStateResolutionStore(object):
             Deferred[dict[str, FrozenEvent]]: Dict from event_id to event.
         """
 
-        return {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map}
+        return defer.succeed(
+            {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map}
+        )
 
-    def _get_auth_chain(self, event_ids):
+    def _get_auth_chain(self, event_ids: List[str]) -> List[str]:
         """Gets the full auth chain for a set of events (including rejected
         events).
 
@@ -622,10 +625,10 @@ class TestStateResolutionStore(object):
                presence of rejected events
 
         Args:
-            event_ids (list): The event IDs of the events to fetch the auth
+            event_ids: The event IDs of the events to fetch the auth
                 chain for. Must be state events.
         Returns:
-            Deferred[list[str]]: List of event IDs of the auth chain.
+            List of event IDs of the auth chain.
         """
 
         # Simple DFS for auth chain
@@ -648,4 +651,4 @@ class TestStateResolutionStore(object):
         chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets]
 
         common = set(chains[0]).intersection(*chains[1:])
-        return set(chains[0]).union(*chains[1:]) - common
+        return defer.succeed(set(chains[0]).union(*chains[1:]) - common)