diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index ddeffe1ad5..9e104fd96a 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -22,6 +22,7 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EduTypes, RoomEncryptionAlgorithms
from synapse.federation.units import Transaction
+from synapse.handlers.device import DeviceHandler
from synapse.rest import admin
from synapse.rest.client import login
from synapse.server import HomeServer
@@ -41,8 +42,9 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
"""
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ self.federation_transport_client = Mock(spec=["send_transaction"])
hs = self.setup_test_homeserver(
- federation_transport_client=Mock(spec=["send_transaction"]),
+ federation_transport_client=self.federation_transport_client,
)
hs.get_storage_controllers().state.get_current_hosts_in_room = Mock( # type: ignore[assignment]
@@ -61,9 +63,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
return config
def test_send_receipts(self) -> None:
- mock_send_transaction = (
- self.hs.get_federation_transport_client().send_transaction
- )
+ mock_send_transaction = self.federation_transport_client.send_transaction
mock_send_transaction.return_value = make_awaitable({})
sender = self.hs.get_federation_sender()
@@ -103,9 +103,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
)
def test_send_receipts_thread(self) -> None:
- mock_send_transaction = (
- self.hs.get_federation_transport_client().send_transaction
- )
+ mock_send_transaction = self.federation_transport_client.send_transaction
mock_send_transaction.return_value = make_awaitable({})
# Create receipts for:
@@ -181,9 +179,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
def test_send_receipts_with_backoff(self) -> None:
"""Send two receipts in quick succession; the second should be flushed, but
only after 20ms"""
- mock_send_transaction = (
- self.hs.get_federation_transport_client().send_transaction
- )
+ mock_send_transaction = self.federation_transport_client.send_transaction
mock_send_transaction.return_value = make_awaitable({})
sender = self.hs.get_federation_sender()
@@ -277,10 +273,11 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ self.federation_transport_client = Mock(
+ spec=["send_transaction", "query_user_devices"]
+ )
return self.setup_test_homeserver(
- federation_transport_client=Mock(
- spec=["send_transaction", "query_user_devices"]
- ),
+ federation_transport_client=self.federation_transport_client,
)
def default_config(self) -> JsonDict:
@@ -310,9 +307,13 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
hs.get_datastores().main.get_current_hosts_in_room = get_current_hosts_in_room # type: ignore[assignment]
+ device_handler = hs.get_device_handler()
+ assert isinstance(device_handler, DeviceHandler)
+ self.device_handler = device_handler
+
# whenever send_transaction is called, record the edu data
self.edus: List[JsonDict] = []
- self.hs.get_federation_transport_client().send_transaction.side_effect = (
+ self.federation_transport_client.send_transaction.side_effect = (
self.record_transaction
)
@@ -353,7 +354,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
# Send the server a device list EDU for the other user, this will cause
# it to try and resync the device lists.
- self.hs.get_federation_transport_client().query_user_devices.return_value = (
+ self.federation_transport_client.query_user_devices.return_value = (
make_awaitable(
{
"stream_id": "1",
@@ -364,7 +365,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
)
self.get_success(
- self.hs.get_device_handler().device_list_updater.incoming_device_list_update(
+ self.device_handler.device_list_updater.incoming_device_list_update(
"host2",
{
"user_id": "@user2:host2",
@@ -507,9 +508,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D3", stream_id)
# delete them again
- self.get_success(
- self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
- )
+ self.get_success(self.device_handler.delete_devices(u1, ["D1", "D2", "D3"]))
# We queue up device list updates to be sent over federation, so we
# advance to clear the queue.
@@ -533,7 +532,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
"""If the destination server is unreachable, all the updates should get sent on
recovery
"""
- mock_send_txn = self.hs.get_federation_transport_client().send_transaction
+ mock_send_txn = self.federation_transport_client.send_transaction
mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
# create devices
@@ -543,9 +542,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.login("user", "pass", device_id="D3")
# delete them again
- self.get_success(
- self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
- )
+ self.get_success(self.device_handler.delete_devices(u1, ["D1", "D2", "D3"]))
# We queue up device list updates to be sent over federation, so we
# advance to clear the queue.
@@ -580,7 +577,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
This case tests the behaviour when the server has never been reachable.
"""
- mock_send_txn = self.hs.get_federation_transport_client().send_transaction
+ mock_send_txn = self.federation_transport_client.send_transaction
mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
# create devices
@@ -590,9 +587,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.login("user", "pass", device_id="D3")
# delete them again
- self.get_success(
- self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
- )
+ self.get_success(self.device_handler.delete_devices(u1, ["D1", "D2", "D3"]))
# We queue up device list updates to be sent over federation, so we
# advance to clear the queue.
@@ -640,7 +635,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.check_device_update_edu(self.edus.pop(0), u1, "D1", None)
# now the server goes offline
- mock_send_txn = self.hs.get_federation_transport_client().send_transaction
+ mock_send_txn = self.federation_transport_client.send_transaction
mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
self.login("user", "pass", device_id="D2")
@@ -651,9 +646,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.reactor.advance(1)
# delete them again
- self.get_success(
- self.hs.get_device_handler().delete_devices(u1, ["D1", "D2", "D3"])
- )
+ self.get_success(self.device_handler.delete_devices(u1, ["D1", "D2", "D3"]))
self.assertGreaterEqual(mock_send_txn.call_count, 3)
|