diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 936ebf3dde..e1557566e4 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -23,6 +23,7 @@ from synapse.federation.federation_base import event_from_pdu_json
from synapse.logging.context import LoggingContext, run_in_background
from synapse.rest import admin
from synapse.rest.client import login, room
+from synapse.types import create_requester
from synapse.util.stringutils import random_string
from tests import unittest
@@ -30,6 +31,10 @@ from tests import unittest
logger = logging.getLogger(__name__)
+def generate_fake_event_id() -> str:
+ return "$fake_" + random_string(43)
+
+
class FederationTestCase(unittest.HomeserverTestCase):
servlets = [
admin.register_servlets,
@@ -198,6 +203,65 @@ class FederationTestCase(unittest.HomeserverTestCase):
self.assertEqual(sg, sg2)
+ def test_backfill_with_many_backward_extremities(self):
+ """
+ Check that we can backfill with many backward extremities.
+ The goal is to make sure that when we only use a portion
+ of backwards extremities(the magic number is more than 5),
+ no errors are thrown.
+
+ Regression test, see #11027
+ """
+ # create the room
+ user_id = self.register_user("kermit", "test")
+ tok = self.login("kermit", "test")
+ requester = create_requester(user_id)
+
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+
+ ev1 = self.helper.send(room_id, "first message", tok=tok)
+
+ # Create "many" backward extremities. The magic number we're trying to
+ # create more than is 5 which corresponds to the number of backward
+ # extremities we slice off in `_maybe_backfill_inner`
+ for _ in range(0, 8):
+ event_handler = self.hs.get_event_creation_handler()
+ event, context = self.get_success(
+ event_handler.create_event(
+ requester,
+ {
+ "type": "m.room.message",
+ "content": {
+ "msgtype": "m.text",
+ "body": "message connected to fake event",
+ },
+ "room_id": room_id,
+ "sender": user_id,
+ },
+ prev_event_ids=[
+ ev1["event_id"],
+ # We're creating an backward extremity each time thanks
+ # to this fake event
+ generate_fake_event_id(),
+ ],
+ )
+ )
+ self.get_success(
+ event_handler.handle_new_client_event(requester, event, context)
+ )
+
+ current_depth = 1
+ limit = 100
+ with LoggingContext("receive_pdu"):
+ # Make sure backfill still works
+ d = run_in_background(
+ self.hs.get_federation_handler().maybe_backfill,
+ room_id,
+ current_depth,
+ limit,
+ )
+ self.get_success(d)
+
def test_backfill_floating_outlier_membership_auth(self):
"""
As the local homeserver, check that we can properly process a federated
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 38e6d9f536..7dd4a5a367 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -20,6 +20,8 @@ from unittest.mock import Mock
from twisted.internet import defer
import synapse
+from synapse.handlers.auth import load_legacy_password_auth_providers
+from synapse.module_api import ModuleApi
from synapse.rest.client import devices, login
from synapse.types import JsonDict
@@ -36,8 +38,8 @@ ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_servi
mock_password_provider = Mock()
-class PasswordOnlyAuthProvider:
- """A password_provider which only implements `check_password`."""
+class LegacyPasswordOnlyAuthProvider:
+ """A legacy password_provider which only implements `check_password`."""
@staticmethod
def parse_config(self):
@@ -50,8 +52,8 @@ class PasswordOnlyAuthProvider:
return mock_password_provider.check_password(*args)
-class CustomAuthProvider:
- """A password_provider which implements a custom login type."""
+class LegacyCustomAuthProvider:
+ """A legacy password_provider which implements a custom login type."""
@staticmethod
def parse_config(self):
@@ -67,7 +69,23 @@ class CustomAuthProvider:
return mock_password_provider.check_auth(*args)
-class PasswordCustomAuthProvider:
+class CustomAuthProvider:
+ """A module which registers password_auth_provider callbacks for a custom login type."""
+
+ @staticmethod
+ def parse_config(self):
+ pass
+
+ def __init__(self, config, api: ModuleApi):
+ api.register_password_auth_provider_callbacks(
+ auth_checkers={("test.login_type", ("test_field",)): self.check_auth},
+ )
+
+ def check_auth(self, *args):
+ return mock_password_provider.check_auth(*args)
+
+
+class LegacyPasswordCustomAuthProvider:
"""A password_provider which implements password login via `check_auth`, as well
as a custom type."""
@@ -85,8 +103,32 @@ class PasswordCustomAuthProvider:
return mock_password_provider.check_auth(*args)
-def providers_config(*providers: Type[Any]) -> dict:
- """Returns a config dict that will enable the given password auth providers"""
+class PasswordCustomAuthProvider:
+ """A module which registers password_auth_provider callbacks for a custom login type.
+ as well as a password login"""
+
+ @staticmethod
+ def parse_config(self):
+ pass
+
+ def __init__(self, config, api: ModuleApi):
+ api.register_password_auth_provider_callbacks(
+ auth_checkers={
+ ("test.login_type", ("test_field",)): self.check_auth,
+ ("m.login.password", ("password",)): self.check_auth,
+ },
+ )
+ pass
+
+ def check_auth(self, *args):
+ return mock_password_provider.check_auth(*args)
+
+ def check_pass(self, *args):
+ return mock_password_provider.check_password(*args)
+
+
+def legacy_providers_config(*providers: Type[Any]) -> dict:
+ """Returns a config dict that will enable the given legacy password auth providers"""
return {
"password_providers": [
{"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}}
@@ -95,6 +137,16 @@ def providers_config(*providers: Type[Any]) -> dict:
}
+def providers_config(*providers: Type[Any]) -> dict:
+ """Returns a config dict that will enable the given modules"""
+ return {
+ "modules": [
+ {"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}}
+ for provider in providers
+ ]
+ }
+
+
class PasswordAuthProviderTests(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets,
@@ -107,8 +159,21 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.reset_mock()
super().setUp()
- @override_config(providers_config(PasswordOnlyAuthProvider))
- def test_password_only_auth_provider_login(self):
+ def make_homeserver(self, reactor, clock):
+ hs = self.setup_test_homeserver()
+ # Load the modules into the homeserver
+ module_api = hs.get_module_api()
+ for module, config in hs.config.modules.loaded_modules:
+ module(config=config, api=module_api)
+ load_legacy_password_auth_providers(hs)
+
+ return hs
+
+ @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
+ def test_password_only_auth_progiver_login_legacy(self):
+ self.password_only_auth_provider_login_test_body()
+
+ def password_only_auth_provider_login_test_body(self):
# login flows should only have m.login.password
flows = self._get_login_flows()
self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)
@@ -138,8 +203,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"@ USER🙂NAME :test", " pASS😢word "
)
- @override_config(providers_config(PasswordOnlyAuthProvider))
- def test_password_only_auth_provider_ui_auth(self):
+ @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
+ def test_password_only_auth_provider_ui_auth_legacy(self):
+ self.password_only_auth_provider_ui_auth_test_body()
+
+ def password_only_auth_provider_ui_auth_test_body(self):
"""UI Auth should delegate correctly to the password provider"""
# create the user, otherwise access doesn't work
@@ -172,8 +240,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
- @override_config(providers_config(PasswordOnlyAuthProvider))
- def test_local_user_fallback_login(self):
+ @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
+ def test_local_user_fallback_login_legacy(self):
+ self.local_user_fallback_login_test_body()
+
+ def local_user_fallback_login_test_body(self):
"""rejected login should fall back to local db"""
self.register_user("localuser", "localpass")
@@ -186,8 +257,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual("@localuser:test", channel.json_body["user_id"])
- @override_config(providers_config(PasswordOnlyAuthProvider))
- def test_local_user_fallback_ui_auth(self):
+ @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
+ def test_local_user_fallback_ui_auth_legacy(self):
+ self.local_user_fallback_ui_auth_test_body()
+
+ def local_user_fallback_ui_auth_test_body(self):
"""rejected login should fall back to local db"""
self.register_user("localuser", "localpass")
@@ -223,11 +297,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
@override_config(
{
- **providers_config(PasswordOnlyAuthProvider),
+ **legacy_providers_config(LegacyPasswordOnlyAuthProvider),
"password_config": {"localdb_enabled": False},
}
)
- def test_no_local_user_fallback_login(self):
+ def test_no_local_user_fallback_login_legacy(self):
+ self.no_local_user_fallback_login_test_body()
+
+ def no_local_user_fallback_login_test_body(self):
"""localdb_enabled can block login with the local password"""
self.register_user("localuser", "localpass")
@@ -242,11 +319,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
@override_config(
{
- **providers_config(PasswordOnlyAuthProvider),
+ **legacy_providers_config(LegacyPasswordOnlyAuthProvider),
"password_config": {"localdb_enabled": False},
}
)
- def test_no_local_user_fallback_ui_auth(self):
+ def test_no_local_user_fallback_ui_auth_legacy(self):
+ self.no_local_user_fallback_ui_auth_test_body()
+
+ def no_local_user_fallback_ui_auth_test_body(self):
"""localdb_enabled can block ui auth with the local password"""
self.register_user("localuser", "localpass")
@@ -280,11 +360,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
@override_config(
{
- **providers_config(PasswordOnlyAuthProvider),
+ **legacy_providers_config(LegacyPasswordOnlyAuthProvider),
"password_config": {"enabled": False},
}
)
- def test_password_auth_disabled(self):
+ def test_password_auth_disabled_legacy(self):
+ self.password_auth_disabled_test_body()
+
+ def password_auth_disabled_test_body(self):
"""password auth doesn't work if it's disabled across the board"""
# login flows should be empty
flows = self._get_login_flows()
@@ -295,8 +378,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400, channel.result)
mock_password_provider.check_password.assert_not_called()
+ @override_config(legacy_providers_config(LegacyCustomAuthProvider))
+ def test_custom_auth_provider_login_legacy(self):
+ self.custom_auth_provider_login_test_body()
+
@override_config(providers_config(CustomAuthProvider))
def test_custom_auth_provider_login(self):
+ self.custom_auth_provider_login_test_body()
+
+ def custom_auth_provider_login_test_body(self):
# login flows should have the custom flow and m.login.password, since we
# haven't disabled local password lookup.
# (password must come first, because reasons)
@@ -312,7 +402,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400, channel.result)
mock_password_provider.check_auth.assert_not_called()
- mock_password_provider.check_auth.return_value = defer.succeed("@user:bz")
+ mock_password_provider.check_auth.return_value = defer.succeed(
+ ("@user:bz", None)
+ )
channel = self._send_login("test.login_type", "u", test_field="y")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual("@user:bz", channel.json_body["user_id"])
@@ -325,7 +417,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# in these cases, but at least we can guard against the API changing
# unexpectedly
mock_password_provider.check_auth.return_value = defer.succeed(
- "@ MALFORMED! :bz"
+ ("@ MALFORMED! :bz", None)
)
channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
self.assertEqual(channel.code, 200, channel.result)
@@ -334,8 +426,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
" USER🙂NAME ", "test.login_type", {"test_field": " abc "}
)
+ @override_config(legacy_providers_config(LegacyCustomAuthProvider))
+ def test_custom_auth_provider_ui_auth_legacy(self):
+ self.custom_auth_provider_ui_auth_test_body()
+
@override_config(providers_config(CustomAuthProvider))
def test_custom_auth_provider_ui_auth(self):
+ self.custom_auth_provider_ui_auth_test_body()
+
+ def custom_auth_provider_ui_auth_test_body(self):
# register the user and log in twice, to get two devices
self.register_user("localuser", "localpass")
tok1 = self.login("localuser", "localpass")
@@ -367,7 +466,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.reset_mock()
# right params, but authing as the wrong user
- mock_password_provider.check_auth.return_value = defer.succeed("@user:bz")
+ mock_password_provider.check_auth.return_value = defer.succeed(
+ ("@user:bz", None)
+ )
body["auth"]["test_field"] = "foo"
channel = self._delete_device(tok1, "dev2", body)
self.assertEqual(channel.code, 403)
@@ -379,7 +480,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# and finally, succeed
mock_password_provider.check_auth.return_value = defer.succeed(
- "@localuser:test"
+ ("@localuser:test", None)
)
channel = self._delete_device(tok1, "dev2", body)
self.assertEqual(channel.code, 200)
@@ -387,8 +488,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"localuser", "test.login_type", {"test_field": "foo"}
)
+ @override_config(legacy_providers_config(LegacyCustomAuthProvider))
+ def test_custom_auth_provider_callback_legacy(self):
+ self.custom_auth_provider_callback_test_body()
+
@override_config(providers_config(CustomAuthProvider))
def test_custom_auth_provider_callback(self):
+ self.custom_auth_provider_callback_test_body()
+
+ def custom_auth_provider_callback_test_body(self):
callback = Mock(return_value=defer.succeed(None))
mock_password_provider.check_auth.return_value = defer.succeed(
@@ -411,9 +519,21 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertIn(p, call_args[0])
@override_config(
+ {
+ **legacy_providers_config(LegacyCustomAuthProvider),
+ "password_config": {"enabled": False},
+ }
+ )
+ def test_custom_auth_password_disabled_legacy(self):
+ self.custom_auth_password_disabled_test_body()
+
+ @override_config(
{**providers_config(CustomAuthProvider), "password_config": {"enabled": False}}
)
def test_custom_auth_password_disabled(self):
+ self.custom_auth_password_disabled_test_body()
+
+ def custom_auth_password_disabled_test_body(self):
"""Test login with a custom auth provider where password login is disabled"""
self.register_user("localuser", "localpass")
@@ -427,11 +547,23 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
@override_config(
{
+ **legacy_providers_config(LegacyCustomAuthProvider),
+ "password_config": {"enabled": False, "localdb_enabled": False},
+ }
+ )
+ def test_custom_auth_password_disabled_localdb_enabled_legacy(self):
+ self.custom_auth_password_disabled_localdb_enabled_test_body()
+
+ @override_config(
+ {
**providers_config(CustomAuthProvider),
"password_config": {"enabled": False, "localdb_enabled": False},
}
)
def test_custom_auth_password_disabled_localdb_enabled(self):
+ self.custom_auth_password_disabled_localdb_enabled_test_body()
+
+ def custom_auth_password_disabled_localdb_enabled_test_body(self):
"""Check the localdb_enabled == enabled == False
Regression test for https://github.com/matrix-org/synapse/issues/8914: check
@@ -450,11 +582,23 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
@override_config(
{
+ **legacy_providers_config(LegacyPasswordCustomAuthProvider),
+ "password_config": {"enabled": False},
+ }
+ )
+ def test_password_custom_auth_password_disabled_login_legacy(self):
+ self.password_custom_auth_password_disabled_login_test_body()
+
+ @override_config(
+ {
**providers_config(PasswordCustomAuthProvider),
"password_config": {"enabled": False},
}
)
def test_password_custom_auth_password_disabled_login(self):
+ self.password_custom_auth_password_disabled_login_test_body()
+
+ def password_custom_auth_password_disabled_login_test_body(self):
"""log in with a custom auth provider which implements password, but password
login is disabled"""
self.register_user("localuser", "localpass")
@@ -466,6 +610,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 400, channel.result)
mock_password_provider.check_auth.assert_not_called()
+ mock_password_provider.check_password.assert_not_called()
+
+ @override_config(
+ {
+ **legacy_providers_config(LegacyPasswordCustomAuthProvider),
+ "password_config": {"enabled": False},
+ }
+ )
+ def test_password_custom_auth_password_disabled_ui_auth_legacy(self):
+ self.password_custom_auth_password_disabled_ui_auth_test_body()
@override_config(
{
@@ -474,12 +628,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
}
)
def test_password_custom_auth_password_disabled_ui_auth(self):
+ self.password_custom_auth_password_disabled_ui_auth_test_body()
+
+ def password_custom_auth_password_disabled_ui_auth_test_body(self):
"""UI Auth with a custom auth provider which implements password, but password
login is disabled"""
# register the user and log in twice via the test login type to get two devices,
self.register_user("localuser", "localpass")
mock_password_provider.check_auth.return_value = defer.succeed(
- "@localuser:test"
+ ("@localuser:test", None)
)
channel = self._send_login("test.login_type", "localuser", test_field="")
self.assertEqual(channel.code, 200, channel.result)
@@ -516,6 +673,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"Password login has been disabled.", channel.json_body["error"]
)
mock_password_provider.check_auth.assert_not_called()
+ mock_password_provider.check_password.assert_not_called()
mock_password_provider.reset_mock()
# successful auth
@@ -526,6 +684,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.check_auth.assert_called_once_with(
"localuser", "test.login_type", {"test_field": "x"}
)
+ mock_password_provider.check_password.assert_not_called()
+
+ @override_config(
+ {
+ **legacy_providers_config(LegacyCustomAuthProvider),
+ "password_config": {"localdb_enabled": False},
+ }
+ )
+ def test_custom_auth_no_local_user_fallback_legacy(self):
+ self.custom_auth_no_local_user_fallback_test_body()
@override_config(
{
@@ -534,6 +702,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
}
)
def test_custom_auth_no_local_user_fallback(self):
+ self.custom_auth_no_local_user_fallback_test_body()
+
+ def custom_auth_no_local_user_fallback_test_body(self):
"""Test login with a custom auth provider where the local db is disabled"""
self.register_user("localuser", "localpass")
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 0120b4688b..b9ad92b977 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -109,18 +109,14 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
tok=alice_token,
)
- users = self.get_success(self.user_dir_helper.get_users_in_user_directory())
- in_public = self.get_success(self.user_dir_helper.get_users_in_public_rooms())
- in_private = self.get_success(
- self.user_dir_helper.get_users_who_share_private_rooms()
+ # The user directory should reflect the room memberships above.
+ users, in_public, in_private = self.get_success(
+ self.user_dir_helper.get_tables()
)
-
self.assertEqual(users, {alice, bob})
+ self.assertEqual(in_public, {(alice, public), (bob, public), (alice, public2)})
self.assertEqual(
- set(in_public), {(alice, public), (bob, public), (alice, public2)}
- )
- self.assertEqual(
- self.user_dir_helper._compress_shared(in_private),
+ in_private,
{(alice, bob, private), (bob, alice, private)},
)
@@ -209,6 +205,88 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
in_public = self.get_success(self.user_dir_helper.get_users_in_public_rooms())
self.assertEqual(set(in_public), {(user1, room), (user2, room)})
+ def test_excludes_users_when_making_room_public(self) -> None:
+ # Create a regular user and a support user.
+ alice = self.register_user("alice", "pass")
+ alice_token = self.login(alice, "pass")
+ support = "@support1:test"
+ self.get_success(
+ self.store.register_user(
+ user_id=support, password_hash=None, user_type=UserTypes.SUPPORT
+ )
+ )
+
+ # Make a public and private room containing Alice and the support user
+ public, initially_private = self._create_rooms_and_inject_memberships(
+ alice, alice_token, support
+ )
+ self._check_only_one_user_in_directory(alice, public)
+
+ # Alice makes the private room public.
+ self.helper.send_state(
+ initially_private,
+ "m.room.join_rules",
+ {"join_rule": "public"},
+ tok=alice_token,
+ )
+
+ users, in_public, in_private = self.get_success(
+ self.user_dir_helper.get_tables()
+ )
+ self.assertEqual(users, {alice})
+ self.assertEqual(in_public, {(alice, public), (alice, initially_private)})
+ self.assertEqual(in_private, set())
+
+ def test_switching_from_private_to_public_to_private(self) -> None:
+ """Check we update the room sharing tables when switching a room
+ from private to public, then back again to private."""
+ # Alice and Bob share a private room.
+ alice = self.register_user("alice", "pass")
+ alice_token = self.login(alice, "pass")
+ bob = self.register_user("bob", "pass")
+ bob_token = self.login(bob, "pass")
+ room = self.helper.create_room_as(alice, is_public=False, tok=alice_token)
+ self.helper.invite(room, alice, bob, tok=alice_token)
+ self.helper.join(room, bob, tok=bob_token)
+
+ # The user directory should reflect this.
+ def check_user_dir_for_private_room() -> None:
+ users, in_public, in_private = self.get_success(
+ self.user_dir_helper.get_tables()
+ )
+ self.assertEqual(users, {alice, bob})
+ self.assertEqual(in_public, set())
+ self.assertEqual(in_private, {(alice, bob, room), (bob, alice, room)})
+
+ check_user_dir_for_private_room()
+
+ # Alice makes the room public.
+ self.helper.send_state(
+ room,
+ "m.room.join_rules",
+ {"join_rule": "public"},
+ tok=alice_token,
+ )
+
+ # The user directory should be updated accordingly
+ users, in_public, in_private = self.get_success(
+ self.user_dir_helper.get_tables()
+ )
+ self.assertEqual(users, {alice, bob})
+ self.assertEqual(in_public, {(alice, room), (bob, room)})
+ self.assertEqual(in_private, set())
+
+ # Alice makes the room private.
+ self.helper.send_state(
+ room,
+ "m.room.join_rules",
+ {"join_rule": "invite"},
+ tok=alice_token,
+ )
+
+ # The user directory should be updated accordingly
+ check_user_dir_for_private_room()
+
def _create_rooms_and_inject_memberships(
self, creator: str, token: str, joiner: str
) -> Tuple[str, str]:
@@ -232,15 +310,18 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
return public_room, private_room
def _check_only_one_user_in_directory(self, user: str, public: str) -> None:
- users = self.get_success(self.user_dir_helper.get_users_in_user_directory())
- in_public = self.get_success(self.user_dir_helper.get_users_in_public_rooms())
- in_private = self.get_success(
- self.user_dir_helper.get_users_who_share_private_rooms()
- )
+ """Check that the user directory DB tables show that:
+ - only one user is in the user directory
+ - they belong to exactly one public room
+ - they don't share a private room with anyone.
+ """
+ users, in_public, in_private = self.get_success(
+ self.user_dir_helper.get_tables()
+ )
self.assertEqual(users, {user})
- self.assertEqual(set(in_public), {(user, public)})
- self.assertEqual(in_private, [])
+ self.assertEqual(in_public, {(user, public)})
+ self.assertEqual(in_private, set())
def test_handle_local_profile_change_with_support_user(self) -> None:
support_user_id = "@support:test"
@@ -581,11 +662,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.user_dir_helper.get_users_in_public_rooms()
)
- self.assertEqual(
- self.user_dir_helper._compress_shared(shares_private),
- {(u1, u2, room), (u2, u1, room)},
- )
- self.assertEqual(public_users, [])
+ self.assertEqual(shares_private, {(u1, u2, room), (u2, u1, room)})
+ self.assertEqual(public_users, set())
# We get one search result when searching for user2 by user1.
s = self.get_success(self.handler.search_users(u1, "user2", 10))
@@ -610,8 +688,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.user_dir_helper.get_users_in_public_rooms()
)
- self.assertEqual(self.user_dir_helper._compress_shared(shares_private), set())
- self.assertEqual(public_users, [])
+ self.assertEqual(shares_private, set())
+ self.assertEqual(public_users, set())
# User1 now gets no search results for any of the other users.
s = self.get_success(self.handler.search_users(u1, "user2", 10))
@@ -645,11 +723,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.user_dir_helper.get_users_in_public_rooms()
)
- self.assertEqual(
- self.user_dir_helper._compress_shared(shares_private),
- {(u1, u2, room), (u2, u1, room)},
- )
- self.assertEqual(public_users, [])
+ self.assertEqual(shares_private, {(u1, u2, room), (u2, u1, room)})
+ self.assertEqual(public_users, set())
# We get one search result when searching for user2 by user1.
s = self.get_success(self.handler.search_users(u1, "user2", 10))
@@ -704,11 +779,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.user_dir_helper.get_users_in_public_rooms()
)
- self.assertEqual(
- self.user_dir_helper._compress_shared(shares_private),
- {(u1, u2, room), (u2, u1, room)},
- )
- self.assertEqual(public_users, [])
+ self.assertEqual(shares_private, {(u1, u2, room), (u2, u1, room)})
+ self.assertEqual(public_users, set())
# Configure a spam checker.
spam_checker = self.hs.get_spam_checker()
@@ -740,8 +812,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
# No users share rooms
- self.assertEqual(public_users, [])
- self.assertEqual(self.user_dir_helper._compress_shared(shares_private), set())
+ self.assertEqual(public_users, set())
+ self.assertEqual(shares_private, set())
# Despite not sharing a room, search_all_users means we get a search
# result.
@@ -842,6 +914,56 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.hs.get_storage().persistence.persist_event(event, context)
)
+ def test_local_user_leaving_room_remains_in_user_directory(self) -> None:
+ """We've chosen to simplify the user directory's implementation by
+ always including local users. Ensure this invariant is maintained when
+ a local user
+ - leaves a room, and
+ - leaves the last room they're in which is visible to this server.
+
+ This is user-visible if the "search_all_users" config option is on: the
+ local user who left a room would no longer be searchable if this test fails!
+ """
+ alice = self.register_user("alice", "pass")
+ alice_token = self.login(alice, "pass")
+ bob = self.register_user("bob", "pass")
+ bob_token = self.login(bob, "pass")
+
+ # Alice makes two public rooms, which Bob joins.
+ room1 = self.helper.create_room_as(alice, is_public=True, tok=alice_token)
+ room2 = self.helper.create_room_as(alice, is_public=True, tok=alice_token)
+ self.helper.join(room1, bob, tok=bob_token)
+ self.helper.join(room2, bob, tok=bob_token)
+
+ # The user directory tables are updated.
+ users, in_public, in_private = self.get_success(
+ self.user_dir_helper.get_tables()
+ )
+ self.assertEqual(users, {alice, bob})
+ self.assertEqual(
+ in_public, {(alice, room1), (alice, room2), (bob, room1), (bob, room2)}
+ )
+ self.assertEqual(in_private, set())
+
+ # Alice leaves one room. She should still be in the directory.
+ self.helper.leave(room1, alice, tok=alice_token)
+ users, in_public, in_private = self.get_success(
+ self.user_dir_helper.get_tables()
+ )
+ self.assertEqual(users, {alice, bob})
+ self.assertEqual(in_public, {(alice, room2), (bob, room1), (bob, room2)})
+ self.assertEqual(in_private, set())
+
+ # Alice leaves the other. She should still be in the directory.
+ self.helper.leave(room2, alice, tok=alice_token)
+ self.wait_for_background_updates()
+ users, in_public, in_private = self.get_success(
+ self.user_dir_helper.get_tables()
+ )
+ self.assertEqual(users, {alice, bob})
+ self.assertEqual(in_public, {(bob, room1), (bob, room2)})
+ self.assertEqual(in_private, set())
+
class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
servlets = [
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 6ed9e42173..c9e2754b09 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -14,14 +14,13 @@
import hashlib
import hmac
-import json
import os
import urllib.parse
from binascii import unhexlify
from typing import List, Optional
from unittest.mock import Mock, patch
-from parameterized import parameterized
+from parameterized import parameterized, parameterized_class
import synapse.rest.admin
from synapse.api.constants import UserTypes
@@ -104,8 +103,8 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# 59 seconds
self.reactor.advance(59)
- body = json.dumps({"nonce": nonce})
- channel = self.make_request("POST", self.url, body.encode("utf8"))
+ body = {"nonce": nonce}
+ channel = self.make_request("POST", self.url, body)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("username must be specified", channel.json_body["error"])
@@ -113,7 +112,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# 61 seconds
self.reactor.advance(2)
- channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, body)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("unrecognised nonce", channel.json_body["error"])
@@ -129,18 +128,16 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
want_mac.update(b"notthenonce\x00bob\x00abc123\x00admin")
want_mac = want_mac.hexdigest()
- body = json.dumps(
- {
- "nonce": nonce,
- "username": "bob",
- "password": "abc123",
- "admin": True,
- "mac": want_mac,
- }
- )
- channel = self.make_request("POST", self.url, body.encode("utf8"))
+ body = {
+ "nonce": nonce,
+ "username": "bob",
+ "password": "abc123",
+ "admin": True,
+ "mac": want_mac,
+ }
+ channel = self.make_request("POST", self.url, body)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual("HMAC incorrect", channel.json_body["error"])
def test_register_correct_nonce(self):
@@ -157,17 +154,15 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
)
want_mac = want_mac.hexdigest()
- body = json.dumps(
- {
- "nonce": nonce,
- "username": "bob",
- "password": "abc123",
- "admin": True,
- "user_type": UserTypes.SUPPORT,
- "mac": want_mac,
- }
- )
- channel = self.make_request("POST", self.url, body.encode("utf8"))
+ body = {
+ "nonce": nonce,
+ "username": "bob",
+ "password": "abc123",
+ "admin": True,
+ "user_type": UserTypes.SUPPORT,
+ "mac": want_mac,
+ }
+ channel = self.make_request("POST", self.url, body)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["user_id"])
@@ -183,22 +178,20 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
want_mac.update(nonce.encode("ascii") + b"\x00bob\x00abc123\x00admin")
want_mac = want_mac.hexdigest()
- body = json.dumps(
- {
- "nonce": nonce,
- "username": "bob",
- "password": "abc123",
- "admin": True,
- "mac": want_mac,
- }
- )
- channel = self.make_request("POST", self.url, body.encode("utf8"))
+ body = {
+ "nonce": nonce,
+ "username": "bob",
+ "password": "abc123",
+ "admin": True,
+ "mac": want_mac,
+ }
+ channel = self.make_request("POST", self.url, body)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["user_id"])
# Now, try and reuse it
- channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, body)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("unrecognised nonce", channel.json_body["error"])
@@ -218,9 +211,8 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
# Nonce check
#
- # Must be present
- body = json.dumps({})
- channel = self.make_request("POST", self.url, body.encode("utf8"))
+ # Must be an empty body present
+ channel = self.make_request("POST", self.url, {})
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("nonce must be specified", channel.json_body["error"])
@@ -230,29 +222,28 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
#
# Must be present
- body = json.dumps({"nonce": nonce()})
- channel = self.make_request("POST", self.url, body.encode("utf8"))
+ channel = self.make_request("POST", self.url, {"nonce": nonce()})
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("username must be specified", channel.json_body["error"])
# Must be a string
- body = json.dumps({"nonce": nonce(), "username": 1234})
- channel = self.make_request("POST", self.url, body.encode("utf8"))
+ body = {"nonce": nonce(), "username": 1234}
+ channel = self.make_request("POST", self.url, body)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid username", channel.json_body["error"])
# Must not have null bytes
- body = json.dumps({"nonce": nonce(), "username": "abcd\u0000"})
- channel = self.make_request("POST", self.url, body.encode("utf8"))
+ body = {"nonce": nonce(), "username": "abcd\u0000"}
+ channel = self.make_request("POST", self.url, body)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid username", channel.json_body["error"])
# Must not have null bytes
- body = json.dumps({"nonce": nonce(), "username": "a" * 1000})
- channel = self.make_request("POST", self.url, body.encode("utf8"))
+ body = {"nonce": nonce(), "username": "a" * 1000}
+ channel = self.make_request("POST", self.url, body)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid username", channel.json_body["error"])
@@ -262,29 +253,29 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
#
# Must be present
- body = json.dumps({"nonce": nonce(), "username": "a"})
- channel = self.make_request("POST", self.url, body.encode("utf8"))
+ body = {"nonce": nonce(), "username": "a"}
+ channel = self.make_request("POST", self.url, body)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("password must be specified", channel.json_body["error"])
# Must be a string
- body = json.dumps({"nonce": nonce(), "username": "a", "password": 1234})
- channel = self.make_request("POST", self.url, body.encode("utf8"))
+ body = {"nonce": nonce(), "username": "a", "password": 1234}
+ channel = self.make_request("POST", self.url, body)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid password", channel.json_body["error"])
# Must not have null bytes
- body = json.dumps({"nonce": nonce(), "username": "a", "password": "abcd\u0000"})
- channel = self.make_request("POST", self.url, body.encode("utf8"))
+ body = {"nonce": nonce(), "username": "a", "password": "abcd\u0000"}
+ channel = self.make_request("POST", self.url, body)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid password", channel.json_body["error"])
# Super long
- body = json.dumps({"nonce": nonce(), "username": "a", "password": "A" * 1000})
- channel = self.make_request("POST", self.url, body.encode("utf8"))
+ body = {"nonce": nonce(), "username": "a", "password": "A" * 1000}
+ channel = self.make_request("POST", self.url, body)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid password", channel.json_body["error"])
@@ -294,15 +285,13 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
#
# Invalid user_type
- body = json.dumps(
- {
- "nonce": nonce(),
- "username": "a",
- "password": "1234",
- "user_type": "invalid",
- }
- )
- channel = self.make_request("POST", self.url, body.encode("utf8"))
+ body = {
+ "nonce": nonce(),
+ "username": "a",
+ "password": "1234",
+ "user_type": "invalid",
+ }
+ channel = self.make_request("POST", self.url, body)
self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual("Invalid user type", channel.json_body["error"])
@@ -320,10 +309,14 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
want_mac.update(nonce.encode("ascii") + b"\x00bob1\x00abc123\x00notadmin")
want_mac = want_mac.hexdigest()
- body = json.dumps(
- {"nonce": nonce, "username": "bob1", "password": "abc123", "mac": want_mac}
- )
- channel = self.make_request("POST", self.url, body.encode("utf8"))
+ body = {
+ "nonce": nonce,
+ "username": "bob1",
+ "password": "abc123",
+ "mac": want_mac,
+ }
+
+ channel = self.make_request("POST", self.url, body)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob1:test", channel.json_body["user_id"])
@@ -340,16 +333,14 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
want_mac.update(nonce.encode("ascii") + b"\x00bob2\x00abc123\x00notadmin")
want_mac = want_mac.hexdigest()
- body = json.dumps(
- {
- "nonce": nonce,
- "username": "bob2",
- "displayname": None,
- "password": "abc123",
- "mac": want_mac,
- }
- )
- channel = self.make_request("POST", self.url, body.encode("utf8"))
+ body = {
+ "nonce": nonce,
+ "username": "bob2",
+ "displayname": None,
+ "password": "abc123",
+ "mac": want_mac,
+ }
+ channel = self.make_request("POST", self.url, body)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob2:test", channel.json_body["user_id"])
@@ -366,22 +357,20 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
want_mac.update(nonce.encode("ascii") + b"\x00bob3\x00abc123\x00notadmin")
want_mac = want_mac.hexdigest()
- body = json.dumps(
- {
- "nonce": nonce,
- "username": "bob3",
- "displayname": "",
- "password": "abc123",
- "mac": want_mac,
- }
- )
- channel = self.make_request("POST", self.url, body.encode("utf8"))
+ body = {
+ "nonce": nonce,
+ "username": "bob3",
+ "displayname": "",
+ "password": "abc123",
+ "mac": want_mac,
+ }
+ channel = self.make_request("POST", self.url, body)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob3:test", channel.json_body["user_id"])
channel = self.make_request("GET", "/profile/@bob3:test/displayname")
- self.assertEqual(404, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(404, channel.code, msg=channel.json_body)
# set displayname
channel = self.make_request("GET", self.url)
@@ -391,16 +380,14 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
want_mac.update(nonce.encode("ascii") + b"\x00bob4\x00abc123\x00notadmin")
want_mac = want_mac.hexdigest()
- body = json.dumps(
- {
- "nonce": nonce,
- "username": "bob4",
- "displayname": "Bob's Name",
- "password": "abc123",
- "mac": want_mac,
- }
- )
- channel = self.make_request("POST", self.url, body.encode("utf8"))
+ body = {
+ "nonce": nonce,
+ "username": "bob4",
+ "displayname": "Bob's Name",
+ "password": "abc123",
+ "mac": want_mac,
+ }
+ channel = self.make_request("POST", self.url, body)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob4:test", channel.json_body["user_id"])
@@ -440,17 +427,15 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
)
want_mac = want_mac.hexdigest()
- body = json.dumps(
- {
- "nonce": nonce,
- "username": "bob",
- "password": "abc123",
- "admin": True,
- "user_type": UserTypes.SUPPORT,
- "mac": want_mac,
- }
- )
- channel = self.make_request("POST", self.url, body.encode("utf8"))
+ body = {
+ "nonce": nonce,
+ "username": "bob",
+ "password": "abc123",
+ "admin": True,
+ "user_type": UserTypes.SUPPORT,
+ "mac": want_mac,
+ }
+ channel = self.make_request("POST", self.url, body)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual("@bob:test", channel.json_body["user_id"])
@@ -993,12 +978,11 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
"""
If parameter `erase` is not boolean, return an error
"""
- body = json.dumps({"erase": "False"})
channel = self.make_request(
"POST",
self.url,
- content=body.encode(encoding="utf_8"),
+ content={"erase": "False"},
access_token=self.admin_user_tok,
)
@@ -2201,7 +2185,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self):
@@ -2216,7 +2200,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
access_token=other_user_token,
)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_user_does_not_exist(self):
@@ -2359,7 +2343,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request("GET", self.url, b"{}")
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_no_admin(self):
@@ -2374,7 +2358,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
access_token=other_user_token,
)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_user_does_not_exist(self):
@@ -3073,7 +3057,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
"""Try to login as a user without authentication."""
channel = self.make_request("POST", self.url, b"{}")
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_not_admin(self):
@@ -3082,7 +3066,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
"POST", self.url, b"{}", access_token=self.other_user_tok
)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(403, channel.code, msg=channel.json_body)
def test_send_event(self):
"""Test that sending event as a user works."""
@@ -3127,7 +3111,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
# The puppet token should no longer work
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(401, channel.code, msg=channel.json_body)
# .. but the real user's tokens should still work
channel = self.make_request(
@@ -3160,7 +3144,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"GET", "devices", b"{}", access_token=self.other_user_tok
)
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(401, channel.code, msg=channel.json_body)
def test_admin_logout_all(self):
"""Tests that the admin user calling `/logout/all` does expire the
@@ -3181,7 +3165,7 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
# The puppet token should no longer work
channel = self.make_request("GET", "devices", b"{}", access_token=puppet_token)
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(401, channel.code, msg=channel.json_body)
# .. but the real user's tokens should still work
channel = self.make_request(
@@ -3242,6 +3226,13 @@ class UserTokenRestTestCase(unittest.HomeserverTestCase):
self.helper.join(room_id, user=self.other_user, tok=puppet_token)
+@parameterized_class(
+ ("url_prefix",),
+ [
+ ("/_synapse/admin/v1/whois/%s",),
+ ("/_matrix/client/r0/admin/whois/%s",),
+ ],
+)
class WhoisRestTestCase(unittest.HomeserverTestCase):
servlets = [
@@ -3254,21 +3245,14 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
self.admin_user_tok = self.login("admin", "pass")
self.other_user = self.register_user("user", "pass")
- self.url1 = "/_synapse/admin/v1/whois/%s" % urllib.parse.quote(self.other_user)
- self.url2 = "/_matrix/client/r0/admin/whois/%s" % urllib.parse.quote(
- self.other_user
- )
+ self.url = self.url_prefix % self.other_user
def test_no_auth(self):
"""
Try to get information of an user without authentication.
"""
- channel = self.make_request("GET", self.url1, b"{}")
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
-
- channel = self.make_request("GET", self.url2, b"{}")
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ channel = self.make_request("GET", self.url, b"{}")
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_not_admin(self):
@@ -3280,38 +3264,21 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
- self.url1,
- access_token=other_user2_token,
- )
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
-
- channel = self.make_request(
- "GET",
- self.url2,
+ self.url,
access_token=other_user2_token,
)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_user_is_not_local(self):
"""
Tests that a lookup for a user that is not a local returns a 400
"""
- url1 = "/_synapse/admin/v1/whois/@unknown_person:unknown_domain"
- url2 = "/_matrix/client/r0/admin/whois/@unknown_person:unknown_domain"
-
- channel = self.make_request(
- "GET",
- url1,
- access_token=self.admin_user_tok,
- )
- self.assertEqual(400, channel.code, msg=channel.json_body)
- self.assertEqual("Can only whois a local user", channel.json_body["error"])
+ url = self.url_prefix % "@unknown_person:unknown_domain"
channel = self.make_request(
"GET",
- url2,
+ url,
access_token=self.admin_user_tok,
)
self.assertEqual(400, channel.code, msg=channel.json_body)
@@ -3323,16 +3290,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
"""
channel = self.make_request(
"GET",
- self.url1,
- access_token=self.admin_user_tok,
- )
- self.assertEqual(200, channel.code, msg=channel.json_body)
- self.assertEqual(self.other_user, channel.json_body["user_id"])
- self.assertIn("devices", channel.json_body)
-
- channel = self.make_request(
- "GET",
- self.url2,
+ self.url,
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -3347,16 +3305,7 @@ class WhoisRestTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
- self.url1,
- access_token=other_user_token,
- )
- self.assertEqual(200, channel.code, msg=channel.json_body)
- self.assertEqual(self.other_user, channel.json_body["user_id"])
- self.assertIn("devices", channel.json_body)
-
- channel = self.make_request(
- "GET",
- self.url2,
+ self.url,
access_token=other_user_token,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
@@ -3388,7 +3337,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
Try to get information of an user without authentication.
"""
channel = self.make_request("POST", self.url)
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_requester_is_not_admin(self):
@@ -3398,7 +3347,7 @@ class ShadowBanRestTestCase(unittest.HomeserverTestCase):
other_user_token = self.login("user", "pass")
channel = self.make_request("POST", self.url, access_token=other_user_token)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
def test_user_is_not_local(self):
@@ -3447,84 +3396,41 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
% urllib.parse.quote(self.other_user)
)
- def test_no_auth(self):
+ @parameterized.expand(["GET", "POST", "DELETE"])
+ def test_no_auth(self, method: str):
"""
Try to get information of a user without authentication.
"""
- channel = self.make_request("GET", self.url, b"{}")
-
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
-
- channel = self.make_request("POST", self.url, b"{}")
+ channel = self.make_request(method, self.url, b"{}")
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
-
- channel = self.make_request("DELETE", self.url, b"{}")
-
- self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
- def test_requester_is_no_admin(self):
+ @parameterized.expand(["GET", "POST", "DELETE"])
+ def test_requester_is_no_admin(self, method: str):
"""
If the user is not a server admin, an error is returned.
"""
other_user_token = self.login("user", "pass")
channel = self.make_request(
- "GET",
- self.url,
- access_token=other_user_token,
- )
-
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
-
- channel = self.make_request(
- "POST",
- self.url,
- access_token=other_user_token,
- )
-
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
- self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
-
- channel = self.make_request(
- "DELETE",
+ method,
self.url,
access_token=other_user_token,
)
- self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"])
+ self.assertEqual(403, channel.code, msg=channel.json_body)
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
- def test_user_does_not_exist(self):
+ @parameterized.expand(["GET", "POST", "DELETE"])
+ def test_user_does_not_exist(self, method: str):
"""
Tests that a lookup for a user that does not exist returns a 404
"""
url = "/_synapse/admin/v1/users/@unknown_person:test/override_ratelimit"
channel = self.make_request(
- "GET",
- url,
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(404, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
-
- channel = self.make_request(
- "POST",
- url,
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(404, channel.code, msg=channel.json_body)
- self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
-
- channel = self.make_request(
- "DELETE",
+ method,
url,
access_token=self.admin_user_tok,
)
@@ -3532,7 +3438,14 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
- def test_user_is_not_local(self):
+ @parameterized.expand(
+ [
+ ("GET", "Can only look up local users"),
+ ("POST", "Only local users can be ratelimited"),
+ ("DELETE", "Only local users can be ratelimited"),
+ ]
+ )
+ def test_user_is_not_local(self, method: str, error_msg: str):
"""
Tests that a lookup for a user that is not a local returns a 400
"""
@@ -3541,35 +3454,13 @@ class RateLimitTestCase(unittest.HomeserverTestCase):
)
channel = self.make_request(
- "GET",
- url,
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(400, channel.code, msg=channel.json_body)
- self.assertEqual("Can only look up local users", channel.json_body["error"])
-
- channel = self.make_request(
- "POST",
- url,
- access_token=self.admin_user_tok,
- )
-
- self.assertEqual(400, channel.code, msg=channel.json_body)
- self.assertEqual(
- "Only local users can be ratelimited", channel.json_body["error"]
- )
-
- channel = self.make_request(
- "DELETE",
+ method,
url,
access_token=self.admin_user_tok,
)
self.assertEqual(400, channel.code, msg=channel.json_body)
- self.assertEqual(
- "Only local users can be ratelimited", channel.json_body["error"]
- )
+ self.assertEqual(error_msg, channel.json_body["error"])
def test_invalid_parameter(self):
"""
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 02b5e9a8d0..3c7d49f0b4 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -13,15 +13,15 @@
# limitations under the License.
import itertools
-import json
-import urllib
-from typing import Optional
+import urllib.parse
+from typing import Dict, List, Optional, Tuple
from synapse.api.constants import EventTypes, RelationTypes
from synapse.rest import admin
from synapse.rest.client import login, register, relations, room
from tests import unittest
+from tests.server import FakeChannel
class RelationsTestCase(unittest.HomeserverTestCase):
@@ -34,16 +34,16 @@ class RelationsTestCase(unittest.HomeserverTestCase):
]
hijack_auth = False
- def make_homeserver(self, reactor, clock):
+ def default_config(self) -> dict:
# We need to enable msc1849 support for aggregations
- config = self.default_config()
+ config = super().default_config()
config["experimental_msc1849_support_enabled"] = True
# We enable frozen dicts as relations/edits change event contents, so we
# want to test that we don't modify the events in the caches.
config["use_frozen_dicts"] = True
- return self.setup_test_homeserver(config=config)
+ return config
def prepare(self, reactor, clock, hs):
self.user_id, self.user_token = self._create_user("alice")
@@ -146,8 +146,8 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code, channel.json_body)
expected_event_ids.append(channel.json_body["event_id"])
- prev_token = None
- found_event_ids = []
+ prev_token: Optional[str] = None
+ found_event_ids: List[str] = []
for _ in range(20):
from_token = ""
if prev_token:
@@ -203,8 +203,8 @@ class RelationsTestCase(unittest.HomeserverTestCase):
idx += 1
idx %= len(access_tokens)
- prev_token = None
- found_groups = {}
+ prev_token: Optional[str] = None
+ found_groups: Dict[str, int] = {}
for _ in range(20):
from_token = ""
if prev_token:
@@ -270,8 +270,8 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
self.assertEquals(200, channel.code, channel.json_body)
- prev_token = None
- found_event_ids = []
+ prev_token: Optional[str] = None
+ found_event_ids: List[str] = []
encoded_key = urllib.parse.quote_plus("👍".encode())
for _ in range(20):
from_token = ""
@@ -677,24 +677,23 @@ class RelationsTestCase(unittest.HomeserverTestCase):
def _send_relation(
self,
- relation_type,
- event_type,
- key=None,
+ relation_type: str,
+ event_type: str,
+ key: Optional[str] = None,
content: Optional[dict] = None,
- access_token=None,
- parent_id=None,
- ):
+ access_token: Optional[str] = None,
+ parent_id: Optional[str] = None,
+ ) -> FakeChannel:
"""Helper function to send a relation pointing at `self.parent_id`
Args:
- relation_type (str): One of `RelationTypes`
- event_type (str): The type of the event to create
- parent_id (str): The event_id this relation relates to. If None, then self.parent_id
- key (str|None): The aggregation key used for m.annotation relation
- type.
- content(dict|None): The content of the created event.
- access_token (str|None): The access token used to send the relation,
- defaults to `self.user_token`
+ relation_type: One of `RelationTypes`
+ event_type: The type of the event to create
+ key: The aggregation key used for m.annotation relation type.
+ content: The content of the created event.
+ access_token: The access token used to send the relation, defaults
+ to `self.user_token`
+ parent_id: The event_id this relation relates to. If None, then self.parent_id
Returns:
FakeChannel
@@ -712,12 +711,12 @@ class RelationsTestCase(unittest.HomeserverTestCase):
"POST",
"/_matrix/client/unstable/rooms/%s/send_relation/%s/%s/%s%s"
% (self.room, original_id, relation_type, event_type, query),
- json.dumps(content or {}).encode("utf-8"),
+ content or {},
access_token=access_token,
)
return channel
- def _create_user(self, localpart):
+ def _create_user(self, localpart: str) -> Tuple[str, str]:
user_id = self.register_user(localpart, "abc123")
access_token = self.login(localpart, "abc123")
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 38ac9be113..531f09c48b 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -12,25 +12,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
-from typing import Dict
+from typing import TYPE_CHECKING, Dict, Optional, Tuple
from unittest.mock import Mock
from synapse.api.constants import EventTypes
+from synapse.api.errors import SynapseError
from synapse.events import EventBase
from synapse.events.third_party_rules import load_legacy_third_party_event_rules
-from synapse.module_api import ModuleApi
from synapse.rest import admin
from synapse.rest.client import login, room
-from synapse.types import Requester, StateMap
+from synapse.types import JsonDict, Requester, StateMap
from synapse.util.frozenutils import unfreeze
from tests import unittest
+if TYPE_CHECKING:
+ from synapse.module_api import ModuleApi
+
thread_local = threading.local()
class LegacyThirdPartyRulesTestModule:
- def __init__(self, config: Dict, module_api: ModuleApi):
+ def __init__(self, config: Dict, module_api: "ModuleApi"):
# keep a record of the "current" rules module, so that the test can patch
# it if desired.
thread_local.rules_module = self
@@ -50,7 +53,7 @@ class LegacyThirdPartyRulesTestModule:
class LegacyDenyNewRooms(LegacyThirdPartyRulesTestModule):
- def __init__(self, config: Dict, module_api: ModuleApi):
+ def __init__(self, config: Dict, module_api: "ModuleApi"):
super().__init__(config, module_api)
def on_create_room(
@@ -60,7 +63,7 @@ class LegacyDenyNewRooms(LegacyThirdPartyRulesTestModule):
class LegacyChangeEvents(LegacyThirdPartyRulesTestModule):
- def __init__(self, config: Dict, module_api: ModuleApi):
+ def __init__(self, config: Dict, module_api: "ModuleApi"):
super().__init__(config, module_api)
async def check_event_allowed(self, event: EventBase, state: StateMap[EventBase]):
@@ -136,6 +139,47 @@ class ThirdPartyRulesTestCase(unittest.HomeserverTestCase):
)
self.assertEquals(channel.result["code"], b"403", channel.result)
+ def test_third_party_rules_workaround_synapse_errors_pass_through(self):
+ """
+ Tests that the workaround introduced by https://github.com/matrix-org/synapse/pull/11042
+ is functional: that SynapseErrors are passed through from check_event_allowed
+ and bubble up to the web resource.
+
+ NEW MODULES SHOULD NOT MAKE USE OF THIS WORKAROUND!
+ This is a temporary workaround!
+ """
+
+ class NastyHackException(SynapseError):
+ def error_dict(self):
+ """
+ This overrides SynapseError's `error_dict` to nastily inject
+ JSON into the error response.
+ """
+ result = super().error_dict()
+ result["nasty"] = "very"
+ return result
+
+ # add a callback that will raise our hacky exception
+ async def check(ev, state) -> Tuple[bool, Optional[JsonDict]]:
+ raise NastyHackException(429, "message")
+
+ self.hs.get_third_party_event_rules()._check_event_allowed_callbacks = [check]
+
+ # Make a request
+ channel = self.make_request(
+ "PUT",
+ "/_matrix/client/r0/rooms/%s/send/foo.bar.forbidden/2" % self.room_id,
+ {},
+ access_token=self.tok,
+ )
+ # Check the error code
+ self.assertEquals(channel.result["code"], b"429", channel.result)
+ # Check the JSON body has had the `nasty` key injected
+ self.assertEqual(
+ channel.json_body,
+ {"errcode": "M_UNKNOWN", "error": "message", "nasty": "very"},
+ )
+
def test_cannot_modify_event(self):
"""cannot accidentally modify an event before it is persisted"""
diff --git a/tests/rest/media/v1/test_filepath.py b/tests/rest/media/v1/test_filepath.py
new file mode 100644
index 0000000000..09504a485f
--- /dev/null
+++ b/tests/rest/media/v1/test_filepath.py
@@ -0,0 +1,238 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.rest.media.v1.filepath import MediaFilePaths
+
+from tests import unittest
+
+
+class MediaFilePathsTestCase(unittest.TestCase):
+ def setUp(self):
+ super().setUp()
+
+ self.filepaths = MediaFilePaths("/media_store")
+
+ def test_local_media_filepath(self):
+ """Test local media paths"""
+ self.assertEqual(
+ self.filepaths.local_media_filepath_rel("GerZNDnDZVjsOtardLuwfIBg"),
+ "local_content/Ge/rZ/NDnDZVjsOtardLuwfIBg",
+ )
+ self.assertEqual(
+ self.filepaths.local_media_filepath("GerZNDnDZVjsOtardLuwfIBg"),
+ "/media_store/local_content/Ge/rZ/NDnDZVjsOtardLuwfIBg",
+ )
+
+ def test_local_media_thumbnail(self):
+ """Test local media thumbnail paths"""
+ self.assertEqual(
+ self.filepaths.local_media_thumbnail_rel(
+ "GerZNDnDZVjsOtardLuwfIBg", 800, 600, "image/jpeg", "scale"
+ ),
+ "local_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
+ )
+ self.assertEqual(
+ self.filepaths.local_media_thumbnail(
+ "GerZNDnDZVjsOtardLuwfIBg", 800, 600, "image/jpeg", "scale"
+ ),
+ "/media_store/local_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
+ )
+
+ def test_local_media_thumbnail_dir(self):
+ """Test local media thumbnail directory paths"""
+ self.assertEqual(
+ self.filepaths.local_media_thumbnail_dir("GerZNDnDZVjsOtardLuwfIBg"),
+ "/media_store/local_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg",
+ )
+
+ def test_remote_media_filepath(self):
+ """Test remote media paths"""
+ self.assertEqual(
+ self.filepaths.remote_media_filepath_rel(
+ "example.com", "GerZNDnDZVjsOtardLuwfIBg"
+ ),
+ "remote_content/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg",
+ )
+ self.assertEqual(
+ self.filepaths.remote_media_filepath(
+ "example.com", "GerZNDnDZVjsOtardLuwfIBg"
+ ),
+ "/media_store/remote_content/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg",
+ )
+
+ def test_remote_media_thumbnail(self):
+ """Test remote media thumbnail paths"""
+ self.assertEqual(
+ self.filepaths.remote_media_thumbnail_rel(
+ "example.com",
+ "GerZNDnDZVjsOtardLuwfIBg",
+ 800,
+ 600,
+ "image/jpeg",
+ "scale",
+ ),
+ "remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
+ )
+ self.assertEqual(
+ self.filepaths.remote_media_thumbnail(
+ "example.com",
+ "GerZNDnDZVjsOtardLuwfIBg",
+ 800,
+ 600,
+ "image/jpeg",
+ "scale",
+ ),
+ "/media_store/remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
+ )
+
+ def test_remote_media_thumbnail_legacy(self):
+ """Test old-style remote media thumbnail paths"""
+ self.assertEqual(
+ self.filepaths.remote_media_thumbnail_rel_legacy(
+ "example.com", "GerZNDnDZVjsOtardLuwfIBg", 800, 600, "image/jpeg"
+ ),
+ "remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg",
+ )
+
+ def test_remote_media_thumbnail_dir(self):
+ """Test remote media thumbnail directory paths"""
+ self.assertEqual(
+ self.filepaths.remote_media_thumbnail_dir(
+ "example.com", "GerZNDnDZVjsOtardLuwfIBg"
+ ),
+ "/media_store/remote_thumbnail/example.com/Ge/rZ/NDnDZVjsOtardLuwfIBg",
+ )
+
+ def test_url_cache_filepath(self):
+ """Test URL cache paths"""
+ self.assertEqual(
+ self.filepaths.url_cache_filepath_rel("2020-01-02_GerZNDnDZVjsOtar"),
+ "url_cache/2020-01-02/GerZNDnDZVjsOtar",
+ )
+ self.assertEqual(
+ self.filepaths.url_cache_filepath("2020-01-02_GerZNDnDZVjsOtar"),
+ "/media_store/url_cache/2020-01-02/GerZNDnDZVjsOtar",
+ )
+
+ def test_url_cache_filepath_legacy(self):
+ """Test old-style URL cache paths"""
+ self.assertEqual(
+ self.filepaths.url_cache_filepath_rel("GerZNDnDZVjsOtardLuwfIBg"),
+ "url_cache/Ge/rZ/NDnDZVjsOtardLuwfIBg",
+ )
+ self.assertEqual(
+ self.filepaths.url_cache_filepath("GerZNDnDZVjsOtardLuwfIBg"),
+ "/media_store/url_cache/Ge/rZ/NDnDZVjsOtardLuwfIBg",
+ )
+
+ def test_url_cache_filepath_dirs_to_delete(self):
+ """Test URL cache cleanup paths"""
+ self.assertEqual(
+ self.filepaths.url_cache_filepath_dirs_to_delete(
+ "2020-01-02_GerZNDnDZVjsOtar"
+ ),
+ ["/media_store/url_cache/2020-01-02"],
+ )
+
+ def test_url_cache_filepath_dirs_to_delete_legacy(self):
+ """Test old-style URL cache cleanup paths"""
+ self.assertEqual(
+ self.filepaths.url_cache_filepath_dirs_to_delete(
+ "GerZNDnDZVjsOtardLuwfIBg"
+ ),
+ [
+ "/media_store/url_cache/Ge/rZ",
+ "/media_store/url_cache/Ge",
+ ],
+ )
+
+ def test_url_cache_thumbnail(self):
+ """Test URL cache thumbnail paths"""
+ self.assertEqual(
+ self.filepaths.url_cache_thumbnail_rel(
+ "2020-01-02_GerZNDnDZVjsOtar", 800, 600, "image/jpeg", "scale"
+ ),
+ "url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar/800-600-image-jpeg-scale",
+ )
+ self.assertEqual(
+ self.filepaths.url_cache_thumbnail(
+ "2020-01-02_GerZNDnDZVjsOtar", 800, 600, "image/jpeg", "scale"
+ ),
+ "/media_store/url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar/800-600-image-jpeg-scale",
+ )
+
+ def test_url_cache_thumbnail_legacy(self):
+ """Test old-style URL cache thumbnail paths"""
+ self.assertEqual(
+ self.filepaths.url_cache_thumbnail_rel(
+ "GerZNDnDZVjsOtardLuwfIBg", 800, 600, "image/jpeg", "scale"
+ ),
+ "url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
+ )
+ self.assertEqual(
+ self.filepaths.url_cache_thumbnail(
+ "GerZNDnDZVjsOtardLuwfIBg", 800, 600, "image/jpeg", "scale"
+ ),
+ "/media_store/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg/800-600-image-jpeg-scale",
+ )
+
+ def test_url_cache_thumbnail_directory(self):
+ """Test URL cache thumbnail directory paths"""
+ self.assertEqual(
+ self.filepaths.url_cache_thumbnail_directory_rel(
+ "2020-01-02_GerZNDnDZVjsOtar"
+ ),
+ "url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar",
+ )
+ self.assertEqual(
+ self.filepaths.url_cache_thumbnail_directory("2020-01-02_GerZNDnDZVjsOtar"),
+ "/media_store/url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar",
+ )
+
+ def test_url_cache_thumbnail_directory_legacy(self):
+ """Test old-style URL cache thumbnail directory paths"""
+ self.assertEqual(
+ self.filepaths.url_cache_thumbnail_directory_rel(
+ "GerZNDnDZVjsOtardLuwfIBg"
+ ),
+ "url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg",
+ )
+ self.assertEqual(
+ self.filepaths.url_cache_thumbnail_directory("GerZNDnDZVjsOtardLuwfIBg"),
+ "/media_store/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg",
+ )
+
+ def test_url_cache_thumbnail_dirs_to_delete(self):
+ """Test URL cache thumbnail cleanup paths"""
+ self.assertEqual(
+ self.filepaths.url_cache_thumbnail_dirs_to_delete(
+ "2020-01-02_GerZNDnDZVjsOtar"
+ ),
+ [
+ "/media_store/url_cache_thumbnails/2020-01-02/GerZNDnDZVjsOtar",
+ "/media_store/url_cache_thumbnails/2020-01-02",
+ ],
+ )
+
+ def test_url_cache_thumbnail_dirs_to_delete_legacy(self):
+ """Test old-style URL cache thumbnail cleanup paths"""
+ self.assertEqual(
+ self.filepaths.url_cache_thumbnail_dirs_to_delete(
+ "GerZNDnDZVjsOtardLuwfIBg"
+ ),
+ [
+ "/media_store/url_cache_thumbnails/Ge/rZ/NDnDZVjsOtardLuwfIBg",
+ "/media_store/url_cache_thumbnails/Ge/rZ",
+ "/media_store/url_cache_thumbnails/Ge",
+ ],
+ )
diff --git a/tests/rest/media/v1/test_oembed.py b/tests/rest/media/v1/test_oembed.py
new file mode 100644
index 0000000000..048d0ca44a
--- /dev/null
+++ b/tests/rest/media/v1/test_oembed.py
@@ -0,0 +1,51 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.rest.media.v1.oembed import OEmbedProvider
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
+
+from tests.unittest import HomeserverTestCase
+
+
+class OEmbedTests(HomeserverTestCase):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
+ self.oembed = OEmbedProvider(homeserver)
+
+ def parse_response(self, response: JsonDict):
+ return self.oembed.parse_oembed_response(
+ "https://test", json.dumps(response).encode("utf-8")
+ )
+
+ def test_version(self):
+ """Accept versions that are similar to 1.0 as a string or int (or missing)."""
+ for version in ("1.0", 1.0, 1):
+ result = self.parse_response({"version": version, "type": "link"})
+ # An empty Open Graph response is an error, ensure the URL is included.
+ self.assertIn("og:url", result.open_graph_result)
+
+ # A missing version should be treated as 1.0.
+ result = self.parse_response({"type": "link"})
+ self.assertIn("og:url", result.open_graph_result)
+
+ # Invalid versions should be rejected.
+ for version in ("2.0", "1", 1.1, 0, None, {}, []):
+ result = self.parse_response({"version": version, "type": "link"})
+ # An empty Open Graph response is an error, ensure the URL is included.
+ self.assertEqual({}, result.open_graph_result)
diff --git a/tests/server.py b/tests/server.py
index 64645651ce..103351b487 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -1,3 +1,17 @@
+# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import json
import logging
from collections import deque
@@ -27,9 +41,10 @@ from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock
from twisted.web.http_headers import Headers
from twisted.web.resource import IResource
-from twisted.web.server import Site
+from twisted.web.server import Request, Site
from synapse.http.site import SynapseRequest
+from synapse.types import JsonDict
from synapse.util import Clock
from tests.utils import setup_test_homeserver as _sth
@@ -198,14 +213,14 @@ class FakeSite:
def make_request(
reactor,
site: Union[Site, FakeSite],
- method,
- path,
- content=b"",
- access_token=None,
- request=SynapseRequest,
- shorthand=True,
- federation_auth_origin=None,
- content_is_form=False,
+ method: Union[bytes, str],
+ path: Union[bytes, str],
+ content: Union[bytes, str, JsonDict] = b"",
+ access_token: Optional[str] = None,
+ request: Request = SynapseRequest,
+ shorthand: bool = True,
+ federation_auth_origin: Optional[bytes] = None,
+ content_is_form: bool = False,
await_result: bool = True,
custom_headers: Optional[
Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
@@ -218,26 +233,23 @@ def make_request(
Returns the fake Channel object which records the response to the request.
Args:
+ reactor:
site: The twisted Site to use to render the request
-
- method (bytes/unicode): The HTTP request method ("verb").
- path (bytes/unicode): The HTTP path, suitably URL encoded (e.g.
- escaped UTF-8 & spaces and such).
- content (bytes or dict): The body of the request. JSON-encoded, if
- a dict.
+ method: The HTTP request method ("verb").
+ path: The HTTP path, suitably URL encoded (e.g. escaped UTF-8 & spaces and such).
+ content: The body of the request. JSON-encoded, if a str of bytes.
+ access_token: The access token to add as authorization for the request.
+ request: The request class to create.
shorthand: Whether to try and be helpful and prefix the given URL
- with the usual REST API path, if it doesn't contain it.
- federation_auth_origin (bytes|None): if set to not-None, we will add a fake
+ with the usual REST API path, if it doesn't contain it.
+ federation_auth_origin: if set to not-None, we will add a fake
Authorization header pretenting to be the given server name.
content_is_form: Whether the content is URL encoded form data. Adds the
'Content-Type': 'application/x-www-form-urlencoded' header.
-
- custom_headers: (name, value) pairs to add as request headers
-
await_result: whether to wait for the request to complete rendering. If true,
will pump the reactor until the the renderer tells the channel the request
is finished.
-
+ custom_headers: (name, value) pairs to add as request headers
client_ip: The IP to use as the requesting IP. Useful for testing
ratelimiting.
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index be3ed64f5e..37cf7bb232 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Any, Dict, List, Set, Tuple
+from typing import Any, Dict, Set, Tuple
from unittest import mock
from unittest.mock import Mock, patch
@@ -42,18 +42,7 @@ class GetUserDirectoryTables:
def __init__(self, store: DataStore):
self.store = store
- def _compress_shared(
- self, shared: List[Dict[str, str]]
- ) -> Set[Tuple[str, str, str]]:
- """
- Compress a list of users who share rooms dicts to a list of tuples.
- """
- r = set()
- for i in shared:
- r.add((i["user_id"], i["other_user_id"], i["room_id"]))
- return r
-
- async def get_users_in_public_rooms(self) -> List[Tuple[str, str]]:
+ async def get_users_in_public_rooms(self) -> Set[Tuple[str, str]]:
"""Fetch the entire `users_in_public_rooms` table.
Returns a list of tuples (user_id, room_id) where room_id is public and
@@ -63,24 +52,27 @@ class GetUserDirectoryTables:
"users_in_public_rooms", None, ("user_id", "room_id")
)
- retval = []
+ retval = set()
for i in r:
- retval.append((i["user_id"], i["room_id"]))
+ retval.add((i["user_id"], i["room_id"]))
return retval
- async def get_users_who_share_private_rooms(self) -> List[Dict[str, str]]:
+ async def get_users_who_share_private_rooms(self) -> Set[Tuple[str, str, str]]:
"""Fetch the entire `users_who_share_private_rooms` table.
- Returns a dict containing "user_id", "other_user_id" and "room_id" keys.
- The dicts can be flattened to Tuples with the `_compress_shared` method.
- (This seems a little awkward---maybe we could clean this up.)
+ Returns a set of tuples (user_id, other_user_id, room_id) corresponding
+ to the rows of `users_who_share_private_rooms`.
"""
- return await self.store.db_pool.simple_select_list(
+ rows = await self.store.db_pool.simple_select_list(
"users_who_share_private_rooms",
None,
["user_id", "other_user_id", "room_id"],
)
+ rv = set()
+ for row in rows:
+ rv.add((row["user_id"], row["other_user_id"], row["room_id"]))
+ return rv
async def get_users_in_user_directory(self) -> Set[str]:
"""Fetch the set of users in the `user_directory` table.
@@ -113,6 +105,16 @@ class GetUserDirectoryTables:
for row in rows
}
+ async def get_tables(
+ self,
+ ) -> Tuple[Set[str], Set[Tuple[str, str]], Set[Tuple[str, str, str]]]:
+ """Multiple tests want to inspect these tables, so expose them together."""
+ return (
+ await self.get_users_in_user_directory(),
+ await self.get_users_in_public_rooms(),
+ await self.get_users_who_share_private_rooms(),
+ )
+
class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
"""Ensure that rebuilding the directory writes the correct data to the DB.
@@ -166,8 +168,8 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
)
# Nothing updated yet
- self.assertEqual(shares_private, [])
- self.assertEqual(public_users, [])
+ self.assertEqual(shares_private, set())
+ self.assertEqual(public_users, set())
# Ugh, have to reset this flag
self.store.db_pool.updates._all_done = False
@@ -236,24 +238,15 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
# Do the initial population of the user directory via the background update
self._purge_and_rebuild_user_dir()
- shares_private = self.get_success(
- self.user_dir_helper.get_users_who_share_private_rooms()
- )
- public_users = self.get_success(
- self.user_dir_helper.get_users_in_public_rooms()
+ users, in_public, in_private = self.get_success(
+ self.user_dir_helper.get_tables()
)
# User 1 and User 2 are in the same public room
- self.assertEqual(set(public_users), {(u1, room), (u2, room)})
-
+ self.assertEqual(in_public, {(u1, room), (u2, room)})
# User 1 and User 3 share private rooms
- self.assertEqual(
- self.user_dir_helper._compress_shared(shares_private),
- {(u1, u3, private_room), (u3, u1, private_room)},
- )
-
+ self.assertEqual(in_private, {(u1, u3, private_room), (u3, u1, private_room)})
# All three should have entries in the directory
- users = self.get_success(self.user_dir_helper.get_users_in_user_directory())
self.assertEqual(users, {u1, u2, u3})
# The next four tests (test_population_excludes_*) all set up
@@ -289,16 +282,12 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
self, normal_user: str, public_room: str, private_room: str
) -> None:
# After rebuilding the directory, we should only see the normal user.
- users = self.get_success(self.user_dir_helper.get_users_in_user_directory())
- self.assertEqual(users, {normal_user})
- in_public_rooms = self.get_success(
- self.user_dir_helper.get_users_in_public_rooms()
+ users, in_public, in_private = self.get_success(
+ self.user_dir_helper.get_tables()
)
- self.assertEqual(set(in_public_rooms), {(normal_user, public_room)})
- in_private_rooms = self.get_success(
- self.user_dir_helper.get_users_who_share_private_rooms()
- )
- self.assertEqual(in_private_rooms, [])
+ self.assertEqual(users, {normal_user})
+ self.assertEqual(in_public, {(normal_user, public_room)})
+ self.assertEqual(in_private, set())
def test_population_excludes_support_user(self) -> None:
# Create a normal and support user.
diff --git a/tests/test_event_auth.py b/tests/test_event_auth.py
index cf407c51cf..e2c506e5a4 100644
--- a/tests/test_event_auth.py
+++ b/tests/test_event_auth.py
@@ -24,6 +24,47 @@ from synapse.types import JsonDict, get_domain_from_id
class EventAuthTestCase(unittest.TestCase):
+ def test_rejected_auth_events(self):
+ """
+ Events that refer to rejected events in their auth events are rejected
+ """
+ creator = "@creator:example.com"
+ auth_events = [
+ _create_event(creator),
+ _join_event(creator),
+ ]
+
+ # creator should be able to send state
+ event_auth.check_auth_rules_for_event(
+ RoomVersions.V9,
+ _random_state_event(creator),
+ auth_events,
+ )
+
+ # ... but a rejected join_rules event should cause it to be rejected
+ rejected_join_rules = _join_rules_event(creator, "public")
+ rejected_join_rules.rejected_reason = "stinky"
+ auth_events.append(rejected_join_rules)
+
+ self.assertRaises(
+ AuthError,
+ event_auth.check_auth_rules_for_event,
+ RoomVersions.V9,
+ _random_state_event(creator),
+ auth_events,
+ )
+
+ # ... even if there is *also* a good join rules
+ auth_events.append(_join_rules_event(creator, "public"))
+
+ self.assertRaises(
+ AuthError,
+ event_auth.check_auth_rules_for_event,
+ RoomVersions.V9,
+ _random_state_event(creator),
+ auth_events,
+ )
+
def test_random_users_cannot_send_state_before_first_pl(self):
"""
Check that, before the first PL lands, the creator is the only user
@@ -31,11 +72,11 @@ class EventAuthTestCase(unittest.TestCase):
"""
creator = "@creator:example.com"
joiner = "@joiner:example.com"
- auth_events = {
- ("m.room.create", ""): _create_event(creator),
- ("m.room.member", creator): _join_event(creator),
- ("m.room.member", joiner): _join_event(joiner),
- }
+ auth_events = [
+ _create_event(creator),
+ _join_event(creator),
+ _join_event(joiner),
+ ]
# creator should be able to send state
event_auth.check_auth_rules_for_event(
@@ -62,15 +103,15 @@ class EventAuthTestCase(unittest.TestCase):
pleb = "@joiner:example.com"
king = "@joiner2:example.com"
- auth_events = {
- ("m.room.create", ""): _create_event(creator),
- ("m.room.member", creator): _join_event(creator),
- ("m.room.power_levels", ""): _power_levels_event(
+ auth_events = [
+ _create_event(creator),
+ _join_event(creator),
+ _power_levels_event(
creator, {"state_default": "30", "users": {pleb: "29", king: "30"}}
),
- ("m.room.member", pleb): _join_event(pleb),
- ("m.room.member", king): _join_event(king),
- }
+ _join_event(pleb),
+ _join_event(king),
+ ]
# pleb should not be able to send state
self.assertRaises(
@@ -92,10 +133,10 @@ class EventAuthTestCase(unittest.TestCase):
"""Alias events have special behavior up through room version 6."""
creator = "@creator:example.com"
other = "@other:example.com"
- auth_events = {
- ("m.room.create", ""): _create_event(creator),
- ("m.room.member", creator): _join_event(creator),
- }
+ auth_events = [
+ _create_event(creator),
+ _join_event(creator),
+ ]
# creator should be able to send aliases
event_auth.check_auth_rules_for_event(
@@ -131,10 +172,10 @@ class EventAuthTestCase(unittest.TestCase):
"""After MSC2432, alias events have no special behavior."""
creator = "@creator:example.com"
other = "@other:example.com"
- auth_events = {
- ("m.room.create", ""): _create_event(creator),
- ("m.room.member", creator): _join_event(creator),
- }
+ auth_events = [
+ _create_event(creator),
+ _join_event(creator),
+ ]
# creator should be able to send aliases
event_auth.check_auth_rules_for_event(
@@ -170,14 +211,14 @@ class EventAuthTestCase(unittest.TestCase):
creator = "@creator:example.com"
pleb = "@joiner:example.com"
- auth_events = {
- ("m.room.create", ""): _create_event(creator),
- ("m.room.member", creator): _join_event(creator),
- ("m.room.power_levels", ""): _power_levels_event(
+ auth_events = [
+ _create_event(creator),
+ _join_event(creator),
+ _power_levels_event(
creator, {"state_default": "30", "users": {pleb: "30"}}
),
- ("m.room.member", pleb): _join_event(pleb),
- }
+ _join_event(pleb),
+ ]
# pleb should be able to modify the notifications power level.
event_auth.check_auth_rules_for_event(
@@ -211,7 +252,7 @@ class EventAuthTestCase(unittest.TestCase):
event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_join_event(pleb),
- auth_events,
+ auth_events.values(),
)
# A user cannot be force-joined to a room.
@@ -219,7 +260,7 @@ class EventAuthTestCase(unittest.TestCase):
event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_member_event(pleb, "join", sender=creator),
- auth_events,
+ auth_events.values(),
)
# Banned should be rejected.
@@ -228,7 +269,7 @@ class EventAuthTestCase(unittest.TestCase):
event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_join_event(pleb),
- auth_events,
+ auth_events.values(),
)
# A user who left can re-join.
@@ -236,7 +277,7 @@ class EventAuthTestCase(unittest.TestCase):
event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_join_event(pleb),
- auth_events,
+ auth_events.values(),
)
# A user can send a join if they're in the room.
@@ -244,7 +285,7 @@ class EventAuthTestCase(unittest.TestCase):
event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_join_event(pleb),
- auth_events,
+ auth_events.values(),
)
# A user can accept an invite.
@@ -254,7 +295,7 @@ class EventAuthTestCase(unittest.TestCase):
event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_join_event(pleb),
- auth_events,
+ auth_events.values(),
)
def test_join_rules_invite(self):
@@ -275,7 +316,7 @@ class EventAuthTestCase(unittest.TestCase):
event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_join_event(pleb),
- auth_events,
+ auth_events.values(),
)
# A user cannot be force-joined to a room.
@@ -283,7 +324,7 @@ class EventAuthTestCase(unittest.TestCase):
event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_member_event(pleb, "join", sender=creator),
- auth_events,
+ auth_events.values(),
)
# Banned should be rejected.
@@ -292,7 +333,7 @@ class EventAuthTestCase(unittest.TestCase):
event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_join_event(pleb),
- auth_events,
+ auth_events.values(),
)
# A user who left cannot re-join.
@@ -301,7 +342,7 @@ class EventAuthTestCase(unittest.TestCase):
event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_join_event(pleb),
- auth_events,
+ auth_events.values(),
)
# A user can send a join if they're in the room.
@@ -309,7 +350,7 @@ class EventAuthTestCase(unittest.TestCase):
event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_join_event(pleb),
- auth_events,
+ auth_events.values(),
)
# A user can accept an invite.
@@ -319,7 +360,7 @@ class EventAuthTestCase(unittest.TestCase):
event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_join_event(pleb),
- auth_events,
+ auth_events.values(),
)
def test_join_rules_msc3083_restricted(self):
@@ -347,7 +388,7 @@ class EventAuthTestCase(unittest.TestCase):
event_auth.check_auth_rules_for_event(
RoomVersions.V6,
_join_event(pleb),
- auth_events,
+ auth_events.values(),
)
# A properly formatted join event should work.
@@ -360,7 +401,7 @@ class EventAuthTestCase(unittest.TestCase):
event_auth.check_auth_rules_for_event(
RoomVersions.V8,
authorised_join_event,
- auth_events,
+ auth_events.values(),
)
# A join issued by a specific user works (i.e. the power level checks
@@ -380,7 +421,7 @@ class EventAuthTestCase(unittest.TestCase):
EventContentFields.AUTHORISING_USER: "@inviter:foo.test"
},
),
- pl_auth_events,
+ pl_auth_events.values(),
)
# A join which is missing an authorised server is rejected.
@@ -388,7 +429,7 @@ class EventAuthTestCase(unittest.TestCase):
event_auth.check_auth_rules_for_event(
RoomVersions.V8,
_join_event(pleb),
- auth_events,
+ auth_events.values(),
)
# An join authorised by a user who is not in the room is rejected.
@@ -405,7 +446,7 @@ class EventAuthTestCase(unittest.TestCase):
EventContentFields.AUTHORISING_USER: "@other:example.com"
},
),
- auth_events,
+ auth_events.values(),
)
# A user cannot be force-joined to a room. (This uses an event which
@@ -421,7 +462,7 @@ class EventAuthTestCase(unittest.TestCase):
EventContentFields.AUTHORISING_USER: "@inviter:foo.test"
},
),
- auth_events,
+ auth_events.values(),
)
# Banned should be rejected.
@@ -430,7 +471,7 @@ class EventAuthTestCase(unittest.TestCase):
event_auth.check_auth_rules_for_event(
RoomVersions.V8,
authorised_join_event,
- auth_events,
+ auth_events.values(),
)
# A user who left can re-join.
@@ -438,7 +479,7 @@ class EventAuthTestCase(unittest.TestCase):
event_auth.check_auth_rules_for_event(
RoomVersions.V8,
authorised_join_event,
- auth_events,
+ auth_events.values(),
)
# A user can send a join if they're in the room. (This doesn't need to
@@ -447,7 +488,7 @@ class EventAuthTestCase(unittest.TestCase):
event_auth.check_auth_rules_for_event(
RoomVersions.V8,
_join_event(pleb),
- auth_events,
+ auth_events.values(),
)
# A user can accept an invite. (This doesn't need to be authorised since
@@ -458,7 +499,7 @@ class EventAuthTestCase(unittest.TestCase):
event_auth.check_auth_rules_for_event(
RoomVersions.V8,
_join_event(pleb),
- auth_events,
+ auth_events.values(),
)
@@ -473,6 +514,7 @@ def _create_event(user_id: str) -> EventBase:
"room_id": TEST_ROOM_ID,
"event_id": _get_event_id(),
"type": "m.room.create",
+ "state_key": "",
"sender": user_id,
"content": {"creator": user_id},
}
diff --git a/tests/test_preview.py b/tests/test_preview.py
index 09e017b4d9..9a576f9a4e 100644
--- a/tests/test_preview.py
+++ b/tests/test_preview.py
@@ -15,7 +15,7 @@
from synapse.rest.media.v1.preview_url_resource import (
_calc_og,
decode_body,
- get_html_media_encoding,
+ get_html_media_encodings,
summarize_paragraphs,
)
@@ -159,7 +159,7 @@ class CalcOgTestCase(unittest.TestCase):
</html>
"""
- tree = decode_body(html)
+ tree = decode_body(html, "http://example.com/test.html")
og = _calc_og(tree, "http://example.com/test.html")
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
@@ -175,7 +175,7 @@ class CalcOgTestCase(unittest.TestCase):
</html>
"""
- tree = decode_body(html)
+ tree = decode_body(html, "http://example.com/test.html")
og = _calc_og(tree, "http://example.com/test.html")
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
@@ -194,7 +194,7 @@ class CalcOgTestCase(unittest.TestCase):
</html>
"""
- tree = decode_body(html)
+ tree = decode_body(html, "http://example.com/test.html")
og = _calc_og(tree, "http://example.com/test.html")
self.assertEqual(
@@ -216,7 +216,7 @@ class CalcOgTestCase(unittest.TestCase):
</html>
"""
- tree = decode_body(html)
+ tree = decode_body(html, "http://example.com/test.html")
og = _calc_og(tree, "http://example.com/test.html")
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
@@ -230,7 +230,7 @@ class CalcOgTestCase(unittest.TestCase):
</html>
"""
- tree = decode_body(html)
+ tree = decode_body(html, "http://example.com/test.html")
og = _calc_og(tree, "http://example.com/test.html")
self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
@@ -245,7 +245,7 @@ class CalcOgTestCase(unittest.TestCase):
</html>
"""
- tree = decode_body(html)
+ tree = decode_body(html, "http://example.com/test.html")
og = _calc_og(tree, "http://example.com/test.html")
self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."})
@@ -260,7 +260,7 @@ class CalcOgTestCase(unittest.TestCase):
</html>
"""
- tree = decode_body(html)
+ tree = decode_body(html, "http://example.com/test.html")
og = _calc_og(tree, "http://example.com/test.html")
self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
@@ -268,13 +268,13 @@ class CalcOgTestCase(unittest.TestCase):
def test_empty(self):
"""Test a body with no data in it."""
html = b""
- tree = decode_body(html)
+ tree = decode_body(html, "http://example.com/test.html")
self.assertIsNone(tree)
def test_no_tree(self):
"""A valid body with no tree in it."""
html = b"\x00"
- tree = decode_body(html)
+ tree = decode_body(html, "http://example.com/test.html")
self.assertIsNone(tree)
def test_invalid_encoding(self):
@@ -287,7 +287,7 @@ class CalcOgTestCase(unittest.TestCase):
</body>
</html>
"""
- tree = decode_body(html, "invalid-encoding")
+ tree = decode_body(html, "http://example.com/test.html", "invalid-encoding")
og = _calc_og(tree, "http://example.com/test.html")
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
@@ -302,15 +302,29 @@ class CalcOgTestCase(unittest.TestCase):
</body>
</html>
"""
- tree = decode_body(html)
+ tree = decode_body(html, "http://example.com/test.html")
og = _calc_og(tree, "http://example.com/test.html")
self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."})
+ def test_windows_1252(self):
+ """A body which uses cp1252, but doesn't declare that."""
+ html = b"""
+ <html>
+ <head><title>\xf3</title></head>
+ <body>
+ Some text.
+ </body>
+ </html>
+ """
+ tree = decode_body(html, "http://example.com/test.html")
+ og = _calc_og(tree, "http://example.com/test.html")
+ self.assertEqual(og, {"og:title": "ó", "og:description": "Some text."})
+
class MediaEncodingTestCase(unittest.TestCase):
def test_meta_charset(self):
"""A character encoding is found via the meta tag."""
- encoding = get_html_media_encoding(
+ encodings = get_html_media_encodings(
b"""
<html>
<head><meta charset="ascii">
@@ -319,10 +333,10 @@ class MediaEncodingTestCase(unittest.TestCase):
""",
"text/html",
)
- self.assertEqual(encoding, "ascii")
+ self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"])
# A less well-formed version.
- encoding = get_html_media_encoding(
+ encodings = get_html_media_encodings(
b"""
<html>
<head>< meta charset = ascii>
@@ -331,11 +345,11 @@ class MediaEncodingTestCase(unittest.TestCase):
""",
"text/html",
)
- self.assertEqual(encoding, "ascii")
+ self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"])
def test_meta_charset_underscores(self):
"""A character encoding contains underscore."""
- encoding = get_html_media_encoding(
+ encodings = get_html_media_encodings(
b"""
<html>
<head><meta charset="Shift_JIS">
@@ -344,11 +358,11 @@ class MediaEncodingTestCase(unittest.TestCase):
""",
"text/html",
)
- self.assertEqual(encoding, "Shift_JIS")
+ self.assertEqual(list(encodings), ["shift_jis", "utf-8", "cp1252"])
def test_xml_encoding(self):
"""A character encoding is found via the meta tag."""
- encoding = get_html_media_encoding(
+ encodings = get_html_media_encodings(
b"""
<?xml version="1.0" encoding="ascii"?>
<html>
@@ -356,11 +370,11 @@ class MediaEncodingTestCase(unittest.TestCase):
""",
"text/html",
)
- self.assertEqual(encoding, "ascii")
+ self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"])
def test_meta_xml_encoding(self):
"""Meta tags take precedence over XML encoding."""
- encoding = get_html_media_encoding(
+ encodings = get_html_media_encodings(
b"""
<?xml version="1.0" encoding="ascii"?>
<html>
@@ -370,7 +384,7 @@ class MediaEncodingTestCase(unittest.TestCase):
""",
"text/html",
)
- self.assertEqual(encoding, "UTF-16")
+ self.assertEqual(list(encodings), ["utf-16", "ascii", "utf-8", "cp1252"])
def test_content_type(self):
"""A character encoding is found via the Content-Type header."""
@@ -384,10 +398,37 @@ class MediaEncodingTestCase(unittest.TestCase):
'text/html; charset=ascii";',
)
for header in headers:
- encoding = get_html_media_encoding(b"", header)
- self.assertEqual(encoding, "ascii")
+ encodings = get_html_media_encodings(b"", header)
+ self.assertEqual(list(encodings), ["ascii", "utf-8", "cp1252"])
def test_fallback(self):
"""A character encoding cannot be found in the body or header."""
- encoding = get_html_media_encoding(b"", "text/html")
- self.assertEqual(encoding, "utf-8")
+ encodings = get_html_media_encodings(b"", "text/html")
+ self.assertEqual(list(encodings), ["utf-8", "cp1252"])
+
+ def test_duplicates(self):
+ """Ensure each encoding is only attempted once."""
+ encodings = get_html_media_encodings(
+ b"""
+ <?xml version="1.0" encoding="utf8"?>
+ <html>
+ <head><meta charset="UTF-8">
+ </head>
+ </html>
+ """,
+ 'text/html; charset="UTF_8"',
+ )
+ self.assertEqual(list(encodings), ["utf-8", "cp1252"])
+
+ def test_unknown_invalid(self):
+ """A character encoding should be ignored if it is unknown or invalid."""
+ encodings = get_html_media_encodings(
+ b"""
+ <html>
+ <head><meta charset="invalid">
+ </head>
+ </html>
+ """,
+ 'text/html; charset="invalid"',
+ )
+ self.assertEqual(list(encodings), ["utf-8", "cp1252"])
diff --git a/tests/unittest.py b/tests/unittest.py
index 81c1a9e9d2..a9b60b7eeb 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -46,7 +46,7 @@ from synapse.logging.context import (
set_current_context,
)
from synapse.server import HomeServer
-from synapse.types import UserID, create_requester
+from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.ratelimitutils import FederationRateLimiter
@@ -401,7 +401,7 @@ class HomeserverTestCase(TestCase):
self,
method: Union[bytes, str],
path: Union[bytes, str],
- content: Union[bytes, dict] = b"",
+ content: Union[bytes, str, JsonDict] = b"",
access_token: Optional[str] = None,
request: Type[T] = SynapseRequest,
shorthand: bool = True,
|