diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/handlers/test_directory.py | 84 | ||||
-rw-r--r-- | tests/handlers/test_e2e_keys.py | 36 | ||||
-rw-r--r-- | tests/handlers/test_oidc.py | 94 | ||||
-rw-r--r-- | tests/handlers/test_profile.py | 43 | ||||
-rw-r--r-- | tests/handlers/test_saml.py | 24 | ||||
-rw-r--r-- | tests/rest/client/test_relations.py | 769 | ||||
-rw-r--r-- | tests/rest/media/v1/test_html_preview.py | 54 | ||||
-rw-r--r-- | tests/storage/test_account_data.py | 17 | ||||
-rw-r--r-- | tests/storage/test_database.py | 162 |
9 files changed, 707 insertions, 576 deletions
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 6e403a87c5..11ad44223d 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -12,14 +12,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +from typing import Any, Awaitable, Callable, Dict from unittest.mock import Mock +from twisted.test.proto_helpers import MemoryReactor + import synapse.api.errors import synapse.rest.admin from synapse.api.constants import EventTypes from synapse.rest.client import directory, login, room -from synapse.types import RoomAlias, create_requester +from synapse.server import HomeServer +from synapse.types import JsonDict, RoomAlias, create_requester +from synapse.util import Clock from tests import unittest from tests.test_utils import make_awaitable @@ -28,13 +32,15 @@ from tests.test_utils import make_awaitable class DirectoryTestCase(unittest.HomeserverTestCase): """Tests the directory service.""" - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.mock_federation = Mock() self.mock_registry = Mock() - self.query_handlers = {} + self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {} - def register_query_handler(query_type, handler): + def register_query_handler( + query_type: str, handler: Callable[[dict], Awaitable[JsonDict]] + ) -> None: self.query_handlers[query_type] = handler self.mock_registry.register_query_handler = register_query_handler @@ -54,7 +60,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): return hs - def test_get_local_association(self): + def test_get_local_association(self) -> None: self.get_success( self.store.create_room_alias_association( self.my_room, "!8765qwer:test", ["test"] @@ -65,7 +71,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.assertEqual({"room_id": "!8765qwer:test", "servers": ["test"]}, result) - def test_get_remote_association(self): + def test_get_remote_association(self) -> None: self.mock_federation.make_query.return_value = make_awaitable( {"room_id": "!8765qwer:test", "servers": ["test", "remote"]} ) @@ -83,7 +89,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): ignore_backoff=True, ) - def test_incoming_fed_query(self): + def test_incoming_fed_query(self) -> None: self.get_success( self.store.create_room_alias_association( self.your_room, "!8765asdf:test", ["test"] @@ -105,7 +111,7 @@ class TestCreateAlias(unittest.HomeserverTestCase): directory.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.handler = hs.get_directory_handler() # Create user @@ -125,7 +131,7 @@ class TestCreateAlias(unittest.HomeserverTestCase): self.test_user_tok = self.login("user", "pass") self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok) - def test_create_alias_joined_room(self): + def test_create_alias_joined_room(self) -> None: """A user can create an alias for a room they're in.""" self.get_success( self.handler.create_association( @@ -135,7 +141,7 @@ class TestCreateAlias(unittest.HomeserverTestCase): ) ) - def test_create_alias_other_room(self): + def test_create_alias_other_room(self) -> None: """A user cannot create an alias for a room they're NOT in.""" other_room_id = self.helper.create_room_as( self.admin_user, tok=self.admin_user_tok @@ -150,7 +156,7 @@ class TestCreateAlias(unittest.HomeserverTestCase): synapse.api.errors.SynapseError, ) - def test_create_alias_admin(self): + def test_create_alias_admin(self) -> None: """An admin can create an alias for a room they're NOT in.""" other_room_id = self.helper.create_room_as( self.test_user, tok=self.test_user_tok @@ -173,7 +179,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase): directory.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.handler = hs.get_directory_handler() self.state_handler = hs.get_state_handler() @@ -195,7 +201,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase): self.test_user_tok = self.login("user", "pass") self.helper.join(room=self.room_id, user=self.test_user, tok=self.test_user_tok) - def _create_alias(self, user): + def _create_alias(self, user) -> None: # Create a new alias to this room. self.get_success( self.store.create_room_alias_association( @@ -203,7 +209,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase): ) ) - def test_delete_alias_not_allowed(self): + def test_delete_alias_not_allowed(self) -> None: """A user that doesn't meet the expected guidelines cannot delete an alias.""" self._create_alias(self.admin_user) self.get_failure( @@ -213,7 +219,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase): synapse.api.errors.AuthError, ) - def test_delete_alias_creator(self): + def test_delete_alias_creator(self) -> None: """An alias creator can delete their own alias.""" # Create an alias from a different user. self._create_alias(self.test_user) @@ -232,7 +238,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase): synapse.api.errors.SynapseError, ) - def test_delete_alias_admin(self): + def test_delete_alias_admin(self) -> None: """A server admin can delete an alias created by another user.""" # Create an alias from a different user. self._create_alias(self.test_user) @@ -251,7 +257,7 @@ class TestDeleteAlias(unittest.HomeserverTestCase): synapse.api.errors.SynapseError, ) - def test_delete_alias_sufficient_power(self): + def test_delete_alias_sufficient_power(self) -> None: """A user with a sufficient power level should be able to delete an alias.""" self._create_alias(self.admin_user) @@ -288,7 +294,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): directory.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.handler = hs.get_directory_handler() self.state_handler = hs.get_state_handler() @@ -317,7 +323,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): ) return room_alias - def _set_canonical_alias(self, content): + def _set_canonical_alias(self, content) -> None: """Configure the canonical alias state on the room.""" self.helper.send_state( self.room_id, @@ -334,7 +340,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): ) ) - def test_remove_alias(self): + def test_remove_alias(self) -> None: """Removing an alias that is the canonical alias should remove it there too.""" # Set this new alias as the canonical alias for this room self._set_canonical_alias( @@ -356,7 +362,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): self.assertNotIn("alias", data["content"]) self.assertNotIn("alt_aliases", data["content"]) - def test_remove_other_alias(self): + def test_remove_other_alias(self) -> None: """Removing an alias listed as in alt_aliases should remove it there too.""" # Create a second alias. other_test_alias = "#test2:test" @@ -393,7 +399,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): servlets = [directory.register_servlets, room.register_servlets] - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() # Add custom alias creation rules to the config. @@ -403,7 +409,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): return config - def test_denied(self): + def test_denied(self) -> None: room_id = self.helper.create_room_as(self.user_id) channel = self.make_request( @@ -413,7 +419,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): ) self.assertEqual(403, channel.code, channel.result) - def test_allowed(self): + def test_allowed(self) -> None: room_id = self.helper.create_room_as(self.user_id) channel = self.make_request( @@ -423,7 +429,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): ) self.assertEqual(200, channel.code, channel.result) - def test_denied_during_creation(self): + def test_denied_during_creation(self) -> None: """A room alias that is not allowed should be rejected during creation.""" # Invalid room alias. self.helper.create_room_as( @@ -432,7 +438,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): extra_content={"room_alias_name": "foo"}, ) - def test_allowed_during_creation(self): + def test_allowed_during_creation(self) -> None: """A valid room alias should be allowed during creation.""" room_id = self.helper.create_room_as( self.user_id, @@ -459,7 +465,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase): data = {"room_alias_name": "unofficial_test"} allowed_localpart = "allowed" - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() # Add custom room list publication rules to the config. @@ -474,7 +480,9 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase): return config - def prepare(self, reactor, clock, hs): + def prepare( + self, reactor: MemoryReactor, clock: Clock, hs: HomeServer + ) -> HomeServer: self.allowed_user_id = self.register_user(self.allowed_localpart, "pass") self.allowed_access_token = self.login(self.allowed_localpart, "pass") @@ -483,7 +491,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase): return hs - def test_denied_without_publication_permission(self): + def test_denied_without_publication_permission(self) -> None: """ Try to create a room, register an alias for it, and publish it, as a user without permission to publish rooms. @@ -497,7 +505,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase): expect_code=403, ) - def test_allowed_when_creating_private_room(self): + def test_allowed_when_creating_private_room(self) -> None: """ Try to create a room, register an alias for it, and NOT publish it, as a user without permission to publish rooms. @@ -511,7 +519,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase): expect_code=200, ) - def test_allowed_with_publication_permission(self): + def test_allowed_with_publication_permission(self) -> None: """ Try to create a room, register an alias for it, and publish it, as a user WITH permission to publish rooms. @@ -525,7 +533,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase): expect_code=200, ) - def test_denied_publication_with_invalid_alias(self): + def test_denied_publication_with_invalid_alias(self) -> None: """ Try to create a room, register an alias for it, and publish it, as a user WITH permission to publish rooms. @@ -538,7 +546,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase): expect_code=403, ) - def test_can_create_as_private_room_after_rejection(self): + def test_can_create_as_private_room_after_rejection(self) -> None: """ After failing to publish a room with an alias as a user without publish permission, retry as the same user, but without publishing the room. @@ -549,7 +557,7 @@ class TestCreatePublishedRoomACL(unittest.HomeserverTestCase): self.test_denied_without_publication_permission() self.test_allowed_when_creating_private_room() - def test_can_create_with_permission_after_rejection(self): + def test_can_create_with_permission_after_rejection(self) -> None: """ After failing to publish a room with an alias as a user without publish permission, retry as someone with permission, using the same alias. @@ -566,7 +574,9 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): servlets = [directory.register_servlets, room.register_servlets] - def prepare(self, reactor, clock, hs): + def prepare( + self, reactor: MemoryReactor, clock: Clock, hs: HomeServer + ) -> HomeServer: room_id = self.helper.create_room_as(self.user_id) channel = self.make_request( @@ -579,7 +589,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase): return hs - def test_disabling_room_list(self): + def test_disabling_room_list(self) -> None: self.room_list_handler.enable_room_list_search = True self.directory_handler.enable_room_list_search = True diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 9338ab92e9..ac21a28c43 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -20,33 +20,37 @@ from parameterized import parameterized from signedjson import key as key, sign as sign from twisted.internet import defer +from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import RoomEncryptionAlgorithms from synapse.api.errors import Codes, SynapseError +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock from tests import unittest from tests.test_utils import make_awaitable class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: return self.setup_test_homeserver(federation_client=mock.Mock()) - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.handler = hs.get_e2e_keys_handler() self.store = self.hs.get_datastores().main - def test_query_local_devices_no_devices(self): + def test_query_local_devices_no_devices(self) -> None: """If the user has no devices, we expect an empty list.""" local_user = "@boris:" + self.hs.hostname res = self.get_success(self.handler.query_local_devices({local_user: None})) self.assertDictEqual(res, {local_user: {}}) - def test_reupload_one_time_keys(self): + def test_reupload_one_time_keys(self) -> None: """we should be able to re-upload the same keys""" local_user = "@boris:" + self.hs.hostname device_id = "xyz" - keys = { + keys: JsonDict = { "alg1:k1": "key1", "alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}}, "alg2:k3": {"key": "key3"}, @@ -74,7 +78,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): res, {"one_time_key_counts": {"alg1": 1, "alg2": 2, "signed_curve25519": 0}} ) - def test_change_one_time_keys(self): + def test_change_one_time_keys(self) -> None: """attempts to change one-time-keys should be rejected""" local_user = "@boris:" + self.hs.hostname @@ -134,7 +138,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): SynapseError, ) - def test_claim_one_time_key(self): + def test_claim_one_time_key(self) -> None: local_user = "@boris:" + self.hs.hostname device_id = "xyz" keys = {"alg1:k1": "key1"} @@ -161,7 +165,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): }, ) - def test_fallback_key(self): + def test_fallback_key(self) -> None: local_user = "@boris:" + self.hs.hostname device_id = "xyz" fallback_key = {"alg1:k1": "fallback_key1"} @@ -294,7 +298,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): {"failures": {}, "one_time_keys": {local_user: {device_id: fallback_key3}}}, ) - def test_replace_master_key(self): + def test_replace_master_key(self) -> None: """uploading a new signing key should make the old signing key unavailable""" local_user = "@boris:" + self.hs.hostname keys1 = { @@ -328,7 +332,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): ) self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]}) - def test_reupload_signatures(self): + def test_reupload_signatures(self) -> None: """re-uploading a signature should not fail""" local_user = "@boris:" + self.hs.hostname keys1 = { @@ -433,7 +437,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): self.assertDictEqual(devices["device_keys"][local_user]["abc"], device_key_1) self.assertDictEqual(devices["device_keys"][local_user]["def"], device_key_2) - def test_self_signing_key_doesnt_show_up_as_device(self): + def test_self_signing_key_doesnt_show_up_as_device(self) -> None: """signing keys should be hidden when fetching a user's devices""" local_user = "@boris:" + self.hs.hostname keys1 = { @@ -462,7 +466,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): res = self.get_success(self.handler.query_local_devices({local_user: None})) self.assertDictEqual(res, {local_user: {}}) - def test_upload_signatures(self): + def test_upload_signatures(self) -> None: """should check signatures that are uploaded""" # set up a user with cross-signing keys and a device. This user will # try uploading signatures @@ -686,7 +690,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey], ) - def test_query_devices_remote_no_sync(self): + def test_query_devices_remote_no_sync(self) -> None: """Tests that querying keys for a remote user that we don't share a room with returns the cross signing keys correctly. """ @@ -759,7 +763,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): }, ) - def test_query_devices_remote_sync(self): + def test_query_devices_remote_sync(self) -> None: """Tests that querying keys for a remote user that we share a room with, but haven't yet fetched the keys for, returns the cross signing keys correctly. @@ -845,7 +849,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): (["device_1", "device_2"],), ] ) - def test_query_all_devices_caches_result(self, device_ids: Iterable[str]): + def test_query_all_devices_caches_result(self, device_ids: Iterable[str]) -> None: """Test that requests for all of a remote user's devices are cached. We do this by asserting that only one call over federation was made, and that @@ -853,7 +857,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): """ local_user_id = "@test:test" remote_user_id = "@test:other" - request_body = {"device_keys": {remote_user_id: []}} + request_body: JsonDict = {"device_keys": {remote_user_id: []}} response_devices = [ { diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index e8418b6638..014815db6e 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -13,14 +13,18 @@ # limitations under the License. import json import os +from typing import Any, Dict from unittest.mock import ANY, Mock, patch from urllib.parse import parse_qs, urlparse import pymacaroons +from twisted.test.proto_helpers import MemoryReactor + from synapse.handlers.sso import MappingException from synapse.server import HomeServer -from synapse.types import UserID +from synapse.types import JsonDict, UserID +from synapse.util import Clock from synapse.util.macaroons import get_value_from_macaroon from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock @@ -98,7 +102,7 @@ class TestMappingProviderFailures(TestMappingProvider): } -async def get_json(url): +async def get_json(url: str) -> JsonDict: # Mock get_json calls to handle jwks & oidc discovery endpoints if url == WELL_KNOWN: # Minimal discovery document, as defined in OpenID.Discovery @@ -116,6 +120,8 @@ async def get_json(url): elif url == JWKS_URI: return {"keys": []} + return {} + def _key_file_path() -> str: """path to a file containing the private half of a test key""" @@ -147,12 +153,12 @@ class OidcHandlerTestCase(HomeserverTestCase): if not HAS_OIDC: skip = "requires OIDC" - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() config["public_baseurl"] = BASE_URL return config - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.http_client = Mock(spec=["get_json"]) self.http_client.get_json.side_effect = get_json self.http_client.user_agent = b"Synapse Test" @@ -164,7 +170,7 @@ class OidcHandlerTestCase(HomeserverTestCase): sso_handler = hs.get_sso_handler() # Mock the render error method. self.render_error = Mock(return_value=None) - sso_handler.render_error = self.render_error + sso_handler.render_error = self.render_error # type: ignore[assignment] # Reduce the number of attempts when generating MXIDs. sso_handler._MAP_USERNAME_RETRIES = 3 @@ -193,14 +199,14 @@ class OidcHandlerTestCase(HomeserverTestCase): return args @override_config({"oidc_config": DEFAULT_CONFIG}) - def test_config(self): + def test_config(self) -> None: """Basic config correctly sets up the callback URL and client auth correctly.""" self.assertEqual(self.provider._callback_url, CALLBACK_URL) self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID) self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET) @override_config({"oidc_config": {**DEFAULT_CONFIG, "discover": True}}) - def test_discovery(self): + def test_discovery(self) -> None: """The handler should discover the endpoints from OIDC discovery document.""" # This would throw if some metadata were invalid metadata = self.get_success(self.provider.load_metadata()) @@ -219,13 +225,13 @@ class OidcHandlerTestCase(HomeserverTestCase): self.http_client.get_json.assert_not_called() @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG}) - def test_no_discovery(self): + def test_no_discovery(self) -> None: """When discovery is disabled, it should not try to load from discovery document.""" self.get_success(self.provider.load_metadata()) self.http_client.get_json.assert_not_called() @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG}) - def test_load_jwks(self): + def test_load_jwks(self) -> None: """JWKS loading is done once (then cached) if used.""" jwks = self.get_success(self.provider.load_jwks()) self.http_client.get_json.assert_called_once_with(JWKS_URI) @@ -253,7 +259,7 @@ class OidcHandlerTestCase(HomeserverTestCase): self.get_failure(self.provider.load_jwks(force=True), RuntimeError) @override_config({"oidc_config": DEFAULT_CONFIG}) - def test_validate_config(self): + def test_validate_config(self) -> None: """Provider metadatas are extensively validated.""" h = self.provider @@ -336,14 +342,14 @@ class OidcHandlerTestCase(HomeserverTestCase): force_load_metadata() @override_config({"oidc_config": {**DEFAULT_CONFIG, "skip_verification": True}}) - def test_skip_verification(self): + def test_skip_verification(self) -> None: """Provider metadata validation can be disabled by config.""" with self.metadata_edit({"issuer": "http://insecure"}): # This should not throw get_awaitable_result(self.provider.load_metadata()) @override_config({"oidc_config": DEFAULT_CONFIG}) - def test_redirect_request(self): + def test_redirect_request(self) -> None: """The redirect request has the right arguments & generates a valid session cookie.""" req = Mock(spec=["cookies"]) req.cookies = [] @@ -387,7 +393,7 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertEqual(redirect, "http://client/redirect") @override_config({"oidc_config": DEFAULT_CONFIG}) - def test_callback_error(self): + def test_callback_error(self) -> None: """Errors from the provider returned in the callback are displayed.""" request = Mock(args={}) request.args[b"error"] = [b"invalid_client"] @@ -399,7 +405,7 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertRenderedError("invalid_client", "some description") @override_config({"oidc_config": DEFAULT_CONFIG}) - def test_callback(self): + def test_callback(self) -> None: """Code callback works and display errors if something went wrong. A lot of scenarios are tested here: @@ -428,9 +434,9 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": username, } expected_user_id = "@%s:%s" % (username, self.hs.hostname) - self.provider._exchange_code = simple_async_mock(return_value=token) - self.provider._parse_id_token = simple_async_mock(return_value=userinfo) - self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo) + self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] + self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment] + self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment] auth_handler = self.hs.get_auth_handler() auth_handler.complete_sso_login = simple_async_mock() @@ -468,7 +474,7 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertRenderedError("mapping_error") # Handle ID token errors - self.provider._parse_id_token = simple_async_mock(raises=Exception()) + self.provider._parse_id_token = simple_async_mock(raises=Exception()) # type: ignore[assignment] self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_token") @@ -483,7 +489,7 @@ class OidcHandlerTestCase(HomeserverTestCase): "type": "bearer", "access_token": "access_token", } - self.provider._exchange_code = simple_async_mock(return_value=token) + self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] self.get_success(self.handler.handle_oidc_callback(request)) auth_handler.complete_sso_login.assert_called_once_with( @@ -510,8 +516,8 @@ class OidcHandlerTestCase(HomeserverTestCase): id_token = { "sid": "abcdefgh", } - self.provider._parse_id_token = simple_async_mock(return_value=id_token) - self.provider._exchange_code = simple_async_mock(return_value=token) + self.provider._parse_id_token = simple_async_mock(return_value=id_token) # type: ignore[assignment] + self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] auth_handler.complete_sso_login.reset_mock() self.provider._fetch_userinfo.reset_mock() self.get_success(self.handler.handle_oidc_callback(request)) @@ -531,21 +537,21 @@ class OidcHandlerTestCase(HomeserverTestCase): self.render_error.assert_not_called() # Handle userinfo fetching error - self.provider._fetch_userinfo = simple_async_mock(raises=Exception()) + self.provider._fetch_userinfo = simple_async_mock(raises=Exception()) # type: ignore[assignment] self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("fetch_error") # Handle code exchange failure from synapse.handlers.oidc import OidcError - self.provider._exchange_code = simple_async_mock( + self.provider._exchange_code = simple_async_mock( # type: ignore[assignment] raises=OidcError("invalid_request") ) self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_request") @override_config({"oidc_config": DEFAULT_CONFIG}) - def test_callback_session(self): + def test_callback_session(self) -> None: """The callback verifies the session presence and validity""" request = Mock(spec=["args", "getCookie", "cookies"]) @@ -590,7 +596,7 @@ class OidcHandlerTestCase(HomeserverTestCase): @override_config( {"oidc_config": {**DEFAULT_CONFIG, "client_auth_method": "client_secret_post"}} ) - def test_exchange_code(self): + def test_exchange_code(self) -> None: """Code exchange behaves correctly and handles various error scenarios.""" token = {"type": "bearer"} token_json = json.dumps(token).encode("utf-8") @@ -686,7 +692,7 @@ class OidcHandlerTestCase(HomeserverTestCase): } } ) - def test_exchange_code_jwt_key(self): + def test_exchange_code_jwt_key(self) -> None: """Test that code exchange works with a JWK client secret.""" from authlib.jose import jwt @@ -741,7 +747,7 @@ class OidcHandlerTestCase(HomeserverTestCase): } } ) - def test_exchange_code_no_auth(self): + def test_exchange_code_no_auth(self) -> None: """Test that code exchange works with no client secret.""" token = {"type": "bearer"} self.http_client.request = simple_async_mock( @@ -776,7 +782,7 @@ class OidcHandlerTestCase(HomeserverTestCase): } } ) - def test_extra_attributes(self): + def test_extra_attributes(self) -> None: """ Login while using a mapping provider that implements get_extra_attributes. """ @@ -790,8 +796,8 @@ class OidcHandlerTestCase(HomeserverTestCase): "username": "foo", "phone": "1234567", } - self.provider._exchange_code = simple_async_mock(return_value=token) - self.provider._parse_id_token = simple_async_mock(return_value=userinfo) + self.provider._exchange_code = simple_async_mock(return_value=token) # type: ignore[assignment] + self.provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment] auth_handler = self.hs.get_auth_handler() auth_handler.complete_sso_login = simple_async_mock() @@ -817,12 +823,12 @@ class OidcHandlerTestCase(HomeserverTestCase): ) @override_config({"oidc_config": DEFAULT_CONFIG}) - def test_map_userinfo_to_user(self): + def test_map_userinfo_to_user(self) -> None: """Ensure that mapping the userinfo returned from a provider to an MXID works properly.""" auth_handler = self.hs.get_auth_handler() auth_handler.complete_sso_login = simple_async_mock() - userinfo = { + userinfo: dict = { "sub": "test_user", "username": "test_user", } @@ -870,7 +876,7 @@ class OidcHandlerTestCase(HomeserverTestCase): ) @override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}}) - def test_map_userinfo_to_existing_user(self): + def test_map_userinfo_to_existing_user(self) -> None: """Existing users can log in with OpenID Connect when allow_existing_users is True.""" store = self.hs.get_datastores().main user = UserID.from_string("@test_user:test") @@ -974,7 +980,7 @@ class OidcHandlerTestCase(HomeserverTestCase): ) @override_config({"oidc_config": DEFAULT_CONFIG}) - def test_map_userinfo_to_invalid_localpart(self): + def test_map_userinfo_to_invalid_localpart(self) -> None: """If the mapping provider generates an invalid localpart it should be rejected.""" self.get_success( _make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "fΓΆΓΆ"}) @@ -991,7 +997,7 @@ class OidcHandlerTestCase(HomeserverTestCase): } } ) - def test_map_userinfo_to_user_retries(self): + def test_map_userinfo_to_user_retries(self) -> None: """The mapping provider can retry generating an MXID if the MXID is already in use.""" auth_handler = self.hs.get_auth_handler() auth_handler.complete_sso_login = simple_async_mock() @@ -1039,7 +1045,7 @@ class OidcHandlerTestCase(HomeserverTestCase): ) @override_config({"oidc_config": DEFAULT_CONFIG}) - def test_empty_localpart(self): + def test_empty_localpart(self) -> None: """Attempts to map onto an empty localpart should be rejected.""" userinfo = { "sub": "tester", @@ -1058,7 +1064,7 @@ class OidcHandlerTestCase(HomeserverTestCase): } } ) - def test_null_localpart(self): + def test_null_localpart(self) -> None: """Mapping onto a null localpart via an empty OIDC attribute should be rejected""" userinfo = { "sub": "tester", @@ -1075,7 +1081,7 @@ class OidcHandlerTestCase(HomeserverTestCase): } } ) - def test_attribute_requirements(self): + def test_attribute_requirements(self) -> None: """The required attributes must be met from the OIDC userinfo response.""" auth_handler = self.hs.get_auth_handler() auth_handler.complete_sso_login = simple_async_mock() @@ -1115,7 +1121,7 @@ class OidcHandlerTestCase(HomeserverTestCase): } } ) - def test_attribute_requirements_contains(self): + def test_attribute_requirements_contains(self) -> None: """Test that auth succeeds if userinfo attribute CONTAINS required value""" auth_handler = self.hs.get_auth_handler() auth_handler.complete_sso_login = simple_async_mock() @@ -1146,7 +1152,7 @@ class OidcHandlerTestCase(HomeserverTestCase): } } ) - def test_attribute_requirements_mismatch(self): + def test_attribute_requirements_mismatch(self) -> None: """ Test that auth fails if attributes exist but don't match, or are non-string values. @@ -1154,7 +1160,7 @@ class OidcHandlerTestCase(HomeserverTestCase): auth_handler = self.hs.get_auth_handler() auth_handler.complete_sso_login = simple_async_mock() # userinfo with "test": "not_foobar" attribute should fail - userinfo = { + userinfo: dict = { "sub": "tester", "username": "tester", "test": "not_foobar", @@ -1248,9 +1254,9 @@ async def _make_callback_with_userinfo( handler = hs.get_oidc_handler() provider = handler._providers["oidc"] - provider._exchange_code = simple_async_mock(return_value={"id_token": ""}) - provider._parse_id_token = simple_async_mock(return_value=userinfo) - provider._fetch_userinfo = simple_async_mock(return_value=userinfo) + provider._exchange_code = simple_async_mock(return_value={"id_token": ""}) # type: ignore[assignment] + provider._parse_id_token = simple_async_mock(return_value=userinfo) # type: ignore[assignment] + provider._fetch_userinfo = simple_async_mock(return_value=userinfo) # type: ignore[assignment] state = "state" session = handler._token_generator.generate_oidc_session_token( diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 972cbac6e4..08733a9f2d 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -11,14 +11,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict +from typing import Any, Awaitable, Callable, Dict from unittest.mock import Mock +from twisted.test.proto_helpers import MemoryReactor + import synapse.types from synapse.api.errors import AuthError, SynapseError from synapse.rest import admin from synapse.server import HomeServer -from synapse.types import UserID +from synapse.types import JsonDict, UserID +from synapse.util import Clock from tests import unittest from tests.test_utils import make_awaitable @@ -29,13 +32,15 @@ class ProfileTestCase(unittest.HomeserverTestCase): servlets = [admin.register_servlets] - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.mock_federation = Mock() self.mock_registry = Mock() - self.query_handlers = {} + self.query_handlers: Dict[str, Callable[[dict], Awaitable[JsonDict]]] = {} - def register_query_handler(query_type, handler): + def register_query_handler( + query_type: str, handler: Callable[[dict], Awaitable[JsonDict]] + ) -> None: self.query_handlers[query_type] = handler self.mock_registry.register_query_handler = register_query_handler @@ -47,7 +52,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) return hs - def prepare(self, reactor, clock, hs: HomeServer): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main self.frank = UserID.from_string("@1234abcd:test") @@ -58,7 +63,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.handler = hs.get_profile_handler() - def test_get_my_name(self): + def test_get_my_name(self) -> None: self.get_success( self.store.set_profile_displayname(self.frank.localpart, "Frank") ) @@ -67,7 +72,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.assertEqual("Frank", displayname) - def test_set_my_name(self): + def test_set_my_name(self) -> None: self.get_success( self.handler.set_displayname( self.frank, synapse.types.create_requester(self.frank), "Frank Jr." @@ -110,7 +115,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.get_success(self.store.get_profile_displayname(self.frank.localpart)) ) - def test_set_my_name_if_disabled(self): + def test_set_my_name_if_disabled(self) -> None: self.hs.config.registration.enable_set_displayname = False # Setting displayname for the first time is allowed @@ -135,7 +140,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): SynapseError, ) - def test_set_my_name_noauth(self): + def test_set_my_name_noauth(self) -> None: self.get_failure( self.handler.set_displayname( self.frank, synapse.types.create_requester(self.bob), "Frank Jr." @@ -143,7 +148,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): AuthError, ) - def test_get_other_name(self): + def test_get_other_name(self) -> None: self.mock_federation.make_query.return_value = make_awaitable( {"displayname": "Alice"} ) @@ -158,7 +163,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ignore_backoff=True, ) - def test_incoming_fed_query(self): + def test_incoming_fed_query(self) -> None: self.get_success(self.store.create_profile("caroline")) self.get_success(self.store.set_profile_displayname("caroline", "Caroline")) @@ -174,7 +179,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.assertEqual({"displayname": "Caroline"}, response) - def test_get_my_avatar(self): + def test_get_my_avatar(self) -> None: self.get_success( self.store.set_profile_avatar_url( self.frank.localpart, "http://my.server/me.png" @@ -184,7 +189,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.assertEqual("http://my.server/me.png", avatar_url) - def test_set_my_avatar(self): + def test_set_my_avatar(self) -> None: self.get_success( self.handler.set_avatar_url( self.frank, @@ -225,7 +230,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): (self.get_success(self.store.get_profile_avatar_url(self.frank.localpart))), ) - def test_set_my_avatar_if_disabled(self): + def test_set_my_avatar_if_disabled(self) -> None: self.hs.config.registration.enable_set_avatar_url = False # Setting displayname for the first time is allowed @@ -250,7 +255,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): SynapseError, ) - def test_avatar_constraints_no_config(self): + def test_avatar_constraints_no_config(self) -> None: """Tests that the method to check an avatar against configured constraints skips all of its check if no constraint is configured. """ @@ -263,7 +268,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.assertTrue(res) @unittest.override_config({"max_avatar_size": 50}) - def test_avatar_constraints_missing(self): + def test_avatar_constraints_missing(self) -> None: """Tests that an avatar isn't allowed if the file at the given MXC URI couldn't be found. """ @@ -273,7 +278,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.assertFalse(res) @unittest.override_config({"max_avatar_size": 50}) - def test_avatar_constraints_file_size(self): + def test_avatar_constraints_file_size(self) -> None: """Tests that a file that's above the allowed file size is forbidden but one that's below it is allowed. """ @@ -295,7 +300,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): self.assertFalse(res) @unittest.override_config({"allowed_avatar_mimetypes": ["image/png"]}) - def test_avatar_constraint_mime_type(self): + def test_avatar_constraint_mime_type(self) -> None: """Tests that a file with an unauthorised MIME type is forbidden but one with an authorised content type is allowed. """ diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py index 23941abed8..8d4404eda1 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py @@ -12,12 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Any, Dict, Optional from unittest.mock import Mock import attr +from twisted.test.proto_helpers import MemoryReactor + from synapse.api.errors import RedirectException +from synapse.server import HomeServer +from synapse.util import Clock from tests.test_utils import simple_async_mock from tests.unittest import HomeserverTestCase, override_config @@ -81,10 +85,10 @@ class TestRedirectMappingProvider(TestMappingProvider): class SamlHandlerTestCase(HomeserverTestCase): - def default_config(self): + def default_config(self) -> Dict[str, Any]: config = super().default_config() config["public_baseurl"] = BASE_URL - saml_config = { + saml_config: Dict[str, Any] = { "sp_config": {"metadata": {}}, # Disable grandfathering. "grandfathered_mxid_source_attribute": None, @@ -98,7 +102,7 @@ class SamlHandlerTestCase(HomeserverTestCase): return config - def make_homeserver(self, reactor, clock): + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: hs = self.setup_test_homeserver() self.handler = hs.get_saml_handler() @@ -114,7 +118,7 @@ class SamlHandlerTestCase(HomeserverTestCase): elif not has_xmlsec1: skip = "Requires xmlsec1" - def test_map_saml_response_to_user(self): + def test_map_saml_response_to_user(self) -> None: """Ensure that mapping the SAML response returned from a provider to an MXID works properly.""" # stub out the auth handler @@ -140,7 +144,7 @@ class SamlHandlerTestCase(HomeserverTestCase): ) @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}}) - def test_map_saml_response_to_existing_user(self): + def test_map_saml_response_to_existing_user(self) -> None: """Existing users can log in with SAML account.""" store = self.hs.get_datastores().main self.get_success( @@ -186,7 +190,7 @@ class SamlHandlerTestCase(HomeserverTestCase): auth_provider_session_id=None, ) - def test_map_saml_response_to_invalid_localpart(self): + def test_map_saml_response_to_invalid_localpart(self) -> None: """If the mapping provider generates an invalid localpart it should be rejected.""" # stub out the auth handler @@ -207,7 +211,7 @@ class SamlHandlerTestCase(HomeserverTestCase): ) auth_handler.complete_sso_login.assert_not_called() - def test_map_saml_response_to_user_retries(self): + def test_map_saml_response_to_user_retries(self) -> None: """The mapping provider can retry generating an MXID if the MXID is already in use.""" # stub out the auth handler and error renderer @@ -271,7 +275,7 @@ class SamlHandlerTestCase(HomeserverTestCase): } } ) - def test_map_saml_response_redirect(self): + def test_map_saml_response_redirect(self) -> None: """Test a mapping provider that raises a RedirectException""" saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"}) @@ -292,7 +296,7 @@ class SamlHandlerTestCase(HomeserverTestCase): }, } ) - def test_attribute_requirements(self): + def test_attribute_requirements(self) -> None: """The required attributes must be met from the SAML response.""" # stub out the auth handler diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 171f4e97c8..f3741b3001 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -79,6 +79,7 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase): content: Optional[dict] = None, access_token: Optional[str] = None, parent_id: Optional[str] = None, + expected_response_code: int = 200, ) -> FakeChannel: """Helper function to send a relation pointing at `self.parent_id` @@ -115,16 +116,50 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase): content, access_token=access_token, ) + self.assertEqual(expected_response_code, channel.code, channel.json_body) return channel + def _get_related_events(self) -> List[str]: + """ + Requests /relations on the parent ID and returns a list of event IDs. + """ + # Request the relations of the event. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + return [ev["event_id"] for ev in channel.json_body["chunk"]] + + def _get_bundled_aggregations(self) -> JsonDict: + """ + Requests /event on the parent ID and returns the m.relations field (from unsigned), if it exists. + """ + # Fetch the bundled aggregations of the event. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/event/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + return channel.json_body["unsigned"].get("m.relations", {}) + + def _get_aggregations(self) -> List[JsonDict]: + """Request /aggregations on the parent ID and includes the returned chunk.""" + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + return channel.json_body["chunk"] + class RelationsTestCase(BaseRelationsTestCase): def test_send_relation(self) -> None: """Tests that sending a relation works.""" - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="π") - self.assertEqual(200, channel.code, channel.json_body) - event_id = channel.json_body["event_id"] channel = self.make_request( @@ -151,13 +186,13 @@ class RelationsTestCase(BaseRelationsTestCase): def test_deny_invalid_event(self) -> None: """Test that we deny relations on non-existant events""" - channel = self._send_relation( + self._send_relation( RelationTypes.ANNOTATION, EventTypes.Message, parent_id="foo", content={"body": "foo", "msgtype": "m.text"}, + expected_response_code=400, ) - self.assertEqual(400, channel.code, channel.json_body) # Unless that event is referenced from another event! self.get_success( @@ -171,13 +206,12 @@ class RelationsTestCase(BaseRelationsTestCase): desc="test_deny_invalid_event", ) ) - channel = self._send_relation( + self._send_relation( RelationTypes.THREAD, EventTypes.Message, parent_id="foo", content={"body": "foo", "msgtype": "m.text"}, ) - self.assertEqual(200, channel.code, channel.json_body) def test_deny_invalid_room(self) -> None: """Test that we deny relations on non-existant events""" @@ -187,18 +221,20 @@ class RelationsTestCase(BaseRelationsTestCase): parent_id = res["event_id"] # Attempt to send an annotation to that event. - channel = self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", parent_id=parent_id, key="A" + self._send_relation( + RelationTypes.ANNOTATION, + "m.reaction", + parent_id=parent_id, + key="A", + expected_response_code=400, ) - self.assertEqual(400, channel.code, channel.json_body) def test_deny_double_react(self) -> None: """Test that we deny relations on membership events""" - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") - self.assertEqual(200, channel.code, channel.json_body) - - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEqual(400, channel.code, channel.json_body) + self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") + self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", expected_response_code=400 + ) def test_deny_forked_thread(self) -> None: """It is invalid to start a thread off a thread.""" @@ -208,316 +244,24 @@ class RelationsTestCase(BaseRelationsTestCase): content={"msgtype": "m.text", "body": "foo"}, parent_id=self.parent_id, ) - self.assertEqual(200, channel.code, channel.json_body) parent_id = channel.json_body["event_id"] - channel = self._send_relation( + self._send_relation( RelationTypes.THREAD, "m.room.message", content={"msgtype": "m.text", "body": "foo"}, parent_id=parent_id, + expected_response_code=400, ) - self.assertEqual(400, channel.code, channel.json_body) - - def test_basic_paginate_relations(self) -> None: - """Tests that calling pagination API correctly the latest relations.""" - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEqual(200, channel.code, channel.json_body) - first_annotation_id = channel.json_body["event_id"] - - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") - self.assertEqual(200, channel.code, channel.json_body) - second_annotation_id = channel.json_body["event_id"] - - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - # We expect to get back a single pagination result, which is the latest - # full relation event we sent above. - self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) - self.assert_dict( - { - "event_id": second_annotation_id, - "sender": self.user_id, - "type": "m.reaction", - }, - channel.json_body["chunk"][0], - ) - - # We also expect to get the original event (the id of which is self.parent_id) - self.assertEqual( - channel.json_body["original_event"]["event_id"], self.parent_id - ) - - # Make sure next_batch has something in it that looks like it could be a - # valid token. - self.assertIsInstance( - channel.json_body.get("next_batch"), str, channel.json_body - ) - - # Request the relations again, but with a different direction. - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations" - f"/{self.parent_id}?limit=1&org.matrix.msc3715.dir=f", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - # We expect to get back a single pagination result, which is the earliest - # full relation event we sent above. - self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) - self.assert_dict( - { - "event_id": first_annotation_id, - "sender": self.user_id, - "type": "m.reaction", - }, - channel.json_body["chunk"][0], - ) - - def test_repeated_paginate_relations(self) -> None: - """Test that if we paginate using a limit and tokens then we get the - expected events. - """ - - expected_event_ids = [] - for idx in range(10): - channel = self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", chr(ord("a") + idx) - ) - self.assertEqual(200, channel.code, channel.json_body) - expected_event_ids.append(channel.json_body["event_id"]) - - prev_token = "" - found_event_ids: List[str] = [] - for _ in range(20): - from_token = "" - if prev_token: - from_token = "&from=" + prev_token - - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) - next_batch = channel.json_body.get("next_batch") - - self.assertNotEqual(prev_token, next_batch) - prev_token = next_batch - - if not prev_token: - break - - # We paginated backwards, so reverse - found_event_ids.reverse() - self.assertEqual(found_event_ids, expected_event_ids) - - def test_pagination_from_sync_and_messages(self) -> None: - """Pagination tokens from /sync and /messages can be used to paginate /relations.""" - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A") - self.assertEqual(200, channel.code, channel.json_body) - annotation_id = channel.json_body["event_id"] - # Send an event after the relation events. - self.helper.send(self.room, body="Latest event", tok=self.user_token) - - # Request /sync, limiting it such that only the latest event is returned - # (and not the relation). - filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 1}}}') - channel = self.make_request( - "GET", f"/sync?filter={filter}", access_token=self.user_token - ) - self.assertEqual(200, channel.code, channel.json_body) - room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] - sync_prev_batch = room_timeline["prev_batch"] - self.assertIsNotNone(sync_prev_batch) - # Ensure the relation event is not in the batch returned from /sync. - self.assertNotIn( - annotation_id, [ev["event_id"] for ev in room_timeline["events"]] - ) - - # Request /messages, limiting it such that only the latest event is - # returned (and not the relation). - channel = self.make_request( - "GET", - f"/rooms/{self.room}/messages?dir=b&limit=1", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - messages_end = channel.json_body["end"] - self.assertIsNotNone(messages_end) - # Ensure the relation event is not in the chunk returned from /messages. - self.assertNotIn( - annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]] - ) - - # Request /relations with the pagination tokens received from both the - # /sync and /messages responses above, in turn. - # - # This is a tiny bit silly since the client wouldn't know the parent ID - # from the requests above; consider the parent ID to be known from a - # previous /sync. - for from_token in (sync_prev_batch, messages_end): - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?from={from_token}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - # The relation should be in the returned chunk. - self.assertIn( - annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]] - ) - - def test_aggregation_pagination_groups(self) -> None: - """Test that we can paginate annotation groups correctly.""" - - # We need to create ten separate users to send each reaction. - access_tokens = [self.user_token, self.user2_token] - idx = 0 - while len(access_tokens) < 10: - user_id, token = self._create_user("test" + str(idx)) - idx += 1 - - self.helper.join(self.room, user=user_id, tok=token) - access_tokens.append(token) - - idx = 0 - sent_groups = {"π": 10, "a": 7, "b": 5, "c": 3, "d": 2, "e": 1} - for key in itertools.chain.from_iterable( - itertools.repeat(key, num) for key, num in sent_groups.items() - ): - channel = self._send_relation( - RelationTypes.ANNOTATION, - "m.reaction", - key=key, - access_token=access_tokens[idx], - ) - self.assertEqual(200, channel.code, channel.json_body) - - idx += 1 - idx %= len(access_tokens) - - prev_token: Optional[str] = None - found_groups: Dict[str, int] = {} - for _ in range(20): - from_token = "" - if prev_token: - from_token = "&from=" + prev_token - - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1{from_token}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) - - for groups in channel.json_body["chunk"]: - # We only expect reactions - self.assertEqual(groups["type"], "m.reaction", channel.json_body) - - # We should only see each key once - self.assertNotIn(groups["key"], found_groups, channel.json_body) - - found_groups[groups["key"]] = groups["count"] - - next_batch = channel.json_body.get("next_batch") - - self.assertNotEqual(prev_token, next_batch) - prev_token = next_batch - - if not prev_token: - break - - self.assertEqual(sent_groups, found_groups) - - def test_aggregation_pagination_within_group(self) -> None: - """Test that we can paginate within an annotation group.""" - - # We need to create ten separate users to send each reaction. - access_tokens = [self.user_token, self.user2_token] - idx = 0 - while len(access_tokens) < 10: - user_id, token = self._create_user("test" + str(idx)) - idx += 1 - - self.helper.join(self.room, user=user_id, tok=token) - access_tokens.append(token) - - idx = 0 - expected_event_ids = [] - for _ in range(10): - channel = self._send_relation( - RelationTypes.ANNOTATION, - "m.reaction", - key="π", - access_token=access_tokens[idx], - ) - self.assertEqual(200, channel.code, channel.json_body) - expected_event_ids.append(channel.json_body["event_id"]) - - idx += 1 - - # Also send a different type of reaction so that we test we don't see it - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") - self.assertEqual(200, channel.code, channel.json_body) - - prev_token = "" - found_event_ids: List[str] = [] - encoded_key = urllib.parse.quote_plus("π".encode()) - for _ in range(20): - from_token = "" - if prev_token: - from_token = "&from=" + prev_token - - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}" - f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}" - f"/m.reaction/{encoded_key}?limit=1{from_token}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - - self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) - - found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) - - next_batch = channel.json_body.get("next_batch") - - self.assertNotEqual(prev_token, next_batch) - prev_token = next_batch - - if not prev_token: - break - - # We paginated backwards, so reverse - found_event_ids.reverse() - self.assertEqual(found_event_ids, expected_event_ids) def test_aggregation(self) -> None: """Test that annotations get correctly aggregated.""" - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEqual(200, channel.code, channel.json_body) - - channel = self._send_relation( + self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self._send_relation( RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token ) - self.assertEqual(200, channel.code, channel.json_body) - - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") - self.assertEqual(200, channel.code, channel.json_body) + self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") channel = self.make_request( "GET", @@ -558,30 +302,21 @@ class RelationsTestCase(BaseRelationsTestCase): See test_edit for a similar test for edits. """ # Setup by sending a variety of relations. - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEqual(200, channel.code, channel.json_body) - - channel = self._send_relation( + self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self._send_relation( RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token ) - self.assertEqual(200, channel.code, channel.json_body) - - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") - self.assertEqual(200, channel.code, channel.json_body) + self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") - self.assertEqual(200, channel.code, channel.json_body) reply_1 = channel.json_body["event_id"] channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") - self.assertEqual(200, channel.code, channel.json_body) reply_2 = channel.json_body["event_id"] - channel = self._send_relation(RelationTypes.THREAD, "m.room.test") - self.assertEqual(200, channel.code, channel.json_body) + self._send_relation(RelationTypes.THREAD, "m.room.test") channel = self._send_relation(RelationTypes.THREAD, "m.room.test") - self.assertEqual(200, channel.code, channel.json_body) thread_2 = channel.json_body["event_id"] def assert_bundle(event_json: JsonDict) -> None: @@ -693,14 +428,12 @@ class RelationsTestCase(BaseRelationsTestCase): when directly requested. """ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEqual(200, channel.code, channel.json_body) annotation_id = channel.json_body["event_id"] # Annotate the annotation. - channel = self._send_relation( + self._send_relation( RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=annotation_id ) - self.assertEqual(200, channel.code, channel.json_body) channel = self.make_request( "GET", @@ -713,14 +446,12 @@ class RelationsTestCase(BaseRelationsTestCase): def test_aggregation_get_event_for_thread(self) -> None: """Test that threads get bundled aggregations included when directly requested.""" channel = self._send_relation(RelationTypes.THREAD, "m.room.test") - self.assertEqual(200, channel.code, channel.json_body) thread_id = channel.json_body["event_id"] # Annotate the annotation. - channel = self._send_relation( + self._send_relation( RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id ) - self.assertEqual(200, channel.code, channel.json_body) channel = self.make_request( "GET", @@ -877,8 +608,6 @@ class RelationsTestCase(BaseRelationsTestCase): "m.room.message", content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, ) - self.assertEqual(200, channel.code, channel.json_body) - edit_event_id = channel.json_body["event_id"] def assert_bundle(event_json: JsonDict) -> None: @@ -954,7 +683,7 @@ class RelationsTestCase(BaseRelationsTestCase): shouldn't be allowed, are correctly handled. """ - channel = self._send_relation( + self._send_relation( RelationTypes.REPLACE, "m.room.message", content={ @@ -963,7 +692,6 @@ class RelationsTestCase(BaseRelationsTestCase): "m.new_content": {"msgtype": "m.text", "body": "First edit"}, }, ) - self.assertEqual(200, channel.code, channel.json_body) new_body = {"msgtype": "m.text", "body": "I've been edited!"} channel = self._send_relation( @@ -971,11 +699,9 @@ class RelationsTestCase(BaseRelationsTestCase): "m.room.message", content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, ) - self.assertEqual(200, channel.code, channel.json_body) - edit_event_id = channel.json_body["event_id"] - channel = self._send_relation( + self._send_relation( RelationTypes.REPLACE, "m.room.message.WRONG_TYPE", content={ @@ -984,7 +710,6 @@ class RelationsTestCase(BaseRelationsTestCase): "m.new_content": {"msgtype": "m.text", "body": "Edit, but wrong type"}, }, ) - self.assertEqual(200, channel.code, channel.json_body) channel = self.make_request( "GET", @@ -1015,7 +740,6 @@ class RelationsTestCase(BaseRelationsTestCase): "m.room.message", content={"msgtype": "m.text", "body": "A reply!"}, ) - self.assertEqual(200, channel.code, channel.json_body) reply = channel.json_body["event_id"] new_body = {"msgtype": "m.text", "body": "I've been edited!"} @@ -1025,8 +749,6 @@ class RelationsTestCase(BaseRelationsTestCase): content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, parent_id=reply, ) - self.assertEqual(200, channel.code, channel.json_body) - edit_event_id = channel.json_body["event_id"] channel = self.make_request( @@ -1071,7 +793,6 @@ class RelationsTestCase(BaseRelationsTestCase): "m.room.message", content={"msgtype": "m.text", "body": "A threaded reply!"}, ) - self.assertEqual(200, channel.code, channel.json_body) threaded_event_id = channel.json_body["event_id"] new_body = {"msgtype": "m.text", "body": "I've been edited!"} @@ -1081,7 +802,6 @@ class RelationsTestCase(BaseRelationsTestCase): content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, parent_id=threaded_event_id, ) - self.assertEqual(200, channel.code, channel.json_body) # Fetch the thread root, to get the bundled aggregation for the thread. channel = self.make_request( @@ -1113,7 +833,6 @@ class RelationsTestCase(BaseRelationsTestCase): "m.new_content": new_body, }, ) - self.assertEqual(200, channel.code, channel.json_body) edit_event_id = channel.json_body["event_id"] # Edit the edit event. @@ -1127,7 +846,6 @@ class RelationsTestCase(BaseRelationsTestCase): }, parent_id=edit_event_id, ) - self.assertEqual(200, channel.code, channel.json_body) # Request the original event. channel = self.make_request( @@ -1154,7 +872,6 @@ class RelationsTestCase(BaseRelationsTestCase): def test_unknown_relations(self) -> None: """Unknown relations should be accepted.""" channel = self._send_relation("m.relation.test", "m.room.test") - self.assertEqual(200, channel.code, channel.json_body) event_id = channel.json_body["event_id"] channel = self.make_request( @@ -1208,15 +925,12 @@ class RelationsTestCase(BaseRelationsTestCase): def test_background_update(self) -> None: """Test the event_arbitrary_relations background update.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="π") - self.assertEqual(200, channel.code, channel.json_body) annotation_event_id_good = channel.json_body["event_id"] channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="A") - self.assertEqual(200, channel.code, channel.json_body) annotation_event_id_bad = channel.json_body["event_id"] channel = self._send_relation(RelationTypes.THREAD, "m.room.test") - self.assertEqual(200, channel.code, channel.json_body) thread_event_id = channel.json_body["event_id"] # Clean-up the table as if the inserts did not happen during event creation. @@ -1296,71 +1010,312 @@ class RelationsTestCase(BaseRelationsTestCase): self.assertIn("m.relations", parent_event["unsigned"]) -class RelationRedactionTestCase(BaseRelationsTestCase): - """ - Test the behaviour of relations when the parent or child event is redacted. +class RelationPaginationTestCase(BaseRelationsTestCase): + def test_basic_paginate_relations(self) -> None: + """Tests that calling pagination API correctly the latest relations.""" + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + first_annotation_id = channel.json_body["event_id"] - The behaviour of each relation type is subtly different which causes the tests - to be a bit repetitive, they follow a naming scheme of: + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") + second_annotation_id = channel.json_body["event_id"] - test_redact_(relation|parent)_{relation_type} + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) - The first bit of "relation" means that the event with the relation defined - on it (the child event) is to be redacted. A "parent" means that the target - of the relation (the parent event) is to be redacted. + # We expect to get back a single pagination result, which is the latest + # full relation event we sent above. + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) + self.assert_dict( + { + "event_id": second_annotation_id, + "sender": self.user_id, + "type": "m.reaction", + }, + channel.json_body["chunk"][0], + ) - The relation_type describes which type of relation is under test (i.e. it is - related to the value of rel_type in the event content). - """ + # We also expect to get the original event (the id of which is self.parent_id) + self.assertEqual( + channel.json_body["original_event"]["event_id"], self.parent_id + ) - def _redact(self, event_id: str) -> None: + # Make sure next_batch has something in it that looks like it could be a + # valid token. + self.assertIsInstance( + channel.json_body.get("next_batch"), str, channel.json_body + ) + + # Request the relations again, but with a different direction. channel = self.make_request( - "POST", - f"/_matrix/client/r0/rooms/{self.room}/redact/{event_id}", + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations" + f"/{self.parent_id}?limit=1&org.matrix.msc3715.dir=f", access_token=self.user_token, - content={}, ) self.assertEqual(200, channel.code, channel.json_body) - def _make_relation_requests(self) -> Tuple[List[str], JsonDict]: - """ - Makes requests and ensures they result in a 200 response, returns a - tuple of results: + # We expect to get back a single pagination result, which is the earliest + # full relation event we sent above. + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) + self.assert_dict( + { + "event_id": first_annotation_id, + "sender": self.user_id, + "type": "m.reaction", + }, + channel.json_body["chunk"][0], + ) - 1. `/relations` -> Returns a list of event IDs. - 2. `/event` -> Returns the response's m.relations field (from unsigned), - if it exists. + def test_repeated_paginate_relations(self) -> None: + """Test that if we paginate using a limit and tokens then we get the + expected events. """ - # Request the relations of the event. + expected_event_ids = [] + for idx in range(10): + channel = self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", chr(ord("a") + idx) + ) + expected_event_ids.append(channel.json_body["event_id"]) + + prev_token = "" + found_event_ids: List[str] = [] + for _ in range(20): + from_token = "" + if prev_token: + from_token = "&from=" + prev_token + + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + + found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) + next_batch = channel.json_body.get("next_batch") + + self.assertNotEqual(prev_token, next_batch) + prev_token = next_batch + + if not prev_token: + break + + # We paginated backwards, so reverse + found_event_ids.reverse() + self.assertEqual(found_event_ids, expected_event_ids) + + def test_pagination_from_sync_and_messages(self) -> None: + """Pagination tokens from /sync and /messages can be used to paginate /relations.""" + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A") + annotation_id = channel.json_body["event_id"] + # Send an event after the relation events. + self.helper.send(self.room, body="Latest event", tok=self.user_token) + + # Request /sync, limiting it such that only the latest event is returned + # (and not the relation). + filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 1}}}') channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}", - access_token=self.user_token, + "GET", f"/sync?filter={filter}", access_token=self.user_token + ) + self.assertEqual(200, channel.code, channel.json_body) + room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] + sync_prev_batch = room_timeline["prev_batch"] + self.assertIsNotNone(sync_prev_batch) + # Ensure the relation event is not in the batch returned from /sync. + self.assertNotIn( + annotation_id, [ev["event_id"] for ev in room_timeline["events"]] ) - self.assertEquals(200, channel.code, channel.json_body) - event_ids = [ev["event_id"] for ev in channel.json_body["chunk"]] - # Fetch the bundled aggregations of the event. + # Request /messages, limiting it such that only the latest event is + # returned (and not the relation). channel = self.make_request( "GET", - f"/_matrix/client/unstable/rooms/{self.room}/event/{self.parent_id}", + f"/rooms/{self.room}/messages?dir=b&limit=1", access_token=self.user_token, ) - self.assertEquals(200, channel.code, channel.json_body) - bundled_relations = channel.json_body["unsigned"].get("m.relations", {}) + self.assertEqual(200, channel.code, channel.json_body) + messages_end = channel.json_body["end"] + self.assertIsNotNone(messages_end) + # Ensure the relation event is not in the chunk returned from /messages. + self.assertNotIn( + annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]] + ) + + # Request /relations with the pagination tokens received from both the + # /sync and /messages responses above, in turn. + # + # This is a tiny bit silly since the client wouldn't know the parent ID + # from the requests above; consider the parent ID to be known from a + # previous /sync. + for from_token in (sync_prev_batch, messages_end): + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?from={from_token}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) - return event_ids, bundled_relations + # The relation should be in the returned chunk. + self.assertIn( + annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]] + ) - def _get_aggregations(self) -> List[JsonDict]: - """Request /aggregations on the parent ID and includes the returned chunk.""" + def test_aggregation_pagination_groups(self) -> None: + """Test that we can paginate annotation groups correctly.""" + + # We need to create ten separate users to send each reaction. + access_tokens = [self.user_token, self.user2_token] + idx = 0 + while len(access_tokens) < 10: + user_id, token = self._create_user("test" + str(idx)) + idx += 1 + + self.helper.join(self.room, user=user_id, tok=token) + access_tokens.append(token) + + idx = 0 + sent_groups = {"π": 10, "a": 7, "b": 5, "c": 3, "d": 2, "e": 1} + for key in itertools.chain.from_iterable( + itertools.repeat(key, num) for key, num in sent_groups.items() + ): + self._send_relation( + RelationTypes.ANNOTATION, + "m.reaction", + key=key, + access_token=access_tokens[idx], + ) + + idx += 1 + idx %= len(access_tokens) + + prev_token: Optional[str] = None + found_groups: Dict[str, int] = {} + for _ in range(20): + from_token = "" + if prev_token: + from_token = "&from=" + prev_token + + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}?limit=1{from_token}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) + + for groups in channel.json_body["chunk"]: + # We only expect reactions + self.assertEqual(groups["type"], "m.reaction", channel.json_body) + + # We should only see each key once + self.assertNotIn(groups["key"], found_groups, channel.json_body) + + found_groups[groups["key"]] = groups["count"] + + next_batch = channel.json_body.get("next_batch") + + self.assertNotEqual(prev_token, next_batch) + prev_token = next_batch + + if not prev_token: + break + + self.assertEqual(sent_groups, found_groups) + + def test_aggregation_pagination_within_group(self) -> None: + """Test that we can paginate within an annotation group.""" + + # We need to create ten separate users to send each reaction. + access_tokens = [self.user_token, self.user2_token] + idx = 0 + while len(access_tokens) < 10: + user_id, token = self._create_user("test" + str(idx)) + idx += 1 + + self.helper.join(self.room, user=user_id, tok=token) + access_tokens.append(token) + + idx = 0 + expected_event_ids = [] + for _ in range(10): + channel = self._send_relation( + RelationTypes.ANNOTATION, + "m.reaction", + key="π", + access_token=access_tokens[idx], + ) + expected_event_ids.append(channel.json_body["event_id"]) + + idx += 1 + + # Also send a different type of reaction so that we test we don't see it + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") + + prev_token = "" + found_event_ids: List[str] = [] + encoded_key = urllib.parse.quote_plus("π".encode()) + for _ in range(20): + from_token = "" + if prev_token: + from_token = "&from=" + prev_token + + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}" + f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}" + f"/m.reaction/{encoded_key}?limit=1{from_token}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + + self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body) + + found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) + + next_batch = channel.json_body.get("next_batch") + + self.assertNotEqual(prev_token, next_batch) + prev_token = next_batch + + if not prev_token: + break + + # We paginated backwards, so reverse + found_event_ids.reverse() + self.assertEqual(found_event_ids, expected_event_ids) + + +class RelationRedactionTestCase(BaseRelationsTestCase): + """ + Test the behaviour of relations when the parent or child event is redacted. + + The behaviour of each relation type is subtly different which causes the tests + to be a bit repetitive, they follow a naming scheme of: + + test_redact_(relation|parent)_{relation_type} + + The first bit of "relation" means that the event with the relation defined + on it (the child event) is to be redacted. A "parent" means that the target + of the relation (the parent event) is to be redacted. + + The relation_type describes which type of relation is under test (i.e. it is + related to the value of rel_type in the event content). + """ + + def _redact(self, event_id: str) -> None: channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}", + "POST", + f"/_matrix/client/r0/rooms/{self.room}/redact/{event_id}", access_token=self.user_token, + content={}, ) self.assertEqual(200, channel.code, channel.json_body) - return channel.json_body["chunk"] def test_redact_relation_annotation(self) -> None: """ @@ -1371,17 +1326,16 @@ class RelationRedactionTestCase(BaseRelationsTestCase): the response to relations. """ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self.assertEqual(200, channel.code, channel.json_body) to_redact_event_id = channel.json_body["event_id"] channel = self._send_relation( RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token ) - self.assertEqual(200, channel.code, channel.json_body) unredacted_event_id = channel.json_body["event_id"] # Both relations should exist. - event_ids, relations = self._make_relation_requests() + event_ids = self._get_related_events() + relations = self._get_bundled_aggregations() self.assertCountEqual(event_ids, [to_redact_event_id, unredacted_event_id]) self.assertEquals( relations["m.annotation"], @@ -1396,7 +1350,8 @@ class RelationRedactionTestCase(BaseRelationsTestCase): self._redact(to_redact_event_id) # The unredacted relation should still exist. - event_ids, relations = self._make_relation_requests() + event_ids = self._get_related_events() + relations = self._get_bundled_aggregations() self.assertEquals(event_ids, [unredacted_event_id]) self.assertEquals( relations["m.annotation"], @@ -1419,7 +1374,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase): EventTypes.Message, content={"body": "reply 1", "msgtype": "m.text"}, ) - self.assertEqual(200, channel.code, channel.json_body) unredacted_event_id = channel.json_body["event_id"] # Note that the *last* event in the thread is redacted, as that gets @@ -1429,11 +1383,11 @@ class RelationRedactionTestCase(BaseRelationsTestCase): EventTypes.Message, content={"body": "reply 2", "msgtype": "m.text"}, ) - self.assertEqual(200, channel.code, channel.json_body) to_redact_event_id = channel.json_body["event_id"] # Both relations exist. - event_ids, relations = self._make_relation_requests() + event_ids = self._get_related_events() + relations = self._get_bundled_aggregations() self.assertEquals(event_ids, [to_redact_event_id, unredacted_event_id]) self.assertDictContainsSubset( { @@ -1452,7 +1406,8 @@ class RelationRedactionTestCase(BaseRelationsTestCase): self._redact(to_redact_event_id) # The unredacted relation should still exist. - event_ids, relations = self._make_relation_requests() + event_ids = self._get_related_events() + relations = self._get_bundled_aggregations() self.assertEquals(event_ids, [unredacted_event_id]) self.assertDictContainsSubset( { @@ -1472,7 +1427,7 @@ class RelationRedactionTestCase(BaseRelationsTestCase): is redacted. """ # Add a relation - channel = self._send_relation( + self._send_relation( RelationTypes.REPLACE, "m.room.message", parent_id=self.parent_id, @@ -1482,10 +1437,10 @@ class RelationRedactionTestCase(BaseRelationsTestCase): "m.new_content": {"msgtype": "m.text", "body": "First edit"}, }, ) - self.assertEqual(200, channel.code, channel.json_body) # Check the relation is returned - event_ids, relations = self._make_relation_requests() + event_ids = self._get_related_events() + relations = self._get_bundled_aggregations() self.assertEqual(len(event_ids), 1) self.assertIn(RelationTypes.REPLACE, relations) @@ -1493,7 +1448,8 @@ class RelationRedactionTestCase(BaseRelationsTestCase): self._redact(self.parent_id) # The relations are not returned. - event_ids, relations = self._make_relation_requests() + event_ids = self._get_related_events() + relations = self._get_bundled_aggregations() self.assertEqual(len(event_ids), 0) self.assertEqual(relations, {}) @@ -1503,11 +1459,11 @@ class RelationRedactionTestCase(BaseRelationsTestCase): """ # Add a relation channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="π") - self.assertEqual(200, channel.code, channel.json_body) related_event_id = channel.json_body["event_id"] # The relations should exist. - event_ids, relations = self._make_relation_requests() + event_ids = self._get_related_events() + relations = self._get_bundled_aggregations() self.assertEqual(len(event_ids), 1) self.assertIn(RelationTypes.ANNOTATION, relations) @@ -1519,7 +1475,8 @@ class RelationRedactionTestCase(BaseRelationsTestCase): self._redact(self.parent_id) # The relations are returned. - event_ids, relations = self._make_relation_requests() + event_ids = self._get_related_events() + relations = self._get_bundled_aggregations() self.assertEquals(event_ids, [related_event_id]) self.assertEquals( relations["m.annotation"], @@ -1540,14 +1497,14 @@ class RelationRedactionTestCase(BaseRelationsTestCase): EventTypes.Message, content={"body": "reply 1", "msgtype": "m.text"}, ) - self.assertEqual(200, channel.code, channel.json_body) related_event_id = channel.json_body["event_id"] # Redact one of the reactions. self._redact(self.parent_id) # The unredacted relation should still exist. - event_ids, relations = self._make_relation_requests() + event_ids = self._get_related_events() + relations = self._get_bundled_aggregations() self.assertEquals(len(event_ids), 1) self.assertDictContainsSubset( { diff --git a/tests/rest/media/v1/test_html_preview.py b/tests/rest/media/v1/test_html_preview.py index 3fb37a2a59..62e308814d 100644 --- a/tests/rest/media/v1/test_html_preview.py +++ b/tests/rest/media/v1/test_html_preview.py @@ -16,7 +16,6 @@ from synapse.rest.media.v1.preview_html import ( _get_html_media_encodings, decode_body, parse_html_to_open_graph, - rebase_url, summarize_paragraphs, ) @@ -161,7 +160,7 @@ class CalcOgTestCase(unittest.TestCase): """ tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) @@ -177,7 +176,7 @@ class CalcOgTestCase(unittest.TestCase): """ tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) @@ -196,7 +195,7 @@ class CalcOgTestCase(unittest.TestCase): """ tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual( og, @@ -218,7 +217,7 @@ class CalcOgTestCase(unittest.TestCase): """ tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) @@ -232,7 +231,7 @@ class CalcOgTestCase(unittest.TestCase): """ tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) @@ -247,7 +246,7 @@ class CalcOgTestCase(unittest.TestCase): """ tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."}) @@ -262,7 +261,7 @@ class CalcOgTestCase(unittest.TestCase): """ tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) @@ -290,7 +289,7 @@ class CalcOgTestCase(unittest.TestCase): <head><title>Foo</title></head><body>Some text.</body></html> """.strip() tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) def test_invalid_encoding(self) -> None: @@ -304,7 +303,7 @@ class CalcOgTestCase(unittest.TestCase): </html> """ tree = decode_body(html, "http://example.com/test.html", "invalid-encoding") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) def test_invalid_encoding2(self) -> None: @@ -319,7 +318,7 @@ class CalcOgTestCase(unittest.TestCase): </html> """ tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "ΓΏΓΏ Foo", "og:description": "Some text."}) def test_windows_1252(self) -> None: @@ -333,7 +332,7 @@ class CalcOgTestCase(unittest.TestCase): </html> """ tree = decode_body(html, "http://example.com/test.html") - og = parse_html_to_open_graph(tree, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) self.assertEqual(og, {"og:title": "Γ³", "og:description": "Some text."}) @@ -448,34 +447,3 @@ class MediaEncodingTestCase(unittest.TestCase): 'text/html; charset="invalid"', ) self.assertEqual(list(encodings), ["utf-8", "cp1252"]) - - -class RebaseUrlTestCase(unittest.TestCase): - def test_relative(self) -> None: - """Relative URLs should be resolved based on the context of the base URL.""" - self.assertEqual( - rebase_url("subpage", "https://example.com/foo/"), - "https://example.com/foo/subpage", - ) - self.assertEqual( - rebase_url("sibling", "https://example.com/foo"), - "https://example.com/sibling", - ) - self.assertEqual( - rebase_url("/bar", "https://example.com/foo/"), - "https://example.com/bar", - ) - - def test_absolute(self) -> None: - """Absolute URLs should not be modified.""" - self.assertEqual( - rebase_url("https://alice.com/a/", "https://example.com/foo/"), - "https://alice.com/a/", - ) - - def test_data(self) -> None: - """Data URLs should not be modified.""" - self.assertEqual( - rebase_url("data:,Hello%2C%20World%21", "https://example.com/foo/"), - "data:,Hello%2C%20World%21", - ) diff --git a/tests/storage/test_account_data.py b/tests/storage/test_account_data.py index 272cd35402..72bf5b3d31 100644 --- a/tests/storage/test_account_data.py +++ b/tests/storage/test_account_data.py @@ -47,9 +47,18 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase): expected_ignorer_user_ids, ) + def assert_ignored( + self, ignorer_user_id: str, expected_ignored_user_ids: Set[str] + ) -> None: + self.assertEqual( + self.get_success(self.store.ignored_users(ignorer_user_id)), + expected_ignored_user_ids, + ) + def test_ignoring_users(self): """Basic adding/removing of users from the ignore list.""" self._update_ignore_list("@other:test", "@another:remote") + self.assert_ignored(self.user, {"@other:test", "@another:remote"}) # Check a user which no one ignores. self.assert_ignorers("@user:test", set()) @@ -62,6 +71,7 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase): # Add one user, remove one user, and leave one user. self._update_ignore_list("@foo:test", "@another:remote") + self.assert_ignored(self.user, {"@foo:test", "@another:remote"}) # Check the removed user. self.assert_ignorers("@other:test", set()) @@ -76,20 +86,24 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase): """Ensure that caching works properly between different users.""" # The first user ignores a user. self._update_ignore_list("@other:test") + self.assert_ignored(self.user, {"@other:test"}) self.assert_ignorers("@other:test", {self.user}) # The second user ignores them. self._update_ignore_list("@other:test", ignorer_user_id="@second:test") + self.assert_ignored("@second:test", {"@other:test"}) self.assert_ignorers("@other:test", {self.user, "@second:test"}) # The first user un-ignores them. self._update_ignore_list() + self.assert_ignored(self.user, set()) self.assert_ignorers("@other:test", {"@second:test"}) def test_invalid_data(self): """Invalid data ends up clearing out the ignored users list.""" # Add some data and ensure it is there. self._update_ignore_list("@other:test") + self.assert_ignored(self.user, {"@other:test"}) self.assert_ignorers("@other:test", {self.user}) # No ignored_users key. @@ -102,10 +116,12 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase): ) # No one ignores the user now. + self.assert_ignored(self.user, set()) self.assert_ignorers("@other:test", set()) # Add some data and ensure it is there. self._update_ignore_list("@other:test") + self.assert_ignored(self.user, {"@other:test"}) self.assert_ignorers("@other:test", {self.user}) # Invalid data. @@ -118,4 +134,5 @@ class IgnoredUsersTestCase(unittest.HomeserverTestCase): ) # No one ignores the user now. + self.assert_ignored(self.user, set()) self.assert_ignorers("@other:test", set()) diff --git a/tests/storage/test_database.py b/tests/storage/test_database.py index 8597867563..a40fc20ef9 100644 --- a/tests/storage/test_database.py +++ b/tests/storage/test_database.py @@ -12,7 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.storage.database import make_tuple_comparison_clause +from typing import Callable, Tuple +from unittest.mock import Mock, call + +from twisted.internet import defer +from twisted.internet.defer import CancelledError, Deferred +from twisted.test.proto_helpers import MemoryReactor + +from synapse.server import HomeServer +from synapse.storage.database import ( + DatabasePool, + LoggingTransaction, + make_tuple_comparison_clause, +) +from synapse.util import Clock from tests import unittest @@ -22,3 +35,150 @@ class TupleComparisonClauseTestCase(unittest.TestCase): clause, args = make_tuple_comparison_clause([("a", 1), ("b", 2)]) self.assertEqual(clause, "(a,b) > (?,?)") self.assertEqual(args, [1, 2]) + + +class CallbacksTestCase(unittest.HomeserverTestCase): + """Tests for transaction callbacks.""" + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.db_pool: DatabasePool = self.store.db_pool + + def _run_interaction( + self, func: Callable[[LoggingTransaction], object] + ) -> Tuple[Mock, Mock]: + """Run the given function in a database transaction, with callbacks registered. + + Args: + func: The function to be run in a transaction. The transaction will be + retried if `func` raises an `OperationalError`. + + Returns: + Two mocks, which were registered as an `after_callback` and an + `exception_callback` respectively, on every transaction attempt. + """ + after_callback = Mock() + exception_callback = Mock() + + def _test_txn(txn: LoggingTransaction) -> None: + txn.call_after(after_callback, 123, 456, extra=789) + txn.call_on_exception(exception_callback, 987, 654, extra=321) + func(txn) + + try: + self.get_success_or_raise( + self.db_pool.runInteraction("test_transaction", _test_txn) + ) + except Exception: + pass + + return after_callback, exception_callback + + def test_after_callback(self) -> None: + """Test that the after callback is called when a transaction succeeds.""" + after_callback, exception_callback = self._run_interaction(lambda txn: None) + + after_callback.assert_called_once_with(123, 456, extra=789) + exception_callback.assert_not_called() + + def test_exception_callback(self) -> None: + """Test that the exception callback is called when a transaction fails.""" + _test_txn = Mock(side_effect=ZeroDivisionError) + after_callback, exception_callback = self._run_interaction(_test_txn) + + after_callback.assert_not_called() + exception_callback.assert_called_once_with(987, 654, extra=321) + + def test_failed_retry(self) -> None: + """Test that the exception callback is called for every failed attempt.""" + # Always raise an `OperationalError`. + _test_txn = Mock(side_effect=self.db_pool.engine.module.OperationalError) + after_callback, exception_callback = self._run_interaction(_test_txn) + + after_callback.assert_not_called() + exception_callback.assert_has_calls( + [ + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + ] + ) + self.assertEqual(exception_callback.call_count, 6) # no additional calls + + def test_successful_retry(self) -> None: + """Test callbacks for a failed transaction followed by a successful attempt.""" + # Raise an `OperationalError` on the first attempt only. + _test_txn = Mock( + side_effect=[self.db_pool.engine.module.OperationalError, None] + ) + after_callback, exception_callback = self._run_interaction(_test_txn) + + # Calling both `after_callback`s when the first attempt failed is rather + # surprising (#12184). Let's document the behaviour in a test. + after_callback.assert_has_calls( + [ + call(123, 456, extra=789), + call(123, 456, extra=789), + ] + ) + self.assertEqual(after_callback.call_count, 2) # no additional calls + exception_callback.assert_not_called() + + +class CancellationTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.db_pool: DatabasePool = self.store.db_pool + + def test_after_callback(self) -> None: + """Test that the after callback is called when a transaction succeeds.""" + d: "Deferred[None]" + after_callback = Mock() + exception_callback = Mock() + + def _test_txn(txn: LoggingTransaction) -> None: + txn.call_after(after_callback, 123, 456, extra=789) + txn.call_on_exception(exception_callback, 987, 654, extra=321) + d.cancel() + + d = defer.ensureDeferred( + self.db_pool.runInteraction("test_transaction", _test_txn) + ) + self.get_failure(d, CancelledError) + + after_callback.assert_called_once_with(123, 456, extra=789) + exception_callback.assert_not_called() + + def test_exception_callback(self) -> None: + """Test that the exception callback is called when a transaction fails.""" + d: "Deferred[None]" + after_callback = Mock() + exception_callback = Mock() + + def _test_txn(txn: LoggingTransaction) -> None: + txn.call_after(after_callback, 123, 456, extra=789) + txn.call_on_exception(exception_callback, 987, 654, extra=321) + d.cancel() + # Simulate a retryable failure on every attempt. + raise self.db_pool.engine.module.OperationalError() + + d = defer.ensureDeferred( + self.db_pool.runInteraction("test_transaction", _test_txn) + ) + self.get_failure(d, CancelledError) + + after_callback.assert_not_called() + exception_callback.assert_has_calls( + [ + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + call(987, 654, extra=321), + ] + ) + self.assertEqual(exception_callback.call_count, 6) # no additional calls |