diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 1a9bd5f37d..d1bd18da39 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -26,21 +26,24 @@ from synapse.rest import admin
from synapse.rest.client.v1 import login
from synapse.types import JsonDict, ReadReceipt
+from tests.test_utils import make_awaitable
from tests.unittest import HomeserverTestCase, override_config
class FederationSenderReceiptsTestCases(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
+ mock_state_handler = Mock(spec=["get_current_hosts_in_room"])
+ # Ensure a new Awaitable is created for each call.
+ mock_state_handler.get_current_hosts_in_room.side_effect = lambda room_Id: make_awaitable(
+ ["test", "host2"]
+ )
return self.setup_test_homeserver(
- state_handler=Mock(spec=["get_current_hosts_in_room"]),
+ state_handler=mock_state_handler,
federation_transport_client=Mock(spec=["send_transaction"]),
)
@override_config({"send_federation": True})
def test_send_receipts(self):
- mock_state_handler = self.hs.get_state_handler()
- mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"]
-
mock_send_transaction = (
self.hs.get_federation_transport_client().send_transaction
)
@@ -81,9 +84,6 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
def test_send_receipts_with_backoff(self):
"""Send two receipts in quick succession; the second should be flushed, but
only after 20ms"""
- mock_state_handler = self.hs.get_state_handler()
- mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"]
-
mock_send_transaction = (
self.hs.get_federation_transport_client().send_transaction
)
@@ -164,7 +164,6 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
return self.setup_test_homeserver(
- state_handler=Mock(spec=["get_current_hosts_in_room"]),
federation_transport_client=Mock(spec=["send_transaction"]),
)
@@ -174,10 +173,6 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
return c
def prepare(self, reactor, clock, hs):
- # stub out get_current_hosts_in_room
- mock_state_handler = hs.get_state_handler()
- mock_state_handler.get_current_hosts_in_room.return_value = ["test", "host2"]
-
# stub out get_users_who_share_room_with_user so that it claims that
# `@user2:host2` is in the room
def get_users_who_share_room_with_user(user_id):
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index 38f9b423ef..f2955a9c69 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -14,6 +14,7 @@
# limitations under the License.
import itertools
+from typing import List
import attr
@@ -432,7 +433,7 @@ class StateTestCase(unittest.TestCase):
state_res_store=TestStateResolutionStore(event_map),
)
- state_before = self.successResultOf(state_d)
+ state_before = self.successResultOf(defer.ensureDeferred(state_d))
state_after = dict(state_before)
if fake_event.state_key is not None:
@@ -581,7 +582,7 @@ class SimpleParamStateTestCase(unittest.TestCase):
state_res_store=TestStateResolutionStore(self.event_map),
)
- state = self.successResultOf(state_d)
+ state = self.successResultOf(defer.ensureDeferred(state_d))
self.assert_dict(self.expected_combined_state, state)
@@ -608,9 +609,11 @@ class TestStateResolutionStore(object):
Deferred[dict[str, FrozenEvent]]: Dict from event_id to event.
"""
- return {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map}
+ return defer.succeed(
+ {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map}
+ )
- def _get_auth_chain(self, event_ids):
+ def _get_auth_chain(self, event_ids: List[str]) -> List[str]:
"""Gets the full auth chain for a set of events (including rejected
events).
@@ -622,10 +625,10 @@ class TestStateResolutionStore(object):
presence of rejected events
Args:
- event_ids (list): The event IDs of the events to fetch the auth
+ event_ids: The event IDs of the events to fetch the auth
chain for. Must be state events.
Returns:
- Deferred[list[str]]: List of event IDs of the auth chain.
+ List of event IDs of the auth chain.
"""
# Simple DFS for auth chain
@@ -648,4 +651,4 @@ class TestStateResolutionStore(object):
chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets]
common = set(chains[0]).intersection(*chains[1:])
- return set(chains[0]).union(*chains[1:]) - common
+ return defer.succeed(set(chains[0]).union(*chains[1:]) - common)
diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py
index b1dceb2918..1d77b4a2d6 100644
--- a/tests/storage/test_room.py
+++ b/tests/storage/test_room.py
@@ -109,7 +109,9 @@ class RoomEventsStoreTestCase(unittest.TestCase):
etype=EventTypes.Name, name=name, content={"name": name}, depth=1
)
- state = yield self.store.get_current_state(room_id=self.room.to_string())
+ state = yield defer.ensureDeferred(
+ self.store.get_current_state(room_id=self.room.to_string())
+ )
self.assertEquals(1, len(state))
self.assertObjectHasAttributes(
@@ -125,7 +127,9 @@ class RoomEventsStoreTestCase(unittest.TestCase):
etype=EventTypes.Topic, topic=topic, content={"topic": topic}, depth=1
)
- state = yield self.store.get_current_state(room_id=self.room.to_string())
+ state = yield defer.ensureDeferred(
+ self.store.get_current_state(room_id=self.room.to_string())
+ )
self.assertEquals(1, len(state))
self.assertObjectHasAttributes(
diff --git a/tests/test_state.py b/tests/test_state.py
index 66f22f6813..4858e8fc59 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -97,17 +97,19 @@ class StateGroupStore(object):
self._group_to_state[state_group] = dict(current_state_ids)
- return state_group
+ return defer.succeed(state_group)
def get_events(self, event_ids, **kwargs):
- return {
- e_id: self._event_id_to_event[e_id]
- for e_id in event_ids
- if e_id in self._event_id_to_event
- }
+ return defer.succeed(
+ {
+ e_id: self._event_id_to_event[e_id]
+ for e_id in event_ids
+ if e_id in self._event_id_to_event
+ }
+ )
def get_state_group_delta(self, name):
- return None, None
+ return defer.succeed((None, None))
def register_events(self, events):
for e in events:
@@ -120,7 +122,7 @@ class StateGroupStore(object):
self._event_to_state_group[event_id] = state_group
def get_room_version_id(self, room_id):
- return RoomVersions.V1.identifier
+ return defer.succeed(RoomVersions.V1.identifier)
class DictObj(dict):
@@ -202,7 +204,9 @@ class StateTestCase(unittest.TestCase):
context_store = {} # type: dict[str, EventContext]
for event in graph.walk():
- context = yield self.state.compute_event_context(event)
+ context = yield defer.ensureDeferred(
+ self.state.compute_event_context(event)
+ )
self.store.register_event_context(event, context)
context_store[event.event_id] = context
@@ -244,7 +248,9 @@ class StateTestCase(unittest.TestCase):
context_store = {}
for event in graph.walk():
- context = yield self.state.compute_event_context(event)
+ context = yield defer.ensureDeferred(
+ self.state.compute_event_context(event)
+ )
self.store.register_event_context(event, context)
context_store[event.event_id] = context
@@ -300,7 +306,9 @@ class StateTestCase(unittest.TestCase):
context_store = {}
for event in graph.walk():
- context = yield self.state.compute_event_context(event)
+ context = yield defer.ensureDeferred(
+ self.state.compute_event_context(event)
+ )
self.store.register_event_context(event, context)
context_store[event.event_id] = context
@@ -373,7 +381,9 @@ class StateTestCase(unittest.TestCase):
context_store = {}
for event in graph.walk():
- context = yield self.state.compute_event_context(event)
+ context = yield defer.ensureDeferred(
+ self.state.compute_event_context(event)
+ )
self.store.register_event_context(event, context)
context_store[event.event_id] = context
@@ -411,12 +421,14 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
- context = yield self.state.compute_event_context(event, old_state=old_state)
+ context = yield defer.ensureDeferred(
+ self.state.compute_event_context(event, old_state=old_state)
+ )
prev_state_ids = yield context.get_prev_state_ids()
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertCountEqual(
(e.event_id for e in old_state), current_state_ids.values()
)
@@ -434,12 +446,14 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
- context = yield self.state.compute_event_context(event, old_state=old_state)
+ context = yield defer.ensureDeferred(
+ self.state.compute_event_context(event, old_state=old_state)
+ )
prev_state_ids = yield context.get_prev_state_ids()
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertCountEqual(
(e.event_id for e in old_state + [event]), current_state_ids.values()
)
@@ -462,7 +476,7 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
- group_name = self.store.store_state_group(
+ group_name = yield self.store.store_state_group(
prev_event_id,
event.room_id,
None,
@@ -471,9 +485,9 @@ class StateTestCase(unittest.TestCase):
)
self.store.register_event_id_state_group(prev_event_id, group_name)
- context = yield self.state.compute_event_context(event)
+ context = yield defer.ensureDeferred(self.state.compute_event_context(event))
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertEqual(
{e.event_id for e in old_state}, set(current_state_ids.values())
@@ -494,7 +508,7 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
- group_name = self.store.store_state_group(
+ group_name = yield self.store.store_state_group(
prev_event_id,
event.room_id,
None,
@@ -503,7 +517,7 @@ class StateTestCase(unittest.TestCase):
)
self.store.register_event_id_state_group(prev_event_id, group_name)
- context = yield self.state.compute_event_context(event)
+ context = yield defer.ensureDeferred(self.state.compute_event_context(event))
prev_state_ids = yield context.get_prev_state_ids()
@@ -544,7 +558,7 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertEqual(len(current_state_ids), 6)
@@ -586,7 +600,7 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertEqual(len(current_state_ids), 6)
@@ -641,7 +655,7 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")])
@@ -669,14 +683,15 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
- current_state_ids = yield context.get_current_state_ids()
+ current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
self.assertEqual(old_state_1[3].event_id, current_state_ids[("test1", "1")])
+ @defer.inlineCallbacks
def _get_context(
self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2
):
- sg1 = self.store.store_state_group(
+ sg1 = yield self.store.store_state_group(
prev_event_id_1,
event.room_id,
None,
@@ -685,7 +700,7 @@ class StateTestCase(unittest.TestCase):
)
self.store.register_event_id_state_group(prev_event_id_1, sg1)
- sg2 = self.store.store_state_group(
+ sg2 = yield self.store.store_state_group(
prev_event_id_2,
event.room_id,
None,
@@ -694,4 +709,5 @@ class StateTestCase(unittest.TestCase):
)
self.store.register_event_id_state_group(prev_event_id_2, sg2)
- return self.state.compute_event_context(event)
+ result = yield defer.ensureDeferred(self.state.compute_event_context(event))
+ return result
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index 7b345b03bb..508aeba078 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -17,7 +17,7 @@
"""
Utilities for running the unit tests
"""
-from typing import Awaitable, TypeVar
+from typing import Any, Awaitable, TypeVar
TV = TypeVar("TV")
@@ -36,3 +36,8 @@ def get_awaitable_result(awaitable: Awaitable[TV]) -> TV:
# if next didn't raise, the awaitable hasn't completed.
raise Exception("awaitable has not yet completed")
+
+
+async def make_awaitable(result: Any):
+ """Create an awaitable that just returns a result."""
+ return result
|