From cdd308845ba22fef22a39ed5bf904b438e48b491 Mon Sep 17 00:00:00 2001 From: Azrenbeth <77782548+Azrenbeth@users.noreply.github.com> Date: Wed, 13 Oct 2021 12:21:52 +0100 Subject: Port the Password Auth Providers module interface to the new generic interface (#10548) Co-authored-by: Azrenbeth <7782548+Azrenbeth@users.noreply.github.com> Co-authored-by: Brendan Abolivier --- tests/handlers/test_password_providers.py | 223 ++++++++++++++++++++++++++---- 1 file changed, 197 insertions(+), 26 deletions(-) (limited to 'tests/handlers') diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py index 38e6d9f536..7dd4a5a367 100644 --- a/tests/handlers/test_password_providers.py +++ b/tests/handlers/test_password_providers.py @@ -20,6 +20,8 @@ from unittest.mock import Mock from twisted.internet import defer import synapse +from synapse.handlers.auth import load_legacy_password_auth_providers +from synapse.module_api import ModuleApi from synapse.rest.client import devices, login from synapse.types import JsonDict @@ -36,8 +38,8 @@ ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_servi mock_password_provider = Mock() -class PasswordOnlyAuthProvider: - """A password_provider which only implements `check_password`.""" +class LegacyPasswordOnlyAuthProvider: + """A legacy password_provider which only implements `check_password`.""" @staticmethod def parse_config(self): @@ -50,8 +52,8 @@ class PasswordOnlyAuthProvider: return mock_password_provider.check_password(*args) -class CustomAuthProvider: - """A password_provider which implements a custom login type.""" +class LegacyCustomAuthProvider: + """A legacy password_provider which implements a custom login type.""" @staticmethod def parse_config(self): @@ -67,7 +69,23 @@ class CustomAuthProvider: return mock_password_provider.check_auth(*args) -class PasswordCustomAuthProvider: +class CustomAuthProvider: + """A module which registers password_auth_provider callbacks for a custom login type.""" + + @staticmethod + def parse_config(self): + pass + + def __init__(self, config, api: ModuleApi): + api.register_password_auth_provider_callbacks( + auth_checkers={("test.login_type", ("test_field",)): self.check_auth}, + ) + + def check_auth(self, *args): + return mock_password_provider.check_auth(*args) + + +class LegacyPasswordCustomAuthProvider: """A password_provider which implements password login via `check_auth`, as well as a custom type.""" @@ -85,8 +103,32 @@ class PasswordCustomAuthProvider: return mock_password_provider.check_auth(*args) -def providers_config(*providers: Type[Any]) -> dict: - """Returns a config dict that will enable the given password auth providers""" +class PasswordCustomAuthProvider: + """A module which registers password_auth_provider callbacks for a custom login type. + as well as a password login""" + + @staticmethod + def parse_config(self): + pass + + def __init__(self, config, api: ModuleApi): + api.register_password_auth_provider_callbacks( + auth_checkers={ + ("test.login_type", ("test_field",)): self.check_auth, + ("m.login.password", ("password",)): self.check_auth, + }, + ) + pass + + def check_auth(self, *args): + return mock_password_provider.check_auth(*args) + + def check_pass(self, *args): + return mock_password_provider.check_password(*args) + + +def legacy_providers_config(*providers: Type[Any]) -> dict: + """Returns a config dict that will enable the given legacy password auth providers""" return { "password_providers": [ {"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}} @@ -95,6 +137,16 @@ def providers_config(*providers: Type[Any]) -> dict: } +def providers_config(*providers: Type[Any]) -> dict: + """Returns a config dict that will enable the given modules""" + return { + "modules": [ + {"module": "%s.%s" % (__name__, provider.__qualname__), "config": {}} + for provider in providers + ] + } + + class PasswordAuthProviderTests(unittest.HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets, @@ -107,8 +159,21 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): mock_password_provider.reset_mock() super().setUp() - @override_config(providers_config(PasswordOnlyAuthProvider)) - def test_password_only_auth_provider_login(self): + def make_homeserver(self, reactor, clock): + hs = self.setup_test_homeserver() + # Load the modules into the homeserver + module_api = hs.get_module_api() + for module, config in hs.config.modules.loaded_modules: + module(config=config, api=module_api) + load_legacy_password_auth_providers(hs) + + return hs + + @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) + def test_password_only_auth_progiver_login_legacy(self): + self.password_only_auth_provider_login_test_body() + + def password_only_auth_provider_login_test_body(self): # login flows should only have m.login.password flows = self._get_login_flows() self.assertEqual(flows, [{"type": "m.login.password"}] + ADDITIONAL_LOGIN_FLOWS) @@ -138,8 +203,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "@ USER🙂NAME :test", " pASS😢word " ) - @override_config(providers_config(PasswordOnlyAuthProvider)) - def test_password_only_auth_provider_ui_auth(self): + @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) + def test_password_only_auth_provider_ui_auth_legacy(self): + self.password_only_auth_provider_ui_auth_test_body() + + def password_only_auth_provider_ui_auth_test_body(self): """UI Auth should delegate correctly to the password provider""" # create the user, otherwise access doesn't work @@ -172,8 +240,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200) mock_password_provider.check_password.assert_called_once_with("@u:test", "p") - @override_config(providers_config(PasswordOnlyAuthProvider)) - def test_local_user_fallback_login(self): + @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) + def test_local_user_fallback_login_legacy(self): + self.local_user_fallback_login_test_body() + + def local_user_fallback_login_test_body(self): """rejected login should fall back to local db""" self.register_user("localuser", "localpass") @@ -186,8 +257,11 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200, channel.result) self.assertEqual("@localuser:test", channel.json_body["user_id"]) - @override_config(providers_config(PasswordOnlyAuthProvider)) - def test_local_user_fallback_ui_auth(self): + @override_config(legacy_providers_config(LegacyPasswordOnlyAuthProvider)) + def test_local_user_fallback_ui_auth_legacy(self): + self.local_user_fallback_ui_auth_test_body() + + def local_user_fallback_ui_auth_test_body(self): """rejected login should fall back to local db""" self.register_user("localuser", "localpass") @@ -223,11 +297,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): @override_config( { - **providers_config(PasswordOnlyAuthProvider), + **legacy_providers_config(LegacyPasswordOnlyAuthProvider), "password_config": {"localdb_enabled": False}, } ) - def test_no_local_user_fallback_login(self): + def test_no_local_user_fallback_login_legacy(self): + self.no_local_user_fallback_login_test_body() + + def no_local_user_fallback_login_test_body(self): """localdb_enabled can block login with the local password""" self.register_user("localuser", "localpass") @@ -242,11 +319,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): @override_config( { - **providers_config(PasswordOnlyAuthProvider), + **legacy_providers_config(LegacyPasswordOnlyAuthProvider), "password_config": {"localdb_enabled": False}, } ) - def test_no_local_user_fallback_ui_auth(self): + def test_no_local_user_fallback_ui_auth_legacy(self): + self.no_local_user_fallback_ui_auth_test_body() + + def no_local_user_fallback_ui_auth_test_body(self): """localdb_enabled can block ui auth with the local password""" self.register_user("localuser", "localpass") @@ -280,11 +360,14 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): @override_config( { - **providers_config(PasswordOnlyAuthProvider), + **legacy_providers_config(LegacyPasswordOnlyAuthProvider), "password_config": {"enabled": False}, } ) - def test_password_auth_disabled(self): + def test_password_auth_disabled_legacy(self): + self.password_auth_disabled_test_body() + + def password_auth_disabled_test_body(self): """password auth doesn't work if it's disabled across the board""" # login flows should be empty flows = self._get_login_flows() @@ -295,8 +378,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 400, channel.result) mock_password_provider.check_password.assert_not_called() + @override_config(legacy_providers_config(LegacyCustomAuthProvider)) + def test_custom_auth_provider_login_legacy(self): + self.custom_auth_provider_login_test_body() + @override_config(providers_config(CustomAuthProvider)) def test_custom_auth_provider_login(self): + self.custom_auth_provider_login_test_body() + + def custom_auth_provider_login_test_body(self): # login flows should have the custom flow and m.login.password, since we # haven't disabled local password lookup. # (password must come first, because reasons) @@ -312,7 +402,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 400, channel.result) mock_password_provider.check_auth.assert_not_called() - mock_password_provider.check_auth.return_value = defer.succeed("@user:bz") + mock_password_provider.check_auth.return_value = defer.succeed( + ("@user:bz", None) + ) channel = self._send_login("test.login_type", "u", test_field="y") self.assertEqual(channel.code, 200, channel.result) self.assertEqual("@user:bz", channel.json_body["user_id"]) @@ -325,7 +417,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # in these cases, but at least we can guard against the API changing # unexpectedly mock_password_provider.check_auth.return_value = defer.succeed( - "@ MALFORMED! :bz" + ("@ MALFORMED! :bz", None) ) channel = self._send_login("test.login_type", " USER🙂NAME ", test_field=" abc ") self.assertEqual(channel.code, 200, channel.result) @@ -334,8 +426,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): " USER🙂NAME ", "test.login_type", {"test_field": " abc "} ) + @override_config(legacy_providers_config(LegacyCustomAuthProvider)) + def test_custom_auth_provider_ui_auth_legacy(self): + self.custom_auth_provider_ui_auth_test_body() + @override_config(providers_config(CustomAuthProvider)) def test_custom_auth_provider_ui_auth(self): + self.custom_auth_provider_ui_auth_test_body() + + def custom_auth_provider_ui_auth_test_body(self): # register the user and log in twice, to get two devices self.register_user("localuser", "localpass") tok1 = self.login("localuser", "localpass") @@ -367,7 +466,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): mock_password_provider.reset_mock() # right params, but authing as the wrong user - mock_password_provider.check_auth.return_value = defer.succeed("@user:bz") + mock_password_provider.check_auth.return_value = defer.succeed( + ("@user:bz", None) + ) body["auth"]["test_field"] = "foo" channel = self._delete_device(tok1, "dev2", body) self.assertEqual(channel.code, 403) @@ -379,7 +480,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): # and finally, succeed mock_password_provider.check_auth.return_value = defer.succeed( - "@localuser:test" + ("@localuser:test", None) ) channel = self._delete_device(tok1, "dev2", body) self.assertEqual(channel.code, 200) @@ -387,8 +488,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "localuser", "test.login_type", {"test_field": "foo"} ) + @override_config(legacy_providers_config(LegacyCustomAuthProvider)) + def test_custom_auth_provider_callback_legacy(self): + self.custom_auth_provider_callback_test_body() + @override_config(providers_config(CustomAuthProvider)) def test_custom_auth_provider_callback(self): + self.custom_auth_provider_callback_test_body() + + def custom_auth_provider_callback_test_body(self): callback = Mock(return_value=defer.succeed(None)) mock_password_provider.check_auth.return_value = defer.succeed( @@ -410,10 +518,22 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): for p in ["user_id", "access_token", "device_id", "home_server"]: self.assertIn(p, call_args[0]) + @override_config( + { + **legacy_providers_config(LegacyCustomAuthProvider), + "password_config": {"enabled": False}, + } + ) + def test_custom_auth_password_disabled_legacy(self): + self.custom_auth_password_disabled_test_body() + @override_config( {**providers_config(CustomAuthProvider), "password_config": {"enabled": False}} ) def test_custom_auth_password_disabled(self): + self.custom_auth_password_disabled_test_body() + + def custom_auth_password_disabled_test_body(self): """Test login with a custom auth provider where password login is disabled""" self.register_user("localuser", "localpass") @@ -425,6 +545,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 400, channel.result) mock_password_provider.check_auth.assert_not_called() + @override_config( + { + **legacy_providers_config(LegacyCustomAuthProvider), + "password_config": {"enabled": False, "localdb_enabled": False}, + } + ) + def test_custom_auth_password_disabled_localdb_enabled_legacy(self): + self.custom_auth_password_disabled_localdb_enabled_test_body() + @override_config( { **providers_config(CustomAuthProvider), @@ -432,6 +561,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): } ) def test_custom_auth_password_disabled_localdb_enabled(self): + self.custom_auth_password_disabled_localdb_enabled_test_body() + + def custom_auth_password_disabled_localdb_enabled_test_body(self): """Check the localdb_enabled == enabled == False Regression test for https://github.com/matrix-org/synapse/issues/8914: check @@ -448,6 +580,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): self.assertEqual(channel.code, 400, channel.result) mock_password_provider.check_auth.assert_not_called() + @override_config( + { + **legacy_providers_config(LegacyPasswordCustomAuthProvider), + "password_config": {"enabled": False}, + } + ) + def test_password_custom_auth_password_disabled_login_legacy(self): + self.password_custom_auth_password_disabled_login_test_body() + @override_config( { **providers_config(PasswordCustomAuthProvider), @@ -455,6 +596,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): } ) def test_password_custom_auth_password_disabled_login(self): + self.password_custom_auth_password_disabled_login_test_body() + + def password_custom_auth_password_disabled_login_test_body(self): """log in with a custom auth provider which implements password, but password login is disabled""" self.register_user("localuser", "localpass") @@ -466,6 +610,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): channel = self._send_password_login("localuser", "localpass") self.assertEqual(channel.code, 400, channel.result) mock_password_provider.check_auth.assert_not_called() + mock_password_provider.check_password.assert_not_called() + + @override_config( + { + **legacy_providers_config(LegacyPasswordCustomAuthProvider), + "password_config": {"enabled": False}, + } + ) + def test_password_custom_auth_password_disabled_ui_auth_legacy(self): + self.password_custom_auth_password_disabled_ui_auth_test_body() @override_config( { @@ -474,12 +628,15 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): } ) def test_password_custom_auth_password_disabled_ui_auth(self): + self.password_custom_auth_password_disabled_ui_auth_test_body() + + def password_custom_auth_password_disabled_ui_auth_test_body(self): """UI Auth with a custom auth provider which implements password, but password login is disabled""" # register the user and log in twice via the test login type to get two devices, self.register_user("localuser", "localpass") mock_password_provider.check_auth.return_value = defer.succeed( - "@localuser:test" + ("@localuser:test", None) ) channel = self._send_login("test.login_type", "localuser", test_field="") self.assertEqual(channel.code, 200, channel.result) @@ -516,6 +673,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): "Password login has been disabled.", channel.json_body["error"] ) mock_password_provider.check_auth.assert_not_called() + mock_password_provider.check_password.assert_not_called() mock_password_provider.reset_mock() # successful auth @@ -526,6 +684,16 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): mock_password_provider.check_auth.assert_called_once_with( "localuser", "test.login_type", {"test_field": "x"} ) + mock_password_provider.check_password.assert_not_called() + + @override_config( + { + **legacy_providers_config(LegacyCustomAuthProvider), + "password_config": {"localdb_enabled": False}, + } + ) + def test_custom_auth_no_local_user_fallback_legacy(self): + self.custom_auth_no_local_user_fallback_test_body() @override_config( { @@ -534,6 +702,9 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase): } ) def test_custom_auth_no_local_user_fallback(self): + self.custom_auth_no_local_user_fallback_test_body() + + def custom_auth_no_local_user_fallback_test_body(self): """Test login with a custom auth provider where the local db is disabled""" self.register_user("localuser", "localpass") -- cgit 1.5.1 From daf498e099394e206709bbc7a330be4a989e31d5 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Thu, 14 Oct 2021 18:53:45 -0500 Subject: Fix 500 error on `/messages` when we accumulate more than 5 backward extremities (#11027) Found while working on the Gitter backfill script and noticed it only happened after we sent 7 batches, https://gitlab.com/gitterHQ/webapp/-/merge_requests/2229#note_665906390 When there are more than 5 backward extremities for a given depth, backfill will throw an error because we sliced the extremity list to 5 but then try to iterate over the full list. This causes us to look for state that we never fetched and we get a `KeyError`. Before when calling `/messages` when there are more than 5 backward extremities: ``` Traceback (most recent call last): File "/usr/local/lib/python3.8/site-packages/synapse/http/server.py", line 258, in _async_render_wrapper callback_return = await self._async_render(request) File "/usr/local/lib/python3.8/site-packages/synapse/http/server.py", line 446, in _async_render callback_return = await raw_callback_return File "/usr/local/lib/python3.8/site-packages/synapse/rest/client/room.py", line 580, in on_GET msgs = await self.pagination_handler.get_messages( File "/usr/local/lib/python3.8/site-packages/synapse/handlers/pagination.py", line 396, in get_messages await self.hs.get_federation_handler().maybe_backfill( File "/usr/local/lib/python3.8/site-packages/synapse/handlers/federation.py", line 133, in maybe_backfill return await self._maybe_backfill_inner(room_id, current_depth, limit) File "/usr/local/lib/python3.8/site-packages/synapse/handlers/federation.py", line 386, in _maybe_backfill_inner likely_extremeties_domains = get_domains_from_state(states[e_id]) KeyError: '$zpFflMEBtZdgcMQWTakaVItTLMjLFdKcRWUPHbbSZJl' ``` --- changelog.d/11027.bugfix | 1 + synapse/handlers/federation.py | 24 +++++++------- synapse/handlers/federation_event.py | 2 +- tests/handlers/test_federation.py | 64 ++++++++++++++++++++++++++++++++++++ 4 files changed, 79 insertions(+), 12 deletions(-) create mode 100644 changelog.d/11027.bugfix (limited to 'tests/handlers') diff --git a/changelog.d/11027.bugfix b/changelog.d/11027.bugfix new file mode 100644 index 0000000000..ae6cc44470 --- /dev/null +++ b/changelog.d/11027.bugfix @@ -0,0 +1 @@ +Fix 500 error on `/messages` when the server accumulates more than 5 backwards extremities at a given depth for a room. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 3e341bd287..e072efad16 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -238,18 +238,10 @@ class FederationHandler: ) return False - logger.debug( - "room_id: %s, backfill: current_depth: %s, max_depth: %s, extrems: %s", - room_id, - current_depth, - max_depth, - sorted_extremeties_tuple, - ) - # We ignore extremities that have a greater depth than our current depth # as: # 1. we don't really care about getting events that have happened - # before our current position; and + # after our current position; and # 2. we have likely previously tried and failed to backfill from that # extremity, so to avoid getting "stuck" requesting the same # backfill repeatedly we drop those extremities. @@ -257,9 +249,19 @@ class FederationHandler: t for t in sorted_extremeties_tuple if int(t[1]) <= current_depth ] + logger.debug( + "room_id: %s, backfill: current_depth: %s, limit: %s, max_depth: %s, extrems: %s filtered_sorted_extremeties_tuple: %s", + room_id, + current_depth, + limit, + max_depth, + sorted_extremeties_tuple, + filtered_sorted_extremeties_tuple, + ) + # However, we need to check that the filtered extremities are non-empty. # If they are empty then either we can a) bail or b) still attempt to - # backill. We opt to try backfilling anyway just in case we do get + # backfill. We opt to try backfilling anyway just in case we do get # relevant events. if filtered_sorted_extremeties_tuple: sorted_extremeties_tuple = filtered_sorted_extremeties_tuple @@ -389,7 +391,7 @@ class FederationHandler: for key, state_dict in states.items() } - for e_id, _ in sorted_extremeties_tuple: + for e_id in event_ids: likely_extremeties_domains = get_domains_from_state(states[e_id]) success = await try_backfill( diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index f640b417b3..0e455678aa 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -392,7 +392,7 @@ class FederationEventHandler: @log_function async def backfill( - self, dest: str, room_id: str, limit: int, extremities: List[str] + self, dest: str, room_id: str, limit: int, extremities: Iterable[str] ) -> None: """Trigger a backfill request to `dest` for the given `room_id` diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 936ebf3dde..e1557566e4 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -23,6 +23,7 @@ from synapse.federation.federation_base import event_from_pdu_json from synapse.logging.context import LoggingContext, run_in_background from synapse.rest import admin from synapse.rest.client import login, room +from synapse.types import create_requester from synapse.util.stringutils import random_string from tests import unittest @@ -30,6 +31,10 @@ from tests import unittest logger = logging.getLogger(__name__) +def generate_fake_event_id() -> str: + return "$fake_" + random_string(43) + + class FederationTestCase(unittest.HomeserverTestCase): servlets = [ admin.register_servlets, @@ -198,6 +203,65 @@ class FederationTestCase(unittest.HomeserverTestCase): self.assertEqual(sg, sg2) + def test_backfill_with_many_backward_extremities(self): + """ + Check that we can backfill with many backward extremities. + The goal is to make sure that when we only use a portion + of backwards extremities(the magic number is more than 5), + no errors are thrown. + + Regression test, see #11027 + """ + # create the room + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test") + requester = create_requester(user_id) + + room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) + + ev1 = self.helper.send(room_id, "first message", tok=tok) + + # Create "many" backward extremities. The magic number we're trying to + # create more than is 5 which corresponds to the number of backward + # extremities we slice off in `_maybe_backfill_inner` + for _ in range(0, 8): + event_handler = self.hs.get_event_creation_handler() + event, context = self.get_success( + event_handler.create_event( + requester, + { + "type": "m.room.message", + "content": { + "msgtype": "m.text", + "body": "message connected to fake event", + }, + "room_id": room_id, + "sender": user_id, + }, + prev_event_ids=[ + ev1["event_id"], + # We're creating an backward extremity each time thanks + # to this fake event + generate_fake_event_id(), + ], + ) + ) + self.get_success( + event_handler.handle_new_client_event(requester, event, context) + ) + + current_depth = 1 + limit = 100 + with LoggingContext("receive_pdu"): + # Make sure backfill still works + d = run_in_background( + self.hs.get_federation_handler().maybe_backfill, + room_id, + current_depth, + limit, + ) + self.get_success(d) + def test_backfill_floating_outlier_membership_auth(self): """ As the local homeserver, check that we can properly process a federated -- cgit 1.5.1 From e09be0c87a8c0da1381e7986cc940d1411705081 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 15 Oct 2021 15:53:05 +0100 Subject: Correctly exclude users when making a room public or private (#11075) Co-authored-by: Patrick Cloke --- changelog.d/11075.bugfix | 1 + synapse/handlers/user_directory.py | 11 ++- tests/handlers/test_user_directory.py | 142 +++++++++++++++++++++++++--------- tests/storage/test_user_directory.py | 77 ++++++++---------- 4 files changed, 148 insertions(+), 83 deletions(-) create mode 100644 changelog.d/11075.bugfix (limited to 'tests/handlers') diff --git a/changelog.d/11075.bugfix b/changelog.d/11075.bugfix new file mode 100644 index 0000000000..9b24971c5a --- /dev/null +++ b/changelog.d/11075.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where users excluded from the user directory were added into the directory if they belonged to a room which became public or private. \ No newline at end of file diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 52b2de388f..99f23ed967 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -266,14 +266,17 @@ class UserDirectoryHandler(StateDeltasHandler): for user_id in users_in_room: await self.store.remove_user_who_share_room(user_id, room_id) - # Then, re-add them to the tables. + # Then, re-add all remote users and some local users to the tables. # NOTE: this is not the most efficient method, as _track_user_joined_room sets # up local_user -> other_user and other_user_whos_local -> local_user, # which when ran over an entire room, will result in the same values # being added multiple times. The batching upserts shouldn't make this # too bad, though. for user_id in users_in_room: - await self._track_user_joined_room(room_id, user_id) + if not self.is_mine_id( + user_id + ) or await self.store.should_include_local_user_in_dir(user_id): + await self._track_user_joined_room(room_id, user_id) async def _handle_room_membership_event( self, @@ -364,8 +367,8 @@ class UserDirectoryHandler(StateDeltasHandler): """Someone's just joined a room. Update `users_in_public_rooms` or `users_who_share_private_rooms` as appropriate. - The caller is responsible for ensuring that the given user is not excluded - from the user directory. + The caller is responsible for ensuring that the given user should be + included in the user directory. """ is_public = await self.store.is_room_world_readable_or_publicly_joinable( room_id diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 0120b4688b..e0635c8898 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -109,18 +109,14 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): tok=alice_token, ) - users = self.get_success(self.user_dir_helper.get_users_in_user_directory()) - in_public = self.get_success(self.user_dir_helper.get_users_in_public_rooms()) - in_private = self.get_success( - self.user_dir_helper.get_users_who_share_private_rooms() + # The user directory should reflect the room memberships above. + users, in_public, in_private = self.get_success( + self.user_dir_helper.get_tables() ) - self.assertEqual(users, {alice, bob}) + self.assertEqual(in_public, {(alice, public), (bob, public), (alice, public2)}) self.assertEqual( - set(in_public), {(alice, public), (bob, public), (alice, public2)} - ) - self.assertEqual( - self.user_dir_helper._compress_shared(in_private), + in_private, {(alice, bob, private), (bob, alice, private)}, ) @@ -209,6 +205,88 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): in_public = self.get_success(self.user_dir_helper.get_users_in_public_rooms()) self.assertEqual(set(in_public), {(user1, room), (user2, room)}) + def test_excludes_users_when_making_room_public(self) -> None: + # Create a regular user and a support user. + alice = self.register_user("alice", "pass") + alice_token = self.login(alice, "pass") + support = "@support1:test" + self.get_success( + self.store.register_user( + user_id=support, password_hash=None, user_type=UserTypes.SUPPORT + ) + ) + + # Make a public and private room containing Alice and the support user + public, initially_private = self._create_rooms_and_inject_memberships( + alice, alice_token, support + ) + self._check_only_one_user_in_directory(alice, public) + + # Alice makes the private room public. + self.helper.send_state( + initially_private, + "m.room.join_rules", + {"join_rule": "public"}, + tok=alice_token, + ) + + users, in_public, in_private = self.get_success( + self.user_dir_helper.get_tables() + ) + self.assertEqual(users, {alice}) + self.assertEqual(in_public, {(alice, public), (alice, initially_private)}) + self.assertEqual(in_private, set()) + + def test_switching_from_private_to_public_to_private(self) -> None: + """Check we update the room sharing tables when switching a room + from private to public, then back again to private.""" + # Alice and Bob share a private room. + alice = self.register_user("alice", "pass") + alice_token = self.login(alice, "pass") + bob = self.register_user("bob", "pass") + bob_token = self.login(bob, "pass") + room = self.helper.create_room_as(alice, is_public=False, tok=alice_token) + self.helper.invite(room, alice, bob, tok=alice_token) + self.helper.join(room, bob, tok=bob_token) + + # The user directory should reflect this. + def check_user_dir_for_private_room() -> None: + users, in_public, in_private = self.get_success( + self.user_dir_helper.get_tables() + ) + self.assertEqual(users, {alice, bob}) + self.assertEqual(in_public, set()) + self.assertEqual(in_private, {(alice, bob, room), (bob, alice, room)}) + + check_user_dir_for_private_room() + + # Alice makes the room public. + self.helper.send_state( + room, + "m.room.join_rules", + {"join_rule": "public"}, + tok=alice_token, + ) + + # The user directory should be updated accordingly + users, in_public, in_private = self.get_success( + self.user_dir_helper.get_tables() + ) + self.assertEqual(users, {alice, bob}) + self.assertEqual(in_public, {(alice, room), (bob, room)}) + self.assertEqual(in_private, set()) + + # Alice makes the room private. + self.helper.send_state( + room, + "m.room.join_rules", + {"join_rule": "invite"}, + tok=alice_token, + ) + + # The user directory should be updated accordingly + check_user_dir_for_private_room() + def _create_rooms_and_inject_memberships( self, creator: str, token: str, joiner: str ) -> Tuple[str, str]: @@ -232,15 +310,18 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): return public_room, private_room def _check_only_one_user_in_directory(self, user: str, public: str) -> None: - users = self.get_success(self.user_dir_helper.get_users_in_user_directory()) - in_public = self.get_success(self.user_dir_helper.get_users_in_public_rooms()) - in_private = self.get_success( - self.user_dir_helper.get_users_who_share_private_rooms() - ) + """Check that the user directory DB tables show that: + - only one user is in the user directory + - they belong to exactly one public room + - they don't share a private room with anyone. + """ + users, in_public, in_private = self.get_success( + self.user_dir_helper.get_tables() + ) self.assertEqual(users, {user}) - self.assertEqual(set(in_public), {(user, public)}) - self.assertEqual(in_private, []) + self.assertEqual(in_public, {(user, public)}) + self.assertEqual(in_private, set()) def test_handle_local_profile_change_with_support_user(self) -> None: support_user_id = "@support:test" @@ -581,11 +662,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.user_dir_helper.get_users_in_public_rooms() ) - self.assertEqual( - self.user_dir_helper._compress_shared(shares_private), - {(u1, u2, room), (u2, u1, room)}, - ) - self.assertEqual(public_users, []) + self.assertEqual(shares_private, {(u1, u2, room), (u2, u1, room)}) + self.assertEqual(public_users, set()) # We get one search result when searching for user2 by user1. s = self.get_success(self.handler.search_users(u1, "user2", 10)) @@ -610,8 +688,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.user_dir_helper.get_users_in_public_rooms() ) - self.assertEqual(self.user_dir_helper._compress_shared(shares_private), set()) - self.assertEqual(public_users, []) + self.assertEqual(shares_private, set()) + self.assertEqual(public_users, set()) # User1 now gets no search results for any of the other users. s = self.get_success(self.handler.search_users(u1, "user2", 10)) @@ -645,11 +723,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.user_dir_helper.get_users_in_public_rooms() ) - self.assertEqual( - self.user_dir_helper._compress_shared(shares_private), - {(u1, u2, room), (u2, u1, room)}, - ) - self.assertEqual(public_users, []) + self.assertEqual(shares_private, {(u1, u2, room), (u2, u1, room)}) + self.assertEqual(public_users, set()) # We get one search result when searching for user2 by user1. s = self.get_success(self.handler.search_users(u1, "user2", 10)) @@ -704,11 +779,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.user_dir_helper.get_users_in_public_rooms() ) - self.assertEqual( - self.user_dir_helper._compress_shared(shares_private), - {(u1, u2, room), (u2, u1, room)}, - ) - self.assertEqual(public_users, []) + self.assertEqual(shares_private, {(u1, u2, room), (u2, u1, room)}) + self.assertEqual(public_users, set()) # Configure a spam checker. spam_checker = self.hs.get_spam_checker() @@ -740,8 +812,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): ) # No users share rooms - self.assertEqual(public_users, []) - self.assertEqual(self.user_dir_helper._compress_shared(shares_private), set()) + self.assertEqual(public_users, set()) + self.assertEqual(shares_private, set()) # Despite not sharing a room, search_all_users means we get a search # result. diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index be3ed64f5e..37cf7bb232 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Set, Tuple +from typing import Any, Dict, Set, Tuple from unittest import mock from unittest.mock import Mock, patch @@ -42,18 +42,7 @@ class GetUserDirectoryTables: def __init__(self, store: DataStore): self.store = store - def _compress_shared( - self, shared: List[Dict[str, str]] - ) -> Set[Tuple[str, str, str]]: - """ - Compress a list of users who share rooms dicts to a list of tuples. - """ - r = set() - for i in shared: - r.add((i["user_id"], i["other_user_id"], i["room_id"])) - return r - - async def get_users_in_public_rooms(self) -> List[Tuple[str, str]]: + async def get_users_in_public_rooms(self) -> Set[Tuple[str, str]]: """Fetch the entire `users_in_public_rooms` table. Returns a list of tuples (user_id, room_id) where room_id is public and @@ -63,24 +52,27 @@ class GetUserDirectoryTables: "users_in_public_rooms", None, ("user_id", "room_id") ) - retval = [] + retval = set() for i in r: - retval.append((i["user_id"], i["room_id"])) + retval.add((i["user_id"], i["room_id"])) return retval - async def get_users_who_share_private_rooms(self) -> List[Dict[str, str]]: + async def get_users_who_share_private_rooms(self) -> Set[Tuple[str, str, str]]: """Fetch the entire `users_who_share_private_rooms` table. - Returns a dict containing "user_id", "other_user_id" and "room_id" keys. - The dicts can be flattened to Tuples with the `_compress_shared` method. - (This seems a little awkward---maybe we could clean this up.) + Returns a set of tuples (user_id, other_user_id, room_id) corresponding + to the rows of `users_who_share_private_rooms`. """ - return await self.store.db_pool.simple_select_list( + rows = await self.store.db_pool.simple_select_list( "users_who_share_private_rooms", None, ["user_id", "other_user_id", "room_id"], ) + rv = set() + for row in rows: + rv.add((row["user_id"], row["other_user_id"], row["room_id"])) + return rv async def get_users_in_user_directory(self) -> Set[str]: """Fetch the set of users in the `user_directory` table. @@ -113,6 +105,16 @@ class GetUserDirectoryTables: for row in rows } + async def get_tables( + self, + ) -> Tuple[Set[str], Set[Tuple[str, str]], Set[Tuple[str, str, str]]]: + """Multiple tests want to inspect these tables, so expose them together.""" + return ( + await self.get_users_in_user_directory(), + await self.get_users_in_public_rooms(), + await self.get_users_who_share_private_rooms(), + ) + class UserDirectoryInitialPopulationTestcase(HomeserverTestCase): """Ensure that rebuilding the directory writes the correct data to the DB. @@ -166,8 +168,8 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase): ) # Nothing updated yet - self.assertEqual(shares_private, []) - self.assertEqual(public_users, []) + self.assertEqual(shares_private, set()) + self.assertEqual(public_users, set()) # Ugh, have to reset this flag self.store.db_pool.updates._all_done = False @@ -236,24 +238,15 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase): # Do the initial population of the user directory via the background update self._purge_and_rebuild_user_dir() - shares_private = self.get_success( - self.user_dir_helper.get_users_who_share_private_rooms() - ) - public_users = self.get_success( - self.user_dir_helper.get_users_in_public_rooms() + users, in_public, in_private = self.get_success( + self.user_dir_helper.get_tables() ) # User 1 and User 2 are in the same public room - self.assertEqual(set(public_users), {(u1, room), (u2, room)}) - + self.assertEqual(in_public, {(u1, room), (u2, room)}) # User 1 and User 3 share private rooms - self.assertEqual( - self.user_dir_helper._compress_shared(shares_private), - {(u1, u3, private_room), (u3, u1, private_room)}, - ) - + self.assertEqual(in_private, {(u1, u3, private_room), (u3, u1, private_room)}) # All three should have entries in the directory - users = self.get_success(self.user_dir_helper.get_users_in_user_directory()) self.assertEqual(users, {u1, u2, u3}) # The next four tests (test_population_excludes_*) all set up @@ -289,16 +282,12 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase): self, normal_user: str, public_room: str, private_room: str ) -> None: # After rebuilding the directory, we should only see the normal user. - users = self.get_success(self.user_dir_helper.get_users_in_user_directory()) - self.assertEqual(users, {normal_user}) - in_public_rooms = self.get_success( - self.user_dir_helper.get_users_in_public_rooms() + users, in_public, in_private = self.get_success( + self.user_dir_helper.get_tables() ) - self.assertEqual(set(in_public_rooms), {(normal_user, public_room)}) - in_private_rooms = self.get_success( - self.user_dir_helper.get_users_who_share_private_rooms() - ) - self.assertEqual(in_private_rooms, []) + self.assertEqual(users, {normal_user}) + self.assertEqual(in_public, {(normal_user, public_room)}) + self.assertEqual(in_private, set()) def test_population_excludes_support_user(self) -> None: # Create a normal and support user. -- cgit 1.5.1 From 37b845dabc687b1e0d4bc84bf5933db10db641d5 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 18 Oct 2021 14:20:04 +0100 Subject: Don't remove local users from dir when the leave their last room (#11103) --- changelog.d/11103.bugfix | 1 + synapse/handlers/user_directory.py | 13 +++++---- tests/handlers/test_user_directory.py | 50 +++++++++++++++++++++++++++++++++++ 3 files changed, 59 insertions(+), 5 deletions(-) create mode 100644 changelog.d/11103.bugfix (limited to 'tests/handlers') diff --git a/changelog.d/11103.bugfix b/changelog.d/11103.bugfix new file mode 100644 index 0000000000..3498f04a45 --- /dev/null +++ b/changelog.d/11103.bugfix @@ -0,0 +1 @@ +Fix local users who left all their rooms being removed from the user directory, even if the "search_all_users" config option was enabled. \ No newline at end of file diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 99f23ed967..991fee7e58 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -415,16 +415,19 @@ class UserDirectoryHandler(StateDeltasHandler): room_id: The room ID that user left or stopped being public that user_id """ - logger.debug("Removing user %r", user_id) + logger.debug("Removing user %r from room %r", user_id, room_id) # Remove user from sharing tables await self.store.remove_user_who_share_room(user_id, room_id) - # Are they still in any rooms? If not, remove them entirely. - rooms_user_is_in = await self.store.get_user_dir_rooms_user_is_in(user_id) + # Additionally, if they're a remote user and we're no longer joined + # to any rooms they're in, remove them from the user directory. + if not self.is_mine_id(user_id): + rooms_user_is_in = await self.store.get_user_dir_rooms_user_is_in(user_id) - if len(rooms_user_is_in) == 0: - await self.store.remove_from_user_dir(user_id) + if len(rooms_user_is_in) == 0: + logger.debug("Removing user %r from directory", user_id) + await self.store.remove_from_user_dir(user_id) async def _handle_possible_remote_profile_change( self, diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index e0635c8898..b9ad92b977 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -914,6 +914,56 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.hs.get_storage().persistence.persist_event(event, context) ) + def test_local_user_leaving_room_remains_in_user_directory(self) -> None: + """We've chosen to simplify the user directory's implementation by + always including local users. Ensure this invariant is maintained when + a local user + - leaves a room, and + - leaves the last room they're in which is visible to this server. + + This is user-visible if the "search_all_users" config option is on: the + local user who left a room would no longer be searchable if this test fails! + """ + alice = self.register_user("alice", "pass") + alice_token = self.login(alice, "pass") + bob = self.register_user("bob", "pass") + bob_token = self.login(bob, "pass") + + # Alice makes two public rooms, which Bob joins. + room1 = self.helper.create_room_as(alice, is_public=True, tok=alice_token) + room2 = self.helper.create_room_as(alice, is_public=True, tok=alice_token) + self.helper.join(room1, bob, tok=bob_token) + self.helper.join(room2, bob, tok=bob_token) + + # The user directory tables are updated. + users, in_public, in_private = self.get_success( + self.user_dir_helper.get_tables() + ) + self.assertEqual(users, {alice, bob}) + self.assertEqual( + in_public, {(alice, room1), (alice, room2), (bob, room1), (bob, room2)} + ) + self.assertEqual(in_private, set()) + + # Alice leaves one room. She should still be in the directory. + self.helper.leave(room1, alice, tok=alice_token) + users, in_public, in_private = self.get_success( + self.user_dir_helper.get_tables() + ) + self.assertEqual(users, {alice, bob}) + self.assertEqual(in_public, {(alice, room2), (bob, room1), (bob, room2)}) + self.assertEqual(in_private, set()) + + # Alice leaves the other. She should still be in the directory. + self.helper.leave(room2, alice, tok=alice_token) + self.wait_for_background_updates() + users, in_public, in_private = self.get_success( + self.user_dir_helper.get_tables() + ) + self.assertEqual(users, {alice, bob}) + self.assertEqual(in_public, {(bob, room1), (bob, room2)}) + self.assertEqual(in_private, set()) + class TestUserDirSearchDisabled(unittest.HomeserverTestCase): servlets = [ -- cgit 1.5.1 From 2d91b6256e53a9e60027880b0407bd77cb653ad1 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Thu, 21 Oct 2021 17:48:59 +0100 Subject: Fix adding excluded users to the private room sharing tables when joining a room (#11143) * We only need to fetch users in private rooms * Filter out `user_id` at the top * Discard excluded users in the top loop We weren't doing this in the "First, if they're our user" branch so this is a bugfix. * The caller must check that `user_id` is included This is in the docstring. There are two call sites: - one in `_handle_room_publicity_change`, which explicitly checks before calling; - and another in `_handle_room_membership_event`, which returns early if the user is excluded. So this change is safe. * Test joining a private room with an excluded user * Tweak an existing test * Changelog * test docstring * lint --- changelog.d/11143.misc | 1 + synapse/handlers/user_directory.py | 28 +++++++-------- tests/handlers/test_user_directory.py | 67 +++++++++++++++++++++++++++-------- 3 files changed, 67 insertions(+), 29 deletions(-) create mode 100644 changelog.d/11143.misc (limited to 'tests/handlers') diff --git a/changelog.d/11143.misc b/changelog.d/11143.misc new file mode 100644 index 0000000000..496e44a9c0 --- /dev/null +++ b/changelog.d/11143.misc @@ -0,0 +1 @@ +Fix a long-standing bug where users excluded from the directory could still be added to the `users_who_share_private_rooms` table after a regular user joins a private room. \ No newline at end of file diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 991fee7e58..a0eb45446f 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -373,31 +373,29 @@ class UserDirectoryHandler(StateDeltasHandler): is_public = await self.store.is_room_world_readable_or_publicly_joinable( room_id ) - other_users_in_room = await self.store.get_users_in_room(room_id) - if is_public: await self.store.add_users_in_public_rooms(room_id, (user_id,)) else: + users_in_room = await self.store.get_users_in_room(room_id) + other_users_in_room = [ + other + for other in users_in_room + if other != user_id + and ( + not self.is_mine_id(other) + or await self.store.should_include_local_user_in_dir(other) + ) + ] to_insert = set() # First, if they're our user then we need to update for every user if self.is_mine_id(user_id): - if await self.store.should_include_local_user_in_dir(user_id): - for other_user_id in other_users_in_room: - if user_id == other_user_id: - continue - - to_insert.add((user_id, other_user_id)) + for other_user_id in other_users_in_room: + to_insert.add((user_id, other_user_id)) # Next we need to update for every local user in the room for other_user_id in other_users_in_room: - if user_id == other_user_id: - continue - - include_other_user = self.is_mine_id( - other_user_id - ) and await self.store.should_include_local_user_in_dir(other_user_id) - if include_other_user: + if self.is_mine_id(other_user_id): to_insert.add((other_user_id, user_id)) if to_insert: diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index b9ad92b977..70c621b825 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -646,22 +646,20 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): u2_token = self.login(u2, "pass") u3 = self.register_user("user3", "pass") - # We do not add users to the directory until they join a room. + # u1 can't see u2 until they share a private room, or u1 is in a public room. s = self.get_success(self.handler.search_users(u1, "user2", 10)) self.assertEqual(len(s["results"]), 0) + # Get u1 and u2 into a private room. room = self.helper.create_room_as(u1, is_public=False, tok=u1_token) self.helper.invite(room, src=u1, targ=u2, tok=u1_token) self.helper.join(room, user=u2, tok=u2_token) # Check we have populated the database correctly. - shares_private = self.get_success( - self.user_dir_helper.get_users_who_share_private_rooms() - ) - public_users = self.get_success( - self.user_dir_helper.get_users_in_public_rooms() + users, public_users, shares_private = self.get_success( + self.user_dir_helper.get_tables() ) - + self.assertEqual(users, {u1, u2, u3}) self.assertEqual(shares_private, {(u1, u2, room), (u2, u1, room)}) self.assertEqual(public_users, set()) @@ -680,14 +678,11 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # User 2 then leaves. self.helper.leave(room, user=u2, tok=u2_token) - # Check we have removed the values. - shares_private = self.get_success( - self.user_dir_helper.get_users_who_share_private_rooms() - ) - public_users = self.get_success( - self.user_dir_helper.get_users_in_public_rooms() + # Check this is reflected in the DB. + users, public_users, shares_private = self.get_success( + self.user_dir_helper.get_tables() ) - + self.assertEqual(users, {u1, u2, u3}) self.assertEqual(shares_private, set()) self.assertEqual(public_users, set()) @@ -698,6 +693,50 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): s = self.get_success(self.handler.search_users(u1, "user3", 10)) self.assertEqual(len(s["results"]), 0) + def test_joining_private_room_with_excluded_user(self) -> None: + """ + When a user excluded from the user directory, E say, joins a private + room, E will not appear in the `users_who_share_private_rooms` table. + + When a normal user, U say, joins a private room containing E, then + U will appear in the `users_who_share_private_rooms` table, but E will + not. + """ + # Setup a support and two normal users. + alice = self.register_user("alice", "pass") + alice_token = self.login(alice, "pass") + bob = self.register_user("bob", "pass") + bob_token = self.login(bob, "pass") + support = "@support1:test" + self.get_success( + self.store.register_user( + user_id=support, password_hash=None, user_type=UserTypes.SUPPORT + ) + ) + + # Alice makes a room. Inject the support user into the room. + room = self.helper.create_room_as(alice, is_public=False, tok=alice_token) + self.get_success(inject_member_event(self.hs, room, support, "join")) + # Check the DB state. The support user should not be in the directory. + users, in_public, in_private = self.get_success( + self.user_dir_helper.get_tables() + ) + self.assertEqual(users, {alice, bob}) + self.assertEqual(in_public, set()) + self.assertEqual(in_private, set()) + + # Then invite Bob, who accepts. + self.helper.invite(room, alice, bob, tok=alice_token) + self.helper.join(room, bob, tok=bob_token) + + # Check the DB state. The support user should not be in the directory. + users, in_public, in_private = self.get_success( + self.user_dir_helper.get_tables() + ) + self.assertEqual(users, {alice, bob}) + self.assertEqual(in_public, set()) + self.assertEqual(in_private, {(alice, bob, room), (bob, alice, room)}) + def test_spam_checker(self) -> None: """ A user which fails the spam checks will not appear in search results. -- cgit 1.5.1 From 4387b791e01eb1a207fe44fecbc901eead8eb4db Mon Sep 17 00:00:00 2001 From: AndrewFerr Date: Mon, 25 Oct 2021 10:24:49 -0400 Subject: Don't set new room alias before potential 403 (#10930) Fixes: #10929 Signed-off-by: Andrew Ferrazzutti --- changelog.d/10930.bugfix | 1 + synapse/handlers/directory.py | 4 +- synapse/handlers/room.py | 18 +++---- tests/handlers/test_directory.py | 102 ++++++++++++++++++++++++++++++++++++++- 4 files changed, 113 insertions(+), 12 deletions(-) create mode 100644 changelog.d/10930.bugfix (limited to 'tests/handlers') diff --git a/changelog.d/10930.bugfix b/changelog.d/10930.bugfix new file mode 100644 index 0000000000..756bfe9107 --- /dev/null +++ b/changelog.d/10930.bugfix @@ -0,0 +1 @@ +Newly-created public rooms are now only assigned an alias if the room's creation has not been blocked by permission settings. Contributed by @AndrewFerr. diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 14ed7d9879..8567cb0e00 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -145,7 +145,7 @@ class DirectoryHandler: if not self.config.roomdirectory.is_alias_creation_allowed( user_id, room_id, room_alias_str ): - # Lets just return a generic message, as there may be all sorts of + # Let's just return a generic message, as there may be all sorts of # reasons why we said no. TODO: Allow configurable error messages # per alias creation rule? raise SynapseError(403, "Not allowed to create alias") @@ -461,7 +461,7 @@ class DirectoryHandler: if not self.config.roomdirectory.is_publishing_room_allowed( user_id, room_id, room_aliases ): - # Lets just return a generic message, as there may be all sorts of + # Let's just return a generic message, as there may be all sorts of # reasons why we said no. TODO: Allow configurable error messages # per alias creation rule? raise SynapseError(403, "Not allowed to publish room") diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 6f39e9446f..cf01d58ea1 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -773,6 +773,15 @@ class RoomCreationHandler: if not allowed_by_third_party_rules: raise SynapseError(403, "Room visibility value not allowed.") + if is_public: + if not self.config.roomdirectory.is_publishing_room_allowed( + user_id, room_id, room_alias + ): + # Let's just return a generic message, as there may be all sorts of + # reasons why we said no. TODO: Allow configurable error messages + # per alias creation rule? + raise SynapseError(403, "Not allowed to publish room") + directory_handler = self.hs.get_directory_handler() if room_alias: await directory_handler.create_association( @@ -783,15 +792,6 @@ class RoomCreationHandler: check_membership=False, ) - if is_public: - if not self.config.roomdirectory.is_publishing_room_allowed( - user_id, room_id, room_alias - ): - # Lets just return a generic message, as there may be all sorts of - # reasons why we said no. TODO: Allow configurable error messages - # per alias creation rule? - raise SynapseError(403, "Not allowed to publish room") - preset_config = config.get( "preset", RoomCreationPreset.PRIVATE_CHAT diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 6a2e76ca4a..be008227df 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -15,8 +15,8 @@ from unittest.mock import Mock -import synapse import synapse.api.errors +import synapse.rest.admin from synapse.api.constants import EventTypes from synapse.config.room_directory import RoomDirectoryConfig from synapse.rest.client import directory, login, room @@ -432,6 +432,106 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.result) +class TestCreatePublishedRoomACL(unittest.HomeserverTestCase): + data = {"room_alias_name": "unofficial_test"} + + servlets = [ + synapse.rest.admin.register_servlets_for_client_rest_resource, + login.register_servlets, + directory.register_servlets, + room.register_servlets, + ] + hijack_auth = False + + def prepare(self, reactor, clock, hs): + self.allowed_user_id = self.register_user("allowed", "pass") + self.allowed_access_token = self.login("allowed", "pass") + + self.denied_user_id = self.register_user("denied", "pass") + self.denied_access_token = self.login("denied", "pass") + + # This time we add custom room list publication rules + config = {} + config["alias_creation_rules"] = [] + config["room_list_publication_rules"] = [ + {"user_id": "*", "alias": "*", "action": "deny"}, + {"user_id": self.allowed_user_id, "alias": "*", "action": "allow"}, + ] + + rd_config = RoomDirectoryConfig() + rd_config.read_config(config) + + self.hs.config.roomdirectory.is_publishing_room_allowed = ( + rd_config.is_publishing_room_allowed + ) + + return hs + + def test_denied_without_publication_permission(self): + """ + Try to create a room, register an alias for it, and publish it, + as a user without permission to publish rooms. + (This is used as both a standalone test & as a helper function.) + """ + self.helper.create_room_as( + self.denied_user_id, + tok=self.denied_access_token, + extra_content=self.data, + is_public=True, + expect_code=403, + ) + + def test_allowed_when_creating_private_room(self): + """ + Try to create a room, register an alias for it, and NOT publish it, + as a user without permission to publish rooms. + (This is used as both a standalone test & as a helper function.) + """ + self.helper.create_room_as( + self.denied_user_id, + tok=self.denied_access_token, + extra_content=self.data, + is_public=False, + expect_code=200, + ) + + def test_allowed_with_publication_permission(self): + """ + Try to create a room, register an alias for it, and publish it, + as a user WITH permission to publish rooms. + (This is used as both a standalone test & as a helper function.) + """ + self.helper.create_room_as( + self.allowed_user_id, + tok=self.allowed_access_token, + extra_content=self.data, + is_public=False, + expect_code=200, + ) + + def test_can_create_as_private_room_after_rejection(self): + """ + After failing to publish a room with an alias as a user without publish permission, + retry as the same user, but without publishing the room. + + This should pass, but used to fail because the alias was registered by the first + request, even though the room creation was denied. + """ + self.test_denied_without_publication_permission() + self.test_allowed_when_creating_private_room() + + def test_can_create_with_permission_after_rejection(self): + """ + After failing to publish a room with an alias as a user without publish permission, + retry as someone with permission, using the same alias. + + This also used to fail because of the alias having been registered by the first + request, leaving it unavailable for any other user's new rooms. + """ + self.test_denied_without_publication_permission() + self.test_allowed_with_publication_permission() + + class TestRoomListSearchDisabled(unittest.HomeserverTestCase): user_id = "@test:test" -- cgit 1.5.1