Diffstat (limited to 'tests')
34 files changed, 1059 insertions, 250 deletions
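Most of the hunks below follow two recurring mocking patterns: mocks of Twisted's request.getClientIP() are replaced by setting the .host attribute on request.getClientAddress(), and mocked async methods return a value wrapped by a make_awaitable() helper instead of defer.succeed(). The following is a minimal, self-contained sketch of both patterns; the make_awaitable() shown here is an assumption standing in for the helper imported from tests.test_utils, and the asyncio plumbing is only for the sake of a runnable example (the real tests run under Twisted).

import asyncio
from unittest.mock import Mock


async def main() -> None:
    loop = asyncio.get_running_loop()

    def make_awaitable(result):
        # Stand-in for the make_awaitable test helper (assumed, not the real
        # implementation): an already-resolved Future can be awaited, so it
        # works as the return value of a mocked async method, where
        # defer.succeed() was used before.
        future = loop.create_future()
        future.set_result(result)
        return future

    # Old pattern: request.getClientIP.return_value = "127.0.0.1"
    # New pattern: getClientAddress() returns an address object, so the test
    # sets the nested .host attribute on the mock instead.
    request = Mock(args={})
    request.getClientAddress.return_value.host = "127.0.0.1"
    assert request.getClientAddress().host == "127.0.0.1"

    # Mocking an async store method so that awaiting it yields the canned value.
    store = Mock()
    store.get_user_by_access_token.return_value = make_awaitable(None)
    assert await store.get_user_by_access_token("some_token") is None


asyncio.run(main())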
diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 3e05789923..d547df8a64 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -105,7 +105,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.store.get_user_by_access_token = simple_async_mock(None) request = Mock(args={}) - request.getClientIP.return_value = "127.0.0.1" + request.getClientAddress.return_value.host = "127.0.0.1" request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = self.get_success(self.auth.get_user_by_req(request)) @@ -124,7 +124,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.store.get_user_by_access_token = simple_async_mock(None) request = Mock(args={}) - request.getClientIP.return_value = "192.168.10.10" + request.getClientAddress.return_value.host = "192.168.10.10" request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() requester = self.get_success(self.auth.get_user_by_req(request)) @@ -143,7 +143,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.store.get_user_by_access_token = simple_async_mock(None) request = Mock(args={}) - request.getClientIP.return_value = "131.111.8.42" + request.getClientAddress.return_value.host = "131.111.8.42" request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() f = self.get_failure( @@ -190,7 +190,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.store.get_user_by_access_token = simple_async_mock(None) request = Mock(args={}) - request.getClientIP.return_value = "127.0.0.1" + request.getClientAddress.return_value.host = "127.0.0.1" request.args[b"access_token"] = [self.test_token] request.args[b"user_id"] = [masquerading_user_id] request.requestHeaders.getRawHeaders = mock_getRawHeaders() @@ -209,7 +209,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.store.get_user_by_access_token = simple_async_mock(None) request = Mock(args={}) - request.getClientIP.return_value = "127.0.0.1" + request.getClientAddress.return_value.host = "127.0.0.1" request.args[b"access_token"] = [self.test_token] request.args[b"user_id"] = [masquerading_user_id] request.requestHeaders.getRawHeaders = mock_getRawHeaders() @@ -236,7 +236,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.store.get_device = simple_async_mock({"hidden": False}) request = Mock(args={}) - request.getClientIP.return_value = "127.0.0.1" + request.getClientAddress.return_value.host = "127.0.0.1" request.args[b"access_token"] = [self.test_token] request.args[b"user_id"] = [masquerading_user_id] request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id] @@ -268,7 +268,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.store.get_device = simple_async_mock(None) request = Mock(args={}) - request.getClientIP.return_value = "127.0.0.1" + request.getClientAddress.return_value.host = "127.0.0.1" request.args[b"access_token"] = [self.test_token] request.args[b"user_id"] = [masquerading_user_id] request.args[b"org.matrix.msc3202.device_id"] = [masquerading_device_id] @@ -288,7 +288,7 @@ class AuthTestCase(unittest.HomeserverTestCase): ) self.store.insert_client_ip = simple_async_mock(None) request = Mock(args={}) - request.getClientIP.return_value = "127.0.0.1" + request.getClientAddress.return_value.host = "127.0.0.1" request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() 
self.get_success(self.auth.get_user_by_req(request)) @@ -305,7 +305,7 @@ class AuthTestCase(unittest.HomeserverTestCase): ) self.store.insert_client_ip = simple_async_mock(None) request = Mock(args={}) - request.getClientIP.return_value = "127.0.0.1" + request.getClientAddress.return_value.host = "127.0.0.1" request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() self.get_success(self.auth.get_user_by_req(request)) diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py index ec8864dafe..268a48d7ba 100644 --- a/tests/federation/test_federation_client.py +++ b/tests/federation/test_federation_client.py @@ -83,7 +83,7 @@ class FederationClientTest(FederatingHomeserverTestCase): ) # mock up the response, and have the agent return it - self._mock_agent.request.return_value = defer.succeed( + self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed( _mock_response( { "pdus": [ diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index 91f982518e..6b26353d5e 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -226,7 +226,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase): # Send the server a device list EDU for the other user, this will cause # it to try and resync the device lists. self.hs.get_federation_transport_client().query_user_devices.return_value = ( - defer.succeed( + make_awaitable( { "stream_id": "1", "user_id": "@user2:host2", diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 8c72cf6b30..5b0cd1ab86 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -411,6 +411,88 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): "exclusive_as_user", "password", self.exclusive_as_user_device_id ) + def test_sending_read_receipt_batches_to_application_services(self): + """Tests that a large batch of read receipts are sent correctly to + interested application services. + """ + # Register an application service that's interested in a certain user + # and room prefix + interested_appservice = self._register_application_service( + namespaces={ + ApplicationService.NS_USERS: [ + { + "regex": "@exclusive_as_user:.+", + "exclusive": True, + } + ], + ApplicationService.NS_ROOMS: [ + { + "regex": "!fakeroom_.*", + "exclusive": True, + } + ], + }, + ) + + # "Complete" a transaction. + # All this really does for us is make an entry in the application_services_state + # database table, which tracks the current stream_token per stream ID per AS. + self.get_success( + self.hs.get_datastores().main.complete_appservice_txn( + 0, + interested_appservice, + ) + ) + + # Now, pretend that we receive a large burst of read receipts (300 total) that + # all come in at once. + for i in range(300): + self.get_success( + # Insert a fake read receipt into the database + self.hs.get_datastores().main.insert_receipt( + # We have to use unique room ID + user ID combinations here, as the db query + # is an upsert. + room_id=f"!fakeroom_{i}:test", + receipt_type="m.read", + user_id=self.local_user, + event_ids=[f"$eventid_{i}"], + data={}, + ) + ) + + # Now notify the appservice handler that 300 read receipts have all arrived + # at once. What will it do! 
+ # note: stream tokens start at 2 + for stream_token in range(2, 303): + self.get_success( + self.hs.get_application_service_handler()._notify_interested_services_ephemeral( + services=[interested_appservice], + stream_key="receipt_key", + new_token=stream_token, + users=[self.exclusive_as_user], + ) + ) + + # Using our txn send mock, we can see what the AS received. After iterating over every + # transaction, we'd like to see all 300 read receipts accounted for. + # No more, no less. + all_ephemeral_events = [] + for call in self.send_mock.call_args_list: + ephemeral_events = call[0][2] + all_ephemeral_events += ephemeral_events + + # Ensure that no duplicate events were sent + self.assertEqual(len(all_ephemeral_events), 300) + + # Check that the ephemeral event is a read receipt with the expected structure + latest_read_receipt = all_ephemeral_events[-1] + self.assertEqual(latest_read_receipt["type"], "m.receipt") + + event_id = list(latest_read_receipt["content"].keys())[0] + self.assertEqual( + latest_read_receipt["content"][event_id]["m.read"], {self.local_user: {}} + ) + @unittest.override_config( {"experimental_features": {"msc2409_to_device_messages_enabled": True}} ) diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py index a54aa29cf1..2b21547d0f 100644 --- a/tests/handlers/test_cas.py +++ b/tests/handlers/test_cas.py @@ -201,4 +201,16 @@ class CasHandlerTestCase(HomeserverTestCase): def _mock_request(): """Returns a mock which will stand in as a SynapseRequest""" - return Mock(spec=["getClientIP", "getHeader", "_disconnected"]) + mock = Mock( + spec=[ + "finish", + "getClientAddress", + "getHeader", + "setHeader", + "setResponseCode", + "write", + ] + ) + # `_disconnected` musn't be another `Mock`, otherwise it will be truthy. + mock._disconnected = False + return mock diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 8c74ed1fcf..1e6ad4b663 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -19,7 +19,6 @@ from unittest import mock from parameterized import parameterized from signedjson import key as key, sign as sign -from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import RoomEncryptionAlgorithms @@ -704,7 +703,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ" self.hs.get_federation_client().query_client_keys = mock.Mock( - return_value=defer.succeed( + return_value=make_awaitable( { "device_keys": {remote_user_id: {}}, "master_keys": { @@ -777,14 +776,14 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): # Pretend we're sharing a room with the user we're querying. If not, # `_query_devices_for_destination` will return early. 
self.store.get_rooms_for_user = mock.Mock( - return_value=defer.succeed({"some_room_id"}) + return_value=make_awaitable({"some_room_id"}) ) remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY" remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ" self.hs.get_federation_client().query_user_devices = mock.Mock( - return_value=defer.succeed( + return_value=make_awaitable( { "user_id": remote_user_id, "stream_id": 1, diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 9684120c70..1231aed944 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -1300,7 +1300,7 @@ def _build_callback_request( "getCookie", "cookies", "requestHeaders", - "getClientIP", + "getClientAddress", "getHeader", ] ) @@ -1310,5 +1310,5 @@ def _build_callback_request( request.args = {} request.args[b"code"] = [code.encode("utf-8")] request.args[b"state"] = [state.encode("utf-8")] - request.getClientIP.return_value = ip_address + request.getClientAddress.return_value.host = ip_address return request diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index d401fda938..82b3bb3b73 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -17,8 +17,6 @@ from typing import Any, Type, Union from unittest.mock import Mock -from twisted.internet import defer - import synapse from synapse.api.constants import LoginType from synapse.api.errors import Codes @@ -32,11 +30,9 @@ from tests.server import FakeChannel from tests.test_utils import make_awaitable from tests.unittest import override_config -# (possibly experimental) login flows we expect to appear in the list after the normal -# ones +# Login flows we expect to appear in the list after the normal ones. 
ADDITIONAL_LOGIN_FLOWS = [ {"type": "m.login.application_service"}, - {"type": "uk.half-shot.msc2778.login.application_service"}, ] # a mock instance which the dummy auth providers delegate to, so we can see what's going @@ -190,7 +186,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS) # check_password must return an awaitable - mock_password_provider.check_password.return_value = defer.succeed(True) + mock_password_provider.check_password.return_value = make_awaitable(True) channel = self._send_password_login("u", "p") self.assertEqual(channel.code, 200, channel.result) self.assertEqual("@u:test", channel.json_body["user_id"]) @@ -226,13 +222,13 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.get_success(module_api.register_user("u")) # log in twice, to get two devices - mock_password_provider.check_password.return_value = defer.succeed(True) + mock_password_provider.check_password.return_value = make_awaitable(True) tok1 = self.login("u", "p") self.login("u", "p", device_id="dev2") mock_password_provider.reset_mock() # have the auth provider deny the request to start with - mock_password_provider.check_password.return_value = defer.succeed(False) + mock_password_provider.check_password.return_value = make_awaitable(False) # make the initial request which returns a 401 session = self._start_delete_device_session(tok1, "dev2") @@ -246,7 +242,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): mock_password_provider.reset_mock() # Finally, check the request goes through when we allow it - mock_password_provider.check_password.return_value = defer.succeed(True) + mock_password_provider.check_password.return_value = make_awaitable(True) channel = self._authed_delete_device(tok1, "dev2", session, "u", "p") self.assertEqual(channel.code, 200) mock_password_provider.check_password.assert_called_once_with("@u:test", "p") @@ -260,7 +256,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.register_user("localuser", "localpass") # check_password must return an awaitable - mock_password_provider.check_password.return_value = defer.succeed(False) + mock_password_provider.check_password.return_value = make_awaitable(False) channel = self._send_password_login("u", "p") self.assertEqual(channel.code, 403, channel.result) @@ -277,7 +273,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.register_user("localuser", "localpass") # have the auth provider deny the request - mock_password_provider.check_password.return_value = defer.succeed(False) + mock_password_provider.check_password.return_value = make_awaitable(False) # log in twice, to get two devices tok1 = self.login("localuser", "localpass") @@ -320,7 +316,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.register_user("localuser", "localpass") # check_password must return an awaitable - mock_password_provider.check_password.return_value = defer.succeed(False) + mock_password_provider.check_password.return_value = make_awaitable(False) channel = self._send_password_login("localuser", "localpass") self.assertEqual(channel.code, 403) self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN") @@ -342,7 +338,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.register_user("localuser", "localpass") # allow login via the auth provider - mock_password_provider.check_password.return_value = defer.succeed(True) + 
mock_password_provider.check_password.return_value = make_awaitable(True) # log in twice, to get two devices tok1 = self.login("localuser", "p") @@ -359,7 +355,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): mock_password_provider.check_password.assert_not_called() # now try deleting with the local password - mock_password_provider.check_password.return_value = defer.succeed(False) + mock_password_provider.check_password.return_value = make_awaitable(False) channel = self._authed_delete_device( tok1, "dev2", session, "localuser", "localpass" ) @@ -413,7 +409,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 400, channel.result) mock_password_provider.check_auth.assert_not_called() - mock_password_provider.check_auth.return_value = defer.succeed( + mock_password_provider.check_auth.return_value = make_awaitable( ("@user:bz", None) ) channel = self._send_login("test.login_type", "u", test_field="y") @@ -427,7 +423,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # try a weird username. Again, it's unclear what we *expect* to happen # in these cases, but at least we can guard against the API changing # unexpectedly - mock_password_provider.check_auth.return_value = defer.succeed( + mock_password_provider.check_auth.return_value = make_awaitable( ("@ MALFORMED! :bz", None) ) channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ") @@ -477,7 +473,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): mock_password_provider.reset_mock() # right params, but authing as the wrong user - mock_password_provider.check_auth.return_value = defer.succeed( + mock_password_provider.check_auth.return_value = make_awaitable( ("@user:bz", None) ) body["auth"]["test_field"] = "foo" @@ -490,7 +486,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): mock_password_provider.reset_mock() # and finally, succeed - mock_password_provider.check_auth.return_value = defer.succeed( + mock_password_provider.check_auth.return_value = make_awaitable( ("@localuser:test", None) ) channel = self._delete_device(tok1, "dev2", body) @@ -508,9 +504,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.custom_auth_provider_callback_test_body() def custom_auth_provider_callback_test_body(self): - callback = Mock(return_value=defer.succeed(None)) + callback = Mock(return_value=make_awaitable(None)) - mock_password_provider.check_auth.return_value = defer.succeed( + mock_password_provider.check_auth.return_value = make_awaitable( ("@user:bz", callback) ) channel = self._send_login("test.login_type", "u", test_field="y") @@ -646,7 +642,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): login is disabled""" # register the user and log in twice via the test login type to get two devices, self.register_user("localuser", "localpass") - mock_password_provider.check_auth.return_value = defer.succeed( + mock_password_provider.check_auth.return_value = make_awaitable( ("@localuser:test", None) ) channel = self._send_login("test.login_type", "localuser", test_field="") diff --git a/tests/handlers/test_receipts.py b/tests/handlers/test_receipts.py index 5081b97573..0482a1ea34 100644 --- a/tests/handlers/test_receipts.py +++ b/tests/handlers/test_receipts.py @@ -15,7 +15,7 @@ from typing import List -from synapse.api.constants import ReadReceiptEventFields +from synapse.api.constants import ReceiptTypes from synapse.types import JsonDict from tests import unittest @@ -25,20 +25,15 
@@ class ReceiptsTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.event_source = hs.get_event_sources().sources.receipt - # In the first param of _test_filters_hidden we use "hidden" instead of - # ReadReceiptEventFields.MSC2285_HIDDEN. We do this because we're mocking - # the data from the database which doesn't use the prefix - - def test_filters_out_hidden_receipt(self): - self._test_filters_hidden( + def test_filters_out_private_receipt(self): + self._test_filters_private( [ { "content": { "$1435641916114394fHBLK:matrix.org": { - "m.read": { + ReceiptTypes.READ_PRIVATE: { "@rikj:jki.re": { "ts": 1436451550453, - "hidden": True, } } } @@ -50,58 +45,23 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): [], ) - def test_does_not_filter_out_our_hidden_receipt(self): - self._test_filters_hidden( - [ - { - "content": { - "$1435641916hfgh4394fHBLK:matrix.org": { - "m.read": { - "@me:server.org": { - "ts": 1436451550453, - "hidden": True, - }, - } - } - }, - "room_id": "!jEsUZKDJdhlrceRyVU:example.org", - "type": "m.receipt", - } - ], - [ - { - "content": { - "$1435641916hfgh4394fHBLK:matrix.org": { - "m.read": { - "@me:server.org": { - "ts": 1436451550453, - ReadReceiptEventFields.MSC2285_HIDDEN: True, - }, - } - } - }, - "room_id": "!jEsUZKDJdhlrceRyVU:example.org", - "type": "m.receipt", - } - ], - ) - - def test_filters_out_hidden_receipt_and_ignores_rest(self): - self._test_filters_hidden( + def test_filters_out_private_receipt_and_ignores_rest(self): + self._test_filters_private( [ { "content": { "$1dgdgrd5641916114394fHBLK:matrix.org": { - "m.read": { + ReceiptTypes.READ_PRIVATE: { "@rikj:jki.re": { "ts": 1436451550453, - "hidden": True, }, + }, + ReceiptTypes.READ: { "@user:jki.re": { "ts": 1436451550453, }, - } - } + }, + }, }, "room_id": "!jEsUZKDJdhlrceRyVU:example.org", "type": "m.receipt", @@ -111,7 +71,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): { "content": { "$1dgdgrd5641916114394fHBLK:matrix.org": { - "m.read": { + ReceiptTypes.READ: { "@user:jki.re": { "ts": 1436451550453, } @@ -124,21 +84,20 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): ], ) - def test_filters_out_event_with_only_hidden_receipts_and_ignores_the_rest(self): - self._test_filters_hidden( + def test_filters_out_event_with_only_private_receipts_and_ignores_the_rest(self): + self._test_filters_private( [ { "content": { "$14356419edgd14394fHBLK:matrix.org": { - "m.read": { + ReceiptTypes.READ_PRIVATE: { "@rikj:jki.re": { "ts": 1436451550453, - "hidden": True, }, } }, "$1435641916114394fHBLK:matrix.org": { - "m.read": { + ReceiptTypes.READ: { "@user:jki.re": { "ts": 1436451550453, } @@ -153,7 +112,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): { "content": { "$1435641916114394fHBLK:matrix.org": { - "m.read": { + ReceiptTypes.READ: { "@user:jki.re": { "ts": 1436451550453, } @@ -167,13 +126,13 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): ) def test_handles_missing_content_of_m_read(self): - self._test_filters_hidden( + self._test_filters_private( [ { "content": { - "$14356419ggffg114394fHBLK:matrix.org": {"m.read": {}}, + "$14356419ggffg114394fHBLK:matrix.org": {ReceiptTypes.READ: {}}, "$1435641916114394fHBLK:matrix.org": { - "m.read": { + ReceiptTypes.READ: { "@user:jki.re": { "ts": 1436451550453, } @@ -187,9 +146,9 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): [ { "content": { - "$14356419ggffg114394fHBLK:matrix.org": {"m.read": {}}, + "$14356419ggffg114394fHBLK:matrix.org": {ReceiptTypes.READ: {}}, 
"$1435641916114394fHBLK:matrix.org": { - "m.read": { + ReceiptTypes.READ: { "@user:jki.re": { "ts": 1436451550453, } @@ -203,13 +162,13 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): ) def test_handles_empty_event(self): - self._test_filters_hidden( + self._test_filters_private( [ { "content": { "$143564gdfg6114394fHBLK:matrix.org": {}, "$1435641916114394fHBLK:matrix.org": { - "m.read": { + ReceiptTypes.READ: { "@user:jki.re": { "ts": 1436451550453, } @@ -223,9 +182,8 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): [ { "content": { - "$143564gdfg6114394fHBLK:matrix.org": {}, "$1435641916114394fHBLK:matrix.org": { - "m.read": { + ReceiptTypes.READ: { "@user:jki.re": { "ts": 1436451550453, } @@ -238,16 +196,15 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): ], ) - def test_filters_out_receipt_event_with_only_hidden_receipt_and_ignores_rest(self): - self._test_filters_hidden( + def test_filters_out_receipt_event_with_only_private_receipt_and_ignores_rest(self): + self._test_filters_private( [ { "content": { "$14356419edgd14394fHBLK:matrix.org": { - "m.read": { + ReceiptTypes.READ_PRIVATE: { "@rikj:jki.re": { "ts": 1436451550453, - "hidden": True, }, } }, @@ -258,7 +215,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): { "content": { "$1435641916114394fHBLK:matrix.org": { - "m.read": { + ReceiptTypes.READ: { "@user:jki.re": { "ts": 1436451550453, } @@ -273,7 +230,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): { "content": { "$1435641916114394fHBLK:matrix.org": { - "m.read": { + ReceiptTypes.READ: { "@user:jki.re": { "ts": 1436451550453, } @@ -292,12 +249,12 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): Context: https://github.com/matrix-org/synapse/issues/10603 """ - self._test_filters_hidden( + self._test_filters_private( [ { "content": { "$14356419edgd14394fHBLK:matrix.org": { - "m.read": { + ReceiptTypes.READ: { "@rikj:jki.re": "string", } }, @@ -306,12 +263,78 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): "type": "m.receipt", }, ], - [], + [ + { + "content": { + "$14356419edgd14394fHBLK:matrix.org": { + ReceiptTypes.READ: { + "@rikj:jki.re": "string", + } + }, + }, + "room_id": "!jEsUZKDJdhlrceRyVU:example.org", + "type": "m.receipt", + }, + ], + ) + + def test_leaves_our_private_and_their_public(self): + self._test_filters_private( + [ + { + "content": { + "$1dgdgrd5641916114394fHBLK:matrix.org": { + ReceiptTypes.READ_PRIVATE: { + "@me:server.org": { + "ts": 1436451550453, + }, + }, + ReceiptTypes.READ: { + "@rikj:jki.re": { + "ts": 1436451550453, + }, + }, + "a.receipt.type": { + "@rikj:jki.re": { + "ts": 1436451550453, + }, + }, + }, + }, + "room_id": "!jEsUZKDJdhlrceRyVU:example.org", + "type": "m.receipt", + } + ], + [ + { + "content": { + "$1dgdgrd5641916114394fHBLK:matrix.org": { + ReceiptTypes.READ_PRIVATE: { + "@me:server.org": { + "ts": 1436451550453, + }, + }, + ReceiptTypes.READ: { + "@rikj:jki.re": { + "ts": 1436451550453, + }, + }, + "a.receipt.type": { + "@rikj:jki.re": { + "ts": 1436451550453, + }, + }, + } + }, + "room_id": "!jEsUZKDJdhlrceRyVU:example.org", + "type": "m.receipt", + } + ], ) - def _test_filters_hidden( + def _test_filters_private( self, events: List[JsonDict], expected_output: List[JsonDict] ): - """Tests that the _filter_out_hidden returns the expected output""" - filtered_events = self.event_source.filter_out_hidden(events, "@me:server.org") + """Tests that the _filter_out_private returns the expected output""" + filtered_events = self.event_source.filter_out_private(events, 
"@me:server.org") self.assertEqual(filtered_events, expected_output) diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index 45fd30cf43..b6ba19c739 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -193,8 +193,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): @override_config({"limit_usage_by_mau": True}) def test_get_or_create_user_mau_not_blocked(self): - # Type ignore: mypy doesn't like us assigning to methods. - self.store.count_monthly_users = Mock( # type: ignore[assignment] + self.store.count_monthly_users = Mock( return_value=make_awaitable(self.hs.config.server.max_mau_value - 1) ) # Ensure does not throw exception @@ -202,8 +201,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): @override_config({"limit_usage_by_mau": True}) def test_get_or_create_user_mau_blocked(self): - # Type ignore: mypy doesn't like us assigning to methods. - self.store.get_monthly_active_count = Mock( # type: ignore[assignment] + self.store.get_monthly_active_count = Mock( return_value=make_awaitable(self.lots_of_users) ) self.get_failure( @@ -211,8 +209,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): ResourceLimitError, ) - # Type ignore: mypy doesn't like us assigning to methods. - self.store.get_monthly_active_count = Mock( # type: ignore[assignment] + self.store.get_monthly_active_count = Mock( return_value=make_awaitable(self.hs.config.server.max_mau_value) ) self.get_failure( diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py index d37292ce13..e74eb71774 100644 --- a/tests/handlers/test_room_summary.py +++ b/tests/handlers/test_room_summary.py @@ -1092,3 +1092,29 @@ class RoomSummaryTestCase(unittest.HomeserverTestCase): ) result = self.get_success(self.handler.get_room_summary(user2, self.room)) self.assertEqual(result.get("room_id"), self.room) + + def test_fed(self): + """ + Return data over federation and ensure that it is handled properly. + """ + fed_hostname = self.hs.hostname + "2" + fed_room = "#fed_room:" + fed_hostname + + requested_room_entry = _RoomEntry( + fed_room, + {"room_id": fed_room, "world_readable": True}, + ) + + async def summarize_remote_room_hierarchy(_self, room, suggested_only): + return requested_room_entry, {}, set() + + with mock.patch( + "synapse.handlers.room_summary.RoomSummaryHandler._summarize_remote_room_hierarchy", + new=summarize_remote_room_hierarchy, + ): + result = self.get_success( + self.handler.get_room_summary( + self.user, fed_room, remote_room_hosts=[fed_hostname] + ) + ) + self.assertEqual(result.get("room_id"), fed_room) diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py index 8d4404eda1..a0f84e2940 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py @@ -349,4 +349,16 @@ class SamlHandlerTestCase(HomeserverTestCase): def _mock_request(): """Returns a mock which will stand in as a SynapseRequest""" - return Mock(spec=["getClientIP", "getHeader", "_disconnected"]) + mock = Mock( + spec=[ + "finish", + "getClientAddress", + "getHeader", + "setHeader", + "setResponseCode", + "write", + ] + ) + # `_disconnected` musn't be another `Mock`, otherwise it will be truthy. 
+ mock._disconnected = False + return mock diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index ffd5c4cb93..5f2e26a5fc 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -65,11 +65,11 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): # we mock out the keyring so as to skip the authentication check on the # federation API call. mock_keyring = Mock(spec=["verify_json_for_server"]) - mock_keyring.verify_json_for_server.return_value = defer.succeed(True) + mock_keyring.verify_json_for_server.return_value = make_awaitable(True) # we mock out the federation client too mock_federation_client = Mock(spec=["put_json"]) - mock_federation_client.put_json.return_value = defer.succeed((200, "OK")) + mock_federation_client.put_json.return_value = make_awaitable((200, "OK")) # the tests assume that we are starting at unix time 1000 reactor.pump((1000,)) @@ -98,7 +98,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.datastore = hs.get_datastores().main self.datastore.get_destination_retry_timings = Mock( - return_value=defer.succeed(None) + return_value=make_awaitable(None) ) self.datastore.get_device_updates_by_remote = Mock( diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index c6e501c7be..96e2e3039b 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -15,7 +15,6 @@ from typing import Tuple from unittest.mock import Mock, patch from urllib.parse import quote -from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin @@ -30,6 +29,7 @@ from synapse.util import Clock from tests import unittest from tests.storage.test_user_directory import GetUserDirectoryTables +from tests.test_utils import make_awaitable from tests.test_utils.event_injection import inject_member_event from tests.unittest import override_config @@ -439,7 +439,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) ) - mock_remove_from_user_dir = Mock(return_value=defer.succeed(None)) + mock_remove_from_user_dir = Mock(return_value=make_awaitable(None)) with patch.object( self.store, "remove_from_user_dir", mock_remove_from_user_dir ): @@ -454,7 +454,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.store.register_user(user_id=r_user_id, password_hash=None) ) - mock_remove_from_user_dir = Mock(return_value=defer.succeed(None)) + mock_remove_from_user_dir = Mock(return_value=make_awaitable(None)) with patch.object( self.store, "remove_from_user_dir", mock_remove_from_user_dir ): diff --git a/tests/module_api/test_account_data_manager.py b/tests/module_api/test_account_data_manager.py index bec018d9e7..89009bea8c 100644 --- a/tests/module_api/test_account_data_manager.py +++ b/tests/module_api/test_account_data_manager.py @@ -11,8 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from twisted.test.proto_helpers import MemoryReactor + from synapse.api.errors import SynapseError from synapse.rest import admin +from synapse.server import HomeServer +from synapse.util import Clock from tests.unittest import HomeserverTestCase @@ -22,7 +26,9 @@ class ModuleApiTestCase(HomeserverTestCase): admin.register_servlets, ] - def prepare(self, reactor, clock, homeserver) -> None: + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: self._store = homeserver.get_datastores().main self._module_api = homeserver.get_module_api() self._account_data_mgr = self._module_api.account_data_manager @@ -91,7 +97,7 @@ class ModuleApiTestCase(HomeserverTestCase): ) with self.assertRaises(TypeError): # This throws an exception because it's a frozen dict. - the_data["wombat"] = False + the_data["wombat"] = False # type: ignore[index] def test_put_global(self) -> None: """ @@ -143,15 +149,14 @@ class ModuleApiTestCase(HomeserverTestCase): with self.assertRaises(TypeError): # The account data type must be a string. self.get_success_or_raise( - self._module_api.account_data_manager.put_global( - self.user_id, 42, {} # type: ignore - ) + self._module_api.account_data_manager.put_global(self.user_id, 42, {}) # type: ignore[arg-type] ) with self.assertRaises(TypeError): # The account data dict must be a dict. + # noinspection PyTypeChecker self.get_success_or_raise( self._module_api.account_data_manager.put_global( - self.user_id, "test.data", 42 # type: ignore + self.user_id, "test.data", 42 # type: ignore[arg-type] ) ) diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py index 9fd5d59c55..8bc84aaaca 100644 --- a/tests/module_api/test_api.py +++ b/tests/module_api/test_api.py @@ -19,8 +19,9 @@ from synapse.api.constants import EduTypes, EventTypes from synapse.events import EventBase from synapse.federation.units import Transaction from synapse.handlers.presence import UserPresenceState +from synapse.handlers.push_rules import InvalidRuleException from synapse.rest import admin -from synapse.rest.client import login, presence, profile, room +from synapse.rest.client import login, notifications, presence, profile, room from synapse.types import create_requester from tests.events.test_presence_router import send_presence_update, sync_presence @@ -38,6 +39,7 @@ class ModuleApiTestCase(HomeserverTestCase): room.register_servlets, presence.register_servlets, profile.register_servlets, + notifications.register_servlets, ] def prepare(self, reactor, clock, homeserver): @@ -553,6 +555,86 @@ class ModuleApiTestCase(HomeserverTestCase): self.assertEqual(state[("org.matrix.test", "")].state_key, "") self.assertEqual(state[("org.matrix.test", "")].content, {}) + def test_set_push_rules_action(self) -> None: + """Test that a module can change the actions of an existing push rule for a user.""" + + # Create a room with 2 users in it. Push rules must not match if the user is the + # event's sender, so we need one user to send messages and one user to receive + # notifications. + user_id = self.register_user("user", "password") + tok = self.login("user", "password") + + room_id = self.helper.create_room_as(user_id, is_public=True, tok=tok) + + user_id2 = self.register_user("user2", "password") + tok2 = self.login("user2", "password") + self.helper.join(room_id, user_id2, tok=tok2) + + # Register a 3rd user and join them to the room, so that we don't accidentally + # trigger 1:1 push rules. 
+ user_id3 = self.register_user("user3", "password") + tok3 = self.login("user3", "password") + self.helper.join(room_id, user_id3, tok=tok3) + + # Send a message as the second user and check that it notifies. + res = self.helper.send(room_id=room_id, body="here's a message", tok=tok2) + event_id = res["event_id"] + + channel = self.make_request( + "GET", + "/notifications", + access_token=tok, + ) + self.assertEqual(channel.code, 200, channel.result) + + self.assertEqual(len(channel.json_body["notifications"]), 1, channel.json_body) + self.assertEqual( + channel.json_body["notifications"][0]["event"]["event_id"], + event_id, + channel.json_body, + ) + + # Change the .m.rule.message actions to not notify on new messages. + self.get_success( + defer.ensureDeferred( + self.module_api.set_push_rule_action( + user_id=user_id, + scope="global", + kind="underride", + rule_id=".m.rule.message", + actions=["dont_notify"], + ) + ) + ) + + # Send another message as the second user and check that the number of + # notifications didn't change. + self.helper.send(room_id=room_id, body="here's another message", tok=tok2) + + channel = self.make_request( + "GET", + "/notifications?from=", + access_token=tok, + ) + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(len(channel.json_body["notifications"]), 1, channel.json_body) + + def test_check_push_rules_actions(self) -> None: + """Test that modules can check whether a list of push rules actions are spec + compliant. + """ + with self.assertRaises(InvalidRuleException): + self.module_api.check_push_rule_actions(["foo"]) + + with self.assertRaises(InvalidRuleException): + self.module_api.check_push_rule_actions({"foo": "bar"}) + + self.module_api.check_push_rule_actions(["notify"]) + + self.module_api.check_push_rule_actions( + [{"set_tweak": "sound", "value": "default"}] + ) + class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase): """For testing ModuleApi functionality in a multi-worker setup""" diff --git a/tests/replication/_base.py b/tests/replication/_base.py index a0589b6d6a..a7602b4c96 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -154,10 +154,12 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self.assertEqual(port, 8765) # Set up client side protocol - client_protocol = client_factory.buildProtocol(None) + client_address = IPv4Address("TCP", "127.0.0.1", 1234) + client_protocol = client_factory.buildProtocol(("127.0.0.1", 1234)) # Set up the server side protocol - channel = self.site.buildProtocol(None) + server_address = IPv4Address("TCP", host, port) + channel = self.site.buildProtocol((host, port)) # hook into the channel's request factory so that we can keep a record # of the requests @@ -173,12 +175,12 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): # Connect client to server and vice versa. 
client_to_server_transport = FakeTransport( - channel, self.reactor, client_protocol + channel, self.reactor, client_protocol, server_address, client_address ) client_protocol.makeConnection(client_to_server_transport) server_to_client_transport = FakeTransport( - client_protocol, self.reactor, channel + client_protocol, self.reactor, channel, client_address, server_address ) channel.makeConnection(server_to_client_transport) @@ -406,19 +408,21 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): self.assertEqual(port, repl_port) # Set up client side protocol - client_protocol = client_factory.buildProtocol(None) + client_address = IPv4Address("TCP", "127.0.0.1", 1234) + client_protocol = client_factory.buildProtocol(("127.0.0.1", 1234)) # Set up the server side protocol - channel = self._hs_to_site[hs].buildProtocol(None) + server_address = IPv4Address("TCP", host, port) + channel = self._hs_to_site[hs].buildProtocol((host, port)) # Connect client to server and vice versa. client_to_server_transport = FakeTransport( - channel, self.reactor, client_protocol + channel, self.reactor, client_protocol, server_address, client_address ) client_protocol.makeConnection(client_to_server_transport) server_to_client_transport = FakeTransport( - client_protocol, self.reactor, channel + client_protocol, self.reactor, channel, client_address, server_address ) channel.makeConnection(server_to_client_transport) diff --git a/tests/replication/slave/storage/test_receipts.py b/tests/replication/slave/storage/test_receipts.py index f47d94f690..5bbbd5fbcb 100644 --- a/tests/replication/slave/storage/test_receipts.py +++ b/tests/replication/slave/storage/test_receipts.py @@ -12,23 +12,248 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from synapse.api.constants import ReceiptTypes from synapse.replication.slave.storage.receipts import SlavedReceiptsStore +from synapse.types import UserID, create_requester + +from tests.test_utils.event_injection import create_event from ._base import BaseSlavedStoreTestCase -USER_ID = "@feeling:blue" -ROOM_ID = "!room:blue" -EVENT_ID = "$event:blue" +OTHER_USER_ID = "@other:test" +OUR_USER_ID = "@our:test" class SlavedReceiptTestCase(BaseSlavedStoreTestCase): STORE_TYPE = SlavedReceiptsStore - def test_receipt(self): - self.check("get_receipts_for_user", [USER_ID, "m.read"], {}) + def prepare(self, reactor, clock, homeserver): + super().prepare(reactor, clock, homeserver) + self.room_creator = homeserver.get_room_creation_handler() + self.persist_event_storage = self.hs.get_storage().persistence + + # Create a test user + self.ourUser = UserID.from_string(OUR_USER_ID) + self.ourRequester = create_requester(self.ourUser) + + # Create a second test user + self.otherUser = UserID.from_string(OTHER_USER_ID) + self.otherRequester = create_requester(self.otherUser) + + # Create a test room + info, _ = self.get_success(self.room_creator.create_room(self.ourRequester, {})) + self.room_id1 = info["room_id"] + + # Create a second test room + info, _ = self.get_success(self.room_creator.create_room(self.ourRequester, {})) + self.room_id2 = info["room_id"] + + # Join the second user to the first room + memberEvent, memberEventContext = self.get_success( + create_event( + self.hs, + room_id=self.room_id1, + type="m.room.member", + sender=self.otherRequester.user.to_string(), + state_key=self.otherRequester.user.to_string(), + content={"membership": "join"}, + ) + ) self.get_success( - self.master_store.insert_receipt(ROOM_ID, "m.read", USER_ID, [EVENT_ID], {}) + self.persist_event_storage.persist_event(memberEvent, memberEventContext) + ) + + # Join the second user to the second room + memberEvent, memberEventContext = self.get_success( + create_event( + self.hs, + room_id=self.room_id2, + type="m.room.member", + sender=self.otherRequester.user.to_string(), + state_key=self.otherRequester.user.to_string(), + content={"membership": "join"}, + ) + ) + self.get_success( + self.persist_event_storage.persist_event(memberEvent, memberEventContext) + ) + + def test_return_empty_with_no_data(self): + res = self.get_success( + self.master_store.get_receipts_for_user( + OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] + ) + ) + self.assertEqual(res, {}) + + res = self.get_success( + self.master_store.get_receipts_for_user_with_orderings( + OUR_USER_ID, + [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], + ) + ) + self.assertEqual(res, {}) + + res = self.get_success( + self.master_store.get_last_receipt_event_id_for_user( + OUR_USER_ID, + self.room_id1, + [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], + ) + ) + self.assertEqual(res, None) + + def test_get_receipts_for_user(self): + # Send some events into the first room + event1_1_id = self.create_and_send_event( + self.room_id1, UserID.from_string(OTHER_USER_ID) + ) + event1_2_id = self.create_and_send_event( + self.room_id1, UserID.from_string(OTHER_USER_ID) + ) + + # Send public read receipt for the first event + self.get_success( + self.master_store.insert_receipt( + self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], {} + ) + ) + # Send private read receipt for the second event + self.get_success( + self.master_store.insert_receipt( + self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {} + ) + ) + + # Test we get 
the latest event when we want both private and public receipts + res = self.get_success( + self.master_store.get_receipts_for_user( + OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] + ) + ) + self.assertEqual(res, {self.room_id1: event1_2_id}) + + # Test we get the older event when we want only public receipt + res = self.get_success( + self.master_store.get_receipts_for_user(OUR_USER_ID, [ReceiptTypes.READ]) + ) + self.assertEqual(res, {self.room_id1: event1_1_id}) + + # Test we get the latest event when we want only the public receipt + res = self.get_success( + self.master_store.get_receipts_for_user( + OUR_USER_ID, [ReceiptTypes.READ_PRIVATE] + ) + ) + self.assertEqual(res, {self.room_id1: event1_2_id}) + + # Test receipt updating + self.get_success( + self.master_store.insert_receipt( + self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], {} + ) + ) + res = self.get_success( + self.master_store.get_receipts_for_user(OUR_USER_ID, [ReceiptTypes.READ]) + ) + self.assertEqual(res, {self.room_id1: event1_2_id}) + + # Send some events into the second room + event2_1_id = self.create_and_send_event( + self.room_id2, UserID.from_string(OTHER_USER_ID) + ) + + # Test new room is reflected in what the method returns + self.get_success( + self.master_store.insert_receipt( + self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {} + ) + ) + res = self.get_success( + self.master_store.get_receipts_for_user( + OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] + ) + ) + self.assertEqual(res, {self.room_id1: event1_2_id, self.room_id2: event2_1_id}) + + def test_get_last_receipt_event_id_for_user(self): + # Send some events into the first room + event1_1_id = self.create_and_send_event( + self.room_id1, UserID.from_string(OTHER_USER_ID) + ) + event1_2_id = self.create_and_send_event( + self.room_id1, UserID.from_string(OTHER_USER_ID) + ) + + # Send public read receipt for the first event + self.get_success( + self.master_store.insert_receipt( + self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], {} + ) + ) + # Send private read receipt for the second event + self.get_success( + self.master_store.insert_receipt( + self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {} + ) + ) + + # Test we get the latest event when we want both private and public receipts + res = self.get_success( + self.master_store.get_last_receipt_event_id_for_user( + OUR_USER_ID, + self.room_id1, + [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], + ) + ) + self.assertEqual(res, event1_2_id) + + # Test we get the older event when we want only public receipt + res = self.get_success( + self.master_store.get_last_receipt_event_id_for_user( + OUR_USER_ID, self.room_id1, [ReceiptTypes.READ] + ) + ) + self.assertEqual(res, event1_1_id) + + # Test we get the latest event when we want only the private receipt + res = self.get_success( + self.master_store.get_last_receipt_event_id_for_user( + OUR_USER_ID, self.room_id1, [ReceiptTypes.READ_PRIVATE] + ) + ) + self.assertEqual(res, event1_2_id) + + # Test receipt updating + self.get_success( + self.master_store.insert_receipt( + self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], {} + ) + ) + res = self.get_success( + self.master_store.get_last_receipt_event_id_for_user( + OUR_USER_ID, self.room_id1, [ReceiptTypes.READ] + ) + ) + self.assertEqual(res, event1_2_id) + + # Send some events into the second room + event2_1_id = self.create_and_send_event( + self.room_id2, UserID.from_string(OTHER_USER_ID) + ) + + # 
Test new room is reflected in what the method returns + self.get_success( + self.master_store.insert_receipt( + self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {} + ) + ) + res = self.get_success( + self.master_store.get_last_receipt_event_id_for_user( + OUR_USER_ID, + self.room_id2, + [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], + ) ) - self.replicate() - self.check("get_receipts_for_user", [USER_ID, "m.read"], {ROOM_ID: EVENT_ID}) + self.assertEqual(res, event2_1_id) diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py index ba1a63c0d6..6104a55aa1 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py @@ -102,8 +102,8 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): for i in range(20): server_name = "other_server_%d" % (i,) room = self.create_room_with_remote_server(user, token, server_name) - mock_client1.reset_mock() # type: ignore[attr-defined] - mock_client2.reset_mock() # type: ignore[attr-defined] + mock_client1.reset_mock() + mock_client2.reset_mock() self.create_and_send_event(room, UserID.from_string(user)) self.replicate() @@ -167,8 +167,8 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): for i in range(20): server_name = "other_server_%d" % (i,) room = self.create_room_with_remote_server(user, token, server_name) - mock_client1.reset_mock() # type: ignore[attr-defined] - mock_client2.reset_mock() # type: ignore[attr-defined] + mock_client1.reset_mock() + mock_client2.reset_mock() self.get_success( typing_handler.started_typing( diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index e00b5c171c..e0a11da97b 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -520,8 +520,6 @@ class WhoamiTestCase(unittest.HomeserverTestCase): { "user_id": user_id, "device_id": device_id, - # MSC3069 entered spec in Matrix 1.2 but maintained compatibility - "org.matrix.msc3069.is_guest": False, "is_guest": False, }, ) @@ -540,8 +538,6 @@ class WhoamiTestCase(unittest.HomeserverTestCase): { "user_id": user_id, "device_id": device_id, - # MSC3069 entered spec in Matrix 1.2 but maintained compatibility - "org.matrix.msc3069.is_guest": True, "is_guest": True, }, ) @@ -564,8 +560,6 @@ class WhoamiTestCase(unittest.HomeserverTestCase): whoami, { "user_id": user_id, - # MSC3069 entered spec in Matrix 1.2 but maintained compatibility - "org.matrix.msc3069.is_guest": False, "is_guest": False, }, ) diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index 0a3d017dc9..4920468f7a 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -81,11 +81,9 @@ TEST_CLIENT_REDIRECT_URL = 'https://x?<ab c>&q"+%3D%2B"="fö%26=o"' # the query params in TEST_CLIENT_REDIRECT_URL EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("<ab c>", ""), ('q" =+"', '"fö&=o"')] -# (possibly experimental) login flows we expect to appear in the list after the normal -# ones +# Login flows we expect to appear in the list after the normal ones. 
ADDITIONAL_LOGIN_FLOWS = [ {"type": "m.login.application_service"}, - {"type": "uk.half-shot.msc2778.login.application_service"}, ] diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py index 0abe378fe4..b3738a0304 100644 --- a/tests/rest/client/test_presence.py +++ b/tests/rest/client/test_presence.py @@ -14,7 +14,6 @@ from http import HTTPStatus from unittest.mock import Mock -from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactor from synapse.handlers.presence import PresenceHandler @@ -24,6 +23,7 @@ from synapse.types import UserID from synapse.util import Clock from tests import unittest +from tests.test_utils import make_awaitable class PresenceTestCase(unittest.HomeserverTestCase): @@ -37,7 +37,7 @@ class PresenceTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: presence_handler = Mock(spec=PresenceHandler) - presence_handler.set_state.return_value = defer.succeed(None) + presence_handler.set_state.return_value = make_awaitable(None) hs = self.setup_test_homeserver( "red", diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 39667e3225..27dee8f697 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -620,6 +620,19 @@ class RelationsTestCase(BaseRelationsTestCase): {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict ) + # Directly requesting the edit should not have the edit to the edit applied. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/event/{edit_event_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + self.assertEqual("Wibble", channel.json_body["content"]["body"]) + self.assertIn("m.new_content", channel.json_body["content"]) + + # The relations information should not include the edit to the edit. + self.assertNotIn("m.relations", channel.json_body["unsigned"]) + def test_unknown_relations(self) -> None: """Unknown relations should be accepted.""" channel = self._send_relation("m.relation.test", "m.room.test") @@ -984,6 +997,24 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 7) + def test_annotation_to_annotation(self) -> None: + """Any relation to an annotation should be ignored.""" + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + event_id = channel.json_body["event_id"] + self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "b", parent_id=event_id + ) + + # Fetch the initial annotation event to see if it has bundled aggregations. + channel = self.make_request( + "GET", + f"/_matrix/client/v3/rooms/{self.room}/event/{event_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + # The first annotationt should not have any bundled aggregations. + self.assertNotIn("m.relations", channel.json_body["unsigned"]) + def test_reference(self) -> None: """ Test that references get correctly bundled. @@ -1029,7 +1060,106 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): bundled_aggregations.get("latest_event"), ) - self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9) + self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 10) + + def test_thread_with_bundled_aggregations_for_latest(self) -> None: + """ + Bundled aggregations should get applied to the latest thread event. 
+ """ + self._send_relation(RelationTypes.THREAD, "m.room.test") + channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + thread_2 = channel.json_body["event_id"] + + self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_2 + ) + + def assert_thread(bundled_aggregations: JsonDict) -> None: + self.assertEqual(2, bundled_aggregations.get("count")) + self.assertTrue(bundled_aggregations.get("current_user_participated")) + # The latest thread event has some fields that don't matter. + self.assert_dict( + { + "content": { + "m.relates_to": { + "event_id": self.parent_id, + "rel_type": RelationTypes.THREAD, + } + }, + "event_id": thread_2, + "sender": self.user_id, + "type": "m.room.test", + }, + bundled_aggregations.get("latest_event"), + ) + # Check the unsigned field on the latest event. + self.assert_dict( + { + "m.relations": { + RelationTypes.ANNOTATION: { + "chunk": [ + {"type": "m.reaction", "key": "a", "count": 1}, + ] + }, + } + }, + bundled_aggregations["latest_event"].get("unsigned"), + ) + + self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 10) + + def test_nested_thread(self) -> None: + """ + Ensure that a nested thread gets ignored by bundled aggregations, as + those are forbidden. + """ + + # Start a thread. + channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + reply_event_id = channel.json_body["event_id"] + + # Disable the validation to pretend this came over federation, since it is + # not an event the Client-Server API will allow.. + with patch( + "synapse.handlers.message.EventCreationHandler._validate_event_relation", + new=lambda self, event: make_awaitable(None), + ): + # Create a sub-thread off the thread, which is not allowed. + self._send_relation( + RelationTypes.THREAD, "m.room.test", parent_id=reply_event_id + ) + + # Fetch the thread root, to get the bundled aggregation for the thread. + relations_from_event = self._get_bundled_aggregations() + + # Ensure that requesting the room messages also does not return the sub-thread. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/messages?dir=b", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + event = self._find_event_in_chunk(channel.json_body["chunk"]) + relations_from_messages = event["unsigned"]["m.relations"] + + # Check the bundled aggregations from each point. + for aggregations, desc in ( + (relations_from_event, "/event"), + (relations_from_messages, "/messages"), + ): + # The latest event should have bundled aggregations. + self.assertIn(RelationTypes.THREAD, aggregations, desc) + thread_summary = aggregations[RelationTypes.THREAD] + self.assertIn("latest_event", thread_summary, desc) + self.assertEqual( + thread_summary["latest_event"]["event_id"], reply_event_id, desc + ) + + # The latest event should not have any bundled aggregations (since the + # only relation to it is another thread, which is invalid). + self.assertNotIn( + "m.relations", thread_summary["latest_event"]["unsigned"], desc + ) def test_thread_edit_latest_event(self) -> None: """Test that editing the latest event in a thread works.""" @@ -1049,6 +1179,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, parent_id=threaded_event_id, ) + edit_event_id = channel.json_body["event_id"] # Fetch the thread root, to get the bundled aggregation for the thread. 
relations_dict = self._get_bundled_aggregations() @@ -1061,6 +1192,12 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): self.assertIn("latest_event", thread_summary) latest_event_in_thread = thread_summary["latest_event"] self.assertEqual(latest_event_in_thread["content"]["body"], "I've been edited!") + # The latest event in the thread should have the edit appear under the + # bundled aggregations. + self.assertDictContainsSubset( + {"event_id": edit_event_id, "sender": "@alice:test"}, + latest_event_in_thread["unsigned"]["m.relations"][RelationTypes.REPLACE], + ) def test_aggregation_get_event_for_annotation(self) -> None: """Test that annotations do not get bundled aggregations included diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 6ff79b9e2e..9443daa056 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -22,7 +22,6 @@ from typing import Any, Dict, Iterable, List, Optional from unittest.mock import Mock, call from urllib import parse as urlparse -from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin @@ -1426,9 +1425,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): def test_simple(self) -> None: "Simple test for searching rooms over federation" - self.federation_client.get_public_rooms.side_effect = lambda *a, **k: defer.succeed( # type: ignore[attr-defined] - {} - ) + self.federation_client.get_public_rooms.return_value = make_awaitable({}) # type: ignore[attr-defined] search_filter = {"generic_search_term": "foobar"} @@ -1456,7 +1453,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase): # with a 404, when using search filters. self.federation_client.get_public_rooms.side_effect = ( # type: ignore[attr-defined] HttpResponseException(404, "Not Found", b""), - defer.succeed({}), + make_awaitable({}), ) search_filter = {"generic_search_term": "foobar"} diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index 773c16a54c..0108337649 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -23,7 +23,7 @@ import synapse.rest.admin from synapse.api.constants import ( EventContentFields, EventTypes, - ReadReceiptEventFields, + ReceiptTypes, RelationTypes, ) from synapse.rest.client import devices, knock, login, read_marker, receipts, room, sync @@ -346,7 +346,7 @@ class SyncKnockTestCase( # Knock on a room channel = self.make_request( "POST", - "/_matrix/client/r0/knock/%s" % (self.room_id,), + f"/_matrix/client/r0/knock/{self.room_id}", b"{}", self.knocker_tok, ) @@ -407,22 +407,83 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) @override_config({"experimental_features": {"msc2285_enabled": True}}) - def test_hidden_read_receipts(self) -> None: + def test_private_read_receipts(self) -> None: # Send a message as the first user res = self.helper.send(self.room_id, body="hello", tok=self.tok) - # Send a read receipt to tell the server the first user's message was read - body = json.dumps({ReadReceiptEventFields.MSC2285_HIDDEN: True}).encode("utf8") + # Send a private read receipt to tell the server the first user's message was read channel = self.make_request( "POST", - "/rooms/%s/receipt/m.read/%s" % (self.room_id, res["event_id"]), - body, + f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res['event_id']}", + {}, access_token=self.tok2, ) 
self.assertEqual(channel.code, 200) - # Test that the first user can't see the other user's hidden read receipt - self.assertEqual(self._get_read_receipt(), None) + # Test that the first user can't see the other user's private read receipt + self.assertIsNone(self._get_read_receipt()) + + @override_config({"experimental_features": {"msc2285_enabled": True}}) + def test_public_receipt_can_override_private(self) -> None: + """ + Sending a public read receipt to the same event which has a private read + receipt should cause that receipt to become public. + """ + # Send a message as the first user + res = self.helper.send(self.room_id, body="hello", tok=self.tok) + + # Send a private read receipt + channel = self.make_request( + "POST", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}", + {}, + access_token=self.tok2, + ) + self.assertEqual(channel.code, 200) + self.assertIsNone(self._get_read_receipt()) + + # Send a public read receipt + channel = self.make_request( + "POST", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ}/{res['event_id']}", + {}, + access_token=self.tok2, + ) + self.assertEqual(channel.code, 200) + + # Test that we did override the private read receipt + self.assertNotEqual(self._get_read_receipt(), None) + + @override_config({"experimental_features": {"msc2285_enabled": True}}) + def test_private_receipt_cannot_override_public(self) -> None: + """ + Sending a private read receipt to the same event which has a public read + receipt should cause no change. + """ + # Send a message as the first user + res = self.helper.send(self.room_id, body="hello", tok=self.tok) + + # Send a public read receipt + channel = self.make_request( + "POST", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ}/{res['event_id']}", + {}, + access_token=self.tok2, + ) + self.assertEqual(channel.code, 200) + self.assertNotEqual(self._get_read_receipt(), None) + + # Send a private read receipt + channel = self.make_request( + "POST", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}", + {}, + access_token=self.tok2, + ) + self.assertEqual(channel.code, 200) + + # Test that we didn't override the public read receipt + self.assertIsNone(self._get_read_receipt()) @parameterized.expand( [ @@ -454,7 +515,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Send a read receipt for this message with an empty body channel = self.make_request( "POST", - "/rooms/%s/receipt/m.read/%s" % (self.room_id, res["event_id"]), + f"/rooms/{self.room_id}/receipt/m.read/{res['event_id']}", access_token=self.tok2, custom_headers=[("User-Agent", user_agent)], ) @@ -478,6 +539,9 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase): # Store the next batch for the next request. 
self.next_batch = channel.json_body["next_batch"] + if channel.json_body.get("rooms", None) is None: + return None + # Return the read receipt ephemeral_events = channel.json_body["rooms"]["join"][self.room_id][ "ephemeral" @@ -498,7 +562,10 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): def default_config(self) -> JsonDict: config = super().default_config() - config["experimental_features"] = {"msc2654_enabled": True} + config["experimental_features"] = { + "msc2654_enabled": True, + "msc2285_enabled": True, + } return config def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: @@ -560,10 +627,10 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): self._check_unread_count(1) # Send a read receipt to tell the server we've read the latest event. - body = json.dumps({"m.read": res["event_id"]}).encode("utf8") + body = json.dumps({ReceiptTypes.READ: res["event_id"]}).encode("utf8") channel = self.make_request( "POST", - "/rooms/%s/read_markers" % self.room_id, + f"/rooms/{self.room_id}/read_markers", body, access_token=self.tok, ) @@ -572,16 +639,15 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): # Check that the unread counter is back to 0. self._check_unread_count(0) - # Check that hidden read receipts don't break unread counts + # Check that private read receipts don't break unread counts res = self.helper.send(self.room_id, "hello", tok=self.tok2) self._check_unread_count(1) # Send a read receipt to tell the server we've read the latest event. - body = json.dumps({ReadReceiptEventFields.MSC2285_HIDDEN: True}).encode("utf8") channel = self.make_request( "POST", - "/rooms/%s/receipt/m.read/%s" % (self.room_id, res["event_id"]), - body, + f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res['event_id']}", + {}, access_token=self.tok, ) self.assertEqual(channel.code, 200, channel.json_body) @@ -643,13 +709,73 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): self._check_unread_count(4) # Check that tombstone events changes increase the unread counter. 
- self.helper.send_state( + res1 = self.helper.send_state( self.room_id, EventTypes.Tombstone, {"replacement_room": "!someroom:test"}, tok=self.tok2, ) self._check_unread_count(5) + res2 = self.helper.send(self.room_id, "hello", tok=self.tok2) + + # Make sure both m.read and org.matrix.msc2285.read.private advance + channel = self.make_request( + "POST", + f"/rooms/{self.room_id}/receipt/m.read/{res1['event_id']}", + {}, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + self._check_unread_count(1) + + channel = self.make_request( + "POST", + f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res2['event_id']}", + {}, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + self._check_unread_count(0) + + # We test for both receipt types that influence notification counts + @parameterized.expand([ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]) + def test_read_receipts_only_go_down(self, receipt_type: ReceiptTypes) -> None: + # Join the new user + self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) + + # Send messages + res1 = self.helper.send(self.room_id, "hello", tok=self.tok2) + res2 = self.helper.send(self.room_id, "hello", tok=self.tok2) + + # Read last event + channel = self.make_request( + "POST", + f"/rooms/{self.room_id}/receipt/{receipt_type}/{res2['event_id']}", + {}, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + self._check_unread_count(0) + + # Make sure neither m.read nor org.matrix.msc2285.read.private make the + # read receipt go up to an older event + channel = self.make_request( + "POST", + f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res1['event_id']}", + {}, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + self._check_unread_count(0) + + channel = self.make_request( + "POST", + f"/rooms/{self.room_id}/receipt/m.read/{res1['event_id']}", + {}, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + self._check_unread_count(0) def _check_unread_count(self, expected_count: int) -> None: """Syncs and compares the unread count with the expected value.""" @@ -662,9 +788,11 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200, channel.json_body) - room_entry = channel.json_body["rooms"]["join"][self.room_id] + room_entry = ( + channel.json_body.get("rooms", {}).get("join", {}).get(self.room_id, {}) + ) self.assertEqual( - room_entry["org.matrix.msc2654.unread_count"], + room_entry.get("org.matrix.msc2654.unread_count", 0), expected_count, room_entry, ) diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py index 8d8251b2ac..21a1ca2a68 100644 --- a/tests/rest/client/test_transactions.py +++ b/tests/rest/client/test_transactions.py @@ -22,6 +22,7 @@ from synapse.rest.client.transactions import CLEANUP_PERIOD_MS, HttpTransactionC from synapse.util import Clock from tests import unittest +from tests.test_utils import make_awaitable from tests.utils import MockClock @@ -38,7 +39,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase): @defer.inlineCallbacks def test_executes_given_function(self): - cb = Mock(return_value=defer.succeed(self.mock_http_response)) + cb = Mock(return_value=make_awaitable(self.mock_http_response)) res = yield self.cache.fetch_or_execute( self.mock_key, cb, "some_arg", keyword="arg" ) @@ -47,7 +48,7 @@ class 
HttpTransactionCacheTestCase(unittest.TestCase): @defer.inlineCallbacks def test_deduplicates_based_on_key(self): - cb = Mock(return_value=defer.succeed(self.mock_http_response)) + cb = Mock(return_value=make_awaitable(self.mock_http_response)) for i in range(3): # invoke multiple times res = yield self.cache.fetch_or_execute( self.mock_key, cb, "some_arg", keyword="arg", changing_args=i @@ -130,7 +131,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase): @defer.inlineCallbacks def test_cleans_up(self): - cb = Mock(return_value=defer.succeed(self.mock_http_response)) + cb = Mock(return_value=make_awaitable(self.mock_http_response)) yield self.cache.fetch_or_execute(self.mock_key, cb, "an arg") # should NOT have cleaned up yet self.clock.advance_time_msec(CLEANUP_PERIOD_MS / 2) diff --git a/tests/server.py b/tests/server.py index 16559d2588..8f30e250c8 100644 --- a/tests/server.py +++ b/tests/server.py @@ -181,7 +181,7 @@ class FakeChannel: self.resource_usage = _self.logcontext.get_resource_usage() def getPeer(self): - # We give an address so that getClientIP returns a non null entry, + # We give an address so that getClientAddress/getClientIP returns a non null entry, # causing us to record the MAU return address.IPv4Address("TCP", self._ip, 3423) @@ -562,7 +562,10 @@ class FakeTransport: """ _peer_address: Optional[IAddress] = attr.ib(default=None) - """The value to be returend by getPeer""" + """The value to be returned by getPeer""" + + _host_address: Optional[IAddress] = attr.ib(default=None) + """The value to be returned by getHost""" disconnecting = False disconnected = False @@ -571,11 +574,11 @@ class FakeTransport: producer = attr.ib(default=None) autoflush = attr.ib(default=True) - def getPeer(self): + def getPeer(self) -> Optional[IAddress]: return self._peer_address - def getHost(self): - return None + def getHost(self) -> Optional[IAddress]: + return self._host_address def loseConnection(self, reason=None): if not self.disconnecting: diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 02b96c9e6e..9ee9509d3a 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -14,8 +14,6 @@ from unittest.mock import Mock -from twisted.internet import defer - from synapse.api.constants import EventTypes, LimitBlockingTypes, ServerNoticeMsgType from synapse.api.errors import ResourceLimitError from synapse.rest import admin @@ -68,16 +66,16 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): return_value=make_awaitable(1000) ) self._rlsn._server_notices_manager.send_notice = Mock( - return_value=defer.succeed(Mock()) + return_value=make_awaitable(Mock()) ) self._send_notice = self._rlsn._server_notices_manager.send_notice self.user_id = "@user_id:test" self._rlsn._server_notices_manager.get_or_create_notice_room_for_user = Mock( - return_value=defer.succeed("!something:localhost") + return_value=make_awaitable("!something:localhost") ) - self._rlsn._store.add_tag_to_room = Mock(return_value=defer.succeed(None)) + self._rlsn._store.add_tag_to_room = Mock(return_value=make_awaitable(None)) self._rlsn._store.get_tags_for_room = Mock(return_value=make_awaitable({})) @override_config({"hs_disabled": True}) @@ -95,7 +93,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): def test_maybe_send_server_notice_to_user_remove_blocked_notice(self): """Test when user has blocked notice, 
but should have it removed""" - self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None)) + self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None)) mock_event = Mock( type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType} ) @@ -111,7 +109,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): Test when user has blocked notice, but notice ought to be there (NOOP) """ self._rlsn._auth.check_auth_blocking = Mock( - return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo") + return_value=make_awaitable(None), + side_effect=ResourceLimitError(403, "foo"), ) mock_event = Mock( @@ -130,7 +129,8 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): Test when user does not have blocked notice, but should have one """ self._rlsn._auth.check_auth_blocking = Mock( - return_value=defer.succeed(None), side_effect=ResourceLimitError(403, "foo") + return_value=make_awaitable(None), + side_effect=ResourceLimitError(403, "foo"), ) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -141,7 +141,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): """ Test when user does not have blocked notice, nor should they (NOOP) """ - self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None)) + self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None)) self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id)) @@ -152,7 +152,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): Test when user is not part of the MAU cohort - this should not ever happen - but ... """ - self._rlsn._auth.check_auth_blocking = Mock(return_value=defer.succeed(None)) + self._rlsn._auth.check_auth_blocking = Mock(return_value=make_awaitable(None)) self._rlsn._store.user_last_seen_monthly_active = Mock( return_value=make_awaitable(None) ) @@ -167,7 +167,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): an alert message is not sent into the room """ self._rlsn._auth.check_auth_blocking = Mock( - return_value=defer.succeed(None), + return_value=make_awaitable(None), side_effect=ResourceLimitError( 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER ), @@ -182,7 +182,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): Test that when a server is disabled, that MAU limit alerting is ignored. """ self._rlsn._auth.check_auth_blocking = Mock( - return_value=defer.succeed(None), + return_value=make_awaitable(None), side_effect=ResourceLimitError( 403, "foo", limit_type=LimitBlockingTypes.HS_DISABLED ), @@ -199,14 +199,14 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase): is suppressed that the room is returned to an unblocked state. """ self._rlsn._auth.check_auth_blocking = Mock( - return_value=defer.succeed(None), + return_value=make_awaitable(None), side_effect=ResourceLimitError( 403, "foo", limit_type=LimitBlockingTypes.MONTHLY_ACTIVE_USER ), ) self._rlsn._server_notices_manager.__is_room_currently_blocked = Mock( - return_value=defer.succeed((True, [])) + return_value=make_awaitable((True, [])) ) mock_event = Mock( diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py index bf6374f93d..c237a8c7e2 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py @@ -13,7 +13,7 @@ # limitations under the License. 
import json from contextlib import contextmanager -from typing import Generator, Tuple +from typing import Generator, List, Tuple from unittest import mock from twisted.enterprise.adbapi import ConnectionPool @@ -21,6 +21,7 @@ from twisted.internet.defer import CancelledError, Deferred, ensureDeferred from twisted.test.proto_helpers import MemoryReactor from synapse.api.room_versions import EventFormatVersions, RoomVersions +from synapse.events import make_event_from_dict from synapse.logging.context import LoggingContext from synapse.rest import admin from synapse.rest.client import login, room @@ -49,23 +50,28 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): ) ) - for idx, (rid, eid) in enumerate( + self.event_ids: List[str] = [] + for idx, rid in enumerate( ( - ("room1", "event10"), - ("room1", "event11"), - ("room1", "event12"), - ("room2", "event20"), + "room1", + "room1", + "room1", + "room2", ) ): + event_json = {"type": f"test {idx}", "room_id": rid} + event = make_event_from_dict(event_json, room_version=RoomVersions.V4) + event_id = event.event_id + self.get_success( self.store.db_pool.simple_insert( "events", { - "event_id": eid, + "event_id": event_id, "room_id": rid, "topological_ordering": idx, "stream_ordering": idx, - "type": "test", + "type": event.type, "processed": True, "outlier": False, }, @@ -75,21 +81,22 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): self.store.db_pool.simple_insert( "event_json", { - "event_id": eid, + "event_id": event_id, "room_id": rid, - "json": json.dumps({"type": "test", "room_id": rid}), + "json": json.dumps(event_json), "internal_metadata": "{}", "format_version": 3, }, ) ) + self.event_ids.append(event_id) def test_simple(self): with LoggingContext(name="test") as ctx: res = self.get_success( - self.store.have_seen_events("room1", ["event10", "event19"]) + self.store.have_seen_events("room1", [self.event_ids[0], "event19"]) ) - self.assertEqual(res, {"event10"}) + self.assertEqual(res, {self.event_ids[0]}) # that should result in a single db query self.assertEqual(ctx.get_resource_usage().db_txn_count, 1) @@ -97,19 +104,21 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): # a second lookup of the same events should cause no queries with LoggingContext(name="test") as ctx: res = self.get_success( - self.store.have_seen_events("room1", ["event10", "event19"]) + self.store.have_seen_events("room1", [self.event_ids[0], "event19"]) ) - self.assertEqual(res, {"event10"}) + self.assertEqual(res, {self.event_ids[0]}) self.assertEqual(ctx.get_resource_usage().db_txn_count, 0) def test_query_via_event_cache(self): # fetch an event into the event cache - self.get_success(self.store.get_event("event10")) + self.get_success(self.store.get_event(self.event_ids[0])) # looking it up should now cause no db hits with LoggingContext(name="test") as ctx: - res = self.get_success(self.store.have_seen_events("room1", ["event10"])) - self.assertEqual(res, {"event10"}) + res = self.get_success( + self.store.have_seen_events("room1", [self.event_ids[0]]) + ) + self.assertEqual(res, {self.event_ids[0]}) self.assertEqual(ctx.get_resource_usage().db_txn_count, 0) @@ -167,7 +176,6 @@ class DatabaseOutageTestCase(unittest.HomeserverTestCase): self.store: EventsWorkerStore = hs.get_datastores().main self.room_id = f"!room:{hs.hostname}" - self.event_ids = [f"event{i}" for i in range(20)] self._populate_events() @@ -190,8 +198,14 @@ class DatabaseOutageTestCase(unittest.HomeserverTestCase): ) ) - self.event_ids = [f"event{i}" for i 
in range(20)] - for idx, event_id in enumerate(self.event_ids): + self.event_ids: List[str] = [] + for idx in range(20): + event_json = { + "type": f"test {idx}", + "room_id": self.room_id, + } + event = make_event_from_dict(event_json, room_version=RoomVersions.V4) + event_id = event.event_id self.get_success( self.store.db_pool.simple_upsert( "events", @@ -201,7 +215,7 @@ class DatabaseOutageTestCase(unittest.HomeserverTestCase): "room_id": self.room_id, "topological_ordering": idx, "stream_ordering": idx, - "type": "test", + "type": event.type, "processed": True, "outlier": False, }, @@ -213,12 +227,13 @@ class DatabaseOutageTestCase(unittest.HomeserverTestCase): {"event_id": event_id}, { "room_id": self.room_id, - "json": json.dumps({"type": "test", "room_id": self.room_id}), + "json": json.dumps(event_json), "internal_metadata": "{}", "format_version": EventFormatVersions.V3, }, ) ) + self.event_ids.append(event_id) @contextmanager def _outage(self) -> Generator[None, None, None]: diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 60c8d37594..0fbf465670 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -14,7 +14,6 @@ from typing import Any, Dict, List from unittest.mock import Mock -from twisted.internet import defer from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import UserTypes @@ -259,10 +258,10 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): def test_populate_monthly_users_should_update(self): self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] - self.store.is_trial_user = Mock(return_value=defer.succeed(False)) # type: ignore[assignment] + self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment] self.store.user_last_seen_monthly_active = Mock( - return_value=defer.succeed(None) + return_value=make_awaitable(None) ) d = self.store.populate_monthly_active_users("user_id") self.get_success(d) @@ -272,9 +271,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): def test_populate_monthly_users_should_not_update(self): self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment] - self.store.is_trial_user = Mock(return_value=defer.succeed(False)) # type: ignore[assignment] + self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment] self.store.user_last_seen_monthly_active = Mock( - return_value=defer.succeed(self.hs.get_clock().time_msec()) + return_value=make_awaitable(self.hs.get_clock().time_msec()) ) d = self.store.populate_monthly_active_users("user_id") diff --git a/tests/test_federation.py b/tests/test_federation.py index c39816de85..0cbef70bfa 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -233,7 +233,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): # Register mock device list retrieval on the federation client. 
federation_client = self.homeserver.get_federation_client() federation_client.query_user_devices = Mock( - return_value=succeed( + return_value=make_awaitable( { "user_id": remote_user_id, "stream_id": 1, diff --git a/tests/test_mau.py b/tests/test_mau.py index 46bd3075de..5bbc361aa2 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -14,6 +14,8 @@ """Tests REST events for /rooms paths.""" +from typing import List + from synapse.api.constants import APP_SERVICE_REGISTRATION_TYPE, LoginType from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.appservice import ApplicationService @@ -229,6 +231,78 @@ class TestMauLimit(unittest.HomeserverTestCase): self.reactor.advance(100) self.assertEqual(2, self.successResultOf(count)) + @override_config( + { + "mau_trial_days": 3, + "mau_appservice_trial_days": {"SomeASID": 1, "AnotherASID": 2}, + } + ) + def test_as_trial_days(self): + user_tokens: List[str] = [] + + def advance_time_and_sync(): + self.reactor.advance(24 * 60 * 61) + for token in user_tokens: + self.do_sync_for_user(token) + + # Cheekily add an application service that we use to register a new user + # with. + as_token_1 = "foobartoken1" + self.store.services_cache.append( + ApplicationService( + token=as_token_1, + hostname=self.hs.hostname, + id="SomeASID", + sender="@as_sender_1:test", + namespaces={"users": [{"regex": "@as_1.*", "exclusive": True}]}, + ) + ) + + as_token_2 = "foobartoken2" + self.store.services_cache.append( + ApplicationService( + token=as_token_2, + hostname=self.hs.hostname, + id="AnotherASID", + sender="@as_sender_2:test", + namespaces={"users": [{"regex": "@as_2.*", "exclusive": True}]}, + ) + ) + + user_tokens.append(self.create_user("kermit1")) + user_tokens.append(self.create_user("kermit2")) + user_tokens.append( + self.create_user("as_1kermit3", token=as_token_1, appservice=True) + ) + user_tokens.append( + self.create_user("as_2kermit4", token=as_token_2, appservice=True) + ) + + # Advance time by 1 day to include the first appservice + advance_time_and_sync() + self.assertEqual( + self.get_success(self.store.get_monthly_active_count_by_service()), + {"SomeASID": 1}, + ) + + # Advance time by 1 day to include the next appservice + advance_time_and_sync() + self.assertEqual( + self.get_success(self.store.get_monthly_active_count_by_service()), + {"SomeASID": 1, "AnotherASID": 1}, + ) + + # Advance time by 1 day to include the native users + advance_time_and_sync() + self.assertEqual( + self.get_success(self.store.get_monthly_active_count_by_service()), + { + "SomeASID": 1, + "AnotherASID": 1, + "native": 2, + }, + ) + def create_user(self, localpart, token=None, appservice=False): request_data = { "username": localpart, diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py index f05a373aa0..0d0d6faf0d 100644 --- a/tests/test_utils/__init__.py +++ b/tests/test_utils/__init__.py @@ -52,7 +52,7 @@ def make_awaitable(result: TV) -> Awaitable[TV]: This uses Futures as they can be awaited multiple times so can be returned to multiple callers. """ - future = Future() # type: ignore + future: Future[TV] = Future() future.set_result(result) return future @@ -69,7 +69,7 @@ def setup_awaitable_errors() -> Callable[[], None]: # State shared between unraisablehook and check_for_unraisable_exceptions. 
     unraisable_exceptions = []
-    orig_unraisablehook = sys.unraisablehook  # type: ignore
+    orig_unraisablehook = sys.unraisablehook
 
     def unraisablehook(unraisable):
         unraisable_exceptions.append(unraisable.exc_value)
@@ -78,11 +78,11 @@ def setup_awaitable_errors() -> Callable[[], None]:
         """
         A method to be used as a clean-up that fails a test-case if there are any new unraisable exceptions.
         """
-        sys.unraisablehook = orig_unraisablehook  # type: ignore
+        sys.unraisablehook = orig_unraisablehook
 
         if unraisable_exceptions:
             raise unraisable_exceptions.pop()
 
-    sys.unraisablehook = unraisablehook  # type: ignore
+    sys.unraisablehook = unraisablehook
 
     return cleanup
diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py
index 51a197a8c6..9228454c9e 100644
--- a/tests/test_utils/logging_setup.py
+++ b/tests/test_utils/logging_setup.py
@@ -27,7 +27,7 @@ class ToTwistedHandler(logging.Handler):
     def emit(self, record):
         log_entry = self.format(record)
         log_level = record.levelname.lower().replace("warning", "warn")
-        self.tx_log.emit(  # type: ignore
+        self.tx_log.emit(
            twisted.logger.LogLevel.levelWithName(log_level), "{entry}", entry=log_entry
        )
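
A note on the recurring defer.succeed(...) -> make_awaitable(...) substitution throughout these tests: make_awaitable returns an already-resolved awaitable, so a plain Mock can stand in for an async method, and (unlike a coroutine object) the same return value can be awaited more than once. The following is a minimal, self-contained sketch of that pattern, not the real Synapse helper; the mocked get_public_rooms call and its argument are illustrative only.

import asyncio
from unittest.mock import Mock


async def main() -> None:
    def make_awaitable(result):
        # An already-resolved Future; unlike a coroutine object it can be
        # awaited more than once, so one value serves repeated mock calls.
        future = asyncio.get_running_loop().create_future()
        future.set_result(result)
        return future

    # Mock an async method in the style used above for e.g.
    # federation_client.get_public_rooms (names/arguments here are illustrative).
    federation_client = Mock()
    federation_client.get_public_rooms = Mock(return_value=make_awaitable({"chunk": []}))

    # The same resolved awaitable can be awaited on every call.
    assert await federation_client.get_public_rooms("remote.example") == {"chunk": []}
    assert await federation_client.get_public_rooms("remote.example") == {"chunk": []}


asyncio.run(main())

A coroutine-returning side_effect would also work, but a pre-resolved awaitable keeps the mock reusable across multiple await sites, which is why the tests above standardise on it.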