summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/federation/test_federation_sender.py19
-rw-r--r--tests/state/test_v2.py17
-rw-r--r--tests/storage/test_room.py8
-rw-r--r--tests/test_state.py72
-rw-r--r--tests/test_utils/__init__.py7
5 files changed, 73 insertions, 50 deletions
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 1a9bd5f37d..d1bd18da39 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -26,21 +26,24 @@ from synapse.rest import admin
 from synapse.rest.client.v1 import login
 from synapse.types import JsonDict, ReadReceipt
 
+from tests.test_utils import make_awaitable
 from tests.unittest import HomeserverTestCase, override_config
 
 
 class FederationSenderReceiptsTestCases(HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
+        mock_state_handler = Mock(spec=["get_current_hosts_in_room"])
+        # Ensure a new Awaitable is created for each call.
+        mock_state_handler.get_current_hosts_in_room.side_effect = lambda room_Id: make_awaitable(
+            ["test", "host2"]
+        )
         return self.setup_test_homeserver(
-            state_handler=Mock(spec=["get_current_hosts_in_room"]),
+            state_handler=mock_state_handler,
             federation_transport_client=Mock(spec=["send_transaction"]),
         )
 
     @override_config({"send_federation": True})
     def test_send_receipts(self):
-        mock_state_handler = self.hs.get_state_handler()
-        mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"]
-
         mock_send_transaction = (
             self.hs.get_federation_transport_client().send_transaction
         )
@@ -81,9 +84,6 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
     def test_send_receipts_with_backoff(self):
         """Send two receipts in quick succession; the second should be flushed, but
         only after 20ms"""
-        mock_state_handler = self.hs.get_state_handler()
-        mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"]
-
         mock_send_transaction = (
             self.hs.get_federation_transport_client().send_transaction
         )
@@ -164,7 +164,6 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
 
     def make_homeserver(self, reactor, clock):
         return self.setup_test_homeserver(
-            state_handler=Mock(spec=["get_current_hosts_in_room"]),
             federation_transport_client=Mock(spec=["send_transaction"]),
         )
 
@@ -174,10 +173,6 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         return c
 
     def prepare(self, reactor, clock, hs):
-        # stub out get_current_hosts_in_room
-        mock_state_handler = hs.get_state_handler()
-        mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"]
-
         # stub out get_users_who_share_room_with_user so that it claims that
         # `@user2:host2` is in the room
         def get_users_who_share_room_with_user(user_id):
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index 38f9b423ef..f2955a9c69 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import itertools
+from typing import List
 
 import attr
 
@@ -432,7 +433,7 @@ class StateTestCase(unittest.TestCase):
                     state_res_store=TestStateResolutionStore(event_map),
                 )
 
-                state_before = self.successResultOf(state_d)
+                state_before = self.successResultOf(defer.ensureDeferred(state_d))
 
             state_after = dict(state_before)
             if fake_event.state_key is not None:
@@ -581,7 +582,7 @@ class SimpleParamStateTestCase(unittest.TestCase):
             state_res_store=TestStateResolutionStore(self.event_map),
         )
 
-        state = self.successResultOf(state_d)
+        state = self.successResultOf(defer.ensureDeferred(state_d))
 
         self.assert_dict(self.expected_combined_state, state)
 
@@ -608,9 +609,11 @@ class TestStateResolutionStore(object):
             Deferred[dict[str, FrozenEvent]]: Dict from event_id to event.
         """
 
-        return {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map}
+        return defer.succeed(
+            {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map}
+        )
 
-    def _get_auth_chain(self, event_ids):
+    def _get_auth_chain(self, event_ids: List[str]) -> List[str]:
         """Gets the full auth chain for a set of events (including rejected
         events).
 
@@ -622,10 +625,10 @@ class TestStateResolutionStore(object):
                presence of rejected events
 
         Args:
-            event_ids (list): The event IDs of the events to fetch the auth
+            event_ids: The event IDs of the events to fetch the auth
                 chain for. Must be state events.
         Returns:
-            Deferred[list[str]]: List of event IDs of the auth chain.
+            List of event IDs of the auth chain.
         """
 
         # Simple DFS for auth chain
@@ -648,4 +651,4 @@ class TestStateResolutionStore(object):
         chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets]
 
         common = set(chains[0]).intersection(*chains[1:])
-        return set(chains[0]).union(*chains[1:]) - common
+        return defer.succeed(set(chains[0]).union(*chains[1:]) - common)
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index b1dceb2918..1d77b4a2d6 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -109,7 +109,9 @@ class RoomEventsStoreTestCase(unittest.TestCase):
             etype=EventTypes.Name, name=name, content={"name": name}, depth=1
         )
 
-        state = yield self.store.get_current_state(room_id=self.room.to_string())
+        state = yield defer.ensureDeferred(
+            self.store.get_current_state(room_id=self.room.to_string())
+        )
 
         self.assertEquals(1, len(state))
         self.assertObjectHasAttributes(
@@ -125,7 +127,9 @@ class RoomEventsStoreTestCase(unittest.TestCase):
             etype=EventTypes.Topic, topic=topic, content={"topic": topic}, depth=1
         )
 
-        state = yield self.store.get_current_state(room_id=self.room.to_string())
+        state = yield defer.ensureDeferred(
+            self.store.get_current_state(room_id=self.room.to_string())
+        )
 
         self.assertEquals(1, len(state))
         self.assertObjectHasAttributes(
diff --git a/tests/test_state.py b/tests/test_state.py
index 66f22f6813..4858e8fc59 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -97,17 +97,19 @@ class StateGroupStore(object):
 
         self._group_to_state[state_group] = dict(current_state_ids)
 
-        return state_group
+        return defer.succeed(state_group)
 
     def get_events(self, event_ids, **kwargs):
-        return {
-            e_id: self._event_id_to_event[e_id]
-            for e_id in event_ids
-            if e_id in self._event_id_to_event
-        }
+        return defer.succeed(
+            {
+                e_id: self._event_id_to_event[e_id]
+                for e_id in event_ids
+                if e_id in self._event_id_to_event
+            }
+        )
 
     def get_state_group_delta(self, name):
-        return None, None
+        return defer.succeed((None, None))
 
     def register_events(self, events):
         for e in events:
@@ -120,7 +122,7 @@ class StateGroupStore(object):
         self._event_to_state_group[event_id] = state_group
 
     def get_room_version_id(self, room_id):
-        return RoomVersions.V1.identifier
+        return defer.succeed(RoomVersions.V1.identifier)
 
 
 class DictObj(dict):
@@ -202,7 +204,9 @@ class StateTestCase(unittest.TestCase):
         context_store = {}  # type: dict[str, EventContext]
 
         for event in graph.walk():
-            context = yield self.state.compute_event_context(event)
+            context = yield defer.ensureDeferred(
+                self.state.compute_event_context(event)
+            )
             self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
@@ -244,7 +248,9 @@ class StateTestCase(unittest.TestCase):
         context_store = {}
 
         for event in graph.walk():
-            context = yield self.state.compute_event_context(event)
+            context = yield defer.ensureDeferred(
+                self.state.compute_event_context(event)
+            )
             self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
@@ -300,7 +306,9 @@ class StateTestCase(unittest.TestCase):
         context_store = {}
 
         for event in graph.walk():
-            context = yield self.state.compute_event_context(event)
+            context = yield defer.ensureDeferred(
+                self.state.compute_event_context(event)
+            )
             self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
@@ -373,7 +381,9 @@ class StateTestCase(unittest.TestCase):
         context_store = {}
 
         for event in graph.walk():
-            context = yield self.state.compute_event_context(event)
+            context = yield defer.ensureDeferred(
+                self.state.compute_event_context(event)
+            )
             self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
@@ -411,12 +421,14 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test2", state_key=""),
         ]
 
-        context = yield self.state.compute_event_context(event, old_state=old_state)
+        context = yield defer.ensureDeferred(
+            self.state.compute_event_context(event, old_state=old_state)
+        )
 
         prev_state_ids = yield context.get_prev_state_ids()
         self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
 
-        current_state_ids = yield context.get_current_state_ids()
+        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
         self.assertCountEqual(
             (e.event_id for e in old_state), current_state_ids.values()
         )
@@ -434,12 +446,14 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test2", state_key=""),
         ]
 
-        context = yield self.state.compute_event_context(event, old_state=old_state)
+        context = yield defer.ensureDeferred(
+            self.state.compute_event_context(event, old_state=old_state)
+        )
 
         prev_state_ids = yield context.get_prev_state_ids()
         self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
 
-        current_state_ids = yield context.get_current_state_ids()
+        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
         self.assertCountEqual(
             (e.event_id for e in old_state + [event]), current_state_ids.values()
         )
@@ -462,7 +476,7 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test2", state_key=""),
         ]
 
-        group_name = self.store.store_state_group(
+        group_name = yield self.store.store_state_group(
             prev_event_id,
             event.room_id,
             None,
@@ -471,9 +485,9 @@ class StateTestCase(unittest.TestCase):
         )
         self.store.register_event_id_state_group(prev_event_id, group_name)
 
-        context = yield self.state.compute_event_context(event)
+        context = yield defer.ensureDeferred(self.state.compute_event_context(event))
 
-        current_state_ids = yield context.get_current_state_ids()
+        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
 
         self.assertEqual(
             {e.event_id for e in old_state}, set(current_state_ids.values())
@@ -494,7 +508,7 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test2", state_key=""),
         ]
 
-        group_name = self.store.store_state_group(
+        group_name = yield self.store.store_state_group(
             prev_event_id,
             event.room_id,
             None,
@@ -503,7 +517,7 @@ class StateTestCase(unittest.TestCase):
         )
         self.store.register_event_id_state_group(prev_event_id, group_name)
 
-        context = yield self.state.compute_event_context(event)
+        context = yield defer.ensureDeferred(self.state.compute_event_context(event))
 
         prev_state_ids = yield context.get_prev_state_ids()
 
@@ -544,7 +558,7 @@ class StateTestCase(unittest.TestCase):
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
         )
 
-        current_state_ids = yield context.get_current_state_ids()
+        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
 
         self.assertEqual(len(current_state_ids), 6)
 
@@ -586,7 +600,7 @@ class StateTestCase(unittest.TestCase):
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
         )
 
-        current_state_ids = yield context.get_current_state_ids()
+        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
 
         self.assertEqual(len(current_state_ids), 6)
 
@@ -641,7 +655,7 @@ class StateTestCase(unittest.TestCase):
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
         )
 
-        current_state_ids = yield context.get_current_state_ids()
+        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
 
         self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")])
 
@@ -669,14 +683,15 @@ class StateTestCase(unittest.TestCase):
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
         )
 
-        current_state_ids = yield context.get_current_state_ids()
+        current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
 
         self.assertEqual(old_state_1[3].event_id, current_state_ids[("test1", "1")])
 
+    @defer.inlineCallbacks
     def _get_context(
         self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2
     ):
-        sg1 = self.store.store_state_group(
+        sg1 = yield self.store.store_state_group(
             prev_event_id_1,
             event.room_id,
             None,
@@ -685,7 +700,7 @@ class StateTestCase(unittest.TestCase):
         )
         self.store.register_event_id_state_group(prev_event_id_1, sg1)
 
-        sg2 = self.store.store_state_group(
+        sg2 = yield self.store.store_state_group(
             prev_event_id_2,
             event.room_id,
             None,
@@ -694,4 +709,5 @@ class StateTestCase(unittest.TestCase):
         )
         self.store.register_event_id_state_group(prev_event_id_2, sg2)
 
-        return self.state.compute_event_context(event)
+        result = yield defer.ensureDeferred(self.state.compute_event_context(event))
+        return result
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index 7b345b03bb..508aeba078 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -17,7 +17,7 @@
 """
 Utilities for running the unit tests
 """
-from typing import Awaitable, TypeVar
+from typing import Any, Awaitable, TypeVar
 
 TV = TypeVar("TV")
 
@@ -36,3 +36,8 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV:
 
     # if next didn't raise, the awaitable hasn't completed.
     raise Exception("awaitable has not yet completed")
+
+
+async def make_awaitable(result: Any):
+    """Create an awaitable that just returns a result."""
+    return result