diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py
index ec8864dafe..268a48d7ba 100644
--- a/tests/federation/test_federation_client.py
+++ b/tests/federation/test_federation_client.py
@@ -83,7 +83,7 @@ class FederationClientTest(FederatingHomeserverTestCase):
)
# mock up the response, and have the agent return it
- self._mock_agent.request.return_value = defer.succeed(
+ self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed(
_mock_response(
{
"pdus": [
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 91f982518e..6b26353d5e 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -226,7 +226,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
# Send the server a device list EDU for the other user, this will cause
# it to try and resync the device lists.
self.hs.get_federation_transport_client().query_user_devices.return_value = (
- defer.succeed(
+ make_awaitable(
{
"stream_id": "1",
"user_id": "@user2:host2",
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 8c72cf6b30..5b0cd1ab86 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -411,6 +411,88 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
"exclusive_as_user", "password", self.exclusive_as_user_device_id
)
+ def test_sending_read_receipt_batches_to_application_services(self):
+ """Tests that a large batch of read receipts are sent correctly to
+ interested application services.
+ """
+ # Register an application service that's interested in a certain user
+ # and room prefix
+ interested_appservice = self._register_application_service(
+ namespaces={
+ ApplicationService.NS_USERS: [
+ {
+ "regex": "@exclusive_as_user:.+",
+ "exclusive": True,
+ }
+ ],
+ ApplicationService.NS_ROOMS: [
+ {
+ "regex": "!fakeroom_.*",
+ "exclusive": True,
+ }
+ ],
+ },
+ )
+
+ # "Complete" a transaction.
+ # All this really does for us is make an entry in the application_services_state
+ # database table, which tracks the current stream_token per stream ID per AS.
+ self.get_success(
+ self.hs.get_datastores().main.complete_appservice_txn(
+ 0,
+ interested_appservice,
+ )
+ )
+
+ # Now, pretend that we receive a large burst of read receipts (300 total) that
+ # all come in at once.
+ for i in range(300):
+ self.get_success(
+ # Insert a fake read receipt into the database
+ self.hs.get_datastores().main.insert_receipt(
+ # We have to use unique room ID + user ID combinations here, as the db query
+ # is an upsert.
+ room_id=f"!fakeroom_{i}:test",
+ receipt_type="m.read",
+ user_id=self.local_user,
+ event_ids=[f"$eventid_{i}"],
+ data={},
+ )
+ )
+
+ # Now notify the appservice handler that 300 read receipts have all arrived
+ # at once. What will it do!
+ # note: stream tokens start at 2
+ for stream_token in range(2, 303):
+ self.get_success(
+ self.hs.get_application_service_handler()._notify_interested_services_ephemeral(
+ services=[interested_appservice],
+ stream_key="receipt_key",
+ new_token=stream_token,
+ users=[self.exclusive_as_user],
+ )
+ )
+
+ # Using our txn send mock, we can see what the AS received. After iterating over every
+ # transaction, we'd like to see all 300 read receipts accounted for.
+ # No more, no less.
+ all_ephemeral_events = []
+ for call in self.send_mock.call_args_list:
+ ephemeral_events = call[0][2]
+ all_ephemeral_events += ephemeral_events
+
+ # Ensure that no duplicate events were sent
+ self.assertEqual(len(all_ephemeral_events), 300)
+
+ # Check that the ephemeral event is a read receipt with the expected structure
+ latest_read_receipt = all_ephemeral_events[-1]
+ self.assertEqual(latest_read_receipt["type"], "m.receipt")
+
+ event_id = list(latest_read_receipt["content"].keys())[0]
+ self.assertEqual(
+ latest_read_receipt["content"][event_id]["m.read"], {self.local_user: {}}
+ )
+
@unittest.override_config(
{"experimental_features": {"msc2409_to_device_messages_enabled": True}}
)
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index a54aa29cf1..751025c5da 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -201,4 +201,16 @@ class CasHandlerTestCase(HomeserverTestCase):
def _mock_request():
"""Returns a mock which will stand in as a SynapseRequest"""
- return Mock(spec=["getClientIP", "getHeader", "_disconnected"])
+ mock = Mock(
+ spec=[
+ "finish",
+ "getClientIP",
+ "getHeader",
+ "setHeader",
+ "setResponseCode",
+ "write",
+ ]
+ )
+ # `_disconnected` musn't be another `Mock`, otherwise it will be truthy.
+ mock._disconnected = False
+ return mock
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 8c74ed1fcf..1e6ad4b663 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -19,7 +19,6 @@ from unittest import mock
from parameterized import parameterized
from signedjson import key as key, sign as sign
-from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import RoomEncryptionAlgorithms
@@ -704,7 +703,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
self.hs.get_federation_client().query_client_keys = mock.Mock(
- return_value=defer.succeed(
+ return_value=make_awaitable(
{
"device_keys": {remote_user_id: {}},
"master_keys": {
@@ -777,14 +776,14 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# Pretend we're sharing a room with the user we're querying. If not,
# `_query_devices_for_destination` will return early.
self.store.get_rooms_for_user = mock.Mock(
- return_value=defer.succeed({"some_room_id"})
+ return_value=make_awaitable({"some_room_id"})
)
remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
self.hs.get_federation_client().query_user_devices = mock.Mock(
- return_value=defer.succeed(
+ return_value=make_awaitable(
{
"user_id": remote_user_id,
"stream_id": 1,
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index d401fda938..addf14fa2b 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -17,8 +17,6 @@
from typing import Any, Type, Union
from unittest.mock import Mock
-from twisted.internet import defer
-
import synapse
from synapse.api.constants import LoginType
from synapse.api.errors import Codes
@@ -190,7 +188,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)
# check_password must return an awaitable
- mock_password_provider.check_password.return_value = defer.succeed(True)
+ mock_password_provider.check_password.return_value = make_awaitable(True)
channel = self._send_password_login("u", "p")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual("@u:test", channel.json_body["user_id"])
@@ -226,13 +224,13 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.get_success(module_api.register_user("u"))
# log in twice, to get two devices
- mock_password_provider.check_password.return_value = defer.succeed(True)
+ mock_password_provider.check_password.return_value = make_awaitable(True)
tok1 = self.login("u", "p")
self.login("u", "p", device_id="dev2")
mock_password_provider.reset_mock()
# have the auth provider deny the request to start with
- mock_password_provider.check_password.return_value = defer.succeed(False)
+ mock_password_provider.check_password.return_value = make_awaitable(False)
# make the initial request which returns a 401
session = self._start_delete_device_session(tok1, "dev2")
@@ -246,7 +244,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.reset_mock()
# Finally, check the request goes through when we allow it
- mock_password_provider.check_password.return_value = defer.succeed(True)
+ mock_password_provider.check_password.return_value = make_awaitable(True)
channel = self._authed_delete_device(tok1, "dev2", session, "u", "p")
self.assertEqual(channel.code, 200)
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
@@ -260,7 +258,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.register_user("localuser", "localpass")
# check_password must return an awaitable
- mock_password_provider.check_password.return_value = defer.succeed(False)
+ mock_password_provider.check_password.return_value = make_awaitable(False)
channel = self._send_password_login("u", "p")
self.assertEqual(channel.code, 403, channel.result)
@@ -277,7 +275,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.register_user("localuser", "localpass")
# have the auth provider deny the request
- mock_password_provider.check_password.return_value = defer.succeed(False)
+ mock_password_provider.check_password.return_value = make_awaitable(False)
# log in twice, to get two devices
tok1 = self.login("localuser", "localpass")
@@ -320,7 +318,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.register_user("localuser", "localpass")
# check_password must return an awaitable
- mock_password_provider.check_password.return_value = defer.succeed(False)
+ mock_password_provider.check_password.return_value = make_awaitable(False)
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 403)
self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
@@ -342,7 +340,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.register_user("localuser", "localpass")
# allow login via the auth provider
- mock_password_provider.check_password.return_value = defer.succeed(True)
+ mock_password_provider.check_password.return_value = make_awaitable(True)
# log in twice, to get two devices
tok1 = self.login("localuser", "p")
@@ -359,7 +357,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.check_password.assert_not_called()
# now try deleting with the local password
- mock_password_provider.check_password.return_value = defer.succeed(False)
+ mock_password_provider.check_password.return_value = make_awaitable(False)
channel = self._authed_delete_device(
tok1, "dev2", session, "localuser", "localpass"
)
@@ -413,7 +411,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400, channel.result)
mock_password_provider.check_auth.assert_not_called()
- mock_password_provider.check_auth.return_value = defer.succeed(
+ mock_password_provider.check_auth.return_value = make_awaitable(
("@user:bz", None)
)
channel = self._send_login("test.login_type", "u", test_field="y")
@@ -427,7 +425,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# try a weird username. Again, it's unclear what we *expect* to happen
# in these cases, but at least we can guard against the API changing
# unexpectedly
- mock_password_provider.check_auth.return_value = defer.succeed(
+ mock_password_provider.check_auth.return_value = make_awaitable(
("@ MALFORMED! :bz", None)
)
channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
@@ -477,7 +475,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.reset_mock()
# right params, but authing as the wrong user
- mock_password_provider.check_auth.return_value = defer.succeed(
+ mock_password_provider.check_auth.return_value = make_awaitable(
("@user:bz", None)
)
body["auth"]["test_field"] = "foo"
@@ -490,7 +488,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.reset_mock()
# and finally, succeed
- mock_password_provider.check_auth.return_value = defer.succeed(
+ mock_password_provider.check_auth.return_value = make_awaitable(
("@localuser:test", None)
)
channel = self._delete_device(tok1, "dev2", body)
@@ -508,9 +506,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.custom_auth_provider_callback_test_body()
def custom_auth_provider_callback_test_body(self):
- callback = Mock(return_value=defer.succeed(None))
+ callback = Mock(return_value=make_awaitable(None))
- mock_password_provider.check_auth.return_value = defer.succeed(
+ mock_password_provider.check_auth.return_value = make_awaitable(
("@user:bz", callback)
)
channel = self._send_login("test.login_type", "u", test_field="y")
@@ -646,7 +644,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
login is disabled"""
# register the user and log in twice via the test login type to get two devices,
self.register_user("localuser", "localpass")
- mock_password_provider.check_auth.return_value = defer.succeed(
+ mock_password_provider.check_auth.return_value = make_awaitable(
("@localuser:test", None)
)
channel = self._send_login("test.login_type", "localuser", test_field="")
diff --git a/tests/handlers/test_receipts.py b/tests/handlers/test_receipts.py
index 5081b97573..65ab7db0c8 100644
--- a/tests/handlers/test_receipts.py
+++ b/tests/handlers/test_receipts.py
@@ -15,7 +15,7 @@
from typing import List
-from synapse.api.constants import ReadReceiptEventFields
+from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes
from synapse.types import JsonDict
from tests import unittest
@@ -35,7 +35,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
{
"content": {
"$1435641916114394fHBLK:matrix.org": {
- "m.read": {
+ ReceiptTypes.READ: {
"@rikj:jki.re": {
"ts": 1436451550453,
"hidden": True,
@@ -56,7 +56,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
{
"content": {
"$1435641916hfgh4394fHBLK:matrix.org": {
- "m.read": {
+ ReceiptTypes.READ: {
"@me:server.org": {
"ts": 1436451550453,
"hidden": True,
@@ -72,7 +72,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
{
"content": {
"$1435641916hfgh4394fHBLK:matrix.org": {
- "m.read": {
+ ReceiptTypes.READ: {
"@me:server.org": {
"ts": 1436451550453,
ReadReceiptEventFields.MSC2285_HIDDEN: True,
@@ -92,7 +92,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
{
"content": {
"$1dgdgrd5641916114394fHBLK:matrix.org": {
- "m.read": {
+ ReceiptTypes.READ: {
"@rikj:jki.re": {
"ts": 1436451550453,
"hidden": True,
@@ -111,7 +111,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
{
"content": {
"$1dgdgrd5641916114394fHBLK:matrix.org": {
- "m.read": {
+ ReceiptTypes.READ: {
"@user:jki.re": {
"ts": 1436451550453,
}
@@ -130,7 +130,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
{
"content": {
"$14356419edgd14394fHBLK:matrix.org": {
- "m.read": {
+ ReceiptTypes.READ: {
"@rikj:jki.re": {
"ts": 1436451550453,
"hidden": True,
@@ -138,7 +138,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
}
},
"$1435641916114394fHBLK:matrix.org": {
- "m.read": {
+ ReceiptTypes.READ: {
"@user:jki.re": {
"ts": 1436451550453,
}
@@ -153,7 +153,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
{
"content": {
"$1435641916114394fHBLK:matrix.org": {
- "m.read": {
+ ReceiptTypes.READ: {
"@user:jki.re": {
"ts": 1436451550453,
}
@@ -171,9 +171,9 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
[
{
"content": {
- "$14356419ggffg114394fHBLK:matrix.org": {"m.read": {}},
+ "$14356419ggffg114394fHBLK:matrix.org": {ReceiptTypes.READ: {}},
"$1435641916114394fHBLK:matrix.org": {
- "m.read": {
+ ReceiptTypes.READ: {
"@user:jki.re": {
"ts": 1436451550453,
}
@@ -187,9 +187,9 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
[
{
"content": {
- "$14356419ggffg114394fHBLK:matrix.org": {"m.read": {}},
+ "$14356419ggffg114394fHBLK:matrix.org": {ReceiptTypes.READ: {}},
"$1435641916114394fHBLK:matrix.org": {
- "m.read": {
+ ReceiptTypes.READ: {
"@user:jki.re": {
"ts": 1436451550453,
}
@@ -209,7 +209,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
"content": {
"$143564gdfg6114394fHBLK:matrix.org": {},
"$1435641916114394fHBLK:matrix.org": {
- "m.read": {
+ ReceiptTypes.READ: {
"@user:jki.re": {
"ts": 1436451550453,
}
@@ -225,7 +225,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
"content": {
"$143564gdfg6114394fHBLK:matrix.org": {},
"$1435641916114394fHBLK:matrix.org": {
- "m.read": {
+ ReceiptTypes.READ: {
"@user:jki.re": {
"ts": 1436451550453,
}
@@ -244,7 +244,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
{
"content": {
"$14356419edgd14394fHBLK:matrix.org": {
- "m.read": {
+ ReceiptTypes.READ: {
"@rikj:jki.re": {
"ts": 1436451550453,
"hidden": True,
@@ -258,7 +258,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
{
"content": {
"$1435641916114394fHBLK:matrix.org": {
- "m.read": {
+ ReceiptTypes.READ: {
"@user:jki.re": {
"ts": 1436451550453,
}
@@ -273,7 +273,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
{
"content": {
"$1435641916114394fHBLK:matrix.org": {
- "m.read": {
+ ReceiptTypes.READ: {
"@user:jki.re": {
"ts": 1436451550453,
}
@@ -297,7 +297,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase):
{
"content": {
"$14356419edgd14394fHBLK:matrix.org": {
- "m.read": {
+ ReceiptTypes.READ: {
"@rikj:jki.re": "string",
}
},
diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py
index 45fd30cf43..b6ba19c739 100644
--- a/tests/handlers/test_register.py
+++ b/tests/handlers/test_register.py
@@ -193,8 +193,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
@override_config({"limit_usage_by_mau": True})
def test_get_or_create_user_mau_not_blocked(self):
- # Type ignore: mypy doesn't like us assigning to methods.
- self.store.count_monthly_users = Mock( # type: ignore[assignment]
+ self.store.count_monthly_users = Mock(
return_value=make_awaitable(self.hs.config.server.max_mau_value - 1)
)
# Ensure does not throw exception
@@ -202,8 +201,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
@override_config({"limit_usage_by_mau": True})
def test_get_or_create_user_mau_blocked(self):
- # Type ignore: mypy doesn't like us assigning to methods.
- self.store.get_monthly_active_count = Mock( # type: ignore[assignment]
+ self.store.get_monthly_active_count = Mock(
return_value=make_awaitable(self.lots_of_users)
)
self.get_failure(
@@ -211,8 +209,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
ResourceLimitError,
)
- # Type ignore: mypy doesn't like us assigning to methods.
- self.store.get_monthly_active_count = Mock( # type: ignore[assignment]
+ self.store.get_monthly_active_count = Mock(
return_value=make_awaitable(self.hs.config.server.max_mau_value)
)
self.get_failure(
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index 8d4404eda1..e2f0f90ef1 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -349,4 +349,16 @@ class SamlHandlerTestCase(HomeserverTestCase):
def _mock_request():
"""Returns a mock which will stand in as a SynapseRequest"""
- return Mock(spec=["getClientIP", "getHeader", "_disconnected"])
+ mock = Mock(
+ spec=[
+ "finish",
+ "getClientIP",
+ "getHeader",
+ "setHeader",
+ "setResponseCode",
+ "write",
+ ]
+ )
+ # `_disconnected` musn't be another `Mock`, otherwise it will be truthy.
+ mock._disconnected = False
+ return mock
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index ffd5c4cb93..5f2e26a5fc 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -65,11 +65,11 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
# we mock out the keyring so as to skip the authentication check on the
# federation API call.
mock_keyring = Mock(spec=["verify_json_for_server"])
- mock_keyring.verify_json_for_server.return_value = defer.succeed(True)
+ mock_keyring.verify_json_for_server.return_value = make_awaitable(True)
# we mock out the federation client too
mock_federation_client = Mock(spec=["put_json"])
- mock_federation_client.put_json.return_value = defer.succeed((200, "OK"))
+ mock_federation_client.put_json.return_value = make_awaitable((200, "OK"))
# the tests assume that we are starting at unix time 1000
reactor.pump((1000,))
@@ -98,7 +98,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.datastore = hs.get_datastores().main
self.datastore.get_destination_retry_timings = Mock(
- return_value=defer.succeed(None)
+ return_value=make_awaitable(None)
)
self.datastore.get_device_updates_by_remote = Mock(
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index c6e501c7be..96e2e3039b 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -15,7 +15,6 @@ from typing import Tuple
from unittest.mock import Mock, patch
from urllib.parse import quote
-from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
@@ -30,6 +29,7 @@ from synapse.util import Clock
from tests import unittest
from tests.storage.test_user_directory import GetUserDirectoryTables
+from tests.test_utils import make_awaitable
from tests.test_utils.event_injection import inject_member_event
from tests.unittest import override_config
@@ -439,7 +439,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
)
- mock_remove_from_user_dir = Mock(return_value=defer.succeed(None))
+ mock_remove_from_user_dir = Mock(return_value=make_awaitable(None))
with patch.object(
self.store, "remove_from_user_dir", mock_remove_from_user_dir
):
@@ -454,7 +454,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.store.register_user(user_id=r_user_id, password_hash=None)
)
- mock_remove_from_user_dir = Mock(return_value=defer.succeed(None))
+ mock_remove_from_user_dir = Mock(return_value=make_awaitable(None))
with patch.object(
self.store, "remove_from_user_dir", mock_remove_from_user_dir
):
diff --git a/tests/module_api/test_account_data_manager.py b/tests/module_api/test_account_data_manager.py
index bec018d9e7..89009bea8c 100644
--- a/tests/module_api/test_account_data_manager.py
+++ b/tests/module_api/test_account_data_manager.py
@@ -11,8 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from twisted.test.proto_helpers import MemoryReactor
+
from synapse.api.errors import SynapseError
from synapse.rest import admin
+from synapse.server import HomeServer
+from synapse.util import Clock
from tests.unittest import HomeserverTestCase
@@ -22,7 +26,9 @@ class ModuleApiTestCase(HomeserverTestCase):
admin.register_servlets,
]
- def prepare(self, reactor, clock, homeserver) -> None:
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
self._store = homeserver.get_datastores().main
self._module_api = homeserver.get_module_api()
self._account_data_mgr = self._module_api.account_data_manager
@@ -91,7 +97,7 @@ class ModuleApiTestCase(HomeserverTestCase):
)
with self.assertRaises(TypeError):
# This throws an exception because it's a frozen dict.
- the_data["wombat"] = False
+ the_data["wombat"] = False # type: ignore[index]
def test_put_global(self) -> None:
"""
@@ -143,15 +149,14 @@ class ModuleApiTestCase(HomeserverTestCase):
with self.assertRaises(TypeError):
# The account data type must be a string.
self.get_success_or_raise(
- self._module_api.account_data_manager.put_global(
- self.user_id, 42, {} # type: ignore
- )
+ self._module_api.account_data_manager.put_global(self.user_id, 42, {}) # type: ignore[arg-type]
)
with self.assertRaises(TypeError):
# The account data dict must be a dict.
+ # noinspection PyTypeChecker
self.get_success_or_raise(
self._module_api.account_data_manager.put_global(
- self.user_id, "test.data", 42 # type: ignore
+ self.user_id, "test.data", 42 # type: ignore[arg-type]
)
)
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 9fd5d59c55..8bc84aaaca 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -19,8 +19,9 @@ from synapse.api.constants import EduTypes, EventTypes
from synapse.events import EventBase
from synapse.federation.units import Transaction
from synapse.handlers.presence import UserPresenceState
+from synapse.handlers.push_rules import InvalidRuleException
from synapse.rest import admin
-from synapse.rest.client import login, presence, profile, room
+from synapse.rest.client import login, notifications, presence, profile, room
from synapse.types import create_requester
from tests.events.test_presence_router import send_presence_update, sync_presence
@@ -38,6 +39,7 @@ class ModuleApiTestCase(HomeserverTestCase):
room.register_servlets,
presence.register_servlets,
profile.register_servlets,
+ notifications.register_servlets,
]
def prepare(self, reactor, clock, homeserver):
@@ -553,6 +555,86 @@ class ModuleApiTestCase(HomeserverTestCase):
self.assertEqual(state[("org.matrix.test", "")].state_key, "")
self.assertEqual(state[("org.matrix.test", "")].content, {})
+ def test_set_push_rules_action(self) -> None:
+ """Test that a module can change the actions of an existing push rule for a user."""
+
+ # Create a room with 2 users in it. Push rules must not match if the user is the
+ # event's sender, so we need one user to send messages and one user to receive
+ # notifications.
+ user_id = self.register_user("user", "password")
+ tok = self.login("user", "password")
+
+ room_id = self.helper.create_room_as(user_id, is_public=True, tok=tok)
+
+ user_id2 = self.register_user("user2", "password")
+ tok2 = self.login("user2", "password")
+ self.helper.join(room_id, user_id2, tok=tok2)
+
+ # Register a 3rd user and join them to the room, so that we don't accidentally
+ # trigger 1:1 push rules.
+ user_id3 = self.register_user("user3", "password")
+ tok3 = self.login("user3", "password")
+ self.helper.join(room_id, user_id3, tok=tok3)
+
+ # Send a message as the second user and check that it notifies.
+ res = self.helper.send(room_id=room_id, body="here's a message", tok=tok2)
+ event_id = res["event_id"]
+
+ channel = self.make_request(
+ "GET",
+ "/notifications",
+ access_token=tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+
+ self.assertEqual(len(channel.json_body["notifications"]), 1, channel.json_body)
+ self.assertEqual(
+ channel.json_body["notifications"][0]["event"]["event_id"],
+ event_id,
+ channel.json_body,
+ )
+
+ # Change the .m.rule.message actions to not notify on new messages.
+ self.get_success(
+ defer.ensureDeferred(
+ self.module_api.set_push_rule_action(
+ user_id=user_id,
+ scope="global",
+ kind="underride",
+ rule_id=".m.rule.message",
+ actions=["dont_notify"],
+ )
+ )
+ )
+
+ # Send another message as the second user and check that the number of
+ # notifications didn't change.
+ self.helper.send(room_id=room_id, body="here's another message", tok=tok2)
+
+ channel = self.make_request(
+ "GET",
+ "/notifications?from=",
+ access_token=tok,
+ )
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertEqual(len(channel.json_body["notifications"]), 1, channel.json_body)
+
+ def test_check_push_rules_actions(self) -> None:
+ """Test that modules can check whether a list of push rules actions are spec
+ compliant.
+ """
+ with self.assertRaises(InvalidRuleException):
+ self.module_api.check_push_rule_actions(["foo"])
+
+ with self.assertRaises(InvalidRuleException):
+ self.module_api.check_push_rule_actions({"foo": "bar"})
+
+ self.module_api.check_push_rule_actions(["notify"])
+
+ self.module_api.check_push_rule_actions(
+ [{"set_tweak": "sound", "value": "default"}]
+ )
+
class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
"""For testing ModuleApi functionality in a multi-worker setup"""
diff --git a/tests/replication/slave/storage/test_receipts.py b/tests/replication/slave/storage/test_receipts.py
index f47d94f690..de19e75b9d 100644
--- a/tests/replication/slave/storage/test_receipts.py
+++ b/tests/replication/slave/storage/test_receipts.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from synapse.api.constants import ReceiptTypes
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from ._base import BaseSlavedStoreTestCase
@@ -26,9 +27,13 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
STORE_TYPE = SlavedReceiptsStore
def test_receipt(self):
- self.check("get_receipts_for_user", [USER_ID, "m.read"], {})
+ self.check("get_receipts_for_user", [USER_ID, ReceiptTypes.READ], {})
self.get_success(
- self.master_store.insert_receipt(ROOM_ID, "m.read", USER_ID, [EVENT_ID], {})
+ self.master_store.insert_receipt(
+ ROOM_ID, ReceiptTypes.READ, USER_ID, [EVENT_ID], {}
+ )
)
self.replicate()
- self.check("get_receipts_for_user", [USER_ID, "m.read"], {ROOM_ID: EVENT_ID})
+ self.check(
+ "get_receipts_for_user", [USER_ID, ReceiptTypes.READ], {ROOM_ID: EVENT_ID}
+ )
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index ba1a63c0d6..6104a55aa1 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -102,8 +102,8 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
for i in range(20):
server_name = "other_server_%d" % (i,)
room = self.create_room_with_remote_server(user, token, server_name)
- mock_client1.reset_mock() # type: ignore[attr-defined]
- mock_client2.reset_mock() # type: ignore[attr-defined]
+ mock_client1.reset_mock()
+ mock_client2.reset_mock()
self.create_and_send_event(room, UserID.from_string(user))
self.replicate()
@@ -167,8 +167,8 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
for i in range(20):
server_name = "other_server_%d" % (i,)
room = self.create_room_with_remote_server(user, token, server_name)
- mock_client1.reset_mock() # type: ignore[attr-defined]
- mock_client2.reset_mock() # type: ignore[attr-defined]
+ mock_client1.reset_mock()
+ mock_client2.reset_mock()
self.get_success(
typing_handler.started_typing(
diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py
index 0abe378fe4..b3738a0304 100644
--- a/tests/rest/client/test_presence.py
+++ b/tests/rest/client/test_presence.py
@@ -14,7 +14,6 @@
from http import HTTPStatus
from unittest.mock import Mock
-from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
from synapse.handlers.presence import PresenceHandler
@@ -24,6 +23,7 @@ from synapse.types import UserID
from synapse.util import Clock
from tests import unittest
+from tests.test_utils import make_awaitable
class PresenceTestCase(unittest.HomeserverTestCase):
@@ -37,7 +37,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
presence_handler = Mock(spec=PresenceHandler)
- presence_handler.set_state.return_value = defer.succeed(None)
+ presence_handler.set_state.return_value = make_awaitable(None)
hs = self.setup_test_homeserver(
"red",
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 6ff79b9e2e..9443daa056 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -22,7 +22,6 @@ from typing import Any, Dict, Iterable, List, Optional
from unittest.mock import Mock, call
from urllib import parse as urlparse
-from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
@@ -1426,9 +1425,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
def test_simple(self) -> None:
"Simple test for searching rooms over federation"
- self.federation_client.get_public_rooms.side_effect = lambda *a, **k: defer.succeed( # type: ignore[attr-defined]
- {}
- )
+ self.federation_client.get_public_rooms.return_value = make_awaitable({}) # type: ignore[attr-defined]
search_filter = {"generic_search_term": "foobar"}
@@ -1456,7 +1453,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
# with a 404, when using search filters.
self.federation_client.get_public_rooms.side_effect = ( # type: ignore[attr-defined]
HttpResponseException(404, "Not Found", b""),
- defer.succeed({}),
+ make_awaitable({}),
)
search_filter = {"generic_search_term": "foobar"}
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index 773c16a54c..cb765455c1 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -24,6 +24,7 @@ from synapse.api.constants import (
EventContentFields,
EventTypes,
ReadReceiptEventFields,
+ ReceiptTypes,
RelationTypes,
)
from synapse.rest.client import devices, knock, login, read_marker, receipts, room, sync
@@ -560,7 +561,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
self._check_unread_count(1)
# Send a read receipt to tell the server we've read the latest event.
- body = json.dumps({"m.read": res["event_id"]}).encode("utf8")
+ body = json.dumps({ReceiptTypes.READ: res["event_id"]}).encode("utf8")
channel = self.make_request(
"POST",
"/rooms/%s/read_markers" % self.room_id,
diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py
index 8d8251b2ac..21a1ca2a68 100644
--- a/tests/rest/client/test_transactions.py
+++ b/tests/rest/client/test_transactions.py
@@ -22,6 +22,7 @@ from synapse.rest.client.transactions import CLEANUP_PERIOD_MS, HttpTransactionC
from synapse.util import Clock
from tests import unittest
+from tests.test_utils import make_awaitable
from tests.utils import MockClock
@@ -38,7 +39,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_executes_given_function(self):
- cb = Mock(return_value=defer.succeed(self.mock_http_response))
+ cb = Mock(return_value=make_awaitable(self.mock_http_response))
res = yield self.cache.fetch_or_execute(
self.mock_key, cb, "some_arg", keyword="arg"
)
@@ -47,7 +48,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_deduplicates_based_on_key(self):
- cb = Mock(return_value=defer.succeed(self.mock_http_response))
+ cb = Mock(return_value=make_awaitable(self.mock_http_response))
for i in range(3): # invoke multiple times
res = yield self.cache.fetch_or_execute(
self.mock_key, cb, "some_arg", keyword="arg", changing_args=i
@@ -130,7 +131,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_cleans_up(self):
- cb = Mock(return_value=defer.succeed(self.mock_http_response))
+ cb = Mock(return_value=make_awaitable(self.mock_http_response))
yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg")
# should NOT have cleaned up yet
self.clock.advance_time_msec(CLEANUP_PERIOD_MS / 2)
diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py
index 02b96c9e6e..9ee9509d3a 100644
--- a/tests/server_notices/test_resource_limits_server_notices.py
+++ b/tests/server_notices/test_resource_limits_server_notices.py
@@ -14,8 +14,6 @@
from unittest.mock import Mock
-from twisted.internet import defer
-
from synapse.api.constants import EventTypes, LimitBlockingTypes, ServerNoticeMsgType
from synapse.api.errors import ResourceLimitError
from synapse.rest import admin
@@ -68,16 +66,16 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
return_value=make_awaitable(1000)
)
self._rlsn._server_notices_manager.send_notice = Mock(
- return_value=defer.succeed(Mock())
+ return_value=make_awaitable(Mock())
)
self._send_notice = self._rlsn._server_notices_manager.send_notice
self.user_id = "@user_id:test"
self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock(
- return_value=defer.succeed("!something:localhost")
+ return_value=make_awaitable("!something:localhost")
)
- self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None))
+ self._rlsn._store.add_tag_to_room = Mock(return_value=make_awaitable(None))
self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({}))
@override_config({"hs_disabled": True})
@@ -95,7 +93,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
def test_maybe_send_server_notice_to_user_remove_blocked_notice(self):
"""Test when user has blocked notice, but should have it removed"""
- self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
+ self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None))
mock_event = Mock(
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
@@ -111,7 +109,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
Test when user has blocked notice, but notice ought to be there (NOOP)
"""
self._rlsn._auth.check_auth_blocking = Mock(
- return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo")
+ return_value=make_awaitable(None),
+ side_effect=ResourceLimitError(403, "foo"),
)
mock_event = Mock(
@@ -130,7 +129,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
Test when user does not have blocked notice, but should have one
"""
self._rlsn._auth.check_auth_blocking = Mock(
- return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo")
+ return_value=make_awaitable(None),
+ side_effect=ResourceLimitError(403, "foo"),
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@@ -141,7 +141,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
"""
Test when user does not have blocked notice, nor should they (NOOP)
"""
- self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
+ self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None))
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@@ -152,7 +152,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
Test when user is not part of the MAU cohort - this should not ever
happen - but ...
"""
- self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None))
+ self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None))
self._rlsn._store.user_last_seen_monthly_active = Mock(
return_value=make_awaitable(None)
)
@@ -167,7 +167,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
an alert message is not sent into the room
"""
self._rlsn._auth.check_auth_blocking = Mock(
- return_value=defer.succeed(None),
+ return_value=make_awaitable(None),
side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
),
@@ -182,7 +182,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
Test that when a server is disabled, that MAU limit alerting is ignored.
"""
self._rlsn._auth.check_auth_blocking = Mock(
- return_value=defer.succeed(None),
+ return_value=make_awaitable(None),
side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED
),
@@ -199,14 +199,14 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
is suppressed that the room is returned to an unblocked state.
"""
self._rlsn._auth.check_auth_blocking = Mock(
- return_value=defer.succeed(None),
+ return_value=make_awaitable(None),
side_effect=ResourceLimitError(
403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER
),
)
self._rlsn._server_notices_manager.__is_room_currently_blocked = Mock(
- return_value=defer.succeed((True, []))
+ return_value=make_awaitable((True, []))
)
mock_event = Mock(
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 60c8d37594..0fbf465670 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -14,7 +14,6 @@
from typing import Any, Dict, List
from unittest.mock import Mock
-from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import UserTypes
@@ -259,10 +258,10 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
def test_populate_monthly_users_should_update(self):
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
- self.store.is_trial_user = Mock(return_value=defer.succeed(False)) # type: ignore[assignment]
+ self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment]
self.store.user_last_seen_monthly_active = Mock(
- return_value=defer.succeed(None)
+ return_value=make_awaitable(None)
)
d = self.store.populate_monthly_active_users("user_id")
self.get_success(d)
@@ -272,9 +271,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
def test_populate_monthly_users_should_not_update(self):
self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
- self.store.is_trial_user = Mock(return_value=defer.succeed(False)) # type: ignore[assignment]
+ self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment]
self.store.user_last_seen_monthly_active = Mock(
- return_value=defer.succeed(self.hs.get_clock().time_msec())
+ return_value=make_awaitable(self.hs.get_clock().time_msec())
)
d = self.store.populate_monthly_active_users("user_id")
diff --git a/tests/test_federation.py b/tests/test_federation.py
index c39816de85..0cbef70bfa 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -233,7 +233,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
# Register mock device list retrieval on the federation client.
federation_client = self.homeserver.get_federation_client()
federation_client.query_user_devices = Mock(
- return_value=succeed(
+ return_value=make_awaitable(
{
"user_id": remote_user_id,
"stream_id": 1,
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index f05a373aa0..0d0d6faf0d 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -52,7 +52,7 @@ def make_awaitable(result: TV) -> Awaitable[TV]:
This uses Futures as they can be awaited multiple times so can be returned
to multiple callers.
"""
- future = Future() # type: ignore
+ future: Future[TV] = Future()
future.set_result(result)
return future
@@ -69,7 +69,7 @@ def setup_awaitable_errors() -> Callable[[], None]:
# State shared between unraisablehook and check_for_unraisable_exceptions.
unraisable_exceptions = []
- orig_unraisablehook = sys.unraisablehook # type: ignore
+ orig_unraisablehook = sys.unraisablehook
def unraisablehook(unraisable):
unraisable_exceptions.append(unraisable.exc_value)
@@ -78,11 +78,11 @@ def setup_awaitable_errors() -> Callable[[], None]:
"""
A method to be used as a clean-up that fails a test-case if there are any new unraisable exceptions.
"""
- sys.unraisablehook = orig_unraisablehook # type: ignore
+ sys.unraisablehook = orig_unraisablehook
if unraisable_exceptions:
raise unraisable_exceptions.pop()
- sys.unraisablehook = unraisablehook # type: ignore
+ sys.unraisablehook = unraisablehook
return cleanup
diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py
index 51a197a8c6..9228454c9e 100644
--- a/tests/test_utils/logging_setup.py
+++ b/tests/test_utils/logging_setup.py
@@ -27,7 +27,7 @@ class ToTwistedHandler(logging.Handler):
def emit(self, record):
log_entry = self.format(record)
log_level = record.levelname.lower().replace("warning", "warn")
- self.tx_log.emit( # type: ignore
+ self.tx_log.emit(
twisted.logger.LogLevel.levelWithName(log_level), "{entry}", entry=log_entry
)
|