diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index 6a2e76ca4a..be008227df 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -15,8 +15,8 @@
from unittest.mock import Mock
-import synapse
import synapse.api.errors
+import synapse.rest.admin
from synapse.api.constants import EventTypes
from synapse.config.room_directory import RoomDirectoryConfig
from synapse.rest.client import directory, login, room
@@ -432,6 +432,106 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code, channel.result)
+class TestCreatePublishedRoomACL(unittest.HomeserverTestCase):
+ data = {"room_alias_name": "unofficial_test"}
+
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ directory.register_servlets,
+ room.register_servlets,
+ ]
+ hijack_auth = False
+
+ def prepare(self, reactor, clock, hs):
+ self.allowed_user_id = self.register_user("allowed", "pass")
+ self.allowed_access_token = self.login("allowed", "pass")
+
+ self.denied_user_id = self.register_user("denied", "pass")
+ self.denied_access_token = self.login("denied", "pass")
+
+ # This time we add custom room list publication rules
+ config = {}
+ config["alias_creation_rules"] = []
+ config["room_list_publication_rules"] = [
+ {"user_id": "*", "alias": "*", "action": "deny"},
+ {"user_id": self.allowed_user_id, "alias": "*", "action": "allow"},
+ ]
+
+ rd_config = RoomDirectoryConfig()
+ rd_config.read_config(config)
+
+ self.hs.config.roomdirectory.is_publishing_room_allowed = (
+ rd_config.is_publishing_room_allowed
+ )
+
+ return hs
+
+ def test_denied_without_publication_permission(self):
+ """
+ Try to create a room, register an alias for it, and publish it,
+ as a user without permission to publish rooms.
+ (This is used as both a standalone test & as a helper function.)
+ """
+ self.helper.create_room_as(
+ self.denied_user_id,
+ tok=self.denied_access_token,
+ extra_content=self.data,
+ is_public=True,
+ expect_code=403,
+ )
+
+ def test_allowed_when_creating_private_room(self):
+ """
+ Try to create a room, register an alias for it, and NOT publish it,
+ as a user without permission to publish rooms.
+ (This is used as both a standalone test & as a helper function.)
+ """
+ self.helper.create_room_as(
+ self.denied_user_id,
+ tok=self.denied_access_token,
+ extra_content=self.data,
+ is_public=False,
+ expect_code=200,
+ )
+
+ def test_allowed_with_publication_permission(self):
+ """
+ Try to create a room, register an alias for it, and publish it,
+ as a user WITH permission to publish rooms.
+ (This is used as both a standalone test & as a helper function.)
+ """
+ self.helper.create_room_as(
+ self.allowed_user_id,
+ tok=self.allowed_access_token,
+ extra_content=self.data,
+ is_public=False,
+ expect_code=200,
+ )
+
+ def test_can_create_as_private_room_after_rejection(self):
+ """
+ After failing to publish a room with an alias as a user without publish permission,
+ retry as the same user, but without publishing the room.
+
+ This should pass, but used to fail because the alias was registered by the first
+ request, even though the room creation was denied.
+ """
+ self.test_denied_without_publication_permission()
+ self.test_allowed_when_creating_private_room()
+
+ def test_can_create_with_permission_after_rejection(self):
+ """
+ After failing to publish a room with an alias as a user without publish permission,
+ retry as someone with permission, using the same alias.
+
+ This also used to fail because of the alias having been registered by the first
+ request, leaving it unavailable for any other user's new rooms.
+ """
+ self.test_denied_without_publication_permission()
+ self.test_allowed_with_publication_permission()
+
+
class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
user_id = "@test:test"
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 936ebf3dde..e1557566e4 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -23,6 +23,7 @@ from synapse.federation.federation_base import event_from_pdu_json
from synapse.logging.context import LoggingContext, run_in_background
from synapse.rest import admin
from synapse.rest.client import login, room
+from synapse.types import create_requester
from synapse.util.stringutils import random_string
from tests import unittest
@@ -30,6 +31,10 @@ from tests import unittest
logger = logging.getLogger(__name__)
+def generate_fake_event_id() -> str:
+ return "$fake_" + random_string(43)
+
+
class FederationTestCase(unittest.HomeserverTestCase):
servlets = [
admin.register_servlets,
@@ -198,6 +203,65 @@ class FederationTestCase(unittest.HomeserverTestCase):
self.assertEqual(sg, sg2)
+ def test_backfill_with_many_backward_extremities(self):
+ """
+ Check that we can backfill with many backward extremities.
+ The goal is to make sure that when we only use a portion
+ of backwards extremities(the magic number is more than 5),
+ no errors are thrown.
+
+ Regression test, see #11027
+ """
+ # create the room
+ user_id = self.register_user("kermit", "test")
+ tok = self.login("kermit", "test")
+ requester = create_requester(user_id)
+
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+
+ ev1 = self.helper.send(room_id, "first message", tok=tok)
+
+ # Create "many" backward extremities. The magic number we're trying to
+ # create more than is 5 which corresponds to the number of backward
+ # extremities we slice off in `_maybe_backfill_inner`
+ for _ in range(0, 8):
+ event_handler = self.hs.get_event_creation_handler()
+ event, context = self.get_success(
+ event_handler.create_event(
+ requester,
+ {
+ "type": "m.room.message",
+ "content": {
+ "msgtype": "m.text",
+ "body": "message connected to fake event",
+ },
+ "room_id": room_id,
+ "sender": user_id,
+ },
+ prev_event_ids=[
+ ev1["event_id"],
+ # We're creating an backward extremity each time thanks
+ # to this fake event
+ generate_fake_event_id(),
+ ],
+ )
+ )
+ self.get_success(
+ event_handler.handle_new_client_event(requester, event, context)
+ )
+
+ current_depth = 1
+ limit = 100
+ with LoggingContext("receive_pdu"):
+ # Make sure backfill still works
+ d = run_in_background(
+ self.hs.get_federation_handler().maybe_backfill,
+ room_id,
+ current_depth,
+ limit,
+ )
+ self.get_success(d)
+
def test_backfill_floating_outlier_membership_auth(self):
"""
As the local homeserver, check that we can properly process a federated
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index 38e6d9f536..7dd4a5a367 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -20,6 +20,8 @@ from unittest.mock import Mock
from twisted.internet import defer
import synapse
+from synapse.handlers.auth import load_legacy_password_auth_providers
+from synapse.module_api import ModuleApi
from synapse.rest.client import devices, login
from synapse.types import JsonDict
@@ -36,8 +38,8 @@ ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_servi
mock_password_provider = Mock()
-class PasswordOnlyAuthProvider:
- """A password_provider which only implements `check_password`."""
+class LegacyPasswordOnlyAuthProvider:
+ """A legacy password_provider which only implements `check_password`."""
@staticmethod
def parse_config(self):
@@ -50,8 +52,8 @@ class PasswordOnlyAuthProvider:
return mock_password_provider.check_password(*args)
-class CustomAuthProvider:
- """A password_provider which implements a custom login type."""
+class LegacyCustomAuthProvider:
+ """A legacy password_provider which implements a custom login type."""
@staticmethod
def parse_config(self):
@@ -67,7 +69,23 @@ class CustomAuthProvider:
return mock_password_provider.check_auth(*args)
-class PasswordCustomAuthProvider:
+class CustomAuthProvider:
+ """A module which registers password_auth_provider callbacks for a custom login type."""
+
+ @staticmethod
+ def parse_config(self):
+ pass
+
+ def __init__(self, config, api: ModuleApi):
+ api.register_password_auth_provider_callbacks(
+ auth_checkers={("test.login_type", ("test_field",)): self.check_auth},
+ )
+
+ def check_auth(self, *args):
+ return mock_password_provider.check_auth(*args)
+
+
+class LegacyPasswordCustomAuthProvider:
"""A password_provider which implements password login via `check_auth`, as well
as a custom type."""
@@ -85,8 +103,32 @@ class PasswordCustomAuthProvider:
return mock_password_provider.check_auth(*args)
-def providers_config(*providers: Type[Any]) -> dict:
- """Returns a config dict that will enable the given password auth providers"""
+class PasswordCustomAuthProvider:
+ """A module which registers password_auth_provider callbacks for a custom login type.
+ as well as a password login"""
+
+ @staticmethod
+ def parse_config(self):
+ pass
+
+ def __init__(self, config, api: ModuleApi):
+ api.register_password_auth_provider_callbacks(
+ auth_checkers={
+ ("test.login_type", ("test_field",)): self.check_auth,
+ ("m.login.password", ("password",)): self.check_auth,
+ },
+ )
+ pass
+
+ def check_auth(self, *args):
+ return mock_password_provider.check_auth(*args)
+
+ def check_pass(self, *args):
+ return mock_password_provider.check_password(*args)
+
+
+def legacy_providers_config(*providers: Type[Any]) -> dict:
+ """Returns a config dict that will enable the given legacy password auth providers"""
return {
"password_providers": [
{"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}}
@@ -95,6 +137,16 @@ def providers_config(*providers: Type[Any]) -> dict:
}
+def providers_config(*providers: Type[Any]) -> dict:
+ """Returns a config dict that will enable the given modules"""
+ return {
+ "modules": [
+ {"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}}
+ for provider in providers
+ ]
+ }
+
+
class PasswordAuthProviderTests(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets,
@@ -107,8 +159,21 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.reset_mock()
super().setUp()
- @override_config(providers_config(PasswordOnlyAuthProvider))
- def test_password_only_auth_provider_login(self):
+ def make_homeserver(self, reactor, clock):
+ hs = self.setup_test_homeserver()
+ # Load the modules into the homeserver
+ module_api = hs.get_module_api()
+ for module, config in hs.config.modules.loaded_modules:
+ module(config=config, api=module_api)
+ load_legacy_password_auth_providers(hs)
+
+ return hs
+
+ @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
+ def test_password_only_auth_progiver_login_legacy(self):
+ self.password_only_auth_provider_login_test_body()
+
+ def password_only_auth_provider_login_test_body(self):
# login flows should only have m.login.password
flows = self._get_login_flows()
self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS)
@@ -138,8 +203,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"@ USER🙂NAME :test", " pASS😢word "
)
- @override_config(providers_config(PasswordOnlyAuthProvider))
- def test_password_only_auth_provider_ui_auth(self):
+ @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
+ def test_password_only_auth_provider_ui_auth_legacy(self):
+ self.password_only_auth_provider_ui_auth_test_body()
+
+ def password_only_auth_provider_ui_auth_test_body(self):
"""UI Auth should delegate correctly to the password provider"""
# create the user, otherwise access doesn't work
@@ -172,8 +240,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
mock_password_provider.check_password.assert_called_once_with("@u:test", "p")
- @override_config(providers_config(PasswordOnlyAuthProvider))
- def test_local_user_fallback_login(self):
+ @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
+ def test_local_user_fallback_login_legacy(self):
+ self.local_user_fallback_login_test_body()
+
+ def local_user_fallback_login_test_body(self):
"""rejected login should fall back to local db"""
self.register_user("localuser", "localpass")
@@ -186,8 +257,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual("@localuser:test", channel.json_body["user_id"])
- @override_config(providers_config(PasswordOnlyAuthProvider))
- def test_local_user_fallback_ui_auth(self):
+ @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider))
+ def test_local_user_fallback_ui_auth_legacy(self):
+ self.local_user_fallback_ui_auth_test_body()
+
+ def local_user_fallback_ui_auth_test_body(self):
"""rejected login should fall back to local db"""
self.register_user("localuser", "localpass")
@@ -223,11 +297,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
@override_config(
{
- **providers_config(PasswordOnlyAuthProvider),
+ **legacy_providers_config(LegacyPasswordOnlyAuthProvider),
"password_config": {"localdb_enabled": False},
}
)
- def test_no_local_user_fallback_login(self):
+ def test_no_local_user_fallback_login_legacy(self):
+ self.no_local_user_fallback_login_test_body()
+
+ def no_local_user_fallback_login_test_body(self):
"""localdb_enabled can block login with the local password"""
self.register_user("localuser", "localpass")
@@ -242,11 +319,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
@override_config(
{
- **providers_config(PasswordOnlyAuthProvider),
+ **legacy_providers_config(LegacyPasswordOnlyAuthProvider),
"password_config": {"localdb_enabled": False},
}
)
- def test_no_local_user_fallback_ui_auth(self):
+ def test_no_local_user_fallback_ui_auth_legacy(self):
+ self.no_local_user_fallback_ui_auth_test_body()
+
+ def no_local_user_fallback_ui_auth_test_body(self):
"""localdb_enabled can block ui auth with the local password"""
self.register_user("localuser", "localpass")
@@ -280,11 +360,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
@override_config(
{
- **providers_config(PasswordOnlyAuthProvider),
+ **legacy_providers_config(LegacyPasswordOnlyAuthProvider),
"password_config": {"enabled": False},
}
)
- def test_password_auth_disabled(self):
+ def test_password_auth_disabled_legacy(self):
+ self.password_auth_disabled_test_body()
+
+ def password_auth_disabled_test_body(self):
"""password auth doesn't work if it's disabled across the board"""
# login flows should be empty
flows = self._get_login_flows()
@@ -295,8 +378,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400, channel.result)
mock_password_provider.check_password.assert_not_called()
+ @override_config(legacy_providers_config(LegacyCustomAuthProvider))
+ def test_custom_auth_provider_login_legacy(self):
+ self.custom_auth_provider_login_test_body()
+
@override_config(providers_config(CustomAuthProvider))
def test_custom_auth_provider_login(self):
+ self.custom_auth_provider_login_test_body()
+
+ def custom_auth_provider_login_test_body(self):
# login flows should have the custom flow and m.login.password, since we
# haven't disabled local password lookup.
# (password must come first, because reasons)
@@ -312,7 +402,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400, channel.result)
mock_password_provider.check_auth.assert_not_called()
- mock_password_provider.check_auth.return_value = defer.succeed("@user:bz")
+ mock_password_provider.check_auth.return_value = defer.succeed(
+ ("@user:bz", None)
+ )
channel = self._send_login("test.login_type", "u", test_field="y")
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual("@user:bz", channel.json_body["user_id"])
@@ -325,7 +417,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# in these cases, but at least we can guard against the API changing
# unexpectedly
mock_password_provider.check_auth.return_value = defer.succeed(
- "@ MALFORMED! :bz"
+ ("@ MALFORMED! :bz", None)
)
channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ")
self.assertEqual(channel.code, 200, channel.result)
@@ -334,8 +426,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
" USER🙂NAME ", "test.login_type", {"test_field": " abc "}
)
+ @override_config(legacy_providers_config(LegacyCustomAuthProvider))
+ def test_custom_auth_provider_ui_auth_legacy(self):
+ self.custom_auth_provider_ui_auth_test_body()
+
@override_config(providers_config(CustomAuthProvider))
def test_custom_auth_provider_ui_auth(self):
+ self.custom_auth_provider_ui_auth_test_body()
+
+ def custom_auth_provider_ui_auth_test_body(self):
# register the user and log in twice, to get two devices
self.register_user("localuser", "localpass")
tok1 = self.login("localuser", "localpass")
@@ -367,7 +466,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.reset_mock()
# right params, but authing as the wrong user
- mock_password_provider.check_auth.return_value = defer.succeed("@user:bz")
+ mock_password_provider.check_auth.return_value = defer.succeed(
+ ("@user:bz", None)
+ )
body["auth"]["test_field"] = "foo"
channel = self._delete_device(tok1, "dev2", body)
self.assertEqual(channel.code, 403)
@@ -379,7 +480,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
# and finally, succeed
mock_password_provider.check_auth.return_value = defer.succeed(
- "@localuser:test"
+ ("@localuser:test", None)
)
channel = self._delete_device(tok1, "dev2", body)
self.assertEqual(channel.code, 200)
@@ -387,8 +488,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"localuser", "test.login_type", {"test_field": "foo"}
)
+ @override_config(legacy_providers_config(LegacyCustomAuthProvider))
+ def test_custom_auth_provider_callback_legacy(self):
+ self.custom_auth_provider_callback_test_body()
+
@override_config(providers_config(CustomAuthProvider))
def test_custom_auth_provider_callback(self):
+ self.custom_auth_provider_callback_test_body()
+
+ def custom_auth_provider_callback_test_body(self):
callback = Mock(return_value=defer.succeed(None))
mock_password_provider.check_auth.return_value = defer.succeed(
@@ -411,9 +519,21 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertIn(p, call_args[0])
@override_config(
+ {
+ **legacy_providers_config(LegacyCustomAuthProvider),
+ "password_config": {"enabled": False},
+ }
+ )
+ def test_custom_auth_password_disabled_legacy(self):
+ self.custom_auth_password_disabled_test_body()
+
+ @override_config(
{**providers_config(CustomAuthProvider), "password_config": {"enabled": False}}
)
def test_custom_auth_password_disabled(self):
+ self.custom_auth_password_disabled_test_body()
+
+ def custom_auth_password_disabled_test_body(self):
"""Test login with a custom auth provider where password login is disabled"""
self.register_user("localuser", "localpass")
@@ -427,11 +547,23 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
@override_config(
{
+ **legacy_providers_config(LegacyCustomAuthProvider),
+ "password_config": {"enabled": False, "localdb_enabled": False},
+ }
+ )
+ def test_custom_auth_password_disabled_localdb_enabled_legacy(self):
+ self.custom_auth_password_disabled_localdb_enabled_test_body()
+
+ @override_config(
+ {
**providers_config(CustomAuthProvider),
"password_config": {"enabled": False, "localdb_enabled": False},
}
)
def test_custom_auth_password_disabled_localdb_enabled(self):
+ self.custom_auth_password_disabled_localdb_enabled_test_body()
+
+ def custom_auth_password_disabled_localdb_enabled_test_body(self):
"""Check the localdb_enabled == enabled == False
Regression test for https://github.com/matrix-org/synapse/issues/8914: check
@@ -450,11 +582,23 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
@override_config(
{
+ **legacy_providers_config(LegacyPasswordCustomAuthProvider),
+ "password_config": {"enabled": False},
+ }
+ )
+ def test_password_custom_auth_password_disabled_login_legacy(self):
+ self.password_custom_auth_password_disabled_login_test_body()
+
+ @override_config(
+ {
**providers_config(PasswordCustomAuthProvider),
"password_config": {"enabled": False},
}
)
def test_password_custom_auth_password_disabled_login(self):
+ self.password_custom_auth_password_disabled_login_test_body()
+
+ def password_custom_auth_password_disabled_login_test_body(self):
"""log in with a custom auth provider which implements password, but password
login is disabled"""
self.register_user("localuser", "localpass")
@@ -466,6 +610,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
channel = self._send_password_login("localuser", "localpass")
self.assertEqual(channel.code, 400, channel.result)
mock_password_provider.check_auth.assert_not_called()
+ mock_password_provider.check_password.assert_not_called()
+
+ @override_config(
+ {
+ **legacy_providers_config(LegacyPasswordCustomAuthProvider),
+ "password_config": {"enabled": False},
+ }
+ )
+ def test_password_custom_auth_password_disabled_ui_auth_legacy(self):
+ self.password_custom_auth_password_disabled_ui_auth_test_body()
@override_config(
{
@@ -474,12 +628,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
}
)
def test_password_custom_auth_password_disabled_ui_auth(self):
+ self.password_custom_auth_password_disabled_ui_auth_test_body()
+
+ def password_custom_auth_password_disabled_ui_auth_test_body(self):
"""UI Auth with a custom auth provider which implements password, but password
login is disabled"""
# register the user and log in twice via the test login type to get two devices,
self.register_user("localuser", "localpass")
mock_password_provider.check_auth.return_value = defer.succeed(
- "@localuser:test"
+ ("@localuser:test", None)
)
channel = self._send_login("test.login_type", "localuser", test_field="")
self.assertEqual(channel.code, 200, channel.result)
@@ -516,6 +673,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
"Password login has been disabled.", channel.json_body["error"]
)
mock_password_provider.check_auth.assert_not_called()
+ mock_password_provider.check_password.assert_not_called()
mock_password_provider.reset_mock()
# successful auth
@@ -526,6 +684,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
mock_password_provider.check_auth.assert_called_once_with(
"localuser", "test.login_type", {"test_field": "x"}
)
+ mock_password_provider.check_password.assert_not_called()
+
+ @override_config(
+ {
+ **legacy_providers_config(LegacyCustomAuthProvider),
+ "password_config": {"localdb_enabled": False},
+ }
+ )
+ def test_custom_auth_no_local_user_fallback_legacy(self):
+ self.custom_auth_no_local_user_fallback_test_body()
@override_config(
{
@@ -534,6 +702,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
}
)
def test_custom_auth_no_local_user_fallback(self):
+ self.custom_auth_no_local_user_fallback_test_body()
+
+ def custom_auth_no_local_user_fallback_test_body(self):
"""Test login with a custom auth provider where the local db is disabled"""
self.register_user("localuser", "localpass")
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 0120b4688b..70c621b825 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -109,18 +109,14 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
tok=alice_token,
)
- users = self.get_success(self.user_dir_helper.get_users_in_user_directory())
- in_public = self.get_success(self.user_dir_helper.get_users_in_public_rooms())
- in_private = self.get_success(
- self.user_dir_helper.get_users_who_share_private_rooms()
+ # The user directory should reflect the room memberships above.
+ users, in_public, in_private = self.get_success(
+ self.user_dir_helper.get_tables()
)
-
self.assertEqual(users, {alice, bob})
+ self.assertEqual(in_public, {(alice, public), (bob, public), (alice, public2)})
self.assertEqual(
- set(in_public), {(alice, public), (bob, public), (alice, public2)}
- )
- self.assertEqual(
- self.user_dir_helper._compress_shared(in_private),
+ in_private,
{(alice, bob, private), (bob, alice, private)},
)
@@ -209,6 +205,88 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
in_public = self.get_success(self.user_dir_helper.get_users_in_public_rooms())
self.assertEqual(set(in_public), {(user1, room), (user2, room)})
+ def test_excludes_users_when_making_room_public(self) -> None:
+ # Create a regular user and a support user.
+ alice = self.register_user("alice", "pass")
+ alice_token = self.login(alice, "pass")
+ support = "@support1:test"
+ self.get_success(
+ self.store.register_user(
+ user_id=support, password_hash=None, user_type=UserTypes.SUPPORT
+ )
+ )
+
+ # Make a public and private room containing Alice and the support user
+ public, initially_private = self._create_rooms_and_inject_memberships(
+ alice, alice_token, support
+ )
+ self._check_only_one_user_in_directory(alice, public)
+
+ # Alice makes the private room public.
+ self.helper.send_state(
+ initially_private,
+ "m.room.join_rules",
+ {"join_rule": "public"},
+ tok=alice_token,
+ )
+
+ users, in_public, in_private = self.get_success(
+ self.user_dir_helper.get_tables()
+ )
+ self.assertEqual(users, {alice})
+ self.assertEqual(in_public, {(alice, public), (alice, initially_private)})
+ self.assertEqual(in_private, set())
+
+ def test_switching_from_private_to_public_to_private(self) -> None:
+ """Check we update the room sharing tables when switching a room
+ from private to public, then back again to private."""
+ # Alice and Bob share a private room.
+ alice = self.register_user("alice", "pass")
+ alice_token = self.login(alice, "pass")
+ bob = self.register_user("bob", "pass")
+ bob_token = self.login(bob, "pass")
+ room = self.helper.create_room_as(alice, is_public=False, tok=alice_token)
+ self.helper.invite(room, alice, bob, tok=alice_token)
+ self.helper.join(room, bob, tok=bob_token)
+
+ # The user directory should reflect this.
+ def check_user_dir_for_private_room() -> None:
+ users, in_public, in_private = self.get_success(
+ self.user_dir_helper.get_tables()
+ )
+ self.assertEqual(users, {alice, bob})
+ self.assertEqual(in_public, set())
+ self.assertEqual(in_private, {(alice, bob, room), (bob, alice, room)})
+
+ check_user_dir_for_private_room()
+
+ # Alice makes the room public.
+ self.helper.send_state(
+ room,
+ "m.room.join_rules",
+ {"join_rule": "public"},
+ tok=alice_token,
+ )
+
+ # The user directory should be updated accordingly
+ users, in_public, in_private = self.get_success(
+ self.user_dir_helper.get_tables()
+ )
+ self.assertEqual(users, {alice, bob})
+ self.assertEqual(in_public, {(alice, room), (bob, room)})
+ self.assertEqual(in_private, set())
+
+ # Alice makes the room private.
+ self.helper.send_state(
+ room,
+ "m.room.join_rules",
+ {"join_rule": "invite"},
+ tok=alice_token,
+ )
+
+ # The user directory should be updated accordingly
+ check_user_dir_for_private_room()
+
def _create_rooms_and_inject_memberships(
self, creator: str, token: str, joiner: str
) -> Tuple[str, str]:
@@ -232,15 +310,18 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
return public_room, private_room
def _check_only_one_user_in_directory(self, user: str, public: str) -> None:
- users = self.get_success(self.user_dir_helper.get_users_in_user_directory())
- in_public = self.get_success(self.user_dir_helper.get_users_in_public_rooms())
- in_private = self.get_success(
- self.user_dir_helper.get_users_who_share_private_rooms()
- )
+ """Check that the user directory DB tables show that:
+ - only one user is in the user directory
+ - they belong to exactly one public room
+ - they don't share a private room with anyone.
+ """
+ users, in_public, in_private = self.get_success(
+ self.user_dir_helper.get_tables()
+ )
self.assertEqual(users, {user})
- self.assertEqual(set(in_public), {(user, public)})
- self.assertEqual(in_private, [])
+ self.assertEqual(in_public, {(user, public)})
+ self.assertEqual(in_private, set())
def test_handle_local_profile_change_with_support_user(self) -> None:
support_user_id = "@support:test"
@@ -565,27 +646,22 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
u2_token = self.login(u2, "pass")
u3 = self.register_user("user3", "pass")
- # We do not add users to the directory until they join a room.
+ # u1 can't see u2 until they share a private room, or u1 is in a public room.
s = self.get_success(self.handler.search_users(u1, "user2", 10))
self.assertEqual(len(s["results"]), 0)
+ # Get u1 and u2 into a private room.
room = self.helper.create_room_as(u1, is_public=False, tok=u1_token)
self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
self.helper.join(room, user=u2, tok=u2_token)
# Check we have populated the database correctly.
- shares_private = self.get_success(
- self.user_dir_helper.get_users_who_share_private_rooms()
- )
- public_users = self.get_success(
- self.user_dir_helper.get_users_in_public_rooms()
- )
-
- self.assertEqual(
- self.user_dir_helper._compress_shared(shares_private),
- {(u1, u2, room), (u2, u1, room)},
+ users, public_users, shares_private = self.get_success(
+ self.user_dir_helper.get_tables()
)
- self.assertEqual(public_users, [])
+ self.assertEqual(users, {u1, u2, u3})
+ self.assertEqual(shares_private, {(u1, u2, room), (u2, u1, room)})
+ self.assertEqual(public_users, set())
# We get one search result when searching for user2 by user1.
s = self.get_success(self.handler.search_users(u1, "user2", 10))
@@ -602,16 +678,13 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# User 2 then leaves.
self.helper.leave(room, user=u2, tok=u2_token)
- # Check we have removed the values.
- shares_private = self.get_success(
- self.user_dir_helper.get_users_who_share_private_rooms()
- )
- public_users = self.get_success(
- self.user_dir_helper.get_users_in_public_rooms()
+ # Check this is reflected in the DB.
+ users, public_users, shares_private = self.get_success(
+ self.user_dir_helper.get_tables()
)
-
- self.assertEqual(self.user_dir_helper._compress_shared(shares_private), set())
- self.assertEqual(public_users, [])
+ self.assertEqual(users, {u1, u2, u3})
+ self.assertEqual(shares_private, set())
+ self.assertEqual(public_users, set())
# User1 now gets no search results for any of the other users.
s = self.get_success(self.handler.search_users(u1, "user2", 10))
@@ -620,6 +693,50 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, "user3", 10))
self.assertEqual(len(s["results"]), 0)
+ def test_joining_private_room_with_excluded_user(self) -> None:
+ """
+ When a user excluded from the user directory, E say, joins a private
+ room, E will not appear in the `users_who_share_private_rooms` table.
+
+ When a normal user, U say, joins a private room containing E, then
+ U will appear in the `users_who_share_private_rooms` table, but E will
+ not.
+ """
+ # Setup a support and two normal users.
+ alice = self.register_user("alice", "pass")
+ alice_token = self.login(alice, "pass")
+ bob = self.register_user("bob", "pass")
+ bob_token = self.login(bob, "pass")
+ support = "@support1:test"
+ self.get_success(
+ self.store.register_user(
+ user_id=support, password_hash=None, user_type=UserTypes.SUPPORT
+ )
+ )
+
+ # Alice makes a room. Inject the support user into the room.
+ room = self.helper.create_room_as(alice, is_public=False, tok=alice_token)
+ self.get_success(inject_member_event(self.hs, room, support, "join"))
+ # Check the DB state. The support user should not be in the directory.
+ users, in_public, in_private = self.get_success(
+ self.user_dir_helper.get_tables()
+ )
+ self.assertEqual(users, {alice, bob})
+ self.assertEqual(in_public, set())
+ self.assertEqual(in_private, set())
+
+ # Then invite Bob, who accepts.
+ self.helper.invite(room, alice, bob, tok=alice_token)
+ self.helper.join(room, bob, tok=bob_token)
+
+ # Check the DB state. The support user should not be in the directory.
+ users, in_public, in_private = self.get_success(
+ self.user_dir_helper.get_tables()
+ )
+ self.assertEqual(users, {alice, bob})
+ self.assertEqual(in_public, set())
+ self.assertEqual(in_private, {(alice, bob, room), (bob, alice, room)})
+
def test_spam_checker(self) -> None:
"""
A user which fails the spam checks will not appear in search results.
@@ -645,11 +762,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.user_dir_helper.get_users_in_public_rooms()
)
- self.assertEqual(
- self.user_dir_helper._compress_shared(shares_private),
- {(u1, u2, room), (u2, u1, room)},
- )
- self.assertEqual(public_users, [])
+ self.assertEqual(shares_private, {(u1, u2, room), (u2, u1, room)})
+ self.assertEqual(public_users, set())
# We get one search result when searching for user2 by user1.
s = self.get_success(self.handler.search_users(u1, "user2", 10))
@@ -704,11 +818,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.user_dir_helper.get_users_in_public_rooms()
)
- self.assertEqual(
- self.user_dir_helper._compress_shared(shares_private),
- {(u1, u2, room), (u2, u1, room)},
- )
- self.assertEqual(public_users, [])
+ self.assertEqual(shares_private, {(u1, u2, room), (u2, u1, room)})
+ self.assertEqual(public_users, set())
# Configure a spam checker.
spam_checker = self.hs.get_spam_checker()
@@ -740,8 +851,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
)
# No users share rooms
- self.assertEqual(public_users, [])
- self.assertEqual(self.user_dir_helper._compress_shared(shares_private), set())
+ self.assertEqual(public_users, set())
+ self.assertEqual(shares_private, set())
# Despite not sharing a room, search_all_users means we get a search
# result.
@@ -842,6 +953,56 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.hs.get_storage().persistence.persist_event(event, context)
)
+ def test_local_user_leaving_room_remains_in_user_directory(self) -> None:
+ """We've chosen to simplify the user directory's implementation by
+ always including local users. Ensure this invariant is maintained when
+ a local user
+ - leaves a room, and
+ - leaves the last room they're in which is visible to this server.
+
+ This is user-visible if the "search_all_users" config option is on: the
+ local user who left a room would no longer be searchable if this test fails!
+ """
+ alice = self.register_user("alice", "pass")
+ alice_token = self.login(alice, "pass")
+ bob = self.register_user("bob", "pass")
+ bob_token = self.login(bob, "pass")
+
+ # Alice makes two public rooms, which Bob joins.
+ room1 = self.helper.create_room_as(alice, is_public=True, tok=alice_token)
+ room2 = self.helper.create_room_as(alice, is_public=True, tok=alice_token)
+ self.helper.join(room1, bob, tok=bob_token)
+ self.helper.join(room2, bob, tok=bob_token)
+
+ # The user directory tables are updated.
+ users, in_public, in_private = self.get_success(
+ self.user_dir_helper.get_tables()
+ )
+ self.assertEqual(users, {alice, bob})
+ self.assertEqual(
+ in_public, {(alice, room1), (alice, room2), (bob, room1), (bob, room2)}
+ )
+ self.assertEqual(in_private, set())
+
+ # Alice leaves one room. She should still be in the directory.
+ self.helper.leave(room1, alice, tok=alice_token)
+ users, in_public, in_private = self.get_success(
+ self.user_dir_helper.get_tables()
+ )
+ self.assertEqual(users, {alice, bob})
+ self.assertEqual(in_public, {(alice, room2), (bob, room1), (bob, room2)})
+ self.assertEqual(in_private, set())
+
+ # Alice leaves the other. She should still be in the directory.
+ self.helper.leave(room2, alice, tok=alice_token)
+ self.wait_for_background_updates()
+ users, in_public, in_private = self.get_success(
+ self.user_dir_helper.get_tables()
+ )
+ self.assertEqual(users, {alice, bob})
+ self.assertEqual(in_public, {(bob, room1), (bob, room2)})
+ self.assertEqual(in_private, set())
+
class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
servlets = [
|