diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py
index 23f1b33b2f..70e6a7e142 100644
--- a/tests/handlers/test_federation_event.py
+++ b/tests/handlers/test_federation_event.py
@@ -35,7 +35,7 @@ from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest
-from tests.test_utils import event_injection, make_awaitable
+from tests.test_utils import event_injection
class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
@@ -50,6 +50,10 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
self.mock_federation_transport_client = mock.Mock(
spec=["get_room_state_ids", "get_room_state", "get_event", "backfill"]
)
+ self.mock_federation_transport_client.get_room_state_ids = mock.AsyncMock()
+ self.mock_federation_transport_client.get_room_state = mock.AsyncMock()
+ self.mock_federation_transport_client.get_event = mock.AsyncMock()
+ self.mock_federation_transport_client.backfill = mock.AsyncMock()
return super().setup_test_homeserver(
federation_transport_client=self.mock_federation_transport_client
)
@@ -198,20 +202,14 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
)
# we expect an outbound request to /state_ids, so stub that out
- self.mock_federation_transport_client.get_room_state_ids.return_value = (
- make_awaitable(
- {
- "pdu_ids": [e.event_id for e in state_at_prev_event],
- "auth_chain_ids": [],
- }
- )
- )
+ self.mock_federation_transport_client.get_room_state_ids.return_value = {
+ "pdu_ids": [e.event_id for e in state_at_prev_event],
+ "auth_chain_ids": [],
+ }
# we also expect an outbound request to /state
self.mock_federation_transport_client.get_room_state.return_value = (
- make_awaitable(
- StateRequestResponse(auth_events=[], state=state_at_prev_event)
- )
+ StateRequestResponse(auth_events=[], state=state_at_prev_event)
)
# we have to bump the clock a bit, to keep the retry logic in
@@ -273,26 +271,23 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
room_version = self.get_success(main_store.get_room_version(room_id))
# We expect an outbound request to /state_ids, so stub that out
- self.mock_federation_transport_client.get_room_state_ids.return_value = make_awaitable(
- {
- # Mimic the other server not knowing about the state at all.
- # We want to cause Synapse to throw an error (`Unable to get
- # missing prev_event $fake_prev_event`) and fail to backfill
- # the pulled event.
- "pdu_ids": [],
- "auth_chain_ids": [],
- }
- )
+ self.mock_federation_transport_client.get_room_state_ids.return_value = {
+ # Mimic the other server not knowing about the state at all.
+ # We want to cause Synapse to throw an error (`Unable to get
+ # missing prev_event $fake_prev_event`) and fail to backfill
+ # the pulled event.
+ "pdu_ids": [],
+ "auth_chain_ids": [],
+ }
+
# We also expect an outbound request to /state
- self.mock_federation_transport_client.get_room_state.return_value = make_awaitable(
- StateRequestResponse(
- # Mimic the other server not knowing about the state at all.
- # We want to cause Synapse to throw an error (`Unable to get
- # missing prev_event $fake_prev_event`) and fail to backfill
- # the pulled event.
- auth_events=[],
- state=[],
- )
+ self.mock_federation_transport_client.get_room_state.return_value = StateRequestResponse(
+ # Mimic the other server not knowing about the state at all.
+ # We want to cause Synapse to throw an error (`Unable to get
+ # missing prev_event $fake_prev_event`) and fail to backfill
+ # the pulled event.
+ auth_events=[],
+ state=[],
)
pulled_event = make_event_from_dict(
@@ -545,25 +540,23 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
)
# We expect an outbound request to /backfill, so stub that out
- self.mock_federation_transport_client.backfill.return_value = make_awaitable(
- {
- "origin": self.OTHER_SERVER_NAME,
- "origin_server_ts": 123,
- "pdus": [
- # This is one of the important aspects of this test: we include
- # `pulled_event_without_signatures` so it fails the signature check
- # when we filter down the backfill response down to events which
- # have valid signatures in
- # `_check_sigs_and_hash_for_pulled_events_and_fetch`
- pulled_event_without_signatures.get_pdu_json(),
- # Then later when we process this valid signature event, when we
- # fetch the missing `prev_event`s, we want to make sure that we
- # backoff and don't try and fetch `pulled_event_without_signatures`
- # again since we know it just had an invalid signature.
- pulled_event.get_pdu_json(),
- ],
- }
- )
+ self.mock_federation_transport_client.backfill.return_value = {
+ "origin": self.OTHER_SERVER_NAME,
+ "origin_server_ts": 123,
+ "pdus": [
+ # This is one of the important aspects of this test: we include
+ # `pulled_event_without_signatures` so it fails the signature check
+ # when we filter down the backfill response down to events which
+ # have valid signatures in
+ # `_check_sigs_and_hash_for_pulled_events_and_fetch`
+ pulled_event_without_signatures.get_pdu_json(),
+ # Then later when we process this valid signature event, when we
+ # fetch the missing `prev_event`s, we want to make sure that we
+ # backoff and don't try and fetch `pulled_event_without_signatures`
+ # again since we know it just had an invalid signature.
+ pulled_event.get_pdu_json(),
+ ],
+ }
# Keep track of the count and make sure we don't make any of these requests
event_endpoint_requested_count = 0
@@ -731,15 +724,13 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
)
# We expect an outbound request to /backfill, so stub that out
- self.mock_federation_transport_client.backfill.return_value = make_awaitable(
- {
- "origin": self.OTHER_SERVER_NAME,
- "origin_server_ts": 123,
- "pdus": [
- pulled_event.get_pdu_json(),
- ],
- }
- )
+ self.mock_federation_transport_client.backfill.return_value = {
+ "origin": self.OTHER_SERVER_NAME,
+ "origin_server_ts": 123,
+ "pdus": [
+ pulled_event.get_pdu_json(),
+ ],
+ }
# The function under test: try to backfill and process the pulled event
with LoggingContext("test"):
|