diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/handlers/test_federation.py | 64 | ||||
-rw-r--r-- | tests/handlers/test_password_providers.py | 223 | ||||
-rw-r--r-- | tests/handlers/test_user_directory.py | 192 | ||||
-rw-r--r-- | tests/rest/admin/test_user.py | 401 | ||||
-rw-r--r-- | tests/rest/client/test_relations.py | 55 | ||||
-rw-r--r-- | tests/rest/client/test_third_party_rules.py | 56 | ||||
-rw-r--r-- | tests/rest/media/v1/test_filepath.py | 238 | ||||
-rw-r--r-- | tests/rest/media/v1/test_oembed.py | 51 | ||||
-rw-r--r-- | tests/server.py | 54 | ||||
-rw-r--r-- | tests/storage/test_user_directory.py | 77 | ||||
-rw-r--r-- | tests/test_event_auth.py | 138 | ||||
-rw-r--r-- | tests/test_preview.py | 93 | ||||
-rw-r--r-- | tests/unittest.py | 4 |
13 files changed, 1155 insertions, 491 deletions
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 936ebf3dde..e1557566e4 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -23,6 +23,7 @@ from synapse.federation.federation_base import event_from_pdu_json from synapse.logging.context import LoggingContext, run_in_background from synapse.rest import admin from synapse.rest.client import login, room +from synapse.types import create_requester from synapse.util.stringutils import random_string from tests import unittest @@ -30,6 +31,10 @@ from tests import unittest logger = logging.getLogger(__name__) +def generate_fake_event_id() -> str: + return "$fake_" + random_string(43) + + class FederationTestCase(unittest.HomeserverTestCase): servlets = [ admin.register_servlets, @@ -198,6 +203,65 @@ class FederationTestCase(unittest.HomeserverTestCase): self.assertEqual(sg, sg2) + def test_backfill_with_many_backward_extremities(self): + """ + Check that we can backfill with many backward extremities. + The goal is to make sure that when we only use a portion + of backwards extremities(the magic number is more than 5), + no errors are thrown. + + Regression test, see #11027 + """ + # create the room + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test") + requester = create_requester(user_id) + + room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) + + ev1 = self.helper.send(room_id, "first message", tok=tok) + + # Create "many" backward extremities. The magic number we're trying to + # create more than is 5 which corresponds to the number of backward + # extremities we slice off in `_maybe_backfill_inner` + for _ in range(0, 8): + event_handler = self.hs.get_event_creation_handler() + event, context = self.get_success( + event_handler.create_event( + requester, + { + "type": "m.room.message", + "content": { + "msgtype": "m.text", + "body": "message connected to fake event", + }, + "room_id": room_id, + "sender": user_id, + }, + prev_event_ids=[ + ev1["event_id"], + # We're creating an backward extremity each time thanks + # to this fake event + generate_fake_event_id(), + ], + ) + ) + self.get_success( + event_handler.handle_new_client_event(requester, event, context) + ) + + current_depth = 1 + limit = 100 + with LoggingContext("receive_pdu"): + # Make sure backfill still works + d = run_in_background( + self.hs.get_federation_handler().maybe_backfill, + room_id, + current_depth, + limit, + ) + self.get_success(d) + def test_backfill_floating_outlier_membership_auth(self): """ As the local homeserver, check that we can properly process a federated diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index 38e6d9f536..7dd4a5a367 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -20,6 +20,8 @@ from unittest.mock import Mock from twisted.internet import defer import synapse +from synapse.handlers.auth import load_legacy_password_auth_providers +from synapse.module_api import ModuleApi from synapse.rest.client import devices, login from synapse.types import JsonDict @@ -36,8 +38,8 @@ ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_servi mock_password_provider = Mock() -class PasswordOnlyAuthProvider: - """A password_provider which only implements `check_password`.""" +class LegacyPasswordOnlyAuthProvider: + """A legacy password_provider which only implements `check_password`.""" @staticmethod def parse_config(self): @@ -50,8 +52,8 @@ class PasswordOnlyAuthProvider: return mock_password_provider.check_password(*args) -class CustomAuthProvider: - """A password_provider which implements a custom login type.""" +class LegacyCustomAuthProvider: + """A legacy password_provider which implements a custom login type.""" @staticmethod def parse_config(self): @@ -67,7 +69,23 @@ class CustomAuthProvider: return mock_password_provider.check_auth(*args) -class PasswordCustomAuthProvider: +class CustomAuthProvider: + """A module which registers password_auth_provider callbacks for a custom login type.""" + + @staticmethod + def parse_config(self): + pass + + def __init__(self, config, api: ModuleApi): + api.register_password_auth_provider_callbacks( + auth_checkers={("test.login_type", ("test_field",)): self.check_auth}, + ) + + def check_auth(self, *args): + return mock_password_provider.check_auth(*args) + + +class LegacyPasswordCustomAuthProvider: """A password_provider which implements password login via `check_auth`, as well as a custom type.""" @@ -85,8 +103,32 @@ class PasswordCustomAuthProvider: return mock_password_provider.check_auth(*args) -def providers_config(*providers: Type[Any]) -> dict: - """Returns a config dict that will enable the given password auth providers""" +class PasswordCustomAuthProvider: + """A module which registers password_auth_provider callbacks for a custom login type. + as well as a password login""" + + @staticmethod + def parse_config(self): + pass + + def __init__(self, config, api: ModuleApi): + api.register_password_auth_provider_callbacks( + auth_checkers={ + ("test.login_type", ("test_field",)): self.check_auth, + ("m.login.password", ("password",)): self.check_auth, + }, + ) + pass + + def check_auth(self, *args): + return mock_password_provider.check_auth(*args) + + def check_pass(self, *args): + return mock_password_provider.check_password(*args) + + +def legacy_providers_config(*providers: Type[Any]) -> dict: + """Returns a config dict that will enable the given legacy password auth providers""" return { "password_providers": [ {"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}} @@ -95,6 +137,16 @@ def providers_config(*providers: Type[Any]) -> dict: } +def providers_config(*providers: Type[Any]) -> dict: + """Returns a config dict that will enable the given modules""" + return { + "modules": [ + {"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}} + for provider in providers + ] + } + + class PasswordAuthProviderTests(unittest.HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets, @@ -107,8 +159,21 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): mock_password_provider.reset_mock() super().setUp() - @override_config(providers_config(PasswordOnlyAuthProvider)) - def test_password_only_auth_provider_login(self): + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver() + # Load the modules into the homeserver + module_api = hs.get_module_api() + for module, config in hs.config.modules.loaded_modules: + module(config=config, api=module_api) + load_legacy_password_auth_providers(hs) + + return hs + + @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) + def test_password_only_auth_progiver_login_legacy(self): + self.password_only_auth_provider_login_test_body() + + def password_only_auth_provider_login_test_body(self): # login flows should only have m.login.password flows = self._get_login_flows() self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS) @@ -138,8 +203,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "@ USER🙂NAME :test", " pASS😢word " ) - @override_config(providers_config(PasswordOnlyAuthProvider)) - def test_password_only_auth_provider_ui_auth(self): + @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) + def test_password_only_auth_provider_ui_auth_legacy(self): + self.password_only_auth_provider_ui_auth_test_body() + + def password_only_auth_provider_ui_auth_test_body(self): """UI Auth should delegate correctly to the password provider""" # create the user, otherwise access doesn't work @@ -172,8 +240,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) mock_password_provider.check_password.assert_called_once_with("@u:test", "p") - @override_config(providers_config(PasswordOnlyAuthProvider)) - def test_local_user_fallback_login(self): + @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) + def test_local_user_fallback_login_legacy(self): + self.local_user_fallback_login_test_body() + + def local_user_fallback_login_test_body(self): """rejected login should fall back to local db""" self.register_user("localuser", "localpass") @@ -186,8 +257,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200, channel.result) self.assertEqual("@localuser:test", channel.json_body["user_id"]) - @override_config(providers_config(PasswordOnlyAuthProvider)) - def test_local_user_fallback_ui_auth(self): + @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) + def test_local_user_fallback_ui_auth_legacy(self): + self.local_user_fallback_ui_auth_test_body() + + def local_user_fallback_ui_auth_test_body(self): """rejected login should fall back to local db""" self.register_user("localuser", "localpass") @@ -223,11 +297,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): @override_config( { - **providers_config(PasswordOnlyAuthProvider), + **legacy_providers_config(LegacyPasswordOnlyAuthProvider), "password_config": {"localdb_enabled": False}, } ) - def test_no_local_user_fallback_login(self): + def test_no_local_user_fallback_login_legacy(self): + self.no_local_user_fallback_login_test_body() + + def no_local_user_fallback_login_test_body(self): """localdb_enabled can block login with the local password""" self.register_user("localuser", "localpass") @@ -242,11 +319,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): @override_config( { - **providers_config(PasswordOnlyAuthProvider), + **legacy_providers_config(LegacyPasswordOnlyAuthProvider), "password_config": {"localdb_enabled": False}, } ) - def test_no_local_user_fallback_ui_auth(self): + def test_no_local_user_fallback_ui_auth_legacy(self): + self.no_local_user_fallback_ui_auth_test_body() + + def no_local_user_fallback_ui_auth_test_body(self): """localdb_enabled can block ui auth with the local password""" self.register_user("localuser", "localpass") @@ -280,11 +360,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): @override_config( { - **providers_config(PasswordOnlyAuthProvider), + **legacy_providers_config(LegacyPasswordOnlyAuthProvider), "password_config": {"enabled": False}, } ) - def test_password_auth_disabled(self): + def test_password_auth_disabled_legacy(self): + self.password_auth_disabled_test_body() + + def password_auth_disabled_test_body(self): """password auth doesn't work if it's disabled across the board""" # login flows should be empty flows = self._get_login_flows() @@ -295,8 +378,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 400, channel.result) mock_password_provider.check_password.assert_not_called() + @override_config(legacy_providers_config(LegacyCustomAuthProvider)) + def test_custom_auth_provider_login_legacy(self): + self.custom_auth_provider_login_test_body() + @override_config(providers_config(CustomAuthProvider)) def test_custom_auth_provider_login(self): + self.custom_auth_provider_login_test_body() + + def custom_auth_provider_login_test_body(self): # login flows should have the custom flow and m.login.password, since we # haven't disabled local password lookup. # (password must come first, because reasons) @@ -312,7 +402,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 400, channel.result) mock_password_provider.check_auth.assert_not_called() - mock_password_provider.check_auth.return_value = defer.succeed("@user:bz") + mock_password_provider.check_auth.return_value = defer.succeed( + ("@user:bz", None) + ) channel = self._send_login("test.login_type", "u", test_field="y") self.assertEqual(channel.code, 200, channel.result) self.assertEqual("@user:bz", channel.json_body["user_id"]) @@ -325,7 +417,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # in these cases, but at least we can guard against the API changing # unexpectedly mock_password_provider.check_auth.return_value = defer.succeed( - "@ MALFORMED! :bz" + ("@ MALFORMED! :bz", None) ) channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ") self.assertEqual(channel.code, 200, channel.result) @@ -334,8 +426,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): " USER🙂NAME ", "test.login_type", {"test_field": " abc "} ) + @override_config(legacy_providers_config(LegacyCustomAuthProvider)) + def test_custom_auth_provider_ui_auth_legacy(self): + self.custom_auth_provider_ui_auth_test_body() + @override_config(providers_config(CustomAuthProvider)) def test_custom_auth_provider_ui_auth(self): + self.custom_auth_provider_ui_auth_test_body() + + def custom_auth_provider_ui_auth_test_body(self): # register the user and log in twice, to get two devices self.register_user("localuser", "localpass") tok1 = self.login("localuser", "localpass") @@ -367,7 +466,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): mock_password_provider.reset_mock() # right params, but authing as the wrong user - mock_password_provider.check_auth.return_value = defer.succeed("@user:bz") + mock_password_provider.check_auth.return_value = defer.succeed( + ("@user:bz", None) + ) body["auth"]["test_field"] = "foo" channel = self._delete_device(tok1, "dev2", body) self.assertEqual(channel.code, 403) @@ -379,7 +480,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # and finally, succeed mock_password_provider.check_auth.return_value = defer.succeed( - "@localuser:test" + ("@localuser:test", None) ) channel = self._delete_device(tok1, "dev2", body) self.assertEqual(channel.code, 200) @@ -387,8 +488,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "localuser", "test.login_type", {"test_field": "foo"} ) + @override_config(legacy_providers_config(LegacyCustomAuthProvider)) + def test_custom_auth_provider_callback_legacy(self): + self.custom_auth_provider_callback_test_body() + @override_config(providers_config(CustomAuthProvider)) def test_custom_auth_provider_callback(self): + self.custom_auth_provider_callback_test_body() + + def custom_auth_provider_callback_test_body(self): callback = Mock(return_value=defer.succeed(None)) mock_password_provider.check_auth.return_value = defer.succeed( @@ -411,9 +519,21 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertIn(p, call_args[0]) @override_config( + { + **legacy_providers_config(LegacyCustomAuthProvider), + "password_config": {"enabled": False}, + } + ) + def test_custom_auth_password_disabled_legacy(self): + self.custom_auth_password_disabled_test_body() + + @override_config( {**providers_config(CustomAuthProvider), "password_config": {"enabled": False}} ) def test_custom_auth_password_disabled(self): + self.custom_auth_password_disabled_test_body() + + def custom_auth_password_disabled_test_body(self): """Test login with a custom auth provider where password login is disabled""" self.register_user("localuser", "localpass") @@ -427,11 +547,23 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): @override_config( { + **legacy_providers_config(LegacyCustomAuthProvider), + "password_config": {"enabled": False, "localdb_enabled": False}, + } + ) + def test_custom_auth_password_disabled_localdb_enabled_legacy(self): + self.custom_auth_password_disabled_localdb_enabled_test_body() + + @override_config( + { **providers_config(CustomAuthProvider), "password_config": {"enabled": False, "localdb_enabled": False}, } ) def test_custom_auth_password_disabled_localdb_enabled(self): + self.custom_auth_password_disabled_localdb_enabled_test_body() + + def custom_auth_password_disabled_localdb_enabled_test_body(self): """Check the localdb_enabled == enabled == False Regression test for https://github.com/matrix-org/synapse/issues/8914: check @@ -450,11 +582,23 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): @override_config( { + **legacy_providers_config(LegacyPasswordCustomAuthProvider), + "password_config": {"enabled": False}, + } + ) + def test_password_custom_auth_password_disabled_login_legacy(self): + self.password_custom_auth_password_disabled_login_test_body() + + @override_config( + { **providers_config(PasswordCustomAuthProvider), "password_config": {"enabled": False}, } ) def test_password_custom_auth_password_disabled_login(self): + self.password_custom_auth_password_disabled_login_test_body() + + def password_custom_auth_password_disabled_login_test_body(self): """log in with a custom auth provider which implements password, but password login is disabled""" self.register_user("localuser", "localpass") @@ -466,6 +610,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): channel = self._send_password_login("localuser", "localpass") self.assertEqual(channel.code, 400, channel.result) mock_password_provider.check_auth.assert_not_called() + mock_password_provider.check_password.assert_not_called() + + @override_config( + { + **legacy_providers_config(LegacyPasswordCustomAuthProvider), + "password_config": {"enabled": False}, + } + ) + def test_password_custom_auth_password_disabled_ui_auth_legacy(self): + self.password_custom_auth_password_disabled_ui_auth_test_body() @override_config( { @@ -474,12 +628,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): } ) def test_password_custom_auth_password_disabled_ui_auth(self): + self.password_custom_auth_password_disabled_ui_auth_test_body() + + def password_custom_auth_password_disabled_ui_auth_test_body(self): """UI Auth with a custom auth provider which implements password, but password login is disabled""" # register the user and log in twice via the test login type to get two devices, self.register_user("localuser", "localpass") mock_password_provider.check_auth.return_value = defer.succeed( - "@localuser:test" + ("@localuser:test", None) ) channel = self._send_login("test.login_type", "localuser", test_field="") self.assertEqual(channel.code, 200, channel.result) @@ -516,6 +673,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "Password login has been disabled.", channel.json_body["error"] ) mock_password_provider.check_auth.assert_not_called() + mock_password_provider.check_password.assert_not_called() mock_password_provider.reset_mock() # successful auth @@ -526,6 +684,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): mock_password_provider.check_auth.assert_called_once_with( "localuser", "test.login_type", {"test_field": "x"} ) + mock_password_provider.check_password.assert_not_called() + + @override_config( + { + **legacy_providers_config(LegacyCustomAuthProvider), + "password_config": {"localdb_enabled": False}, + } + ) + def test_custom_auth_no_local_user_fallback_legacy(self): + self.custom_auth_no_local_user_fallback_test_body() @override_config( { @@ -534,6 +702,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): } ) def test_custom_auth_no_local_user_fallback(self): + self.custom_auth_no_local_user_fallback_test_body() + + def custom_auth_no_local_user_fallback_test_body(self): """Test login with a custom auth provider where the local db is disabled""" self.register_user("localuser", "localpass") diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 0120b4688b..b9ad92b977 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -109,18 +109,14 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): tok=alice_token, ) - users = self.get_success(self.user_dir_helper.get_users_in_user_directory()) - in_public = self.get_success(self.user_dir_helper.get_users_in_public_rooms()) - in_private = self.get_success( - self.user_dir_helper.get_users_who_share_private_rooms() + # The user directory should reflect the room memberships above. + users, in_public, in_private = self.get_success( + self.user_dir_helper.get_tables() ) - self.assertEqual(users, {alice, bob}) + self.assertEqual(in_public, {(alice, public), (bob, public), (alice, public2)}) self.assertEqual( - set(in_public), {(alice, public), (bob, public), (alice, public2)} - ) - self.assertEqual( - self.user_dir_helper._compress_shared(in_private), + in_private, {(alice, bob, private), (bob, alice, private)}, ) @@ -209,6 +205,88 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): in_public = self.get_success(self.user_dir_helper.get_users_in_public_rooms()) self.assertEqual(set(in_public), {(user1, room), (user2, room)}) + def test_excludes_users_when_making_room_public(self) -> None: + # Create a regular user and a support user. + alice = self.register_user("alice", "pass") + alice_token = self.login(alice, "pass") + support = "@support1:test" + self.get_success( + self.store.register_user( + user_id=support, password_hash=None, user_type=UserTypes.SUPPORT + ) + ) + + # Make a public and private room containing Alice and the support user + public, initially_private = self._create_rooms_and_inject_memberships( + alice, alice_token, support + ) + self._check_only_one_user_in_directory(alice, public) + + # Alice makes the private room public. + self.helper.send_state( + initially_private, + "m.room.join_rules", + {"join_rule": "public"}, + tok=alice_token, + ) + + users, in_public, in_private = self.get_success( + self.user_dir_helper.get_tables() + ) + self.assertEqual(users, {alice}) + self.assertEqual(in_public, {(alice, public), (alice, initially_private)}) + self.assertEqual(in_private, set()) + + def test_switching_from_private_to_public_to_private(self) -> None: + """Check we update the room sharing tables when switching a room + from private to public, then back again to private.""" + # Alice and Bob share a private room. + alice = self.register_user("alice", "pass") + alice_token = self.login(alice, "pass") + bob = self.register_user("bob", "pass") + bob_token = self.login(bob, "pass") + room = self.helper.create_room_as(alice, is_public=False, tok=alice_token) + self.helper.invite(room, alice, bob, tok=alice_token) + self.helper.join(room, bob, tok=bob_token) + + # The user directory should reflect this. + def check_user_dir_for_private_room() -> None: + users, in_public, in_private = self.get_success( + self.user_dir_helper.get_tables() + ) + self.assertEqual(users, {alice, bob}) + self.assertEqual(in_public, set()) + self.assertEqual(in_private, {(alice, bob, room), (bob, alice, room)}) + + check_user_dir_for_private_room() + + # Alice makes the room public. + self.helper.send_state( + room, + "m.room.join_rules", + {"join_rule": "public"}, + tok=alice_token, + ) + + # The user directory should be updated accordingly + users, in_public, in_private = self.get_success( + self.user_dir_helper.get_tables() + ) + self.assertEqual(users, {alice, bob}) + self.assertEqual(in_public, {(alice, room), (bob, room)}) + self.assertEqual(in_private, set()) + + # Alice makes the room private. + self.helper.send_state( + room, + "m.room.join_rules", + {"join_rule": "invite"}, + tok=alice_token, + ) + + # The user directory should be updated accordingly + check_user_dir_for_private_room() + def _create_rooms_and_inject_memberships( self, creator: str, token: str, joiner: str ) -> Tuple[str, str]: @@ -232,15 +310,18 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): return public_room, private_room def _check_only_one_user_in_directory(self, user: str, public: str) -> None: - users = self.get_success(self.user_dir_helper.get_users_in_user_directory()) - in_public = self.get_success(self.user_dir_helper.get_users_in_public_rooms()) - in_private = self.get_success( - self.user_dir_helper.get_users_who_share_private_rooms() - ) + """Check that the user directory DB tables show that: + - only one user is in the user directory + - they belong to exactly one public room + - they don't share a private room with anyone. + """ + users, in_public, in_private = self.get_success( + self.user_dir_helper.get_tables() + ) self.assertEqual(users, {user}) - self.assertEqual(set(in_public), {(user, public)}) - self.assertEqual(in_private, []) + self.assertEqual(in_public, {(user, public)}) + self.assertEqual(in_private, set()) def test_handle_local_profile_change_with_support_user(self) -> None: support_user_id = "@support:test" @@ -581,11 +662,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.user_dir_helper.get_users_in_public_rooms() ) - self.assertEqual( - self.user_dir_helper._compress_shared(shares_private), - {(u1, u2, room), (u2, u1, room)}, - ) - self.assertEqual(public_users, []) + self.assertEqual(shares_private, {(u1, u2, room), (u2, u1, room)}) + self.assertEqual(public_users, set()) # We get one search result when searching for user2 by user1. s = self.get_success(self.handler.search_users(u1, "user2", 10)) @@ -610,8 +688,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.user_dir_helper.get_users_in_public_rooms() ) - self.assertEqual(self.user_dir_helper._compress_shared(shares_private), set()) - self.assertEqual(public_users, []) + self.assertEqual(shares_private, set()) + self.assertEqual(public_users, set()) # User1 now gets no search results for any of the other users. s = self.get_success(self.handler.search_users(u1, "user2", 10)) @@ -645,11 +723,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.user_dir_helper.get_users_in_public_rooms() ) - self.assertEqual( - self.user_dir_helper._compress_shared(shares_private), - {(u1, u2, room), (u2, u1, room)}, - ) - self.assertEqual(public_users, []) + self.assertEqual(shares_private, {(u1, u2, room), (u2, u1, room)}) + self.assertEqual(public_users, set()) # We get one search result when searching for user2 by user1. s = self.get_success(self.handler.search_users(u1, "user2", 10)) @@ -704,11 +779,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.user_dir_helper.get_users_in_public_rooms() ) - self.assertEqual( - self.user_dir_helper._compress_shared(shares_private), - {(u1, u2, room), (u2, u1, room)}, - ) - self.assertEqual(public_users, []) + self.assertEqual(shares_private, {(u1, u2, room), (u2, u1, room)}) + self.assertEqual(public_users, set()) # Configure a spam checker. spam_checker = self.hs.get_spam_checker() @@ -740,8 +812,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) # No users share rooms - self.assertEqual(public_users, []) - self.assertEqual(self.user_dir_helper._compress_shared(shares_private), set()) + self.assertEqual(public_users, set()) + self.assertEqual(shares_private, set()) # Despite not sharing a room, search_all_users means we get a search # result. @@ -842,6 +914,56 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.hs.get_storage().persistence.persist_event(event, context) ) + def test_local_user_leaving_room_remains_in_user_directory(self) -> None: + """We've chosen to simplify the user directory's implementation by + always including local users. Ensure this invariant is maintained when + a local user + - leaves a room, and + - leaves the last room they're in which is visible to this server. + + This is user-visible if the "search_all_users" config option is on: the + local user who left a room would no longer be searchable if this test fails! + """ + alice = self.register_user("alice", "pass") + alice_token = self.login(alice, "pass") + bob = self.register_user("bob", "pass") + bob_token = self.login(bob, "pass") + + # Alice makes two public rooms, which Bob joins. + room1 = self.helper.create_room_as(alice, is_public=True, tok=alice_token) + room2 = self.helper.create_room_as(alice, is_public=True, tok=alice_token) + self.helper.join(room1, bob, tok=bob_token) + self.helper.join(room2, bob, tok=bob_token) + + # The user directory tables are updated. + users, in_public, in_private = self.get_success( + self.user_dir_helper.get_tables() + ) + self.assertEqual(users, {alice, bob}) + self.assertEqual( + in_public, {(alice, room1), (alice, room2), (bob, room1), (bob, room2)} + ) + self.assertEqual(in_private, set()) + + # Alice leaves one room. She should still be in the directory. + self.helper.leave(room1, alice, tok=alice_token) + users, in_public, in_private = self.get_success( + self.user_dir_helper.get_tables() + ) + self.assertEqual(users, {alice, bob}) + self.assertEqual(in_public, {(alice, room2), (bob, room1), (bob, room2)}) + self.assertEqual(in_private, set()) + + # Alice leaves the other. She should still be in the directory. + self.helper.leave(room2, alice, tok=alice_token) + self.wait_for_background_updates() + users, in_public, in_private = self.get_success( + self.user_dir_helper.get_tables() + ) + self.assertEqual(users, {alice, bob}) + self.assertEqual(in_public, {(bob, room1), (bob, room2)}) + self.assertEqual(in_private, set()) + class TestUserDirSearchDisabled(unittest.HomeserverTestCase): servlets = [ diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 6ed9e42173..c9e2754b09 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -14,14 +14,13 @@ import hashlib import hmac -import json import os import urllib.parse from binascii import unhexlify from typing import List, Optional from unittest.mock import Mock, patch -from parameterized import parameterized +from parameterized import parameterized, parameterized_class import synapse.rest.admin from synapse.api.constants import UserTypes @@ -104,8 +103,8 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # 59 seconds self.reactor.advance(59) - body = json.dumps({"nonce": nonce}) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = {"nonce": nonce} + channel = self.make_request("POST", self.url, body) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("username must be specified", channel.json_body["error"]) @@ -113,7 +112,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # 61 seconds self.reactor.advance(2) - channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, body) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("unrecognised nonce", channel.json_body["error"]) @@ -129,18 +128,16 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac.update(b"notthenonce\x00bob\x00abc123\x00admin") want_mac = want_mac.hexdigest() - body = json.dumps( - { - "nonce": nonce, - "username": "bob", - "password": "abc123", - "admin": True, - "mac": want_mac, - } - ) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = { + "nonce": nonce, + "username": "bob", + "password": "abc123", + "admin": True, + "mac": want_mac, + } + channel = self.make_request("POST", self.url, body) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual("HMAC incorrect", channel.json_body["error"]) def test_register_correct_nonce(self): @@ -157,17 +154,15 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): ) want_mac = want_mac.hexdigest() - body = json.dumps( - { - "nonce": nonce, - "username": "bob", - "password": "abc123", - "admin": True, - "user_type": UserTypes.SUPPORT, - "mac": want_mac, - } - ) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = { + "nonce": nonce, + "username": "bob", + "password": "abc123", + "admin": True, + "user_type": UserTypes.SUPPORT, + "mac": want_mac, + } + channel = self.make_request("POST", self.url, body) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["user_id"]) @@ -183,22 +178,20 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac.update(nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin") want_mac = want_mac.hexdigest() - body = json.dumps( - { - "nonce": nonce, - "username": "bob", - "password": "abc123", - "admin": True, - "mac": want_mac, - } - ) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = { + "nonce": nonce, + "username": "bob", + "password": "abc123", + "admin": True, + "mac": want_mac, + } + channel = self.make_request("POST", self.url, body) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["user_id"]) # Now, try and reuse it - channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, body) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("unrecognised nonce", channel.json_body["error"]) @@ -218,9 +211,8 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # Nonce check # - # Must be present - body = json.dumps({}) - channel = self.make_request("POST", self.url, body.encode("utf8")) + # Must be an empty body present + channel = self.make_request("POST", self.url, {}) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("nonce must be specified", channel.json_body["error"]) @@ -230,29 +222,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # # Must be present - body = json.dumps({"nonce": nonce()}) - channel = self.make_request("POST", self.url, body.encode("utf8")) + channel = self.make_request("POST", self.url, {"nonce": nonce()}) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("username must be specified", channel.json_body["error"]) # Must be a string - body = json.dumps({"nonce": nonce(), "username": 1234}) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = {"nonce": nonce(), "username": 1234} + channel = self.make_request("POST", self.url, body) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid username", channel.json_body["error"]) # Must not have null bytes - body = json.dumps({"nonce": nonce(), "username": "abcd\u0000"}) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = {"nonce": nonce(), "username": "abcd\u0000"} + channel = self.make_request("POST", self.url, body) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid username", channel.json_body["error"]) # Must not have null bytes - body = json.dumps({"nonce": nonce(), "username": "a" * 1000}) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = {"nonce": nonce(), "username": "a" * 1000} + channel = self.make_request("POST", self.url, body) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid username", channel.json_body["error"]) @@ -262,29 +253,29 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # # Must be present - body = json.dumps({"nonce": nonce(), "username": "a"}) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = {"nonce": nonce(), "username": "a"} + channel = self.make_request("POST", self.url, body) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("password must be specified", channel.json_body["error"]) # Must be a string - body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234}) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = {"nonce": nonce(), "username": "a", "password": 1234} + channel = self.make_request("POST", self.url, body) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid password", channel.json_body["error"]) # Must not have null bytes - body = json.dumps({"nonce": nonce(), "username": "a", "password": "abcd\u0000"}) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = {"nonce": nonce(), "username": "a", "password": "abcd\u0000"} + channel = self.make_request("POST", self.url, body) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid password", channel.json_body["error"]) # Super long - body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000}) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = {"nonce": nonce(), "username": "a", "password": "A" * 1000} + channel = self.make_request("POST", self.url, body) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid password", channel.json_body["error"]) @@ -294,15 +285,13 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # # Invalid user_type - body = json.dumps( - { - "nonce": nonce(), - "username": "a", - "password": "1234", - "user_type": "invalid", - } - ) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = { + "nonce": nonce(), + "username": "a", + "password": "1234", + "user_type": "invalid", + } + channel = self.make_request("POST", self.url, body) self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual("Invalid user type", channel.json_body["error"]) @@ -320,10 +309,14 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac.update(nonce.encode("ascii") + b"\x00bob1\x00abc123\x00notadmin") want_mac = want_mac.hexdigest() - body = json.dumps( - {"nonce": nonce, "username": "bob1", "password": "abc123", "mac": want_mac} - ) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = { + "nonce": nonce, + "username": "bob1", + "password": "abc123", + "mac": want_mac, + } + + channel = self.make_request("POST", self.url, body) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob1:test", channel.json_body["user_id"]) @@ -340,16 +333,14 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac.update(nonce.encode("ascii") + b"\x00bob2\x00abc123\x00notadmin") want_mac = want_mac.hexdigest() - body = json.dumps( - { - "nonce": nonce, - "username": "bob2", - "displayname": None, - "password": "abc123", - "mac": want_mac, - } - ) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = { + "nonce": nonce, + "username": "bob2", + "displayname": None, + "password": "abc123", + "mac": want_mac, + } + channel = self.make_request("POST", self.url, body) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob2:test", channel.json_body["user_id"]) @@ -366,22 +357,20 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac.update(nonce.encode("ascii") + b"\x00bob3\x00abc123\x00notadmin") want_mac = want_mac.hexdigest() - body = json.dumps( - { - "nonce": nonce, - "username": "bob3", - "displayname": "", - "password": "abc123", - "mac": want_mac, - } - ) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = { + "nonce": nonce, + "username": "bob3", + "displayname": "", + "password": "abc123", + "mac": want_mac, + } + channel = self.make_request("POST", self.url, body) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob3:test", channel.json_body["user_id"]) channel = self.make_request("GET", "/profile/@bob3:test/displayname") - self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(404, channel.code, msg=channel.json_body) # set displayname channel = self.make_request("GET", self.url) @@ -391,16 +380,14 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): want_mac.update(nonce.encode("ascii") + b"\x00bob4\x00abc123\x00notadmin") want_mac = want_mac.hexdigest() - body = json.dumps( - { - "nonce": nonce, - "username": "bob4", - "displayname": "Bob's Name", - "password": "abc123", - "mac": want_mac, - } - ) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = { + "nonce": nonce, + "username": "bob4", + "displayname": "Bob's Name", + "password": "abc123", + "mac": want_mac, + } + channel = self.make_request("POST", self.url, body) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob4:test", channel.json_body["user_id"]) @@ -440,17 +427,15 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): ) want_mac = want_mac.hexdigest() - body = json.dumps( - { - "nonce": nonce, - "username": "bob", - "password": "abc123", - "admin": True, - "user_type": UserTypes.SUPPORT, - "mac": want_mac, - } - ) - channel = self.make_request("POST", self.url, body.encode("utf8")) + body = { + "nonce": nonce, + "username": "bob", + "password": "abc123", + "admin": True, + "user_type": UserTypes.SUPPORT, + "mac": want_mac, + } + channel = self.make_request("POST", self.url, body) self.assertEqual(200, channel.code, msg=channel.json_body) self.assertEqual("@bob:test", channel.json_body["user_id"]) @@ -993,12 +978,11 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase): """ If parameter `erase` is not boolean, return an error """ - body = json.dumps({"erase": "False"}) channel = self.make_request( "POST", self.url, - content=body.encode(encoding="utf_8"), + content={"erase": "False"}, access_token=self.admin_user_tok, ) @@ -2201,7 +2185,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_no_admin(self): @@ -2216,7 +2200,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase): access_token=other_user_token, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_user_does_not_exist(self): @@ -2359,7 +2343,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): """ channel = self.make_request("GET", self.url, b"{}") - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_no_admin(self): @@ -2374,7 +2358,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): access_token=other_user_token, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_user_does_not_exist(self): @@ -3073,7 +3057,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): """Try to login as a user without authentication.""" channel = self.make_request("POST", self.url, b"{}") - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_not_admin(self): @@ -3082,7 +3066,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): "POST", self.url, b"{}", access_token=self.other_user_tok ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) def test_send_event(self): """Test that sending event as a user works.""" @@ -3127,7 +3111,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): # The puppet token should no longer work channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(401, channel.code, msg=channel.json_body) # .. but the real user's tokens should still work channel = self.make_request( @@ -3160,7 +3144,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", "devices", b"{}", access_token=self.other_user_tok ) - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(401, channel.code, msg=channel.json_body) def test_admin_logout_all(self): """Tests that the admin user calling `/logout/all` does expire the @@ -3181,7 +3165,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): # The puppet token should no longer work channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token) - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(401, channel.code, msg=channel.json_body) # .. but the real user's tokens should still work channel = self.make_request( @@ -3242,6 +3226,13 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase): self.helper.join(room_id, user=self.other_user, tok=puppet_token) +@parameterized_class( + ("url_prefix",), + [ + ("/_synapse/admin/v1/whois/%s",), + ("/_matrix/client/r0/admin/whois/%s",), + ], +) class WhoisRestTestCase(unittest.HomeserverTestCase): servlets = [ @@ -3254,21 +3245,14 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): self.admin_user_tok = self.login("admin", "pass") self.other_user = self.register_user("user", "pass") - self.url1 = "/_synapse/admin/v1/whois/%s" % urllib.parse.quote(self.other_user) - self.url2 = "/_matrix/client/r0/admin/whois/%s" % urllib.parse.quote( - self.other_user - ) + self.url = self.url_prefix % self.other_user def test_no_auth(self): """ Try to get information of an user without authentication. """ - channel = self.make_request("GET", self.url1, b"{}") - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - - channel = self.make_request("GET", self.url2, b"{}") - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + channel = self.make_request("GET", self.url, b"{}") + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_not_admin(self): @@ -3280,38 +3264,21 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", - self.url1, - access_token=other_user2_token, - ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - - channel = self.make_request( - "GET", - self.url2, + self.url, access_token=other_user2_token, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_user_is_not_local(self): """ Tests that a lookup for a user that is not a local returns a 400 """ - url1 = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain" - url2 = "/_matrix/client/r0/admin/whois/@unknown_person:unknown_domain" - - channel = self.make_request( - "GET", - url1, - access_token=self.admin_user_tok, - ) - self.assertEqual(400, channel.code, msg=channel.json_body) - self.assertEqual("Can only whois a local user", channel.json_body["error"]) + url = self.url_prefix % "@unknown_person:unknown_domain" channel = self.make_request( "GET", - url2, + url, access_token=self.admin_user_tok, ) self.assertEqual(400, channel.code, msg=channel.json_body) @@ -3323,16 +3290,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): """ channel = self.make_request( "GET", - self.url1, - access_token=self.admin_user_tok, - ) - self.assertEqual(200, channel.code, msg=channel.json_body) - self.assertEqual(self.other_user, channel.json_body["user_id"]) - self.assertIn("devices", channel.json_body) - - channel = self.make_request( - "GET", - self.url2, + self.url, access_token=self.admin_user_tok, ) self.assertEqual(200, channel.code, msg=channel.json_body) @@ -3347,16 +3305,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase): channel = self.make_request( "GET", - self.url1, - access_token=other_user_token, - ) - self.assertEqual(200, channel.code, msg=channel.json_body) - self.assertEqual(self.other_user, channel.json_body["user_id"]) - self.assertIn("devices", channel.json_body) - - channel = self.make_request( - "GET", - self.url2, + self.url, access_token=other_user_token, ) self.assertEqual(200, channel.code, msg=channel.json_body) @@ -3388,7 +3337,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): Try to get information of an user without authentication. """ channel = self.make_request("POST", self.url) - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) def test_requester_is_not_admin(self): @@ -3398,7 +3347,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase): other_user_token = self.login("user", "pass") channel = self.make_request("POST", self.url, access_token=other_user_token) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) def test_user_is_not_local(self): @@ -3447,84 +3396,41 @@ class RateLimitTestCase(unittest.HomeserverTestCase): % urllib.parse.quote(self.other_user) ) - def test_no_auth(self): + @parameterized.expand(["GET", "POST", "DELETE"]) + def test_no_auth(self, method: str): """ Try to get information of a user without authentication. """ - channel = self.make_request("GET", self.url, b"{}") - - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - - channel = self.make_request("POST", self.url, b"{}") + channel = self.make_request(method, self.url, b"{}") - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - - channel = self.make_request("DELETE", self.url, b"{}") - - self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(401, channel.code, msg=channel.json_body) self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) - def test_requester_is_no_admin(self): + @parameterized.expand(["GET", "POST", "DELETE"]) + def test_requester_is_no_admin(self, method: str): """ If the user is not a server admin, an error is returned. """ other_user_token = self.login("user", "pass") channel = self.make_request( - "GET", - self.url, - access_token=other_user_token, - ) - - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - - channel = self.make_request( - "POST", - self.url, - access_token=other_user_token, - ) - - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - - channel = self.make_request( - "DELETE", + method, self.url, access_token=other_user_token, ) - self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(403, channel.code, msg=channel.json_body) self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) - def test_user_does_not_exist(self): + @parameterized.expand(["GET", "POST", "DELETE"]) + def test_user_does_not_exist(self, method: str): """ Tests that a lookup for a user that does not exist returns a 404 """ url = "/_synapse/admin/v1/users/@unknown_person:test/override_ratelimit" channel = self.make_request( - "GET", - url, - access_token=self.admin_user_tok, - ) - - self.assertEqual(404, channel.code, msg=channel.json_body) - self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - - channel = self.make_request( - "POST", - url, - access_token=self.admin_user_tok, - ) - - self.assertEqual(404, channel.code, msg=channel.json_body) - self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - - channel = self.make_request( - "DELETE", + method, url, access_token=self.admin_user_tok, ) @@ -3532,7 +3438,14 @@ class RateLimitTestCase(unittest.HomeserverTestCase): self.assertEqual(404, channel.code, msg=channel.json_body) self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"]) - def test_user_is_not_local(self): + @parameterized.expand( + [ + ("GET", "Can only look up local users"), + ("POST", "Only local users can be ratelimited"), + ("DELETE", "Only local users can be ratelimited"), + ] + ) + def test_user_is_not_local(self, method: str, error_msg: str): """ Tests that a lookup for a user that is not a local returns a 400 """ @@ -3541,35 +3454,13 @@ class RateLimitTestCase(unittest.HomeserverTestCase): ) channel = self.make_request( - "GET", - url, - access_token=self.admin_user_tok, - ) - - self.assertEqual(400, channel.code, msg=channel.json_body) - self.assertEqual("Can only look up local users", channel.json_body["error"]) - - channel = self.make_request( - "POST", - url, - access_token=self.admin_user_tok, - ) - - self.assertEqual(400, channel.code, msg=channel.json_body) - self.assertEqual( - "Only local users can be ratelimited", channel.json_body["error"] - ) - - channel = self.make_request( - "DELETE", + method, url, access_token=self.admin_user_tok, ) self.assertEqual(400, channel.code, msg=channel.json_body) - self.assertEqual( - "Only local users can be ratelimited", channel.json_body["error"] - ) + self.assertEqual(error_msg, channel.json_body["error"]) def test_invalid_parameter(self): """ diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 02b5e9a8d0..3c7d49f0b4 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -13,15 +13,15 @@ # limitations under the License. import itertools -import json -import urllib -from typing import Optional +import urllib.parse +from typing import Dict, List, Optional, Tuple from synapse.api.constants import EventTypes, RelationTypes from synapse.rest import admin from synapse.rest.client import login, register, relations, room from tests import unittest +from tests.server import FakeChannel class RelationsTestCase(unittest.HomeserverTestCase): @@ -34,16 +34,16 @@ class RelationsTestCase(unittest.HomeserverTestCase): ] hijack_auth = False - def make_homeserver(self, reactor, clock): + def default_config(self) -> dict: # We need to enable msc1849 support for aggregations - config = self.default_config() + config = super().default_config() config["experimental_msc1849_support_enabled"] = True # We enable frozen dicts as relations/edits change event contents, so we # want to test that we don't modify the events in the caches. config["use_frozen_dicts"] = True - return self.setup_test_homeserver(config=config) + return config def prepare(self, reactor, clock, hs): self.user_id, self.user_token = self._create_user("alice") @@ -146,8 +146,8 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) expected_event_ids.append(channel.json_body["event_id"]) - prev_token = None - found_event_ids = [] + prev_token: Optional[str] = None + found_event_ids: List[str] = [] for _ in range(20): from_token = "" if prev_token: @@ -203,8 +203,8 @@ class RelationsTestCase(unittest.HomeserverTestCase): idx += 1 idx %= len(access_tokens) - prev_token = None - found_groups = {} + prev_token: Optional[str] = None + found_groups: Dict[str, int] = {} for _ in range(20): from_token = "" if prev_token: @@ -270,8 +270,8 @@ class RelationsTestCase(unittest.HomeserverTestCase): channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") self.assertEquals(200, channel.code, channel.json_body) - prev_token = None - found_event_ids = [] + prev_token: Optional[str] = None + found_event_ids: List[str] = [] encoded_key = urllib.parse.quote_plus("👍".encode()) for _ in range(20): from_token = "" @@ -677,24 +677,23 @@ class RelationsTestCase(unittest.HomeserverTestCase): def _send_relation( self, - relation_type, - event_type, - key=None, + relation_type: str, + event_type: str, + key: Optional[str] = None, content: Optional[dict] = None, - access_token=None, - parent_id=None, - ): + access_token: Optional[str] = None, + parent_id: Optional[str] = None, + ) -> FakeChannel: """Helper function to send a relation pointing at `self.parent_id` Args: - relation_type (str): One of `RelationTypes` - event_type (str): The type of the event to create - parent_id (str): The event_id this relation relates to. If None, then self.parent_id - key (str|None): The aggregation key used for m.annotation relation - type. - content(dict|None): The content of the created event. - access_token (str|None): The access token used to send the relation, - defaults to `self.user_token` + relation_type: One of `RelationTypes` + event_type: The type of the event to create + key: The aggregation key used for m.annotation relation type. + content: The content of the created event. + access_token: The access token used to send the relation, defaults + to `self.user_token` + parent_id: The event_id this relation relates to. If None, then self.parent_id Returns: FakeChannel @@ -712,12 +711,12 @@ class RelationsTestCase(unittest.HomeserverTestCase): "POST", "/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s" % (self.room, original_id, relation_type, event_type, query), - json.dumps(content or {}).encode("utf-8"), + content or {}, access_token=access_token, ) return channel - def _create_user(self, localpart): + def _create_user(self, localpart: str) -> Tuple[str, str]: user_id = self.register_user(localpart, "abc123") access_token = self.login(localpart, "abc123") diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py index 38ac9be113..531f09c48b 100644 --- a/tests/rest/client/test_third_party_rules.py +++ b/tests/rest/client/test_third_party_rules.py @@ -12,25 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. import threading -from typing import Dict +from typing import TYPE_CHECKING, Dict, Optional, Tuple from unittest.mock import Mock from synapse.api.constants import EventTypes +from synapse.api.errors import SynapseError from synapse.events import EventBase from synapse.events.third_party_rules import load_legacy_third_party_event_rules -from synapse.module_api import ModuleApi from synapse.rest import admin from synapse.rest.client import login, room -from synapse.types import Requester, StateMap +from synapse.types import JsonDict, Requester, StateMap from synapse.util.frozenutils import unfreeze from tests import unittest +if TYPE_CHECKING: + from synapse.module_api import ModuleApi + thread_local = threading.local() class LegacyThirdPartyRulesTestModule: - def __init__(self, config: Dict, module_api: ModuleApi): + def __init__(self, config: Dict, module_api: "ModuleApi"): # keep a record of the "current" rules module, so that the test can patch # it if desired. thread_local.rules_module = self @@ -50,7 +53,7 @@ class LegacyThirdPartyRulesTestModule: class LegacyDenyNewRooms(LegacyThirdPartyRulesTestModule): - def __init__(self, config: Dict, module_api: ModuleApi): + def __init__(self, config: Dict, module_api: "ModuleApi"): super().__init__(config, module_api) def on_create_room( @@ -60,7 +63,7 @@ class LegacyDenyNewRooms(LegacyThirdPartyRulesTestModule): class LegacyChangeEvents(LegacyThirdPartyRulesTestModule): - def __init__(self, config: Dict, module_api: ModuleApi): + def __init__(self, config: Dict, module_api: "ModuleApi"): super().__init__(config, module_api) async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]): @@ -136,6 +139,47 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase): ) self.assertEquals(channel.result["code"], b"403", channel.result) + def test_third_party_rules_workaround_synapse_errors_pass_through(self): + """ + Tests that the workaround introduced by https://github.com/matrix-org/synapse/pull/11042 + is functional: that SynapseErrors are passed through from check_event_allowed + and bubble up to the web resource. + + NEW MODULES SHOULD NOT MAKE USE OF THIS WORKAROUND! + This is a temporary workaround! + """ + + class NastyHackException(SynapseError): + def error_dict(self): + """ + This overrides SynapseError's `error_dict` to nastily inject + JSON into the error response. + """ + result = super().error_dict() + result["nasty"] = "very" + return result + + # add a callback that will raise our hacky exception + async def check(ev, state) -> Tuple[bool, Optional[JsonDict]]: + raise NastyHackException(429, "message") + + self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check] + + # Make a request + channel = self.make_request( + "PUT", + "/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/2" % self.room_id, + {}, + access_token=self.tok, + ) + # Check the error code + self.assertEquals(channel.result["code"], b"429", channel.result) + # Check the JSON body has had the `nasty` key injected + self.assertEqual( + channel.json_body, + {"errcode": "M_UNKNOWN", "error": "message", "nasty": "very"}, + ) + def test_cannot_modify_event(self): """cannot accidentally modify an event before it is persisted""" diff --git a/tests/rest/media/v1/test_filepath.py b/tests/rest/media/v1/test_filepath.py new file mode 100644 index 0000000000..09504a485f --- /dev/null +++ b/tests/rest/media/v1/test_filepath.py @@ -0,0 +1,238 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.rest.media.v1.filepath import MediaFilePaths + +from tests import unittest + + +class MediaFilePathsTestCase(unittest.TestCase): + def setUp(self): + super().setUp() + + self.filepaths = MediaFilePaths("/media_store") + + def test_local_media_filepath(self): + """Test local media paths""" + self.assertEqual( + self.filepaths.local_media_filepath_rel("GerZNDnDZVjsOtardLuwfIBg"), + "local_content/Ge/rZ/NDnDZVjsOtardLuwfIBg", + ) + self.assertEqual( + self.filepaths.local_media_filepath("GerZNDnDZVjsOtardLuwfIBg"), + "/media_store/local_content/Ge/rZ/NDnDZVjsOtardLuwfIBg", + ) + + def test_local_media_thumbnail(self): + """Test local media thumbnail paths""" + self.assertEqual( + self.filepaths.local_media_thumbnail_rel( + "GerZNDnDZVjsOtardLuwfIBg", 800, 600, "image/jpeg", "scale" + ), + "local_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale", + ) + self.assertEqual( + self.filepaths.local_media_thumbnail( + "GerZNDnDZVjsOtardLuwfIBg", 800, 600, "image/jpeg", "scale" + ), + "/media_store/local_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale", + ) + + def test_local_media_thumbnail_dir(self): + """Test local media thumbnail directory paths""" + self.assertEqual( + self.filepaths.local_media_thumbnail_dir("GerZNDnDZVjsOtardLuwfIBg"), + "/media_store/local_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg", + ) + + def test_remote_media_filepath(self): + """Test remote media paths""" + self.assertEqual( + self.filepaths.remote_media_filepath_rel( + "example.com", "GerZNDnDZVjsOtardLuwfIBg" + ), + "remote_content/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg", + ) + self.assertEqual( + self.filepaths.remote_media_filepath( + "example.com", "GerZNDnDZVjsOtardLuwfIBg" + ), + "/media_store/remote_content/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg", + ) + + def test_remote_media_thumbnail(self): + """Test remote media thumbnail paths""" + self.assertEqual( + self.filepaths.remote_media_thumbnail_rel( + "example.com", + "GerZNDnDZVjsOtardLuwfIBg", + 800, + 600, + "image/jpeg", + "scale", + ), + "remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale", + ) + self.assertEqual( + self.filepaths.remote_media_thumbnail( + "example.com", + "GerZNDnDZVjsOtardLuwfIBg", + 800, + 600, + "image/jpeg", + "scale", + ), + "/media_store/remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale", + ) + + def test_remote_media_thumbnail_legacy(self): + """Test old-style remote media thumbnail paths""" + self.assertEqual( + self.filepaths.remote_media_thumbnail_rel_legacy( + "example.com", "GerZNDnDZVjsOtardLuwfIBg", 800, 600, "image/jpeg" + ), + "remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg", + ) + + def test_remote_media_thumbnail_dir(self): + """Test remote media thumbnail directory paths""" + self.assertEqual( + self.filepaths.remote_media_thumbnail_dir( + "example.com", "GerZNDnDZVjsOtardLuwfIBg" + ), + "/media_store/remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg", + ) + + def test_url_cache_filepath(self): + """Test URL cache paths""" + self.assertEqual( + self.filepaths.url_cache_filepath_rel("2020-01-02_GerZNDnDZVjsOtar"), + "url_cache/2020-01-02/GerZNDnDZVjsOtar", + ) + self.assertEqual( + self.filepaths.url_cache_filepath("2020-01-02_GerZNDnDZVjsOtar"), + "/media_store/url_cache/2020-01-02/GerZNDnDZVjsOtar", + ) + + def test_url_cache_filepath_legacy(self): + """Test old-style URL cache paths""" + self.assertEqual( + self.filepaths.url_cache_filepath_rel("GerZNDnDZVjsOtardLuwfIBg"), + "url_cache/Ge/rZ/NDnDZVjsOtardLuwfIBg", + ) + self.assertEqual( + self.filepaths.url_cache_filepath("GerZNDnDZVjsOtardLuwfIBg"), + "/media_store/url_cache/Ge/rZ/NDnDZVjsOtardLuwfIBg", + ) + + def test_url_cache_filepath_dirs_to_delete(self): + """Test URL cache cleanup paths""" + self.assertEqual( + self.filepaths.url_cache_filepath_dirs_to_delete( + "2020-01-02_GerZNDnDZVjsOtar" + ), + ["/media_store/url_cache/2020-01-02"], + ) + + def test_url_cache_filepath_dirs_to_delete_legacy(self): + """Test old-style URL cache cleanup paths""" + self.assertEqual( + self.filepaths.url_cache_filepath_dirs_to_delete( + "GerZNDnDZVjsOtardLuwfIBg" + ), + [ + "/media_store/url_cache/Ge/rZ", + "/media_store/url_cache/Ge", + ], + ) + + def test_url_cache_thumbnail(self): + """Test URL cache thumbnail paths""" + self.assertEqual( + self.filepaths.url_cache_thumbnail_rel( + "2020-01-02_GerZNDnDZVjsOtar", 800, 600, "image/jpeg", "scale" + ), + "url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar/800-600-image-jpeg-scale", + ) + self.assertEqual( + self.filepaths.url_cache_thumbnail( + "2020-01-02_GerZNDnDZVjsOtar", 800, 600, "image/jpeg", "scale" + ), + "/media_store/url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar/800-600-image-jpeg-scale", + ) + + def test_url_cache_thumbnail_legacy(self): + """Test old-style URL cache thumbnail paths""" + self.assertEqual( + self.filepaths.url_cache_thumbnail_rel( + "GerZNDnDZVjsOtardLuwfIBg", 800, 600, "image/jpeg", "scale" + ), + "url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale", + ) + self.assertEqual( + self.filepaths.url_cache_thumbnail( + "GerZNDnDZVjsOtardLuwfIBg", 800, 600, "image/jpeg", "scale" + ), + "/media_store/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale", + ) + + def test_url_cache_thumbnail_directory(self): + """Test URL cache thumbnail directory paths""" + self.assertEqual( + self.filepaths.url_cache_thumbnail_directory_rel( + "2020-01-02_GerZNDnDZVjsOtar" + ), + "url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar", + ) + self.assertEqual( + self.filepaths.url_cache_thumbnail_directory("2020-01-02_GerZNDnDZVjsOtar"), + "/media_store/url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar", + ) + + def test_url_cache_thumbnail_directory_legacy(self): + """Test old-style URL cache thumbnail directory paths""" + self.assertEqual( + self.filepaths.url_cache_thumbnail_directory_rel( + "GerZNDnDZVjsOtardLuwfIBg" + ), + "url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg", + ) + self.assertEqual( + self.filepaths.url_cache_thumbnail_directory("GerZNDnDZVjsOtardLuwfIBg"), + "/media_store/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg", + ) + + def test_url_cache_thumbnail_dirs_to_delete(self): + """Test URL cache thumbnail cleanup paths""" + self.assertEqual( + self.filepaths.url_cache_thumbnail_dirs_to_delete( + "2020-01-02_GerZNDnDZVjsOtar" + ), + [ + "/media_store/url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar", + "/media_store/url_cache_thumbnails/2020-01-02", + ], + ) + + def test_url_cache_thumbnail_dirs_to_delete_legacy(self): + """Test old-style URL cache thumbnail cleanup paths""" + self.assertEqual( + self.filepaths.url_cache_thumbnail_dirs_to_delete( + "GerZNDnDZVjsOtardLuwfIBg" + ), + [ + "/media_store/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg", + "/media_store/url_cache_thumbnails/Ge/rZ", + "/media_store/url_cache_thumbnails/Ge", + ], + ) diff --git a/tests/rest/media/v1/test_oembed.py b/tests/rest/media/v1/test_oembed.py new file mode 100644 index 0000000000..048d0ca44a --- /dev/null +++ b/tests/rest/media/v1/test_oembed.py @@ -0,0 +1,51 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.rest.media.v1.oembed import OEmbedProvider +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util import Clock + +from tests.unittest import HomeserverTestCase + + +class OEmbedTests(HomeserverTestCase): + def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer): + self.oembed = OEmbedProvider(homeserver) + + def parse_response(self, response: JsonDict): + return self.oembed.parse_oembed_response( + "https://test", json.dumps(response).encode("utf-8") + ) + + def test_version(self): + """Accept versions that are similar to 1.0 as a string or int (or missing).""" + for version in ("1.0", 1.0, 1): + result = self.parse_response({"version": version, "type": "link"}) + # An empty Open Graph response is an error, ensure the URL is included. + self.assertIn("og:url", result.open_graph_result) + + # A missing version should be treated as 1.0. + result = self.parse_response({"type": "link"}) + self.assertIn("og:url", result.open_graph_result) + + # Invalid versions should be rejected. + for version in ("2.0", "1", 1.1, 0, None, {}, []): + result = self.parse_response({"version": version, "type": "link"}) + # An empty Open Graph response is an error, ensure the URL is included. + self.assertEqual({}, result.open_graph_result) diff --git a/tests/server.py b/tests/server.py index 64645651ce..103351b487 100644 --- a/tests/server.py +++ b/tests/server.py @@ -1,3 +1,17 @@ +# Copyright 2018-2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import json import logging from collections import deque @@ -27,9 +41,10 @@ from twisted.python.failure import Failure from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock from twisted.web.http_headers import Headers from twisted.web.resource import IResource -from twisted.web.server import Site +from twisted.web.server import Request, Site from synapse.http.site import SynapseRequest +from synapse.types import JsonDict from synapse.util import Clock from tests.utils import setup_test_homeserver as _sth @@ -198,14 +213,14 @@ class FakeSite: def make_request( reactor, site: Union[Site, FakeSite], - method, - path, - content=b"", - access_token=None, - request=SynapseRequest, - shorthand=True, - federation_auth_origin=None, - content_is_form=False, + method: Union[bytes, str], + path: Union[bytes, str], + content: Union[bytes, str, JsonDict] = b"", + access_token: Optional[str] = None, + request: Request = SynapseRequest, + shorthand: bool = True, + federation_auth_origin: Optional[bytes] = None, + content_is_form: bool = False, await_result: bool = True, custom_headers: Optional[ Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] @@ -218,26 +233,23 @@ def make_request( Returns the fake Channel object which records the response to the request. Args: + reactor: site: The twisted Site to use to render the request - - method (bytes/unicode): The HTTP request method ("verb"). - path (bytes/unicode): The HTTP path, suitably URL encoded (e.g. - escaped UTF-8 & spaces and such). - content (bytes or dict): The body of the request. JSON-encoded, if - a dict. + method: The HTTP request method ("verb"). + path: The HTTP path, suitably URL encoded (e.g. escaped UTF-8 & spaces and such). + content: The body of the request. JSON-encoded, if a str of bytes. + access_token: The access token to add as authorization for the request. + request: The request class to create. shorthand: Whether to try and be helpful and prefix the given URL - with the usual REST API path, if it doesn't contain it. - federation_auth_origin (bytes|None): if set to not-None, we will add a fake + with the usual REST API path, if it doesn't contain it. + federation_auth_origin: if set to not-None, we will add a fake Authorization header pretenting to be the given server name. content_is_form: Whether the content is URL encoded form data. Adds the 'Content-Type': 'application/x-www-form-urlencoded' header. - - custom_headers: (name, value) pairs to add as request headers - await_result: whether to wait for the request to complete rendering. If true, will pump the reactor until the the renderer tells the channel the request is finished. - + custom_headers: (name, value) pairs to add as request headers client_ip: The IP to use as the requesting IP. Useful for testing ratelimiting. diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index be3ed64f5e..37cf7bb232 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Set, Tuple +from typing import Any, Dict, Set, Tuple from unittest import mock from unittest.mock import Mock, patch @@ -42,18 +42,7 @@ class GetUserDirectoryTables: def __init__(self, store: DataStore): self.store = store - def _compress_shared( - self, shared: List[Dict[str, str]] - ) -> Set[Tuple[str, str, str]]: - """ - Compress a list of users who share rooms dicts to a list of tuples. - """ - r = set() - for i in shared: - r.add((i["user_id"], i["other_user_id"], i["room_id"])) - return r - - async def get_users_in_public_rooms(self) -> List[Tuple[str, str]]: + async def get_users_in_public_rooms(self) -> Set[Tuple[str, str]]: """Fetch the entire `users_in_public_rooms` table. Returns a list of tuples (user_id, room_id) where room_id is public and @@ -63,24 +52,27 @@ class GetUserDirectoryTables: "users_in_public_rooms", None, ("user_id", "room_id") ) - retval = [] + retval = set() for i in r: - retval.append((i["user_id"], i["room_id"])) + retval.add((i["user_id"], i["room_id"])) return retval - async def get_users_who_share_private_rooms(self) -> List[Dict[str, str]]: + async def get_users_who_share_private_rooms(self) -> Set[Tuple[str, str, str]]: """Fetch the entire `users_who_share_private_rooms` table. - Returns a dict containing "user_id", "other_user_id" and "room_id" keys. - The dicts can be flattened to Tuples with the `_compress_shared` method. - (This seems a little awkward---maybe we could clean this up.) + Returns a set of tuples (user_id, other_user_id, room_id) corresponding + to the rows of `users_who_share_private_rooms`. """ - return await self.store.db_pool.simple_select_list( + rows = await self.store.db_pool.simple_select_list( "users_who_share_private_rooms", None, ["user_id", "other_user_id", "room_id"], ) + rv = set() + for row in rows: + rv.add((row["user_id"], row["other_user_id"], row["room_id"])) + return rv async def get_users_in_user_directory(self) -> Set[str]: """Fetch the set of users in the `user_directory` table. @@ -113,6 +105,16 @@ class GetUserDirectoryTables: for row in rows } + async def get_tables( + self, + ) -> Tuple[Set[str], Set[Tuple[str, str]], Set[Tuple[str, str, str]]]: + """Multiple tests want to inspect these tables, so expose them together.""" + return ( + await self.get_users_in_user_directory(), + await self.get_users_in_public_rooms(), + await self.get_users_who_share_private_rooms(), + ) + class UserDirectoryInitialPopulationTestcase(HomeserverTestCase): """Ensure that rebuilding the directory writes the correct data to the DB. @@ -166,8 +168,8 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase): ) # Nothing updated yet - self.assertEqual(shares_private, []) - self.assertEqual(public_users, []) + self.assertEqual(shares_private, set()) + self.assertEqual(public_users, set()) # Ugh, have to reset this flag self.store.db_pool.updates._all_done = False @@ -236,24 +238,15 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase): # Do the initial population of the user directory via the background update self._purge_and_rebuild_user_dir() - shares_private = self.get_success( - self.user_dir_helper.get_users_who_share_private_rooms() - ) - public_users = self.get_success( - self.user_dir_helper.get_users_in_public_rooms() + users, in_public, in_private = self.get_success( + self.user_dir_helper.get_tables() ) # User 1 and User 2 are in the same public room - self.assertEqual(set(public_users), {(u1, room), (u2, room)}) - + self.assertEqual(in_public, {(u1, room), (u2, room)}) # User 1 and User 3 share private rooms - self.assertEqual( - self.user_dir_helper._compress_shared(shares_private), - {(u1, u3, private_room), (u3, u1, private_room)}, - ) - + self.assertEqual(in_private, {(u1, u3, private_room), (u3, u1, private_room)}) # All three should have entries in the directory - users = self.get_success(self.user_dir_helper.get_users_in_user_directory()) self.assertEqual(users, {u1, u2, u3}) # The next four tests (test_population_excludes_*) all set up @@ -289,16 +282,12 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase): self, normal_user: str, public_room: str, private_room: str ) -> None: # After rebuilding the directory, we should only see the normal user. - users = self.get_success(self.user_dir_helper.get_users_in_user_directory()) - self.assertEqual(users, {normal_user}) - in_public_rooms = self.get_success( - self.user_dir_helper.get_users_in_public_rooms() + users, in_public, in_private = self.get_success( + self.user_dir_helper.get_tables() ) - self.assertEqual(set(in_public_rooms), {(normal_user, public_room)}) - in_private_rooms = self.get_success( - self.user_dir_helper.get_users_who_share_private_rooms() - ) - self.assertEqual(in_private_rooms, []) + self.assertEqual(users, {normal_user}) + self.assertEqual(in_public, {(normal_user, public_room)}) + self.assertEqual(in_private, set()) def test_population_excludes_support_user(self) -> None: # Create a normal and support user. diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py index cf407c51cf..e2c506e5a4 100644 --- a/tests/test_event_auth.py +++ b/tests/test_event_auth.py @@ -24,6 +24,47 @@ from synapse.types import JsonDict, get_domain_from_id class EventAuthTestCase(unittest.TestCase): + def test_rejected_auth_events(self): + """ + Events that refer to rejected events in their auth events are rejected + """ + creator = "@creator:example.com" + auth_events = [ + _create_event(creator), + _join_event(creator), + ] + + # creator should be able to send state + event_auth.check_auth_rules_for_event( + RoomVersions.V9, + _random_state_event(creator), + auth_events, + ) + + # ... but a rejected join_rules event should cause it to be rejected + rejected_join_rules = _join_rules_event(creator, "public") + rejected_join_rules.rejected_reason = "stinky" + auth_events.append(rejected_join_rules) + + self.assertRaises( + AuthError, + event_auth.check_auth_rules_for_event, + RoomVersions.V9, + _random_state_event(creator), + auth_events, + ) + + # ... even if there is *also* a good join rules + auth_events.append(_join_rules_event(creator, "public")) + + self.assertRaises( + AuthError, + event_auth.check_auth_rules_for_event, + RoomVersions.V9, + _random_state_event(creator), + auth_events, + ) + def test_random_users_cannot_send_state_before_first_pl(self): """ Check that, before the first PL lands, the creator is the only user @@ -31,11 +72,11 @@ class EventAuthTestCase(unittest.TestCase): """ creator = "@creator:example.com" joiner = "@joiner:example.com" - auth_events = { - ("m.room.create", ""): _create_event(creator), - ("m.room.member", creator): _join_event(creator), - ("m.room.member", joiner): _join_event(joiner), - } + auth_events = [ + _create_event(creator), + _join_event(creator), + _join_event(joiner), + ] # creator should be able to send state event_auth.check_auth_rules_for_event( @@ -62,15 +103,15 @@ class EventAuthTestCase(unittest.TestCase): pleb = "@joiner:example.com" king = "@joiner2:example.com" - auth_events = { - ("m.room.create", ""): _create_event(creator), - ("m.room.member", creator): _join_event(creator), - ("m.room.power_levels", ""): _power_levels_event( + auth_events = [ + _create_event(creator), + _join_event(creator), + _power_levels_event( creator, {"state_default": "30", "users": {pleb: "29", king: "30"}} ), - ("m.room.member", pleb): _join_event(pleb), - ("m.room.member", king): _join_event(king), - } + _join_event(pleb), + _join_event(king), + ] # pleb should not be able to send state self.assertRaises( @@ -92,10 +133,10 @@ class EventAuthTestCase(unittest.TestCase): """Alias events have special behavior up through room version 6.""" creator = "@creator:example.com" other = "@other:example.com" - auth_events = { - ("m.room.create", ""): _create_event(creator), - ("m.room.member", creator): _join_event(creator), - } + auth_events = [ + _create_event(creator), + _join_event(creator), + ] # creator should be able to send aliases event_auth.check_auth_rules_for_event( @@ -131,10 +172,10 @@ class EventAuthTestCase(unittest.TestCase): """After MSC2432, alias events have no special behavior.""" creator = "@creator:example.com" other = "@other:example.com" - auth_events = { - ("m.room.create", ""): _create_event(creator), - ("m.room.member", creator): _join_event(creator), - } + auth_events = [ + _create_event(creator), + _join_event(creator), + ] # creator should be able to send aliases event_auth.check_auth_rules_for_event( @@ -170,14 +211,14 @@ class EventAuthTestCase(unittest.TestCase): creator = "@creator:example.com" pleb = "@joiner:example.com" - auth_events = { - ("m.room.create", ""): _create_event(creator), - ("m.room.member", creator): _join_event(creator), - ("m.room.power_levels", ""): _power_levels_event( + auth_events = [ + _create_event(creator), + _join_event(creator), + _power_levels_event( creator, {"state_default": "30", "users": {pleb: "30"}} ), - ("m.room.member", pleb): _join_event(pleb), - } + _join_event(pleb), + ] # pleb should be able to modify the notifications power level. event_auth.check_auth_rules_for_event( @@ -211,7 +252,7 @@ class EventAuthTestCase(unittest.TestCase): event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), - auth_events, + auth_events.values(), ) # A user cannot be force-joined to a room. @@ -219,7 +260,7 @@ class EventAuthTestCase(unittest.TestCase): event_auth.check_auth_rules_for_event( RoomVersions.V6, _member_event(pleb, "join", sender=creator), - auth_events, + auth_events.values(), ) # Banned should be rejected. @@ -228,7 +269,7 @@ class EventAuthTestCase(unittest.TestCase): event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), - auth_events, + auth_events.values(), ) # A user who left can re-join. @@ -236,7 +277,7 @@ class EventAuthTestCase(unittest.TestCase): event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), - auth_events, + auth_events.values(), ) # A user can send a join if they're in the room. @@ -244,7 +285,7 @@ class EventAuthTestCase(unittest.TestCase): event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), - auth_events, + auth_events.values(), ) # A user can accept an invite. @@ -254,7 +295,7 @@ class EventAuthTestCase(unittest.TestCase): event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), - auth_events, + auth_events.values(), ) def test_join_rules_invite(self): @@ -275,7 +316,7 @@ class EventAuthTestCase(unittest.TestCase): event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), - auth_events, + auth_events.values(), ) # A user cannot be force-joined to a room. @@ -283,7 +324,7 @@ class EventAuthTestCase(unittest.TestCase): event_auth.check_auth_rules_for_event( RoomVersions.V6, _member_event(pleb, "join", sender=creator), - auth_events, + auth_events.values(), ) # Banned should be rejected. @@ -292,7 +333,7 @@ class EventAuthTestCase(unittest.TestCase): event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), - auth_events, + auth_events.values(), ) # A user who left cannot re-join. @@ -301,7 +342,7 @@ class EventAuthTestCase(unittest.TestCase): event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), - auth_events, + auth_events.values(), ) # A user can send a join if they're in the room. @@ -309,7 +350,7 @@ class EventAuthTestCase(unittest.TestCase): event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), - auth_events, + auth_events.values(), ) # A user can accept an invite. @@ -319,7 +360,7 @@ class EventAuthTestCase(unittest.TestCase): event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), - auth_events, + auth_events.values(), ) def test_join_rules_msc3083_restricted(self): @@ -347,7 +388,7 @@ class EventAuthTestCase(unittest.TestCase): event_auth.check_auth_rules_for_event( RoomVersions.V6, _join_event(pleb), - auth_events, + auth_events.values(), ) # A properly formatted join event should work. @@ -360,7 +401,7 @@ class EventAuthTestCase(unittest.TestCase): event_auth.check_auth_rules_for_event( RoomVersions.V8, authorised_join_event, - auth_events, + auth_events.values(), ) # A join issued by a specific user works (i.e. the power level checks @@ -380,7 +421,7 @@ class EventAuthTestCase(unittest.TestCase): EventContentFields.AUTHORISING_USER: "@inviter:foo.test" }, ), - pl_auth_events, + pl_auth_events.values(), ) # A join which is missing an authorised server is rejected. @@ -388,7 +429,7 @@ class EventAuthTestCase(unittest.TestCase): event_auth.check_auth_rules_for_event( RoomVersions.V8, _join_event(pleb), - auth_events, + auth_events.values(), ) # An join authorised by a user who is not in the room is rejected. @@ -405,7 +446,7 @@ class EventAuthTestCase(unittest.TestCase): EventContentFields.AUTHORISING_USER: "@other:example.com" }, ), - auth_events, + auth_events.values(), ) # A user cannot be force-joined to a room. (This uses an event which @@ -421,7 +462,7 @@ class EventAuthTestCase(unittest.TestCase): EventContentFields.AUTHORISING_USER: "@inviter:foo.test" }, ), - auth_events, + auth_events.values(), ) # Banned should be rejected. @@ -430,7 +471,7 @@ class EventAuthTestCase(unittest.TestCase): event_auth.check_auth_rules_for_event( RoomVersions.V8, authorised_join_event, - auth_events, + auth_events.values(), ) # A user who left can re-join. @@ -438,7 +479,7 @@ class EventAuthTestCase(unittest.TestCase): event_auth.check_auth_rules_for_event( RoomVersions.V8, authorised_join_event, - auth_events, + auth_events.values(), ) # A user can send a join if they're in the room. (This doesn't need to @@ -447,7 +488,7 @@ class EventAuthTestCase(unittest.TestCase): event_auth.check_auth_rules_for_event( RoomVersions.V8, _join_event(pleb), - auth_events, + auth_events.values(), ) # A user can accept an invite. (This doesn't need to be authorised since @@ -458,7 +499,7 @@ class EventAuthTestCase(unittest.TestCase): event_auth.check_auth_rules_for_event( RoomVersions.V8, _join_event(pleb), - auth_events, + auth_events.values(), ) @@ -473,6 +514,7 @@ def _create_event(user_id: str) -> EventBase: "room_id": TEST_ROOM_ID, "event_id": _get_event_id(), "type": "m.room.create", + "state_key": "", "sender": user_id, "content": {"creator": user_id}, } diff --git a/tests/test_preview.py b/tests/test_preview.py index 09e017b4d9..9a576f9a4e 100644 --- a/tests/test_preview.py +++ b/tests/test_preview.py @@ -15,7 +15,7 @@ from synapse.rest.media.v1.preview_url_resource import ( _calc_og, decode_body, - get_html_media_encoding, + get_html_media_encodings, summarize_paragraphs, ) @@ -159,7 +159,7 @@ class CalcOgTestCase(unittest.TestCase): </html> """ - tree = decode_body(html) + tree = decode_body(html, "http://example.com/test.html") og = _calc_og(tree, "http://example.com/test.html") self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) @@ -175,7 +175,7 @@ class CalcOgTestCase(unittest.TestCase): </html> """ - tree = decode_body(html) + tree = decode_body(html, "http://example.com/test.html") og = _calc_og(tree, "http://example.com/test.html") self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) @@ -194,7 +194,7 @@ class CalcOgTestCase(unittest.TestCase): </html> """ - tree = decode_body(html) + tree = decode_body(html, "http://example.com/test.html") og = _calc_og(tree, "http://example.com/test.html") self.assertEqual( @@ -216,7 +216,7 @@ class CalcOgTestCase(unittest.TestCase): </html> """ - tree = decode_body(html) + tree = decode_body(html, "http://example.com/test.html") og = _calc_og(tree, "http://example.com/test.html") self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) @@ -230,7 +230,7 @@ class CalcOgTestCase(unittest.TestCase): </html> """ - tree = decode_body(html) + tree = decode_body(html, "http://example.com/test.html") og = _calc_og(tree, "http://example.com/test.html") self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) @@ -245,7 +245,7 @@ class CalcOgTestCase(unittest.TestCase): </html> """ - tree = decode_body(html) + tree = decode_body(html, "http://example.com/test.html") og = _calc_og(tree, "http://example.com/test.html") self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."}) @@ -260,7 +260,7 @@ class CalcOgTestCase(unittest.TestCase): </html> """ - tree = decode_body(html) + tree = decode_body(html, "http://example.com/test.html") og = _calc_og(tree, "http://example.com/test.html") self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) @@ -268,13 +268,13 @@ class CalcOgTestCase(unittest.TestCase): def test_empty(self): """Test a body with no data in it.""" html = b"" - tree = decode_body(html) + tree = decode_body(html, "http://example.com/test.html") self.assertIsNone(tree) def test_no_tree(self): """A valid body with no tree in it.""" html = b"\x00" - tree = decode_body(html) + tree = decode_body(html, "http://example.com/test.html") self.assertIsNone(tree) def test_invalid_encoding(self): @@ -287,7 +287,7 @@ class CalcOgTestCase(unittest.TestCase): </body> </html> """ - tree = decode_body(html, "invalid-encoding") + tree = decode_body(html, "http://example.com/test.html", "invalid-encoding") og = _calc_og(tree, "http://example.com/test.html") self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."}) @@ -302,15 +302,29 @@ class CalcOgTestCase(unittest.TestCase): </body> </html> """ - tree = decode_body(html) + tree = decode_body(html, "http://example.com/test.html") og = _calc_og(tree, "http://example.com/test.html") self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."}) + def test_windows_1252(self): + """A body which uses cp1252, but doesn't declare that.""" + html = b""" + <html> + <head><title>\xf3</title></head> + <body> + Some text. + </body> + </html> + """ + tree = decode_body(html, "http://example.com/test.html") + og = _calc_og(tree, "http://example.com/test.html") + self.assertEqual(og, {"og:title": "ó", "og:description": "Some text."}) + class MediaEncodingTestCase(unittest.TestCase): def test_meta_charset(self): """A character encoding is found via the meta tag.""" - encoding = get_html_media_encoding( + encodings = get_html_media_encodings( b""" <html> <head><meta charset="ascii"> @@ -319,10 +333,10 @@ class MediaEncodingTestCase(unittest.TestCase): """, "text/html", ) - self.assertEqual(encoding, "ascii") + self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"]) # A less well-formed version. - encoding = get_html_media_encoding( + encodings = get_html_media_encodings( b""" <html> <head>< meta charset = ascii> @@ -331,11 +345,11 @@ class MediaEncodingTestCase(unittest.TestCase): """, "text/html", ) - self.assertEqual(encoding, "ascii") + self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"]) def test_meta_charset_underscores(self): """A character encoding contains underscore.""" - encoding = get_html_media_encoding( + encodings = get_html_media_encodings( b""" <html> <head><meta charset="Shift_JIS"> @@ -344,11 +358,11 @@ class MediaEncodingTestCase(unittest.TestCase): """, "text/html", ) - self.assertEqual(encoding, "Shift_JIS") + self.assertEqual(list(encodings), ["shift_jis", "utf-8", "cp1252"]) def test_xml_encoding(self): """A character encoding is found via the meta tag.""" - encoding = get_html_media_encoding( + encodings = get_html_media_encodings( b""" <?xml version="1.0" encoding="ascii"?> <html> @@ -356,11 +370,11 @@ class MediaEncodingTestCase(unittest.TestCase): """, "text/html", ) - self.assertEqual(encoding, "ascii") + self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"]) def test_meta_xml_encoding(self): """Meta tags take precedence over XML encoding.""" - encoding = get_html_media_encoding( + encodings = get_html_media_encodings( b""" <?xml version="1.0" encoding="ascii"?> <html> @@ -370,7 +384,7 @@ class MediaEncodingTestCase(unittest.TestCase): """, "text/html", ) - self.assertEqual(encoding, "UTF-16") + self.assertEqual(list(encodings), ["utf-16", "ascii", "utf-8", "cp1252"]) def test_content_type(self): """A character encoding is found via the Content-Type header.""" @@ -384,10 +398,37 @@ class MediaEncodingTestCase(unittest.TestCase): 'text/html; charset=ascii";', ) for header in headers: - encoding = get_html_media_encoding(b"", header) - self.assertEqual(encoding, "ascii") + encodings = get_html_media_encodings(b"", header) + self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"]) def test_fallback(self): """A character encoding cannot be found in the body or header.""" - encoding = get_html_media_encoding(b"", "text/html") - self.assertEqual(encoding, "utf-8") + encodings = get_html_media_encodings(b"", "text/html") + self.assertEqual(list(encodings), ["utf-8", "cp1252"]) + + def test_duplicates(self): + """Ensure each encoding is only attempted once.""" + encodings = get_html_media_encodings( + b""" + <?xml version="1.0" encoding="utf8"?> + <html> + <head><meta charset="UTF-8"> + </head> + </html> + """, + 'text/html; charset="UTF_8"', + ) + self.assertEqual(list(encodings), ["utf-8", "cp1252"]) + + def test_unknown_invalid(self): + """A character encoding should be ignored if it is unknown or invalid.""" + encodings = get_html_media_encodings( + b""" + <html> + <head><meta charset="invalid"> + </head> + </html> + """, + 'text/html; charset="invalid"', + ) + self.assertEqual(list(encodings), ["utf-8", "cp1252"]) diff --git a/tests/unittest.py b/tests/unittest.py index 81c1a9e9d2..a9b60b7eeb 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -46,7 +46,7 @@ from synapse.logging.context import ( set_current_context, ) from synapse.server import HomeServer -from synapse.types import UserID, create_requester +from synapse.types import JsonDict, UserID, create_requester from synapse.util import Clock from synapse.util.httpresourcetree import create_resource_tree from synapse.util.ratelimitutils import FederationRateLimiter @@ -401,7 +401,7 @@ class HomeserverTestCase(TestCase): self, method: Union[bytes, str], path: Union[bytes, str], - content: Union[bytes, dict] = b"", + content: Union[bytes, str, JsonDict] = b"", access_token: Optional[str] = None, request: Type[T] = SynapseRequest, shorthand: bool = True, |