From 9fd057b8c5a8c5748e7d8137d1485c38abd9602f Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Tue, 28 Sep 2021 21:23:16 -0500 Subject: Ensure `(room_id, next_batch_id)` is unique to avoid cross-talk/conflicts between batches (MSC2716) (#10877) Part of [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) Part of https://github.com/matrix-org/synapse/issues/10737 --- synapse/rest/client/room_batch.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) (limited to 'synapse/rest/client') diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py index bf14ec384e..1dffcc3147 100644 --- a/synapse/rest/client/room_batch.py +++ b/synapse/rest/client/room_batch.py @@ -306,11 +306,13 @@ class RoomBatchSendEventRestServlet(RestServlet): # Verify the batch_id_from_query corresponds to an actual insertion event # and have the batch connected. corresponding_insertion_event_id = ( - await self.store.get_insertion_event_by_batch_id(batch_id_from_query) + await self.store.get_insertion_event_by_batch_id( + room_id, batch_id_from_query + ) ) if corresponding_insertion_event_id is None: raise SynapseError( - 400, + HTTPStatus.BAD_REQUEST, "No insertion event corresponds to the given ?batch_id", errcode=Codes.INVALID_PARAM, ) -- cgit 1.5.1 From 8cef1ab2ac8d1602ea6a087384059d104097140f Mon Sep 17 00:00:00 2001 From: Travis Ralston Date: Wed, 29 Sep 2021 04:32:45 -0600 Subject: Implement MSC3069: Guest support on whoami (#9655) --- changelog.d/9655.feature | 1 + synapse/rest/client/account.py | 8 +++++-- tests/rest/client/test_account.py | 49 +++++++++++++++++++++++++++++++++++---- 3 files changed, 51 insertions(+), 7 deletions(-) create mode 100644 changelog.d/9655.feature (limited to 'synapse/rest/client') diff --git a/changelog.d/9655.feature b/changelog.d/9655.feature new file mode 100644 index 0000000000..70cac230d8 --- /dev/null +++ b/changelog.d/9655.feature @@ -0,0 +1 @@ +Add [MSC3069](https://github.com/matrix-org/matrix-doc/pull/3069) support to `/account/whoami`. \ No newline at end of file diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index 6a7608d60b..bacb828330 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -878,9 +878,13 @@ class WhoamiRestServlet(RestServlet): self.auth = hs.get_auth() async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request, allow_guest=True) - response = {"user_id": requester.user.to_string()} + response = { + "user_id": requester.user.to_string(), + # MSC: https://github.com/matrix-org/matrix-doc/pull/3069 + "org.matrix.msc3069.is_guest": bool(requester.is_guest), + } # Appservices and similar accounts do not have device IDs # that we can report on, so exclude them for compliance. diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index 9e9e953cf4..64b0b8458b 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -470,13 +470,45 @@ class WhoamiTestCase(unittest.HomeserverTestCase): register.register_servlets, ] + def default_config(self): + config = super().default_config() + config["allow_guest_access"] = True + return config + def test_GET_whoami(self): device_id = "wouldgohere" user_id = self.register_user("kermit", "test") tok = self.login("kermit", "test", device_id=device_id) - whoami = self.whoami(tok) - self.assertEqual(whoami, {"user_id": user_id, "device_id": device_id}) + whoami = self._whoami(tok) + self.assertEqual( + whoami, + { + "user_id": user_id, + "device_id": device_id, + # Unstable until MSC3069 enters spec + "org.matrix.msc3069.is_guest": False, + }, + ) + + def test_GET_whoami_guests(self): + channel = self.make_request( + b"POST", b"/_matrix/client/r0/register?kind=guest", b"{}" + ) + tok = channel.json_body["access_token"] + user_id = channel.json_body["user_id"] + device_id = channel.json_body["device_id"] + + whoami = self._whoami(tok) + self.assertEqual( + whoami, + { + "user_id": user_id, + "device_id": device_id, + # Unstable until MSC3069 enters spec + "org.matrix.msc3069.is_guest": True, + }, + ) def test_GET_whoami_appservices(self): user_id = "@as:test" @@ -491,11 +523,18 @@ class WhoamiTestCase(unittest.HomeserverTestCase): ) self.hs.get_datastore().services_cache.append(appservice) - whoami = self.whoami(as_token) - self.assertEqual(whoami, {"user_id": user_id}) + whoami = self._whoami(as_token) + self.assertEqual( + whoami, + { + "user_id": user_id, + # Unstable until MSC3069 enters spec + "org.matrix.msc3069.is_guest": False, + }, + ) self.assertFalse(hasattr(whoami, "device_id")) - def whoami(self, tok): + def _whoami(self, tok): channel = self.make_request("GET", "account/whoami", {}, access_token=tok) self.assertEqual(channel.code, 200) return channel.json_body -- cgit 1.5.1 From 94b620a5edd6b5bc55c8aad6e00a11cc6bf210fa Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 29 Sep 2021 06:44:15 -0400 Subject: Use direct references for configuration variables (part 6). (#10916) --- changelog.d/10916.misc | 1 + synapse/app/_base.py | 8 ++++---- synapse/app/admin_cmd.py | 4 ++-- synapse/app/generic_worker.py | 2 +- synapse/app/homeserver.py | 14 +++++++------- synapse/app/phone_stats_home.py | 8 ++++---- synapse/config/_base.py | 2 +- synapse/config/server.py | 4 +--- synapse/events/presence_router.py | 6 +++--- synapse/events/utils.py | 2 +- synapse/federation/transport/server/__init__.py | 2 +- synapse/handlers/directory.py | 2 +- synapse/handlers/federation.py | 2 +- synapse/handlers/identity.py | 2 +- synapse/handlers/message.py | 14 ++++++++------ synapse/handlers/pagination.py | 14 ++++++++++---- synapse/handlers/profile.py | 2 +- synapse/handlers/register.py | 2 +- synapse/handlers/room.py | 2 +- synapse/handlers/room_member.py | 14 +++++++------- synapse/handlers/search.py | 2 +- synapse/handlers/user_directory.py | 2 +- synapse/http/matrixfederationclient.py | 10 +++++----- synapse/replication/tcp/resource.py | 2 +- synapse/rest/client/account.py | 10 +++++----- synapse/rest/client/capabilities.py | 4 ++-- synapse/rest/client/filter.py | 2 +- synapse/rest/client/profile.py | 6 +++--- synapse/rest/client/register.py | 6 +++--- synapse/rest/client/room.py | 2 +- synapse/rest/client/shared_rooms.py | 2 +- synapse/rest/client/sync.py | 2 +- synapse/server_notices/resource_limits_server_notices.py | 8 ++++---- synapse/storage/databases/main/censor_events.py | 8 +++++--- synapse/storage/databases/main/client_ips.py | 2 +- synapse/storage/databases/main/events.py | 2 +- synapse/storage/databases/main/monthly_active_users.py | 12 ++++++------ synapse/storage/databases/main/registration.py | 2 +- synapse/storage/databases/main/room.py | 8 ++++---- synapse/storage/databases/main/search.py | 4 ++-- synapse/storage/prepare_database.py | 2 +- tests/api/test_auth.py | 14 +++++++------- tests/federation/test_federation_server.py | 2 +- tests/handlers/test_register.py | 14 +++++++------- tests/http/test_fedclient.py | 2 +- tests/rest/admin/test_user.py | 6 +++--- tests/rest/client/test_account.py | 2 +- tests/rest/client/test_capabilities.py | 2 +- tests/rest/client/test_presence.py | 2 +- tests/rest/client/test_register.py | 4 ++-- .../server_notices/test_resource_limits_server_notices.py | 2 +- tests/storage/test_monthly_active_users.py | 14 +++++++------- tests/test_mau.py | 2 +- tests/unittest.py | 2 +- 54 files changed, 141 insertions(+), 132 deletions(-) create mode 100644 changelog.d/10916.misc (limited to 'synapse/rest/client') diff --git a/changelog.d/10916.misc b/changelog.d/10916.misc new file mode 100644 index 0000000000..586a0b3a96 --- /dev/null +++ b/changelog.d/10916.misc @@ -0,0 +1 @@ +Use direct references to config flags. diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 548f6dcde9..749bc1deb9 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -86,11 +86,11 @@ def start_worker_reactor(appname, config, run_command=reactor.run): start_reactor( appname, - soft_file_limit=config.soft_file_limit, - gc_thresholds=config.gc_thresholds, + soft_file_limit=config.server.soft_file_limit, + gc_thresholds=config.server.gc_thresholds, pid_file=config.worker.worker_pid_file, daemonize=config.worker.worker_daemonize, - print_pidfile=config.print_pidfile, + print_pidfile=config.server.print_pidfile, logger=logger, run_command=run_command, ) @@ -298,7 +298,7 @@ def refresh_certificate(hs): Refresh the TLS certificates that Synapse is using by re-reading them from disk and updating the TLS context factories to use them. """ - if not hs.config.has_tls_listener(): + if not hs.config.server.has_tls_listener(): return hs.config.read_certificate_from_disk() diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index f2c5b75247..556bcc124e 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -195,14 +195,14 @@ def start(config_options): config.logging.no_redirect_stdio = True # Explicitly disable background processes - config.update_user_directory = False + config.server.update_user_directory = False config.worker.run_background_tasks = False config.start_pushers = False config.pusher_shard_config.instances = [] config.send_federation = False config.federation_shard_config.instances = [] - synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts + synapse.events.USE_FROZEN_DICTS = config.server.use_frozen_dicts ss = AdminCmdServer( config.server.server_name, diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 3036e1b4a0..7489f31d9a 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -462,7 +462,7 @@ def start(config_options): # For other worker types we force this to off. config.server.update_user_directory = False - synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts + synapse.events.USE_FROZEN_DICTS = config.server.use_frozen_dicts synapse.util.caches.TRACK_MEMORY_USAGE = config.caches.track_memory_usage if config.server.gc_seconds: diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 205831dcda..2b2d4bbf83 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -248,7 +248,7 @@ class SynapseHomeServer(HomeServer): resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self) if name == "webclient": - webclient_loc = self.config.web_client_location + webclient_loc = self.config.server.web_client_location if webclient_loc is None: logger.warning( @@ -343,7 +343,7 @@ def setup(config_options): # generating config files and shouldn't try to continue. sys.exit(0) - events.USE_FROZEN_DICTS = config.use_frozen_dicts + events.USE_FROZEN_DICTS = config.server.use_frozen_dicts synapse.util.caches.TRACK_MEMORY_USAGE = config.caches.track_memory_usage if config.server.gc_seconds: @@ -439,11 +439,11 @@ def run(hs): _base.start_reactor( "synapse-homeserver", - soft_file_limit=hs.config.soft_file_limit, - gc_thresholds=hs.config.gc_thresholds, - pid_file=hs.config.pid_file, - daemonize=hs.config.daemonize, - print_pidfile=hs.config.print_pidfile, + soft_file_limit=hs.config.server.soft_file_limit, + gc_thresholds=hs.config.server.gc_thresholds, + pid_file=hs.config.server.pid_file, + daemonize=hs.config.server.daemonize, + print_pidfile=hs.config.server.print_pidfile, logger=logger, ) diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py index 49e7a45e5c..fcd01e833c 100644 --- a/synapse/app/phone_stats_home.py +++ b/synapse/app/phone_stats_home.py @@ -74,7 +74,7 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process): store = hs.get_datastore() stats["homeserver"] = hs.config.server.server_name - stats["server_context"] = hs.config.server_context + stats["server_context"] = hs.config.server.server_context stats["timestamp"] = now stats["uptime_seconds"] = uptime version = sys.version_info @@ -171,7 +171,7 @@ def start_phone_stats_home(hs): current_mau_count_by_service = {} reserved_users = () store = hs.get_datastore() - if hs.config.limit_usage_by_mau or hs.config.mau_stats_only: + if hs.config.server.limit_usage_by_mau or hs.config.server.mau_stats_only: current_mau_count = await store.get_monthly_active_count() current_mau_count_by_service = ( await store.get_monthly_active_count_by_service() @@ -183,9 +183,9 @@ def start_phone_stats_home(hs): current_mau_by_service_gauge.labels(app_service).set(float(count)) registered_reserved_users_mau_gauge.set(float(len(reserved_users))) - max_mau_gauge.set(float(hs.config.max_mau_value)) + max_mau_gauge.set(float(hs.config.server.max_mau_value)) - if hs.config.limit_usage_by_mau or hs.config.mau_stats_only: + if hs.config.server.limit_usage_by_mau or hs.config.server.mau_stats_only: generate_monthly_active_users() clock.looping_call(generate_monthly_active_users, 5 * 60 * 1000) # End of monthly active user settings diff --git a/synapse/config/_base.py b/synapse/config/_base.py index d974a1a2a8..26152b0924 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -327,7 +327,7 @@ class RootConfig: """ Redirect lookups on this object either to config objects, or values on config objects, so that `config.tls.blah` works, as well as legacy uses - of things like `config.server_name`. It will first look up the config + of things like `config.server.server_name`. It will first look up the config section name, and then values on those config classes. """ if item in self._configs.keys(): diff --git a/synapse/config/server.py b/synapse/config/server.py index 041412d7ad..818b806357 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -1,6 +1,4 @@ -# Copyright 2014-2016 OpenMarket Ltd -# Copyright 2017-2018 New Vector Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2014-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/synapse/events/presence_router.py b/synapse/events/presence_router.py index eb4556cdc1..68b8b19024 100644 --- a/synapse/events/presence_router.py +++ b/synapse/events/presence_router.py @@ -45,11 +45,11 @@ def load_legacy_presence_router(hs: "HomeServer"): configuration, and registers the hooks they implement. """ - if hs.config.presence_router_module_class is None: + if hs.config.server.presence_router_module_class is None: return - module = hs.config.presence_router_module_class - config = hs.config.presence_router_config + module = hs.config.server.presence_router_module_class + config = hs.config.server.presence_router_config api = hs.get_module_api() presence_router = module(config=config, module_api=api) diff --git a/synapse/events/utils.py b/synapse/events/utils.py index f86113a448..a13fb0148f 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -372,7 +372,7 @@ class EventClientSerializer: def __init__(self, hs): self.store = hs.get_datastore() self.experimental_msc1849_support_enabled = ( - hs.config.experimental_msc1849_support_enabled + hs.config.server.experimental_msc1849_support_enabled ) async def serialize_event( diff --git a/synapse/federation/transport/server/__init__.py b/synapse/federation/transport/server/__init__.py index 95176ba6f9..c32539bf5a 100644 --- a/synapse/federation/transport/server/__init__.py +++ b/synapse/federation/transport/server/__init__.py @@ -117,7 +117,7 @@ class PublicRoomList(BaseFederationServlet): ): super().__init__(hs, authenticator, ratelimiter, server_name) self.handler = hs.get_room_list_handler() - self.allow_access = hs.config.allow_public_rooms_over_federation + self.allow_access = hs.config.server.allow_public_rooms_over_federation async def on_GET( self, origin: str, content: Literal[None], query: Dict[bytes, List[bytes]] diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 5cfba3c817..9078781d5a 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -49,7 +49,7 @@ class DirectoryHandler(BaseHandler): self.store = hs.get_datastore() self.config = hs.config self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search - self.require_membership = hs.config.require_membership_for_aliases + self.require_membership = hs.config.server.require_membership_for_aliases self.third_party_event_rules = hs.get_third_party_event_rules() self.federation = hs.get_federation_client() diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 16c435ee86..3b0b895b07 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -762,7 +762,7 @@ class FederationHandler(BaseHandler): if is_blocked: raise SynapseError(403, "This room has been blocked on this server") - if self.hs.config.block_non_admin_invites: + if self.hs.config.server.block_non_admin_invites: raise SynapseError(403, "This server does not accept room invites") if not await self.spam_checker.user_may_invite( diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index fe8a995892..a0640fcac0 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -57,7 +57,7 @@ class IdentityHandler(BaseHandler): self.http_client = SimpleHttpClient(hs) # An HTTP client for contacting identity servers specified by clients. self.blacklisting_http_client = SimpleHttpClient( - hs, ip_blacklist=hs.config.federation_ip_range_blacklist + hs, ip_blacklist=hs.config.server.federation_ip_range_blacklist ) self.federation_http_client = hs.get_federation_http_client() self.hs = hs diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 39c18ecf99..3b8cc50ec0 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -81,7 +81,7 @@ class MessageHandler: self.storage = hs.get_storage() self.state_store = self.storage.state self._event_serializer = hs.get_event_client_serializer() - self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages + self._ephemeral_events_enabled = hs.config.server.enable_ephemeral_messages # The scheduled call to self._expire_event. None if no call is currently # scheduled. @@ -415,7 +415,9 @@ class EventCreationHandler: self.server_name = hs.hostname self.notifier = hs.get_notifier() self.config = hs.config - self.require_membership_for_aliases = hs.config.require_membership_for_aliases + self.require_membership_for_aliases = ( + hs.config.server.require_membership_for_aliases + ) self._events_shard_config = self.config.worker.events_shard_config self._instance_name = hs.get_instance_name() @@ -425,7 +427,7 @@ class EventCreationHandler: Membership.JOIN, Membership.KNOCK, } - if self.hs.config.include_profile_data_on_invite: + if self.hs.config.server.include_profile_data_on_invite: self.membership_types_to_include_profile_data_in.add(Membership.INVITE) self.send_event = ReplicationSendEventRestServlet.make_client(hs) @@ -461,11 +463,11 @@ class EventCreationHandler: # self._rooms_to_exclude_from_dummy_event_insertion: Dict[str, int] = {} # The number of forward extremeities before a dummy event is sent. - self._dummy_events_threshold = hs.config.dummy_events_threshold + self._dummy_events_threshold = hs.config.server.dummy_events_threshold if ( self.config.worker.run_background_tasks - and self.config.cleanup_extremities_with_dummy_events + and self.config.server.cleanup_extremities_with_dummy_events ): self.clock.looping_call( lambda: run_as_background_process( @@ -477,7 +479,7 @@ class EventCreationHandler: self._message_handler = hs.get_message_handler() - self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages + self._ephemeral_events_enabled = hs.config.server.enable_ephemeral_messages self._external_cache = hs.get_external_cache() diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index a5301ece6f..176e4dfdd4 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -85,12 +85,18 @@ class PaginationHandler: self._purges_by_id: Dict[str, PurgeStatus] = {} self._event_serializer = hs.get_event_client_serializer() - self._retention_default_max_lifetime = hs.config.retention_default_max_lifetime + self._retention_default_max_lifetime = ( + hs.config.server.retention_default_max_lifetime + ) - self._retention_allowed_lifetime_min = hs.config.retention_allowed_lifetime_min - self._retention_allowed_lifetime_max = hs.config.retention_allowed_lifetime_max + self._retention_allowed_lifetime_min = ( + hs.config.server.retention_allowed_lifetime_min + ) + self._retention_allowed_lifetime_max = ( + hs.config.server.retention_allowed_lifetime_max + ) - if hs.config.worker.run_background_tasks and hs.config.retention_enabled: + if hs.config.worker.run_background_tasks and hs.config.server.retention_enabled: # Run the purge jobs described in the configuration file. for job in hs.config.server.retention_purge_jobs: logger.info("Setting up purge job with config: %s", job) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index b23a1541bc..425c0d4973 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -397,7 +397,7 @@ class ProfileHandler(BaseHandler): # when building a membership event. In this case, we must allow the # lookup. if ( - not self.hs.config.limit_profile_requests_to_users_who_share_rooms + not self.hs.config.server.limit_profile_requests_to_users_who_share_rooms or not requester ): return diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 4f99f137a2..4a7ccb882e 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -854,7 +854,7 @@ class RegistrationHandler(BaseHandler): # Necessary due to auth checks prior to the threepid being # written to the db if is_threepid_reserved( - self.hs.config.mau_limits_reserved_threepids, threepid + self.hs.config.server.mau_limits_reserved_threepids, threepid ): await self.store.upsert_monthly_active_user(user_id) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index dc4fab2223..bf8a85f563 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -666,7 +666,7 @@ class RoomCreationHandler(BaseHandler): await self.ratelimit(requester) room_version_id = config.get( - "room_version", self.config.default_room_version.identifier + "room_version", self.config.server.default_room_version.identifier ) if not isinstance(room_version_id, str): diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 1a56c82fbd..02103f6c9a 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -90,7 +90,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self.third_party_event_rules = hs.get_third_party_event_rules() self._server_notices_mxid = self.config.servernotices.server_notices_mxid self._enable_lookup = hs.config.enable_3pid_lookup - self.allow_per_room_profiles = self.config.allow_per_room_profiles + self.allow_per_room_profiles = self.config.server.allow_per_room_profiles self._join_rate_limiter_local = Ratelimiter( store=self.store, @@ -617,7 +617,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): is_requester_admin = await self.auth.is_server_admin(requester.user) if not is_requester_admin: - if self.config.block_non_admin_invites: + if self.config.server.block_non_admin_invites: logger.info( "Blocking invite: user is not admin and non-admin " "invites disabled" @@ -1222,7 +1222,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): Raises: ShadowBanError if the requester has been shadow-banned. """ - if self.config.block_non_admin_invites: + if self.config.server.block_non_admin_invites: is_requester_admin = await self.auth.is_server_admin(requester.user) if not is_requester_admin: raise SynapseError( @@ -1420,7 +1420,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): Returns: bool of whether the complexity is too great, or None if unable to be fetched """ - max_complexity = self.hs.config.limit_remote_rooms.complexity + max_complexity = self.hs.config.server.limit_remote_rooms.complexity complexity = await self.federation_handler.get_room_complexity( remote_room_hosts, room_id ) @@ -1436,7 +1436,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): Args: room_id: The room ID to check for complexity. """ - max_complexity = self.hs.config.limit_remote_rooms.complexity + max_complexity = self.hs.config.server.limit_remote_rooms.complexity complexity = await self.store.get_room_complexity(room_id) return complexity["v1"] > max_complexity @@ -1472,7 +1472,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): if too_complex is True: raise SynapseError( code=400, - msg=self.hs.config.limit_remote_rooms.complexity_error, + msg=self.hs.config.server.limit_remote_rooms.complexity_error, errcode=Codes.RESOURCE_LIMIT_EXCEEDED, ) @@ -1507,7 +1507,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): ) raise SynapseError( code=400, - msg=self.hs.config.limit_remote_rooms.complexity_error, + msg=self.hs.config.server.limit_remote_rooms.complexity_error, errcode=Codes.RESOURCE_LIMIT_EXCEEDED, ) diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 8226d6f5a1..6d3333ee00 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -105,7 +105,7 @@ class SearchHandler(BaseHandler): dict to be returned to the client with results of search """ - if not self.hs.config.enable_search: + if not self.hs.config.server.enable_search: raise SynapseError(400, "Search is disabled on this homeserver") batch_group = None diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index b91e7cb501..f4430ce3c9 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -60,7 +60,7 @@ class UserDirectoryHandler(StateDeltasHandler): self.clock = hs.get_clock() self.notifier = hs.get_notifier() self.is_mine_id = hs.is_mine_id - self.update_user_directory = hs.config.update_user_directory + self.update_user_directory = hs.config.server.update_user_directory self.search_all_users = hs.config.userdirectory.user_directory_search_all_users self.spam_checker = hs.get_spam_checker() # The current position in the current_state_delta stream diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index cdc36b8d25..4f59224686 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -327,23 +327,23 @@ class MatrixFederationHttpClient: self.reactor = hs.get_reactor() user_agent = hs.version_string - if hs.config.user_agent_suffix: - user_agent = "%s %s" % (user_agent, hs.config.user_agent_suffix) + if hs.config.server.user_agent_suffix: + user_agent = "%s %s" % (user_agent, hs.config.server.user_agent_suffix) user_agent = user_agent.encode("ascii") federation_agent = MatrixFederationAgent( self.reactor, tls_client_options_factory, user_agent, - hs.config.federation_ip_range_whitelist, - hs.config.federation_ip_range_blacklist, + hs.config.server.federation_ip_range_whitelist, + hs.config.server.federation_ip_range_blacklist, ) # Use a BlacklistingAgentWrapper to prevent circumventing the IP # blacklist via IP literals in server names self.agent = BlacklistingAgentWrapper( federation_agent, - ip_blacklist=hs.config.federation_ip_range_blacklist, + ip_blacklist=hs.config.server.federation_ip_range_blacklist, ) self.clock = hs.get_clock() diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index 030852cb5b..80f9b23bfd 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -71,7 +71,7 @@ class ReplicationStreamer: self.notifier = hs.get_notifier() self._instance_name = hs.get_instance_name() - self._replication_torture_level = hs.config.replication_torture_level + self._replication_torture_level = hs.config.server.replication_torture_level self.notifier.add_replication_callback(self.on_notifier_poke) diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index bacb828330..fff133ef10 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -119,7 +119,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): ) if existing_user_id is None: - if self.config.request_token_inhibit_3pid_errors: + if self.config.server.request_token_inhibit_3pid_errors: # Make the client think the operation succeeded. See the rationale in the # comments for request_token_inhibit_3pid_errors. # Also wait for some random amount of time between 100ms and 1s to make it @@ -403,7 +403,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): existing_user_id = await self.store.get_user_id_by_threepid("email", email) if existing_user_id is not None: - if self.config.request_token_inhibit_3pid_errors: + if self.config.server.request_token_inhibit_3pid_errors: # Make the client think the operation succeeded. See the rationale in the # comments for request_token_inhibit_3pid_errors. # Also wait for some random amount of time between 100ms and 1s to make it @@ -486,7 +486,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): existing_user_id = await self.store.get_user_id_by_threepid("msisdn", msisdn) if existing_user_id is not None: - if self.hs.config.request_token_inhibit_3pid_errors: + if self.hs.config.server.request_token_inhibit_3pid_errors: # Make the client think the operation succeeded. See the rationale in the # comments for request_token_inhibit_3pid_errors. # Also wait for some random amount of time between 100ms and 1s to make it @@ -857,8 +857,8 @@ def assert_valid_next_link(hs: "HomeServer", next_link: str) -> None: # If the domain whitelist is set, the domain must be in it if ( valid - and hs.config.next_link_domain_whitelist is not None - and next_link_parsed.hostname not in hs.config.next_link_domain_whitelist + and hs.config.server.next_link_domain_whitelist is not None + and next_link_parsed.hostname not in hs.config.server.next_link_domain_whitelist ): valid = False diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py index 65b3b5ce2c..d6b6256413 100644 --- a/synapse/rest/client/capabilities.py +++ b/synapse/rest/client/capabilities.py @@ -44,10 +44,10 @@ class CapabilitiesRestServlet(RestServlet): await self.auth.get_user_by_req(request, allow_guest=True) change_password = self.auth_handler.can_change_password() - response = { + response: JsonDict = { "capabilities": { "m.room_versions": { - "default": self.config.default_room_version.identifier, + "default": self.config.server.default_room_version.identifier, "available": { v.identifier: v.disposition for v in KNOWN_ROOM_VERSIONS.values() diff --git a/synapse/rest/client/filter.py b/synapse/rest/client/filter.py index 6ed60c7418..cc1c2f9731 100644 --- a/synapse/rest/client/filter.py +++ b/synapse/rest/client/filter.py @@ -90,7 +90,7 @@ class CreateFilterRestServlet(RestServlet): raise AuthError(403, "Can only create filters for local users") content = parse_json_object_from_request(request) - set_timeline_upper_limit(content, self.hs.config.filter_timeline_limit) + set_timeline_upper_limit(content, self.hs.config.server.filter_timeline_limit) filter_id = await self.filtering.add_user_filter( user_localpart=target_user.localpart, user_filter=content diff --git a/synapse/rest/client/profile.py b/synapse/rest/client/profile.py index d0f20de569..c684636c0a 100644 --- a/synapse/rest/client/profile.py +++ b/synapse/rest/client/profile.py @@ -41,7 +41,7 @@ class ProfileDisplaynameRestServlet(RestServlet): ) -> Tuple[int, JsonDict]: requester_user = None - if self.hs.config.require_auth_for_profile_requests: + if self.hs.config.server.require_auth_for_profile_requests: requester = await self.auth.get_user_by_req(request) requester_user = requester.user @@ -94,7 +94,7 @@ class ProfileAvatarURLRestServlet(RestServlet): ) -> Tuple[int, JsonDict]: requester_user = None - if self.hs.config.require_auth_for_profile_requests: + if self.hs.config.server.require_auth_for_profile_requests: requester = await self.auth.get_user_by_req(request) requester_user = requester.user @@ -146,7 +146,7 @@ class ProfileRestServlet(RestServlet): ) -> Tuple[int, JsonDict]: requester_user = None - if self.hs.config.require_auth_for_profile_requests: + if self.hs.config.server.require_auth_for_profile_requests: requester = await self.auth.get_user_by_req(request) requester_user = requester.user diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index 48b0062cf4..a6eb6f6410 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -129,7 +129,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): ) if existing_user_id is not None: - if self.hs.config.request_token_inhibit_3pid_errors: + if self.hs.config.server.request_token_inhibit_3pid_errors: # Make the client think the operation succeeded. See the rationale in the # comments for request_token_inhibit_3pid_errors. # Also wait for some random amount of time between 100ms and 1s to make it @@ -209,7 +209,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): ) if existing_user_id is not None: - if self.hs.config.request_token_inhibit_3pid_errors: + if self.hs.config.server.request_token_inhibit_3pid_errors: # Make the client think the operation succeeded. See the rationale in the # comments for request_token_inhibit_3pid_errors. # Also wait for some random amount of time between 100ms and 1s to make it @@ -682,7 +682,7 @@ class RegisterRestServlet(RestServlet): # written to the db if threepid: if is_threepid_reserved( - self.hs.config.mau_limits_reserved_threepids, threepid + self.hs.config.server.mau_limits_reserved_threepids, threepid ): await self.store.upsert_monthly_active_user(registered_user_id) diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index bf46dc60f2..ed95189b6d 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -369,7 +369,7 @@ class PublicRoomListRestServlet(TransactionRestServlet): # Option to allow servers to require auth when accessing # /publicRooms via CS API. This is especially helpful in private # federations. - if not self.hs.config.allow_public_rooms_without_auth: + if not self.hs.config.server.allow_public_rooms_without_auth: raise # We allow people to not be authed if they're just looking at our diff --git a/synapse/rest/client/shared_rooms.py b/synapse/rest/client/shared_rooms.py index 1d90493eb0..09a46737de 100644 --- a/synapse/rest/client/shared_rooms.py +++ b/synapse/rest/client/shared_rooms.py @@ -42,7 +42,7 @@ class UserSharedRoomsServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() - self.user_directory_active = hs.config.update_user_directory + self.user_directory_active = hs.config.server.update_user_directory async def on_GET( self, request: SynapseRequest, user_id: str diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 1259058b9b..913216a7c4 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -155,7 +155,7 @@ class SyncRestServlet(RestServlet): try: filter_object = json_decoder.decode(filter_id) set_timeline_upper_limit( - filter_object, self.hs.config.filter_timeline_limit + filter_object, self.hs.config.server.filter_timeline_limit ) except Exception: raise SynapseError(400, "Invalid filter JSON") diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py index 073b0d754f..8522930b50 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -47,9 +47,9 @@ class ResourceLimitsServerNotices: self._notifier = hs.get_notifier() self._enabled = ( - hs.config.limit_usage_by_mau + hs.config.server.limit_usage_by_mau and self._server_notices_manager.is_enabled() - and not hs.config.hs_disabled + and not hs.config.server.hs_disabled ) async def maybe_send_server_notice_to_user(self, user_id: str) -> None: @@ -98,7 +98,7 @@ class ResourceLimitsServerNotices: try: if ( limit_type == LimitBlockingTypes.MONTHLY_ACTIVE_USER - and not self._config.mau_limit_alerting + and not self._config.server.mau_limit_alerting ): # We have hit the MAU limit, but MAU alerting is disabled: # reset room if necessary and return @@ -149,7 +149,7 @@ class ResourceLimitsServerNotices: "body": event_body, "msgtype": ServerNoticeMsgType, "server_notice_type": ServerNoticeLimitReached, - "admin_contact": self._config.admin_contact, + "admin_contact": self._config.server.admin_contact, "limit_type": event_limit_type, } event = await self._server_notices_manager.send_notice( diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py index 6305414e3d..eee07227ef 100644 --- a/synapse/storage/databases/main/censor_events.py +++ b/synapse/storage/databases/main/censor_events.py @@ -36,7 +36,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase if ( hs.config.worker.run_background_tasks - and self.hs.config.redaction_retention_period is not None + and self.hs.config.server.redaction_retention_period is not None ): hs.get_clock().looping_call(self._censor_redactions, 5 * 60 * 1000) @@ -48,7 +48,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase By censor we mean update the event_json table with the redacted event. """ - if self.hs.config.redaction_retention_period is None: + if self.hs.config.server.redaction_retention_period is None: return if not ( @@ -60,7 +60,9 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase # created. return - before_ts = self._clock.time_msec() - self.hs.config.redaction_retention_period + before_ts = ( + self._clock.time_msec() - self.hs.config.server.redaction_retention_period + ) # We fetch all redactions that: # 1. point to an event we have, diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index 7e33ae578c..0e1d97aaeb 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -353,7 +353,7 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) - self.user_ips_max_age = hs.config.user_ips_max_age + self.user_ips_max_age = hs.config.server.user_ips_max_age if hs.config.worker.run_background_tasks and self.user_ips_max_age: self._clock.looping_call(self._prune_old_user_ips, 5 * 1000) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index cc4e31ec30..bc7d213fe2 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -104,7 +104,7 @@ class PersistEventsStore: self._clock = hs.get_clock() self._instance_name = hs.get_instance_name() - self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages + self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages self.is_mine_id = hs.is_mine_id # Ideally we'd move these ID gens here, unfortunately some other ID diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index b76ee51a9b..a14ac03d4b 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -32,8 +32,8 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): self._clock = hs.get_clock() self.hs = hs - self._limit_usage_by_mau = hs.config.limit_usage_by_mau - self._max_mau_value = hs.config.max_mau_value + self._limit_usage_by_mau = hs.config.server.limit_usage_by_mau + self._max_mau_value = hs.config.server.max_mau_value @cached(num_args=0) async def get_monthly_active_count(self) -> int: @@ -96,8 +96,8 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): """ users = [] - for tp in self.hs.config.mau_limits_reserved_threepids[ - : self.hs.config.max_mau_value + for tp in self.hs.config.server.mau_limits_reserved_threepids[ + : self.hs.config.server.max_mau_value ]: user_id = await self.hs.get_datastore().get_user_id_by_threepid( tp["medium"], tp["address"] @@ -212,7 +212,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) - self._mau_stats_only = hs.config.mau_stats_only + self._mau_stats_only = hs.config.server.mau_stats_only # Do not add more reserved users than the total allowable number self.db_pool.new_transaction( @@ -221,7 +221,7 @@ class MonthlyActiveUsersStore(MonthlyActiveUsersWorkerStore): [], [], self._initialise_reserved_users, - hs.config.mau_limits_reserved_threepids[: self._max_mau_value], + hs.config.server.mau_limits_reserved_threepids[: self._max_mau_value], ) def _initialise_reserved_users(self, txn, threepids): diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index c83089ee63..7279b0924e 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -207,7 +207,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): return False now = self._clock.time_msec() - trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000 + trial_duration_ms = self.config.server.mau_trial_days * 24 * 60 * 60 * 1000 is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms return is_trial diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 118b390e93..d69eaf80ce 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -679,8 +679,8 @@ class RoomWorkerStore(SQLBaseStore): # policy. if not ret: return { - "min_lifetime": self.config.retention_default_min_lifetime, - "max_lifetime": self.config.retention_default_max_lifetime, + "min_lifetime": self.config.server.retention_default_min_lifetime, + "max_lifetime": self.config.server.retention_default_max_lifetime, } row = ret[0] @@ -690,10 +690,10 @@ class RoomWorkerStore(SQLBaseStore): # The default values will be None if no default policy has been defined, or if one # of the attributes is missing from the default policy. if row["min_lifetime"] is None: - row["min_lifetime"] = self.config.retention_default_min_lifetime + row["min_lifetime"] = self.config.server.retention_default_min_lifetime if row["max_lifetime"] is None: - row["max_lifetime"] = self.config.retention_default_max_lifetime + row["max_lifetime"] = self.config.server.retention_default_max_lifetime return row diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index 2a1e99e17a..c85383c975 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -51,7 +51,7 @@ class SearchWorkerStore(SQLBaseStore): txn: entries: entries to be added to the table """ - if not self.hs.config.enable_search: + if not self.hs.config.server.enable_search: return if isinstance(self.database_engine, PostgresEngine): sql = ( @@ -105,7 +105,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) - if not hs.config.enable_search: + if not hs.config.server.enable_search: return self.db_pool.updates.register_background_update_handler( diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index f31880b8ec..a63eaddfdc 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -366,7 +366,7 @@ def _upgrade_existing_database( + "new for the server to understand" ) - # some of the deltas assume that config.server_name is set correctly, so now + # some of the deltas assume that server_name is set correctly, so now # is a good time to run the sanity check. if not is_empty and "main" in databases: from synapse.storage.databases.main import check_database_before_upgrade diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index cccff7af26..3aa9ba3c43 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -217,7 +217,7 @@ class AuthTestCase(unittest.HomeserverTestCase): user_id = "@baldrick:matrix.org" macaroon = pymacaroons.Macaroon( - location=self.hs.config.server_name, + location=self.hs.config.server.server_name, identifier="key", key=self.hs.config.key.macaroon_secret_key, ) @@ -239,7 +239,7 @@ class AuthTestCase(unittest.HomeserverTestCase): user_id = "@baldrick:matrix.org" macaroon = pymacaroons.Macaroon( - location=self.hs.config.server_name, + location=self.hs.config.server.server_name, identifier="key", key=self.hs.config.key.macaroon_secret_key, ) @@ -268,7 +268,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.store.get_monthly_active_count = simple_async_mock(lots_of_users) e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError) - self.assertEquals(e.value.admin_contact, self.hs.config.admin_contact) + self.assertEquals(e.value.admin_contact, self.hs.config.server.admin_contact) self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEquals(e.value.code, 403) @@ -303,7 +303,7 @@ class AuthTestCase(unittest.HomeserverTestCase): appservice = ApplicationService( "abcd", - self.hs.config.server_name, + self.hs.config.server.server_name, id="1234", namespaces={ "users": [{"regex": "@_appservice.*:sender", "exclusive": True}] @@ -332,7 +332,7 @@ class AuthTestCase(unittest.HomeserverTestCase): appservice = ApplicationService( "abcd", - self.hs.config.server_name, + self.hs.config.server.server_name, id="1234", namespaces={ "users": [{"regex": "@_appservice.*:sender", "exclusive": True}] @@ -372,7 +372,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.auth_blocking._hs_disabled = True self.auth_blocking._hs_disabled_message = "Reason for being disabled" e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError) - self.assertEquals(e.value.admin_contact, self.hs.config.admin_contact) + self.assertEquals(e.value.admin_contact, self.hs.config.server.admin_contact) self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEquals(e.value.code, 403) @@ -387,7 +387,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.auth_blocking._hs_disabled = True self.auth_blocking._hs_disabled_message = "Reason for being disabled" e = self.get_failure(self.auth.check_auth_blocking(), ResourceLimitError) - self.assertEquals(e.value.admin_contact, self.hs.config.admin_contact) + self.assertEquals(e.value.admin_contact, self.hs.config.server.admin_contact) self.assertEquals(e.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEquals(e.value.code, 403) diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index 0b60cc4261..03e1e11f49 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -120,7 +120,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase): self.assertEqual( channel.json_body["room_version"], - self.hs.config.default_room_version.identifier, + self.hs.config.server.default_room_version.identifier, ) members = set( diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index d3efb67e3e..bd05a2c2d1 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -175,20 +175,20 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.assertTrue(result_token is not None) def test_mau_limits_when_disabled(self): - self.hs.config.limit_usage_by_mau = False + self.hs.config.server.limit_usage_by_mau = False # Ensure does not throw exception self.get_success(self.get_or_create_user(self.requester, "a", "display_name")) def test_get_or_create_user_mau_not_blocked(self): - self.hs.config.limit_usage_by_mau = True + self.hs.config.server.limit_usage_by_mau = True self.store.count_monthly_users = Mock( - return_value=make_awaitable(self.hs.config.max_mau_value - 1) + return_value=make_awaitable(self.hs.config.server.max_mau_value - 1) ) # Ensure does not throw exception self.get_success(self.get_or_create_user(self.requester, "c", "User")) def test_get_or_create_user_mau_blocked(self): - self.hs.config.limit_usage_by_mau = True + self.hs.config.server.limit_usage_by_mau = True self.store.get_monthly_active_count = Mock( return_value=make_awaitable(self.lots_of_users) ) @@ -198,7 +198,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): ) self.store.get_monthly_active_count = Mock( - return_value=make_awaitable(self.hs.config.max_mau_value) + return_value=make_awaitable(self.hs.config.server.max_mau_value) ) self.get_failure( self.get_or_create_user(self.requester, "b", "display_name"), @@ -206,7 +206,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): ) def test_register_mau_blocked(self): - self.hs.config.limit_usage_by_mau = True + self.hs.config.server.limit_usage_by_mau = True self.store.get_monthly_active_count = Mock( return_value=make_awaitable(self.lots_of_users) ) @@ -215,7 +215,7 @@ class RegistrationTestCase(unittest.HomeserverTestCase): ) self.store.get_monthly_active_count = Mock( - return_value=make_awaitable(self.hs.config.max_mau_value) + return_value=make_awaitable(self.hs.config.server.max_mau_value) ) self.get_failure( self.handler.register_user(localpart="local_part"), ResourceLimitError diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py index d9a8b077d3..638babae69 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py @@ -226,7 +226,7 @@ class FederationClientTests(HomeserverTestCase): """Ensure that Synapse does not try to connect to blacklisted IPs""" # Set up the ip_range blacklist - self.hs.config.federation_ip_range_blacklist = IPSet( + self.hs.config.server.federation_ip_range_blacklist = IPSet( ["127.0.0.0/8", "fe80::/64"] ) self.reactor.lookups["internal"] = "127.0.0.1" diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index ee3ae9cce4..a285d5a7fe 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -422,7 +422,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): # Set monthly active users to the limit store.get_monthly_active_count = Mock( - return_value=make_awaitable(self.hs.config.max_mau_value) + return_value=make_awaitable(self.hs.config.server.max_mau_value) ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit @@ -1485,7 +1485,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Set monthly active users to the limit self.store.get_monthly_active_count = Mock( - return_value=make_awaitable(self.hs.config.max_mau_value) + return_value=make_awaitable(self.hs.config.server.max_mau_value) ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit @@ -1522,7 +1522,7 @@ class UserRestTestCase(unittest.HomeserverTestCase): # Set monthly active users to the limit self.store.get_monthly_active_count = Mock( - return_value=make_awaitable(self.hs.config.max_mau_value) + return_value=make_awaitable(self.hs.config.server.max_mau_value) ) # Check that the blocking of monthly active users is working as expected # The registration of a new user fails due to the limit diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index 64b0b8458b..2f44547bfb 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -516,7 +516,7 @@ class WhoamiTestCase(unittest.HomeserverTestCase): appservice = ApplicationService( as_token, - self.hs.config.server_name, + self.hs.config.server.server_name, id="1234", namespaces={"users": [{"regex": user_id, "exclusive": True}]}, sender=user_id, diff --git a/tests/rest/client/test_capabilities.py b/tests/rest/client/test_capabilities.py index 422361b62a..b9e3602552 100644 --- a/tests/rest/client/test_capabilities.py +++ b/tests/rest/client/test_capabilities.py @@ -55,7 +55,7 @@ class CapabilitiesTestCase(unittest.HomeserverTestCase): self.assertTrue(room_version in KNOWN_ROOM_VERSIONS, "" + room_version) self.assertEqual( - self.config.default_room_version.identifier, + self.config.server.default_room_version.identifier, capabilities["m.room_versions"]["default"], ) diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py index 1d152352d1..56fe1a3d01 100644 --- a/tests/rest/client/test_presence.py +++ b/tests/rest/client/test_presence.py @@ -50,7 +50,7 @@ class PresenceTestCase(unittest.HomeserverTestCase): PUT to the status endpoint with use_presence enabled will call set_state on the presence handler. """ - self.hs.config.use_presence = True + self.hs.config.server.use_presence = True body = {"presence": "here", "status_msg": "beep boop"} channel = self.make_request( diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index 72a5a11b46..af135d57e1 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -50,7 +50,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): appservice = ApplicationService( as_token, - self.hs.config.server_name, + self.hs.config.server.server_name, id="1234", namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, sender="@as:test", @@ -74,7 +74,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): appservice = ApplicationService( as_token, - self.hs.config.server_name, + self.hs.config.server.server_name, id="1234", namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, sender="@as:test", diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 7f25200a5d..36c495954f 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -346,7 +346,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): invites = [] # Register as many users as the MAU limit allows. - for i in range(self.hs.config.max_mau_value): + for i in range(self.hs.config.server.max_mau_value): localpart = "user%d" % i user_id = self.register_user(localpart, "password") tok = self.login(localpart, "password") diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py index 944dbc34a2..d6b4cdd788 100644 --- a/tests/storage/test_monthly_active_users.py +++ b/tests/storage/test_monthly_active_users.py @@ -51,7 +51,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): @override_config({"max_mau_value": 3, "mau_limit_reserved_threepids": gen_3pids(3)}) def test_initialise_reserved_users(self): - threepids = self.hs.config.mau_limits_reserved_threepids + threepids = self.hs.config.server.mau_limits_reserved_threepids # register three users, of which two have reserved 3pids, and a third # which is a support user. @@ -101,9 +101,9 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): # XXX some of this is redundant. poking things into the config shouldn't # work, and in any case it's not obvious what we expect to happen when # we advance the reactor. - self.hs.config.max_mau_value = 0 + self.hs.config.server.max_mau_value = 0 self.reactor.advance(FORTY_DAYS) - self.hs.config.max_mau_value = 5 + self.hs.config.server.max_mau_value = 5 self.get_success(self.store.reap_monthly_active_users()) @@ -183,7 +183,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.get_success(d) count = self.get_success(self.store.get_monthly_active_count()) - self.assertEqual(count, self.hs.config.max_mau_value) + self.assertEqual(count, self.hs.config.server.max_mau_value) self.reactor.advance(FORTY_DAYS) @@ -199,7 +199,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): def test_reap_monthly_active_users_reserved_users(self): """Tests that reaping correctly handles reaping where reserved users are present""" - threepids = self.hs.config.mau_limits_reserved_threepids + threepids = self.hs.config.server.mau_limits_reserved_threepids initial_users = len(threepids) reserved_user_number = initial_users - 1 for i in range(initial_users): @@ -234,7 +234,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): self.get_success(d) count = self.get_success(self.store.get_monthly_active_count()) - self.assertEqual(count, self.hs.config.max_mau_value) + self.assertEqual(count, self.hs.config.server.max_mau_value) def test_populate_monthly_users_is_guest(self): # Test that guest users are not added to mau list @@ -294,7 +294,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase): {"medium": "email", "address": user2_email}, ] - self.hs.config.mau_limits_reserved_threepids = threepids + self.hs.config.server.mau_limits_reserved_threepids = threepids d = self.store.db_pool.runInteraction( "initialise", self.store._initialise_reserved_users, threepids ) diff --git a/tests/test_mau.py b/tests/test_mau.py index 66111eb367..80ab40e255 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -165,7 +165,7 @@ class TestMauLimit(unittest.HomeserverTestCase): @override_config({"mau_trial_days": 1}) def test_trial_users_cant_come_back(self): - self.hs.config.mau_trial_days = 1 + self.hs.config.server.mau_trial_days = 1 # We should be able to register more than the limit initially token1 = self.create_user("kermit1") diff --git a/tests/unittest.py b/tests/unittest.py index 7a6f5954d0..6d5d87cb78 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -232,7 +232,7 @@ class HomeserverTestCase(TestCase): # Honour the `use_frozen_dicts` config option. We have to do this # manually because this is taken care of in the app `start` code, which # we don't run. Plus we want to reset it on tearDown. - events.USE_FROZEN_DICTS = self.hs.config.use_frozen_dicts + events.USE_FROZEN_DICTS = self.hs.config.server.use_frozen_dicts if self.hs is None: raise Exception("No homeserver returned from make_homeserver.") -- cgit 1.5.1 From 145cb6d08e2f775da208293a507c1dcd2d4128ce Mon Sep 17 00:00:00 2001 From: Lukas Lihotzki Date: Thu, 30 Sep 2021 14:04:55 +0200 Subject: Fix getTurnServer response: return an integer ttl (#10922) `ttl` must be an integer according to the OpenAPI spec: https://github.com/matrix-org/matrix-doc/blob/old_master/data/api/client-server/voip.yaml#L70 True division (`/`) returns a float instead (`"ttl": 7200.0`). Floor division (`//`) returns an integer, so the response is spec compliant. Signed-off-by: Lukas Lihotzki --- changelog.d/10922.bugfix | 1 + synapse/rest/client/voip.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/10922.bugfix (limited to 'synapse/rest/client') diff --git a/changelog.d/10922.bugfix b/changelog.d/10922.bugfix new file mode 100644 index 0000000000..b7315514e0 --- /dev/null +++ b/changelog.d/10922.bugfix @@ -0,0 +1 @@ +Fix a minor bug in the response to `/_matrix/client/r0/voip/turnServer`. Contributed by @lukaslihotzki. diff --git a/synapse/rest/client/voip.py b/synapse/rest/client/voip.py index ea2b8aa45f..ea7e025156 100644 --- a/synapse/rest/client/voip.py +++ b/synapse/rest/client/voip.py @@ -70,7 +70,7 @@ class VoipRestServlet(RestServlet): { "username": username, "password": password, - "ttl": userLifetime / 1000, + "ttl": userLifetime // 1000, "uris": turnUris, }, ) -- cgit 1.5.1 From a0f48ee89d88fd7b6da8023dbba607a69073152e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 4 Oct 2021 07:18:54 -0400 Subject: Use direct references for configuration variables (part 7). (#10959) --- changelog.d/10959.misc | 1 + synapse/handlers/auth.py | 2 +- synapse/handlers/identity.py | 13 ++++++++++--- synapse/handlers/profile.py | 4 ++-- synapse/handlers/register.py | 9 ++++++--- synapse/handlers/room_member.py | 2 +- synapse/handlers/ui_auth/checkers.py | 14 ++++++++------ synapse/rest/admin/users.py | 4 ++-- synapse/rest/client/account.py | 22 +++++++++++----------- synapse/rest/client/auth.py | 6 ++++-- synapse/rest/client/capabilities.py | 6 +++--- synapse/rest/client/login.py | 6 +++--- synapse/rest/client/register.py | 26 +++++++++++++------------- synapse/rest/well_known.py | 4 ++-- synapse/storage/databases/main/registration.py | 2 +- synapse/util/threepids.py | 4 ++-- tests/config/test_load.py | 6 +++--- tests/handlers/test_profile.py | 4 ++-- tests/rest/admin/test_user.py | 4 ++-- tests/rest/client/test_account.py | 4 ++-- tests/rest/client/test_identity.py | 2 +- tests/rest/client/test_register.py | 4 ++-- tests/unittest.py | 2 +- 23 files changed, 83 insertions(+), 68 deletions(-) create mode 100644 changelog.d/10959.misc (limited to 'synapse/rest/client') diff --git a/changelog.d/10959.misc b/changelog.d/10959.misc new file mode 100644 index 0000000000..586a0b3a96 --- /dev/null +++ b/changelog.d/10959.misc @@ -0,0 +1 @@ +Use direct references to config flags. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index a8c717efd5..2d0f3d566c 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -198,7 +198,7 @@ class AuthHandler(BaseHandler): if inst.is_enabled(): self.checkers[inst.AUTH_TYPE] = inst # type: ignore - self.bcrypt_rounds = hs.config.bcrypt_rounds + self.bcrypt_rounds = hs.config.registration.bcrypt_rounds # we can't use hs.get_module_api() here, because to do so will create an # import loop. diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index a0640fcac0..c881475c25 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -573,9 +573,15 @@ class IdentityHandler(BaseHandler): # Try to validate as email if self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE: + # Remote emails will only be used if a valid identity server is provided. + assert ( + self.hs.config.registration.account_threepid_delegate_email is not None + ) + # Ask our delegated email identity server validation_session = await self.threepid_from_creds( - self.hs.config.account_threepid_delegate_email, threepid_creds + self.hs.config.registration.account_threepid_delegate_email, + threepid_creds, ) elif self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: # Get a validated session matching these details @@ -587,10 +593,11 @@ class IdentityHandler(BaseHandler): return validation_session # Try to validate as msisdn - if self.hs.config.account_threepid_delegate_msisdn: + if self.hs.config.registration.account_threepid_delegate_msisdn: # Ask our delegated msisdn identity server validation_session = await self.threepid_from_creds( - self.hs.config.account_threepid_delegate_msisdn, threepid_creds + self.hs.config.registration.account_threepid_delegate_msisdn, + threepid_creds, ) return validation_session diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 425c0d4973..2e19706c69 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -178,7 +178,7 @@ class ProfileHandler(BaseHandler): if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's displayname") - if not by_admin and not self.hs.config.enable_set_displayname: + if not by_admin and not self.hs.config.registration.enable_set_displayname: profile = await self.store.get_profileinfo(target_user.localpart) if profile.display_name: raise SynapseError( @@ -268,7 +268,7 @@ class ProfileHandler(BaseHandler): if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's avatar_url") - if not by_admin and not self.hs.config.enable_set_avatar_url: + if not by_admin and not self.hs.config.registration.enable_set_avatar_url: profile = await self.store.get_profileinfo(target_user.localpart) if profile.avatar_url: raise SynapseError( diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index cb4eb0720b..441af7a848 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -116,8 +116,8 @@ class RegistrationHandler(BaseHandler): self._register_device_client = self.register_device_inner self.pusher_pool = hs.get_pusherpool() - self.session_lifetime = hs.config.session_lifetime - self.access_token_lifetime = hs.config.access_token_lifetime + self.session_lifetime = hs.config.registration.session_lifetime + self.access_token_lifetime = hs.config.registration.access_token_lifetime init_counters_for_auth_provider("") @@ -343,7 +343,10 @@ class RegistrationHandler(BaseHandler): # If the user does not need to consent at registration, auto-join any # configured rooms. if not self.hs.config.consent.user_consent_at_registration: - if not self.hs.config.auto_join_rooms_for_guests and make_guest: + if ( + not self.hs.config.registration.auto_join_rooms_for_guests + and make_guest + ): logger.info( "Skipping auto-join for %s because auto-join for guests is disabled", user_id, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 29b3e41cc9..c8fb24a20c 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -89,7 +89,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self.spam_checker = hs.get_spam_checker() self.third_party_event_rules = hs.get_third_party_event_rules() self._server_notices_mxid = self.config.servernotices.server_notices_mxid - self._enable_lookup = hs.config.enable_3pid_lookup + self._enable_lookup = hs.config.registration.enable_3pid_lookup self.allow_per_room_profiles = self.config.server.allow_per_room_profiles self._join_rate_limiter_local = Ratelimiter( diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py index 8f5d465fa1..184730ebe8 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py @@ -153,21 +153,23 @@ class _BaseThreepidAuthChecker: # msisdns are currently always ThreepidBehaviour.REMOTE if medium == "msisdn": - if not self.hs.config.account_threepid_delegate_msisdn: + if not self.hs.config.registration.account_threepid_delegate_msisdn: raise SynapseError( 400, "Phone number verification is not enabled on this homeserver" ) threepid = await identity_handler.threepid_from_creds( - self.hs.config.account_threepid_delegate_msisdn, threepid_creds + self.hs.config.registration.account_threepid_delegate_msisdn, + threepid_creds, ) elif medium == "email": if ( self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE ): - assert self.hs.config.account_threepid_delegate_email + assert self.hs.config.registration.account_threepid_delegate_email threepid = await identity_handler.threepid_from_creds( - self.hs.config.account_threepid_delegate_email, threepid_creds + self.hs.config.registration.account_threepid_delegate_email, + threepid_creds, ) elif ( self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL @@ -240,7 +242,7 @@ class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker): _BaseThreepidAuthChecker.__init__(self, hs) def is_enabled(self) -> bool: - return bool(self.hs.config.account_threepid_delegate_msisdn) + return bool(self.hs.config.registration.account_threepid_delegate_msisdn) async def check_auth(self, authdict: dict, clientip: str) -> Any: return await self._check_threepid("msisdn", authdict) @@ -252,7 +254,7 @@ class RegistrationTokenAuthChecker(UserInteractiveAuthChecker): def __init__(self, hs: "HomeServer"): super().__init__(hs) self.hs = hs - self._enabled = bool(hs.config.registration_requires_token) + self._enabled = bool(hs.config.registration.registration_requires_token) self.store = hs.get_datastore() def is_enabled(self) -> bool: diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 46bfec4623..f20aa65301 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -442,7 +442,7 @@ class UserRegisterServlet(RestServlet): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: self._clear_old_nonces() - if not self.hs.config.registration_shared_secret: + if not self.hs.config.registration.registration_shared_secret: raise SynapseError(400, "Shared secret registration is not enabled") body = parse_json_object_from_request(request) @@ -498,7 +498,7 @@ class UserRegisterServlet(RestServlet): got_mac = body["mac"] want_mac_builder = hmac.new( - key=self.hs.config.registration_shared_secret.encode(), + key=self.hs.config.registration.registration_shared_secret.encode(), digestmod=hashlib.sha1, ) want_mac_builder.update(nonce.encode("utf8")) diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index fff133ef10..6b272658fc 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -130,11 +130,11 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND) if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE: - assert self.hs.config.account_threepid_delegate_email + assert self.hs.config.registration.account_threepid_delegate_email # Have the configured identity server handle the request ret = await self.identity_handler.requestEmailToken( - self.hs.config.account_threepid_delegate_email, + self.hs.config.registration.account_threepid_delegate_email, email, client_secret, send_attempt, @@ -414,11 +414,11 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE: - assert self.hs.config.account_threepid_delegate_email + assert self.hs.config.registration.account_threepid_delegate_email # Have the configured identity server handle the request ret = await self.identity_handler.requestEmailToken( - self.hs.config.account_threepid_delegate_email, + self.hs.config.registration.account_threepid_delegate_email, email, client_secret, send_attempt, @@ -496,7 +496,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) - if not self.hs.config.account_threepid_delegate_msisdn: + if not self.hs.config.registration.account_threepid_delegate_msisdn: logger.warning( "No upstream msisdn account_threepid_delegate configured on the server to " "handle this request" @@ -507,7 +507,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): ) ret = await self.identity_handler.requestMsisdnToken( - self.hs.config.account_threepid_delegate_msisdn, + self.hs.config.registration.account_threepid_delegate_msisdn, country, phone_number, client_secret, @@ -604,7 +604,7 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet): self.identity_handler = hs.get_identity_handler() async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: - if not self.config.account_threepid_delegate_msisdn: + if not self.config.registration.account_threepid_delegate_msisdn: raise SynapseError( 400, "This homeserver is not validating phone numbers. Use an identity server " @@ -617,7 +617,7 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet): # Proxy submit_token request to msisdn threepid delegate response = await self.identity_handler.proxy_msisdn_submit_token( - self.config.account_threepid_delegate_msisdn, + self.config.registration.account_threepid_delegate_msisdn, body["client_secret"], body["sid"], body["token"], @@ -644,7 +644,7 @@ class ThreepidRestServlet(RestServlet): return 200, {"threepids": threepids} async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - if not self.hs.config.enable_3pid_changes: + if not self.hs.config.registration.enable_3pid_changes: raise SynapseError( 400, "3PID changes are disabled on this server", Codes.FORBIDDEN ) @@ -693,7 +693,7 @@ class ThreepidAddRestServlet(RestServlet): @interactive_auth_handler async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - if not self.hs.config.enable_3pid_changes: + if not self.hs.config.registration.enable_3pid_changes: raise SynapseError( 400, "3PID changes are disabled on this server", Codes.FORBIDDEN ) @@ -801,7 +801,7 @@ class ThreepidDeleteRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - if not self.hs.config.enable_3pid_changes: + if not self.hs.config.registration.enable_3pid_changes: raise SynapseError( 400, "3PID changes are disabled on this server", Codes.FORBIDDEN ) diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py index 282861fae2..c9ad35a3ad 100644 --- a/synapse/rest/client/auth.py +++ b/synapse/rest/client/auth.py @@ -49,8 +49,10 @@ class AuthRestServlet(RestServlet): self.registration_handler = hs.get_registration_handler() self.recaptcha_template = hs.config.captcha.recaptcha_template self.terms_template = hs.config.terms_template - self.registration_token_template = hs.config.registration_token_template - self.success_template = hs.config.fallback_success_template + self.registration_token_template = ( + hs.config.registration.registration_token_template + ) + self.success_template = hs.config.registration.fallback_success_template async def on_GET(self, request: SynapseRequest, stagetype: str) -> None: session = parse_string(request, "session") diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py index d6b6256413..2a3e24ae7e 100644 --- a/synapse/rest/client/capabilities.py +++ b/synapse/rest/client/capabilities.py @@ -64,13 +64,13 @@ class CapabilitiesRestServlet(RestServlet): if self.config.experimental.msc3283_enabled: response["capabilities"]["org.matrix.msc3283.set_displayname"] = { - "enabled": self.config.enable_set_displayname + "enabled": self.config.registration.enable_set_displayname } response["capabilities"]["org.matrix.msc3283.set_avatar_url"] = { - "enabled": self.config.enable_set_avatar_url + "enabled": self.config.registration.enable_set_avatar_url } response["capabilities"]["org.matrix.msc3283.3pid_changes"] = { - "enabled": self.config.enable_3pid_changes + "enabled": self.config.registration.enable_3pid_changes } return 200, response diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index fa5c173f4b..d49a647b03 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -79,7 +79,7 @@ class LoginRestServlet(RestServlet): self.saml2_enabled = hs.config.saml2.saml2_enabled self.cas_enabled = hs.config.cas.cas_enabled self.oidc_enabled = hs.config.oidc.oidc_enabled - self._msc2918_enabled = hs.config.access_token_lifetime is not None + self._msc2918_enabled = hs.config.registration.access_token_lifetime is not None self.auth = hs.get_auth() @@ -447,7 +447,7 @@ class RefreshTokenServlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth_handler = hs.get_auth_handler() self._clock = hs.get_clock() - self.access_token_lifetime = hs.config.access_token_lifetime + self.access_token_lifetime = hs.config.registration.access_token_lifetime async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: refresh_submission = parse_json_object_from_request(request) @@ -556,7 +556,7 @@ class CasTicketServlet(RestServlet): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: LoginRestServlet(hs).register(http_server) - if hs.config.access_token_lifetime is not None: + if hs.config.registration.access_token_lifetime is not None: RefreshTokenServlet(hs).register(http_server) SsoRedirectServlet(hs).register(http_server) if hs.config.cas.cas_enabled: diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index a6eb6f6410..bf3cb34146 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -140,11 +140,11 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE: - assert self.hs.config.account_threepid_delegate_email + assert self.hs.config.registration.account_threepid_delegate_email # Have the configured identity server handle the request ret = await self.identity_handler.requestEmailToken( - self.hs.config.account_threepid_delegate_email, + self.hs.config.registration.account_threepid_delegate_email, email, client_secret, send_attempt, @@ -221,7 +221,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): 400, "Phone number is already in use", Codes.THREEPID_IN_USE ) - if not self.hs.config.account_threepid_delegate_msisdn: + if not self.hs.config.registration.account_threepid_delegate_msisdn: logger.warning( "No upstream msisdn account_threepid_delegate configured on the server to " "handle this request" @@ -231,7 +231,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): ) ret = await self.identity_handler.requestMsisdnToken( - self.hs.config.account_threepid_delegate_msisdn, + self.hs.config.registration.account_threepid_delegate_msisdn, country, phone_number, client_secret, @@ -341,7 +341,7 @@ class UsernameAvailabilityRestServlet(RestServlet): ) async def on_GET(self, request: Request) -> Tuple[int, JsonDict]: - if not self.hs.config.enable_registration: + if not self.hs.config.registration.enable_registration: raise SynapseError( 403, "Registration has been disabled", errcode=Codes.FORBIDDEN ) @@ -391,7 +391,7 @@ class RegistrationTokenValidityRestServlet(RestServlet): async def on_GET(self, request: Request) -> Tuple[int, JsonDict]: await self.ratelimiter.ratelimit(None, (request.getClientIP(),)) - if not self.hs.config.enable_registration: + if not self.hs.config.registration.enable_registration: raise SynapseError( 403, "Registration has been disabled", errcode=Codes.FORBIDDEN ) @@ -419,8 +419,8 @@ class RegisterRestServlet(RestServlet): self.ratelimiter = hs.get_registration_ratelimiter() self.password_policy_handler = hs.get_password_policy_handler() self.clock = hs.get_clock() - self._registration_enabled = self.hs.config.enable_registration - self._msc2918_enabled = hs.config.access_token_lifetime is not None + self._registration_enabled = self.hs.config.registration.enable_registration + self._msc2918_enabled = hs.config.registration.access_token_lifetime is not None self._registration_flows = _calculate_registration_flows( hs.config, self.auth_handler @@ -800,7 +800,7 @@ class RegisterRestServlet(RestServlet): async def _do_guest_registration( self, params: JsonDict, address: Optional[str] = None ) -> Tuple[int, JsonDict]: - if not self.hs.config.allow_guest_access: + if not self.hs.config.registration.allow_guest_access: raise SynapseError(403, "Guest access is disabled") user_id = await self.registration_handler.register_user( make_guest=True, address=address @@ -849,13 +849,13 @@ def _calculate_registration_flows( """ # FIXME: need a better error than "no auth flow found" for scenarios # where we required 3PID for registration but the user didn't give one - require_email = "email" in config.registrations_require_3pid - require_msisdn = "msisdn" in config.registrations_require_3pid + require_email = "email" in config.registration.registrations_require_3pid + require_msisdn = "msisdn" in config.registration.registrations_require_3pid show_msisdn = True show_email = True - if config.disable_msisdn_registration: + if config.registration.disable_msisdn_registration: show_msisdn = False require_msisdn = False @@ -909,7 +909,7 @@ def _calculate_registration_flows( flow.insert(0, LoginType.RECAPTCHA) # Prepend registration token to all flows if we're requiring a token - if config.registration_requires_token: + if config.registration.registration_requires_token: for flow in flows: flow.insert(0, LoginType.REGISTRATION_TOKEN) diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py index c80a3a99aa..7ac01faab4 100644 --- a/synapse/rest/well_known.py +++ b/synapse/rest/well_known.py @@ -39,9 +39,9 @@ class WellKnownBuilder: result = {"m.homeserver": {"base_url": self._config.server.public_baseurl}} - if self._config.default_identity_server: + if self._config.registration.default_identity_server: result["m.identity_server"] = { - "base_url": self._config.default_identity_server + "base_url": self._config.registration.default_identity_server } return result diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 7279b0924e..de262fbf5a 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -1710,7 +1710,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): We do this by grandfathering in existing user threepids assuming that they used one of the server configured trusted identity servers. """ - id_servers = set(self.config.trusted_third_party_id_servers) + id_servers = set(self.config.registration.trusted_third_party_id_servers) def _bg_user_threepids_grandfather_txn(txn): sql = """ diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py index baa9190a9a..389adf00f6 100644 --- a/synapse/util/threepids.py +++ b/synapse/util/threepids.py @@ -44,8 +44,8 @@ def check_3pid_allowed(hs: "HomeServer", medium: str, address: str) -> bool: bool: whether the 3PID medium/address is allowed to be added to this HS """ - if hs.config.allowed_local_3pids: - for constraint in hs.config.allowed_local_3pids: + if hs.config.registration.allowed_local_3pids: + for constraint in hs.config.registration.allowed_local_3pids: logger.debug( "Checking 3PID %s (%s) against %s (%s)", address, diff --git a/tests/config/test_load.py b/tests/config/test_load.py index ef6c2beec7..8e49ca26d9 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py @@ -84,16 +84,16 @@ class ConfigLoadingTestCase(unittest.TestCase): ) # Check that disable_registration clobbers enable_registration. config = HomeServerConfig.load_config("", ["-c", self.file]) - self.assertFalse(config.enable_registration) + self.assertFalse(config.registration.enable_registration) config = HomeServerConfig.load_or_generate_config("", ["-c", self.file]) - self.assertFalse(config.enable_registration) + self.assertFalse(config.registration.enable_registration) # Check that either config value is clobbered by the command line. config = HomeServerConfig.load_or_generate_config( "", ["-c", self.file, "--enable-registration"] ) - self.assertTrue(config.enable_registration) + self.assertTrue(config.registration.enable_registration) def test_stats_enabled(self): self.generate_config_and_remove_lines_containing("enable_metrics") diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 57cc3e2646..c153018fd8 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -110,7 +110,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) def test_set_my_name_if_disabled(self): - self.hs.config.enable_set_displayname = False + self.hs.config.registration.enable_set_displayname = False # Setting displayname for the first time is allowed self.get_success( @@ -225,7 +225,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) def test_set_my_avatar_if_disabled(self): - self.hs.config.enable_set_avatar_url = False + self.hs.config.registration.enable_set_avatar_url = False # Setting displayname for the first time is allowed self.get_success( diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index a285d5a7fe..6ed9e42173 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -59,7 +59,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): self.hs = self.setup_test_homeserver() - self.hs.config.registration_shared_secret = "shared" + self.hs.config.registration.registration_shared_secret = "shared" self.hs.get_media_repository = Mock() self.hs.get_deactivate_account_handler = Mock() @@ -71,7 +71,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): If there is no shared secret, registration through this method will be prevented. """ - self.hs.config.registration_shared_secret = None + self.hs.config.registration.registration_shared_secret = None channel = self.make_request("POST", self.url, b"{}") diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index 2f44547bfb..89d85b0a17 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -664,7 +664,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): def test_add_email_if_disabled(self): """Test adding email to profile when doing so is disallowed""" - self.hs.config.enable_3pid_changes = False + self.hs.config.registration.enable_3pid_changes = False client_secret = "foobar" session_id = self._request_token(self.email, client_secret) @@ -734,7 +734,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): def test_delete_email_if_disabled(self): """Test deleting an email from profile when disallowed""" - self.hs.config.enable_3pid_changes = False + self.hs.config.registration.enable_3pid_changes = False # Add a threepid self.get_success( diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py index ca2e8ff8ef..becb4e8dcc 100644 --- a/tests/rest/client/test_identity.py +++ b/tests/rest/client/test_identity.py @@ -37,7 +37,7 @@ class IdentityTestCase(unittest.HomeserverTestCase): return self.hs def test_3pid_lookup_disabled(self): - self.hs.config.enable_3pid_lookup = False + self.hs.config.registration.enable_3pid_lookup = False self.register_user("kermit", "monkey") tok = self.login("kermit", "monkey") diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index af135d57e1..66dcfc9f88 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -147,7 +147,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): def test_POST_guest_registration(self): self.hs.config.key.macaroon_secret_key = "test" - self.hs.config.allow_guest_access = True + self.hs.config.registration.allow_guest_access = True channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") @@ -156,7 +156,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertDictContainsSubset(det_data, channel.json_body) def test_POST_disabled_guest_registration(self): - self.hs.config.allow_guest_access = False + self.hs.config.registration.allow_guest_access = False channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") diff --git a/tests/unittest.py b/tests/unittest.py index 0807467e39..1f803564f6 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -560,7 +560,7 @@ class HomeserverTestCase(TestCase): Returns: The MXID of the new user. """ - self.hs.config.registration_shared_secret = "shared" + self.hs.config.registration.registration_shared_secret = "shared" # Create the user channel = self.make_request("GET", "/_synapse/admin/v1/register") -- cgit 1.5.1 From f4b1a9a527273ef71b2f7d970642b7af45462e0f Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 6 Oct 2021 10:47:41 -0400 Subject: Require direct references to configuration variables. (#10985) This removes the magic allowing accessing configurable variables directly from the config object. It is now required that a specific configuration class is used (e.g. `config.foo` must be replaced with `config.server.foo`). --- changelog.d/10985.misc | 1 + scripts/synapse_port_db | 4 +- scripts/update_synapse_database | 2 +- synapse/app/_base.py | 2 +- synapse/app/admin_cmd.py | 4 +- synapse/app/homeserver.py | 2 +- synapse/config/_base.py | 64 ++++---------------------- synapse/config/account_validity.py | 2 +- synapse/config/cas.py | 2 +- synapse/config/emailconfig.py | 9 ++-- synapse/config/key.py | 6 ++- synapse/config/oidc.py | 2 +- synapse/config/registration.py | 7 ++- synapse/config/repository.py | 2 +- synapse/config/saml2.py | 2 +- synapse/config/server_notices.py | 4 +- synapse/config/sso.py | 6 ++- synapse/handlers/account_validity.py | 8 +--- synapse/handlers/room_member.py | 7 ++- synapse/replication/tcp/client.py | 2 +- synapse/replication/tcp/handler.py | 7 ++- synapse/rest/client/auth.py | 2 +- synapse/rest/client/push_rule.py | 4 +- synapse/storage/databases/main/push_rule.py | 4 +- synapse/storage/databases/main/registration.py | 4 +- tests/config/test_base.py | 21 +++++---- tests/config/test_cache.py | 50 ++++++++------------ tests/config/test_load.py | 12 +++-- tests/config/test_tls.py | 38 +++++++-------- tests/storage/test_appservice.py | 2 +- tests/storage/test_txn_limit.py | 2 +- 31 files changed, 124 insertions(+), 160 deletions(-) create mode 100644 changelog.d/10985.misc (limited to 'synapse/rest/client') diff --git a/changelog.d/10985.misc b/changelog.d/10985.misc new file mode 100644 index 0000000000..586a0b3a96 --- /dev/null +++ b/changelog.d/10985.misc @@ -0,0 +1 @@ +Use direct references to config flags. diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index fa6ac6d93a..a947d9e49e 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -215,7 +215,7 @@ class MockHomeserver: def __init__(self, config): self.clock = Clock(reactor) self.config = config - self.hostname = config.server_name + self.hostname = config.server.server_name self.version_string = "Synapse/" + get_version_string(synapse) def get_clock(self): @@ -583,7 +583,7 @@ class Porter(object): return self.postgres_store = self.build_db_store( - self.hs_config.get_single_database() + self.hs_config.database.get_single_database() ) await self.run_background_updates_on_postgres() diff --git a/scripts/update_synapse_database b/scripts/update_synapse_database index 26b29b0b45..6c088bad93 100755 --- a/scripts/update_synapse_database +++ b/scripts/update_synapse_database @@ -36,7 +36,7 @@ class MockHomeserver(HomeServer): def __init__(self, config, **kwargs): super(MockHomeserver, self).__init__( - config.server_name, reactor=reactor, config=config, **kwargs + config.server.server_name, reactor=reactor, config=config, **kwargs ) self.version_string = "Synapse/" + get_version_string(synapse) diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 749bc1deb9..4a204a5823 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -301,7 +301,7 @@ def refresh_certificate(hs): if not hs.config.server.has_tls_listener(): return - hs.config.read_certificate_from_disk() + hs.config.tls.read_certificate_from_disk() hs.tls_server_context_factory = context_factory.ServerContextFactory(hs.config) if hs._listening_services: diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index 556bcc124e..13d20af457 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -197,9 +197,9 @@ def start(config_options): # Explicitly disable background processes config.server.update_user_directory = False config.worker.run_background_tasks = False - config.start_pushers = False + config.worker.start_pushers = False config.pusher_shard_config.instances = [] - config.send_federation = False + config.worker.send_federation = False config.federation_shard_config.instances = [] synapse.events.USE_FROZEN_DICTS = config.server.use_frozen_dicts diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 2b2d4bbf83..422f03cc04 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -234,7 +234,7 @@ class SynapseHomeServer(HomeServer): ) if name in ["media", "federation", "client"]: - if self.config.media.enable_media_repo: + if self.config.server.enable_media_repo: media_repo = self.get_media_repository_resource() resources.update( {MEDIA_PREFIX: media_repo, LEGACY_MEDIA_PREFIX: media_repo} diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 26152b0924..7c4428a138 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -118,21 +118,6 @@ class Config: "synapse", "res/templates" ) - def __getattr__(self, item: str) -> Any: - """ - Try and fetch a configuration option that does not exist on this class. - - This is so that existing configs that rely on `self.value`, where value - is actually from a different config section, continue to work. - """ - if item in ["generate_config_section", "read_config"]: - raise AttributeError(item) - - if self.root is None: - raise AttributeError(item) - else: - return self.root._get_unclassed_config(self.section, item) - @staticmethod def parse_size(value): if isinstance(value, int): @@ -289,7 +274,9 @@ class Config: env.filters.update( { "format_ts": _format_ts_filter, - "mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl), + "mxc_to_http": _create_mxc_to_http_filter( + self.root.server.public_baseurl + ), } ) @@ -311,8 +298,6 @@ class RootConfig: config_classes = [] def __init__(self): - self._configs = OrderedDict() - for config_class in self.config_classes: if config_class.section is None: raise ValueError("%r requires a section name" % (config_class,)) @@ -321,42 +306,7 @@ class RootConfig: conf = config_class(self) except Exception as e: raise Exception("Failed making %s: %r" % (config_class.section, e)) - self._configs[config_class.section] = conf - - def __getattr__(self, item: str) -> Any: - """ - Redirect lookups on this object either to config objects, or values on - config objects, so that `config.tls.blah` works, as well as legacy uses - of things like `config.server.server_name`. It will first look up the config - section name, and then values on those config classes. - """ - if item in self._configs.keys(): - return self._configs[item] - - return self._get_unclassed_config(None, item) - - def _get_unclassed_config(self, asking_section: Optional[str], item: str): - """ - Fetch a config value from one of the instantiated config classes that - has not been fetched directly. - - Args: - asking_section: If this check is coming from a Config child, which - one? This section will not be asked if it has the value. - item: The configuration value key. - - Raises: - AttributeError if no config classes have the config key. The body - will contain what sections were checked. - """ - for key, val in self._configs.items(): - if key == asking_section: - continue - - if item in dir(val): - return getattr(val, item) - - raise AttributeError(item, "not found in %s" % (list(self._configs.keys()),)) + setattr(self, config_class.section, conf) def invoke_all(self, func_name: str, *args, **kwargs) -> MutableMapping[str, Any]: """ @@ -373,9 +323,11 @@ class RootConfig: """ res = OrderedDict() - for name, config in self._configs.items(): + for config_class in self.config_classes: + config = getattr(self, config_class.section) + if hasattr(config, func_name): - res[name] = getattr(config, func_name)(*args, **kwargs) + res[config_class.section] = getattr(config, func_name)(*args, **kwargs) return res diff --git a/synapse/config/account_validity.py b/synapse/config/account_validity.py index ffaffc4931..b56c2a24df 100644 --- a/synapse/config/account_validity.py +++ b/synapse/config/account_validity.py @@ -76,7 +76,7 @@ class AccountValidityConfig(Config): ) if self.account_validity_renew_by_email_enabled: - if not self.public_baseurl: + if not self.root.server.public_baseurl: raise ConfigError("Can't send renewal emails without 'public_baseurl'") # Load account validity templates. diff --git a/synapse/config/cas.py b/synapse/config/cas.py index 901f4123e1..9b58ecf3d8 100644 --- a/synapse/config/cas.py +++ b/synapse/config/cas.py @@ -37,7 +37,7 @@ class CasConfig(Config): # The public baseurl is required because it is used by the redirect # template. - public_baseurl = self.public_baseurl + public_baseurl = self.root.server.public_baseurl if not public_baseurl: raise ConfigError("cas_config requires a public_baseurl to be set") diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index 936abe6178..8ff59aa2f8 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -19,7 +19,6 @@ import email.utils import logging import os from enum import Enum -from typing import Optional import attr @@ -135,7 +134,7 @@ class EmailConfig(Config): # msisdn is currently always remote while Synapse does not support any method of # sending SMS messages ThreepidBehaviour.REMOTE - if self.account_threepid_delegate_email + if self.root.registration.account_threepid_delegate_email else ThreepidBehaviour.LOCAL ) # Prior to Synapse v1.4.0, there was another option that defined whether Synapse would @@ -144,7 +143,7 @@ class EmailConfig(Config): # identity server in the process. self.using_identity_server_from_trusted_list = False if ( - not self.account_threepid_delegate_email + not self.root.registration.account_threepid_delegate_email and config.get("trust_identity_server_for_password_resets", False) is True ): # Use the first entry in self.trusted_third_party_id_servers instead @@ -156,7 +155,7 @@ class EmailConfig(Config): # trusted_third_party_id_servers does not contain a scheme whereas # account_threepid_delegate_email is expected to. Presume https - self.account_threepid_delegate_email: Optional[str] = ( + self.root.registration.account_threepid_delegate_email = ( "https://" + first_trusted_identity_server ) self.using_identity_server_from_trusted_list = True @@ -335,7 +334,7 @@ class EmailConfig(Config): "client_base_url", email_config.get("riot_base_url", None) ) - if self.account_validity_renew_by_email_enabled: + if self.root.account_validity.account_validity_renew_by_email_enabled: expiry_template_html = email_config.get( "expiry_template_html", "notice_expiry.html" ) diff --git a/synapse/config/key.py b/synapse/config/key.py index 94a9063043..015dbb8a67 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -145,11 +145,13 @@ class KeyConfig(Config): # list of TrustedKeyServer objects self.key_servers = list( - _parse_key_servers(key_servers, self.federation_verify_certificates) + _parse_key_servers( + key_servers, self.root.tls.federation_verify_certificates + ) ) self.macaroon_secret_key = config.get( - "macaroon_secret_key", self.registration_shared_secret + "macaroon_secret_key", self.root.registration.registration_shared_secret ) if not self.macaroon_secret_key: diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py index 7e67fbada1..10f5796330 100644 --- a/synapse/config/oidc.py +++ b/synapse/config/oidc.py @@ -58,7 +58,7 @@ class OIDCConfig(Config): "Multiple OIDC providers have the idp_id %r." % idp_id ) - public_baseurl = self.public_baseurl + public_baseurl = self.root.server.public_baseurl if public_baseurl is None: raise ConfigError("oidc_config requires a public_baseurl to be set") self.oidc_callback_url = public_baseurl + "_synapse/client/oidc/callback" diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 7cffdacfa5..a3d2a38c4c 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -45,7 +45,10 @@ class RegistrationConfig(Config): account_threepid_delegates = config.get("account_threepid_delegates") or {} self.account_threepid_delegate_email = account_threepid_delegates.get("email") self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn") - if self.account_threepid_delegate_msisdn and not self.public_baseurl: + if ( + self.account_threepid_delegate_msisdn + and not self.root.server.public_baseurl + ): raise ConfigError( "The configuration option `public_baseurl` is required if " "`account_threepid_delegate.msisdn` is set, such that " @@ -85,7 +88,7 @@ class RegistrationConfig(Config): if mxid_localpart: # Convert the localpart to a full mxid. self.auto_join_user_id = UserID( - mxid_localpart, self.server_name + mxid_localpart, self.root.server.server_name ).to_string() if self.autocreate_auto_join_rooms: diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 7481f3bf5f..69906a98d4 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -94,7 +94,7 @@ class ContentRepositoryConfig(Config): # Only enable the media repo if either the media repo is enabled or the # current worker app is the media repo. if ( - self.enable_media_repo is False + self.root.server.enable_media_repo is False and config.get("worker_app") != "synapse.app.media_repository" ): self.can_load_media_repo = False diff --git a/synapse/config/saml2.py b/synapse/config/saml2.py index 05e983625d..9c51b6a25a 100644 --- a/synapse/config/saml2.py +++ b/synapse/config/saml2.py @@ -199,7 +199,7 @@ class SAML2Config(Config): """ import saml2 - public_baseurl = self.public_baseurl + public_baseurl = self.root.server.public_baseurl if public_baseurl is None: raise ConfigError("saml2_config requires a public_baseurl to be set") diff --git a/synapse/config/server_notices.py b/synapse/config/server_notices.py index 48bf3241b6..bde4e879d9 100644 --- a/synapse/config/server_notices.py +++ b/synapse/config/server_notices.py @@ -73,7 +73,9 @@ class ServerNoticesConfig(Config): return mxid_localpart = c["system_mxid_localpart"] - self.server_notices_mxid = UserID(mxid_localpart, self.server_name).to_string() + self.server_notices_mxid = UserID( + mxid_localpart, self.root.server.server_name + ).to_string() self.server_notices_mxid_display_name = c.get("system_mxid_display_name", None) self.server_notices_mxid_avatar_url = c.get("system_mxid_avatar_url", None) # todo: i18n diff --git a/synapse/config/sso.py b/synapse/config/sso.py index 524a7ff3aa..11a9b76aa0 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -103,8 +103,10 @@ class SSOConfig(Config): # the client's. # public_baseurl is an optional setting, so we only add the fallback's URL to the # list if it's provided (because we can't figure out what that URL is otherwise). - if self.public_baseurl: - login_fallback_url = self.public_baseurl + "_matrix/static/client/login" + if self.root.server.public_baseurl: + login_fallback_url = ( + self.root.server.public_baseurl + "_matrix/static/client/login" + ) self.sso_client_whitelist.append(login_fallback_url) def generate_config_section(self, **kwargs): diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index 5a5f124ddf..87e415df75 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -67,12 +67,8 @@ class AccountValidityHandler: and self._account_validity_renew_by_email_enabled ): # Don't do email-specific configuration if renewal by email is disabled. - self._template_html = ( - hs.config.account_validity.account_validity_template_html - ) - self._template_text = ( - hs.config.account_validity.account_validity_template_text - ) + self._template_html = hs.config.email.account_validity_template_html + self._template_text = hs.config.email.account_validity_template_text self._renew_email_subject = ( hs.config.account_validity.account_validity_renew_email_subject ) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 0b79dbcf8d..c05461bf2a 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -1499,8 +1499,11 @@ class RoomMemberMasterHandler(RoomMemberHandler): if len(remote_room_hosts) == 0: raise SynapseError(404, "No known servers") - check_complexity = self.hs.config.limit_remote_rooms.enabled - if check_complexity and self.hs.config.limit_remote_rooms.admins_can_join: + check_complexity = self.hs.config.server.limit_remote_rooms.enabled + if ( + check_complexity + and self.hs.config.server.limit_remote_rooms.admins_can_join + ): check_complexity = not await self.auth.is_server_admin(user) if check_complexity: diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 37769ace48..961c17762e 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -117,7 +117,7 @@ class ReplicationDataHandler: self._instance_name = hs.get_instance_name() self._typing_handler = hs.get_typing_handler() - self._notify_pushers = hs.config.start_pushers + self._notify_pushers = hs.config.worker.start_pushers self._pusher_pool = hs.get_pusherpool() self._presence_handler = hs.get_presence_handler() diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index d64d1dbacd..6aa9318027 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -171,7 +171,10 @@ class ReplicationCommandHandler: if hs.config.worker.worker_app is not None: continue - if stream.NAME == FederationStream.NAME and hs.config.send_federation: + if ( + stream.NAME == FederationStream.NAME + and hs.config.worker.send_federation + ): # We only support federation stream if federation sending # has been disabled on the master. continue @@ -225,7 +228,7 @@ class ReplicationCommandHandler: self._is_master = hs.config.worker.worker_app is None self._federation_sender = None - if self._is_master and not hs.config.send_federation: + if self._is_master and not hs.config.worker.send_federation: self._federation_sender = hs.get_federation_sender() self._server_notices_sender = None diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py index c9ad35a3ad..9c15a04338 100644 --- a/synapse/rest/client/auth.py +++ b/synapse/rest/client/auth.py @@ -48,7 +48,7 @@ class AuthRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() self.recaptcha_template = hs.config.captcha.recaptcha_template - self.terms_template = hs.config.terms_template + self.terms_template = hs.config.consent.terms_template self.registration_token_template = ( hs.config.registration.registration_token_template ) diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py index ecebc46e8d..6f796d5e50 100644 --- a/synapse/rest/client/push_rule.py +++ b/synapse/rest/client/push_rule.py @@ -61,7 +61,9 @@ class PushRuleRestServlet(RestServlet): self.notifier = hs.get_notifier() self._is_worker = hs.config.worker.worker_app is not None - self._users_new_default_push_rules = hs.config.users_new_default_push_rules + self._users_new_default_push_rules = ( + hs.config.server.users_new_default_push_rules + ) async def on_PUT(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]: if self._is_worker: diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index a7fb8cd848..b81e33964a 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -101,7 +101,9 @@ class PushRulesWorkerStore( prefilled_cache=push_rules_prefill, ) - self._users_new_default_push_rules = hs.config.users_new_default_push_rules + self._users_new_default_push_rules = ( + hs.config.server.users_new_default_push_rules + ) @abc.abstractmethod def get_max_push_rules_stream_id(self): diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index de262fbf5a..7de4ad7f9b 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -1778,7 +1778,9 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): super().__init__(database, db_conn, hs) - self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors + self._ignore_unknown_session_error = ( + hs.config.server.request_token_inhibit_3pid_errors + ) self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id") diff --git a/tests/config/test_base.py b/tests/config/test_base.py index baa5313fb3..6a52f862f4 100644 --- a/tests/config/test_base.py +++ b/tests/config/test_base.py @@ -14,23 +14,28 @@ import os.path import tempfile +from unittest.mock import Mock from synapse.config import ConfigError +from synapse.config._base import Config from synapse.util.stringutils import random_string from tests import unittest -class BaseConfigTestCase(unittest.HomeserverTestCase): - def prepare(self, reactor, clock, hs): - self.hs = hs +class BaseConfigTestCase(unittest.TestCase): + def setUp(self): + # The root object needs a server property with a public_baseurl. + root = Mock() + root.server.public_baseurl = "http://test" + self.config = Config(root) def test_loading_missing_templates(self): # Use a temporary directory that exists on the system, but that isn't likely to # contain template files with tempfile.TemporaryDirectory() as tmp_dir: # Attempt to load an HTML template from our custom template directory - template = self.hs.config.read_templates(["sso_error.html"], (tmp_dir,))[0] + template = self.config.read_templates(["sso_error.html"], (tmp_dir,))[0] # If no errors, we should've gotten the default template instead @@ -60,7 +65,7 @@ class BaseConfigTestCase(unittest.HomeserverTestCase): # Attempt to load the template from our custom template directory template = ( - self.hs.config.read_templates([template_filename], (tmp_dir,)) + self.config.read_templates([template_filename], (tmp_dir,)) )[0] # Render the template @@ -97,7 +102,7 @@ class BaseConfigTestCase(unittest.HomeserverTestCase): # Retrieve the template. template = ( - self.hs.config.read_templates( + self.config.read_templates( [template_filename], (td.name for td in tempdirs), ) @@ -118,7 +123,7 @@ class BaseConfigTestCase(unittest.HomeserverTestCase): # Retrieve the template. template = ( - self.hs.config.read_templates( + self.config.read_templates( [other_template_name], (td.name for td in tempdirs), ) @@ -134,6 +139,6 @@ class BaseConfigTestCase(unittest.HomeserverTestCase): def test_loading_template_from_nonexistent_custom_directory(self): with self.assertRaises(ConfigError): - self.hs.config.read_templates( + self.config.read_templates( ["some_filename.html"], ("a_nonexistent_directory",) ) diff --git a/tests/config/test_cache.py b/tests/config/test_cache.py index 857d9cd096..f518abdb7a 100644 --- a/tests/config/test_cache.py +++ b/tests/config/test_cache.py @@ -12,39 +12,32 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.config._base import Config, RootConfig from synapse.config.cache import CacheConfig, add_resizable_cache from synapse.util.caches.lrucache import LruCache from tests.unittest import TestCase -class FakeServer(Config): - section = "server" - - -class TestConfig(RootConfig): - config_classes = [FakeServer, CacheConfig] - - class CacheConfigTests(TestCase): def setUp(self): # Reset caches before each test - TestConfig().caches.reset() + self.config = CacheConfig() + + def tearDown(self): + self.config.reset() def test_individual_caches_from_environ(self): """ Individual cache factors will be loaded from the environment. """ config = {} - t = TestConfig() - t.caches._environ = { + self.config._environ = { "SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2", "SYNAPSE_NOT_CACHE": "BLAH", } - t.read_config(config, config_dir_path="", data_dir_path="") + self.config.read_config(config, config_dir_path="", data_dir_path="") - self.assertEqual(dict(t.caches.cache_factors), {"something_or_other": 2.0}) + self.assertEqual(dict(self.config.cache_factors), {"something_or_other": 2.0}) def test_config_overrides_environ(self): """ @@ -52,15 +45,14 @@ class CacheConfigTests(TestCase): over those in the config. """ config = {"caches": {"per_cache_factors": {"foo": 2, "bar": 3}}} - t = TestConfig() - t.caches._environ = { + self.config._environ = { "SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2", "SYNAPSE_CACHE_FACTOR_FOO": 1, } - t.read_config(config, config_dir_path="", data_dir_path="") + self.config.read_config(config, config_dir_path="", data_dir_path="") self.assertEqual( - dict(t.caches.cache_factors), + dict(self.config.cache_factors), {"foo": 1.0, "bar": 3.0, "something_or_other": 2.0}, ) @@ -76,8 +68,7 @@ class CacheConfigTests(TestCase): self.assertEqual(cache.max_size, 50) config = {"caches": {"per_cache_factors": {"foo": 3}}} - t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") + self.config.read_config(config) self.assertEqual(cache.max_size, 300) @@ -88,8 +79,7 @@ class CacheConfigTests(TestCase): there is one. """ config = {"caches": {"per_cache_factors": {"foo": 2}}} - t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") + self.config.read_config(config, config_dir_path="", data_dir_path="") cache = LruCache(100) add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) @@ -106,8 +96,7 @@ class CacheConfigTests(TestCase): self.assertEqual(cache.max_size, 50) config = {"caches": {"global_factor": 4}} - t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") + self.config.read_config(config, config_dir_path="", data_dir_path="") self.assertEqual(cache.max_size, 400) @@ -118,8 +107,7 @@ class CacheConfigTests(TestCase): is no per-cache factor. """ config = {"caches": {"global_factor": 1.5}} - t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") + self.config.read_config(config, config_dir_path="", data_dir_path="") cache = LruCache(100) add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) @@ -133,12 +121,11 @@ class CacheConfigTests(TestCase): "per_cache_factors": {"*cache_a*": 5, "cache_b": 6, "cache_c": 2} } } - t = TestConfig() - t.caches._environ = { + self.config._environ = { "SYNAPSE_CACHE_FACTOR_CACHE_A": "2", "SYNAPSE_CACHE_FACTOR_CACHE_B": 3, } - t.read_config(config, config_dir_path="", data_dir_path="") + self.config.read_config(config, config_dir_path="", data_dir_path="") cache_a = LruCache(100) add_resizable_cache("*cache_a*", cache_resize_callback=cache_a.set_cache_factor) @@ -158,11 +145,10 @@ class CacheConfigTests(TestCase): """ config = {"caches": {"event_cache_size": "10k"}} - t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") + self.config.read_config(config, config_dir_path="", data_dir_path="") cache = LruCache( - max_size=t.caches.event_cache_size, + max_size=self.config.event_cache_size, apply_cache_factor_from_config=False, ) add_resizable_cache("event_cache", cache_resize_callback=cache.set_cache_factor) diff --git a/tests/config/test_load.py b/tests/config/test_load.py index 8e49ca26d9..59635de205 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py @@ -49,7 +49,7 @@ class ConfigLoadingTestCase(unittest.TestCase): config = HomeServerConfig.load_config("", ["-c", self.file]) self.assertTrue( - hasattr(config, "macaroon_secret_key"), + hasattr(config.key, "macaroon_secret_key"), "Want config to have attr macaroon_secret_key", ) if len(config.key.macaroon_secret_key) < 5: @@ -60,7 +60,7 @@ class ConfigLoadingTestCase(unittest.TestCase): config = HomeServerConfig.load_or_generate_config("", ["-c", self.file]) self.assertTrue( - hasattr(config, "macaroon_secret_key"), + hasattr(config.key, "macaroon_secret_key"), "Want config to have attr macaroon_secret_key", ) if len(config.key.macaroon_secret_key) < 5: @@ -74,8 +74,12 @@ class ConfigLoadingTestCase(unittest.TestCase): config1 = HomeServerConfig.load_config("", ["-c", self.file]) config2 = HomeServerConfig.load_config("", ["-c", self.file]) config3 = HomeServerConfig.load_or_generate_config("", ["-c", self.file]) - self.assertEqual(config1.macaroon_secret_key, config2.macaroon_secret_key) - self.assertEqual(config1.macaroon_secret_key, config3.macaroon_secret_key) + self.assertEqual( + config1.key.macaroon_secret_key, config2.key.macaroon_secret_key + ) + self.assertEqual( + config1.key.macaroon_secret_key, config3.key.macaroon_secret_key + ) def test_disable_registration(self): self.generate_config() diff --git a/tests/config/test_tls.py b/tests/config/test_tls.py index b6bc1876b5..9ba5781573 100644 --- a/tests/config/test_tls.py +++ b/tests/config/test_tls.py @@ -42,9 +42,9 @@ class TLSConfigTests(TestCase): """ config = {} t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") + t.tls.read_config(config, config_dir_path="", data_dir_path="") - self.assertEqual(t.federation_client_minimum_tls_version, "1") + self.assertEqual(t.tls.federation_client_minimum_tls_version, "1") def test_tls_client_minimum_set(self): """ @@ -52,29 +52,29 @@ class TLSConfigTests(TestCase): """ config = {"federation_client_minimum_tls_version": 1} t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") - self.assertEqual(t.federation_client_minimum_tls_version, "1") + t.tls.read_config(config, config_dir_path="", data_dir_path="") + self.assertEqual(t.tls.federation_client_minimum_tls_version, "1") config = {"federation_client_minimum_tls_version": 1.1} t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") - self.assertEqual(t.federation_client_minimum_tls_version, "1.1") + t.tls.read_config(config, config_dir_path="", data_dir_path="") + self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.1") config = {"federation_client_minimum_tls_version": 1.2} t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") - self.assertEqual(t.federation_client_minimum_tls_version, "1.2") + t.tls.read_config(config, config_dir_path="", data_dir_path="") + self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.2") # Also test a string version config = {"federation_client_minimum_tls_version": "1"} t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") - self.assertEqual(t.federation_client_minimum_tls_version, "1") + t.tls.read_config(config, config_dir_path="", data_dir_path="") + self.assertEqual(t.tls.federation_client_minimum_tls_version, "1") config = {"federation_client_minimum_tls_version": "1.2"} t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") - self.assertEqual(t.federation_client_minimum_tls_version, "1.2") + t.tls.read_config(config, config_dir_path="", data_dir_path="") + self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.2") def test_tls_client_minimum_1_point_3_missing(self): """ @@ -91,7 +91,7 @@ class TLSConfigTests(TestCase): config = {"federation_client_minimum_tls_version": 1.3} t = TestConfig() with self.assertRaises(ConfigError) as e: - t.read_config(config, config_dir_path="", data_dir_path="") + t.tls.read_config(config, config_dir_path="", data_dir_path="") self.assertEqual( e.exception.args[0], ( @@ -112,8 +112,8 @@ class TLSConfigTests(TestCase): config = {"federation_client_minimum_tls_version": 1.3} t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") - self.assertEqual(t.federation_client_minimum_tls_version, "1.3") + t.tls.read_config(config, config_dir_path="", data_dir_path="") + self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.3") def test_tls_client_minimum_set_passed_through_1_2(self): """ @@ -121,7 +121,7 @@ class TLSConfigTests(TestCase): """ config = {"federation_client_minimum_tls_version": 1.2} t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") + t.tls.read_config(config, config_dir_path="", data_dir_path="") cf = FederationPolicyForHTTPS(t) options = _get_ssl_context_options(cf._verify_ssl_context) @@ -137,7 +137,7 @@ class TLSConfigTests(TestCase): """ config = {"federation_client_minimum_tls_version": 1} t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") + t.tls.read_config(config, config_dir_path="", data_dir_path="") cf = FederationPolicyForHTTPS(t) options = _get_ssl_context_options(cf._verify_ssl_context) @@ -159,7 +159,7 @@ class TLSConfigTests(TestCase): } t = TestConfig() e = self.assertRaises( - ConfigError, t.read_config, config, config_dir_path="", data_dir_path="" + ConfigError, t.tls.read_config, config, config_dir_path="", data_dir_path="" ) self.assertIn("IDNA domain names", str(e)) @@ -174,7 +174,7 @@ class TLSConfigTests(TestCase): ] } t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") + t.tls.read_config(config, config_dir_path="", data_dir_path="") cf = FederationPolicyForHTTPS(t) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index cf9748f218..f26d5acf9c 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -126,7 +126,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): self.db_pool = database._db_pool self.engine = database.engine - db_config = hs.config.get_single_database() + db_config = hs.config.database.get_single_database() self.store = TestTransactionStore( database, make_conn(db_config, self.engine, "test"), hs ) diff --git a/tests/storage/test_txn_limit.py b/tests/storage/test_txn_limit.py index 6ff3ebb137..ace82cbf42 100644 --- a/tests/storage/test_txn_limit.py +++ b/tests/storage/test_txn_limit.py @@ -22,7 +22,7 @@ class SQLTransactionLimitTestCase(unittest.HomeserverTestCase): return self.setup_test_homeserver(db_txn_limit=1000) def test_config(self): - db_config = self.hs.config.get_single_database() + db_config = self.hs.config.database.get_single_database() self.assertEqual(db_config.config["txn_limit"], 1000) def test_select(self): -- cgit 1.5.1 From a7d22c36dbbbdd396aeb8938b57b5fd7edb689f3 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Fri, 8 Oct 2021 18:35:00 -0500 Subject: Refactor MSC2716 `/batch_send` endpoint into separate handler functions (#10974) --- changelog.d/10974.misc | 1 + synapse/handlers/room_batch.py | 423 ++++++++++++++++++++++++++++++++++++++ synapse/rest/client/room_batch.py | 339 +++++------------------------- synapse/server.py | 5 + 4 files changed, 485 insertions(+), 283 deletions(-) create mode 100644 changelog.d/10974.misc create mode 100644 synapse/handlers/room_batch.py (limited to 'synapse/rest/client') diff --git a/changelog.d/10974.misc b/changelog.d/10974.misc new file mode 100644 index 0000000000..8695b378aa --- /dev/null +++ b/changelog.d/10974.misc @@ -0,0 +1 @@ +Refactor [MSC2716](https://github.com/matrix-org/matrix-doc/pull/2716) `/batch_send` mega function into smaller handler functions. diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py new file mode 100644 index 0000000000..51dd4e7555 --- /dev/null +++ b/synapse/handlers/room_batch.py @@ -0,0 +1,423 @@ +import logging +from typing import TYPE_CHECKING, List, Tuple + +from synapse.api.constants import EventContentFields, EventTypes +from synapse.appservice import ApplicationService +from synapse.http.servlet import assert_params_in_dict +from synapse.types import JsonDict, Requester, UserID, create_requester +from synapse.util.stringutils import random_string + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class RoomBatchHandler: + def __init__(self, hs: "HomeServer"): + self.hs = hs + self.store = hs.get_datastore() + self.state_store = hs.get_storage().state + self.event_creation_handler = hs.get_event_creation_handler() + self.room_member_handler = hs.get_room_member_handler() + self.auth = hs.get_auth() + + async def inherit_depth_from_prev_ids(self, prev_event_ids: List[str]) -> int: + """Finds the depth which would sort it after the most-recent + prev_event_id but before the successors of those events. If no + successors are found, we assume it's an historical extremity part of the + current batch and use the same depth of the prev_event_ids. + + Args: + prev_event_ids: List of prev event IDs + + Returns: + Inherited depth + """ + ( + most_recent_prev_event_id, + most_recent_prev_event_depth, + ) = await self.store.get_max_depth_of(prev_event_ids) + + # We want to insert the historical event after the `prev_event` but before the successor event + # + # We inherit depth from the successor event instead of the `prev_event` + # because events returned from `/messages` are first sorted by `topological_ordering` + # which is just the `depth` and then tie-break with `stream_ordering`. + # + # We mark these inserted historical events as "backfilled" which gives them a + # negative `stream_ordering`. If we use the same depth as the `prev_event`, + # then our historical event will tie-break and be sorted before the `prev_event` + # when it should come after. + # + # We want to use the successor event depth so they appear after `prev_event` because + # it has a larger `depth` but before the successor event because the `stream_ordering` + # is negative before the successor event. + successor_event_ids = await self.store.get_successor_events( + [most_recent_prev_event_id] + ) + + # If we can't find any successor events, then it's a forward extremity of + # historical messages and we can just inherit from the previous historical + # event which we can already assume has the correct depth where we want + # to insert into. + if not successor_event_ids: + depth = most_recent_prev_event_depth + else: + ( + _, + oldest_successor_depth, + ) = await self.store.get_min_depth_of(successor_event_ids) + + depth = oldest_successor_depth + + return depth + + def create_insertion_event_dict( + self, sender: str, room_id: str, origin_server_ts: int + ) -> JsonDict: + """Creates an event dict for an "insertion" event with the proper fields + and a random batch ID. + + Args: + sender: The event author MXID + room_id: The room ID that the event belongs to + origin_server_ts: Timestamp when the event was sent + + Returns: + The new event dictionary to insert. + """ + + next_batch_id = random_string(8) + insertion_event = { + "type": EventTypes.MSC2716_INSERTION, + "sender": sender, + "room_id": room_id, + "content": { + EventContentFields.MSC2716_NEXT_BATCH_ID: next_batch_id, + EventContentFields.MSC2716_HISTORICAL: True, + }, + "origin_server_ts": origin_server_ts, + } + + return insertion_event + + async def create_requester_for_user_id_from_app_service( + self, user_id: str, app_service: ApplicationService + ) -> Requester: + """Creates a new requester for the given user_id + and validates that the app service is allowed to control + the given user. + + Args: + user_id: The author MXID that the app service is controlling + app_service: The app service that controls the user + + Returns: + Requester object + """ + + await self.auth.validate_appservice_can_control_user_id(app_service, user_id) + + return create_requester(user_id, app_service=app_service) + + async def get_most_recent_auth_event_ids_from_event_id_list( + self, event_ids: List[str] + ) -> List[str]: + """Find the most recent auth event ids (derived from state events) that + allowed that message to be sent. We will use this as a base + to auth our historical messages against. + + Args: + event_ids: List of event ID's to look at + + Returns: + List of event ID's + """ + + ( + most_recent_prev_event_id, + _, + ) = await self.store.get_max_depth_of(event_ids) + # mapping from (type, state_key) -> state_event_id + prev_state_map = await self.state_store.get_state_ids_for_event( + most_recent_prev_event_id + ) + # List of state event ID's + prev_state_ids = list(prev_state_map.values()) + auth_event_ids = prev_state_ids + + return auth_event_ids + + async def persist_state_events_at_start( + self, + state_events_at_start: List[JsonDict], + room_id: str, + initial_auth_event_ids: List[str], + app_service_requester: Requester, + ) -> List[str]: + """Takes all `state_events_at_start` event dictionaries and creates/persists + them as floating state events which don't resolve into the current room state. + They are floating because they reference a fake prev_event which doesn't connect + to the normal DAG at all. + + Args: + state_events_at_start: + room_id: Room where you want the events persisted in. + initial_auth_event_ids: These will be the auth_events for the first + state event created. Each event created afterwards will be + added to the list of auth events for the next state event + created. + app_service_requester: The requester of an application service. + + Returns: + List of state event ID's we just persisted + """ + assert app_service_requester.app_service + + state_event_ids_at_start = [] + auth_event_ids = initial_auth_event_ids.copy() + for state_event in state_events_at_start: + assert_params_in_dict( + state_event, ["type", "origin_server_ts", "content", "sender"] + ) + + logger.debug( + "RoomBatchSendEventRestServlet inserting state_event=%s, auth_event_ids=%s", + state_event, + auth_event_ids, + ) + + event_dict = { + "type": state_event["type"], + "origin_server_ts": state_event["origin_server_ts"], + "content": state_event["content"], + "room_id": room_id, + "sender": state_event["sender"], + "state_key": state_event["state_key"], + } + + # Mark all events as historical + event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True + + # Make the state events float off on their own so we don't have a + # bunch of `@mxid joined the room` noise between each batch + fake_prev_event_id = "$" + random_string(43) + + # TODO: This is pretty much the same as some other code to handle inserting state in this file + if event_dict["type"] == EventTypes.Member: + membership = event_dict["content"].get("membership", None) + event_id, _ = await self.room_member_handler.update_membership( + await self.create_requester_for_user_id_from_app_service( + state_event["sender"], app_service_requester.app_service + ), + target=UserID.from_string(event_dict["state_key"]), + room_id=room_id, + action=membership, + content=event_dict["content"], + outlier=True, + prev_event_ids=[fake_prev_event_id], + # Make sure to use a copy of this list because we modify it + # later in the loop here. Otherwise it will be the same + # reference and also update in the event when we append later. + auth_event_ids=auth_event_ids.copy(), + ) + else: + # TODO: Add some complement tests that adds state that is not member joins + # and will use this code path. Maybe we only want to support join state events + # and can get rid of this `else`? + ( + event, + _, + ) = await self.event_creation_handler.create_and_send_nonmember_event( + await self.create_requester_for_user_id_from_app_service( + state_event["sender"], app_service_requester.app_service + ), + event_dict, + outlier=True, + prev_event_ids=[fake_prev_event_id], + # Make sure to use a copy of this list because we modify it + # later in the loop here. Otherwise it will be the same + # reference and also update in the event when we append later. + auth_event_ids=auth_event_ids.copy(), + ) + event_id = event.event_id + + state_event_ids_at_start.append(event_id) + auth_event_ids.append(event_id) + + return state_event_ids_at_start + + async def persist_historical_events( + self, + events_to_create: List[JsonDict], + room_id: str, + initial_prev_event_ids: List[str], + inherited_depth: int, + auth_event_ids: List[str], + app_service_requester: Requester, + ) -> List[str]: + """Create and persists all events provided sequentially. Handles the + complexity of creating events in chronological order so they can + reference each other by prev_event but still persists in + reverse-chronoloical order so they have the correct + (topological_ordering, stream_ordering) and sort correctly from + /messages. + + Args: + events_to_create: List of historical events to create in JSON + dictionary format. + room_id: Room where you want the events persisted in. + initial_prev_event_ids: These will be the prev_events for the first + event created. Each event created afterwards will point to the + previous event created. + inherited_depth: The depth to create the events at (you will + probably by calling inherit_depth_from_prev_ids(...)). + auth_event_ids: Define which events allow you to create the given + event in the room. + app_service_requester: The requester of an application service. + + Returns: + List of persisted event IDs + """ + assert app_service_requester.app_service + + prev_event_ids = initial_prev_event_ids.copy() + + event_ids = [] + events_to_persist = [] + for ev in events_to_create: + assert_params_in_dict(ev, ["type", "origin_server_ts", "content", "sender"]) + + event_dict = { + "type": ev["type"], + "origin_server_ts": ev["origin_server_ts"], + "content": ev["content"], + "room_id": room_id, + "sender": ev["sender"], # requester.user.to_string(), + "prev_events": prev_event_ids.copy(), + } + + # Mark all events as historical + event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True + + event, context = await self.event_creation_handler.create_event( + await self.create_requester_for_user_id_from_app_service( + ev["sender"], app_service_requester.app_service + ), + event_dict, + prev_event_ids=event_dict.get("prev_events"), + auth_event_ids=auth_event_ids, + historical=True, + depth=inherited_depth, + ) + logger.debug( + "RoomBatchSendEventRestServlet inserting event=%s, prev_event_ids=%s, auth_event_ids=%s", + event, + prev_event_ids, + auth_event_ids, + ) + + assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % ( + event.sender, + ) + + events_to_persist.append((event, context)) + event_id = event.event_id + + event_ids.append(event_id) + prev_event_ids = [event_id] + + # Persist events in reverse-chronological order so they have the + # correct stream_ordering as they are backfilled (which decrements). + # Events are sorted by (topological_ordering, stream_ordering) + # where topological_ordering is just depth. + for (event, context) in reversed(events_to_persist): + await self.event_creation_handler.handle_new_client_event( + await self.create_requester_for_user_id_from_app_service( + event["sender"], app_service_requester.app_service + ), + event=event, + context=context, + ) + + return event_ids + + async def handle_batch_of_events( + self, + events_to_create: List[JsonDict], + room_id: str, + batch_id_to_connect_to: str, + initial_prev_event_ids: List[str], + inherited_depth: int, + auth_event_ids: List[str], + app_service_requester: Requester, + ) -> Tuple[List[str], str]: + """ + Handles creating and persisting all of the historical events as well + as insertion and batch meta events to make the batch navigable in the DAG. + + Args: + events_to_create: List of historical events to create in JSON + dictionary format. + room_id: Room where you want the events created in. + batch_id_to_connect_to: The batch_id from the insertion event you + want this batch to connect to. + initial_prev_event_ids: These will be the prev_events for the first + event created. Each event created afterwards will point to the + previous event created. + inherited_depth: The depth to create the events at (you will + probably by calling inherit_depth_from_prev_ids(...)). + auth_event_ids: Define which events allow you to create the given + event in the room. + app_service_requester: The requester of an application service. + + Returns: + Tuple containing a list of created events and the next_batch_id + """ + + # Connect this current batch to the insertion event from the previous batch + last_event_in_batch = events_to_create[-1] + batch_event = { + "type": EventTypes.MSC2716_BATCH, + "sender": app_service_requester.user.to_string(), + "room_id": room_id, + "content": { + EventContentFields.MSC2716_BATCH_ID: batch_id_to_connect_to, + EventContentFields.MSC2716_HISTORICAL: True, + }, + # Since the batch event is put at the end of the batch, + # where the newest-in-time event is, copy the origin_server_ts from + # the last event we're inserting + "origin_server_ts": last_event_in_batch["origin_server_ts"], + } + # Add the batch event to the end of the batch (newest-in-time) + events_to_create.append(batch_event) + + # Add an "insertion" event to the start of each batch (next to the oldest-in-time + # event in the batch) so the next batch can be connected to this one. + insertion_event = self.create_insertion_event_dict( + sender=app_service_requester.user.to_string(), + room_id=room_id, + # Since the insertion event is put at the start of the batch, + # where the oldest-in-time event is, copy the origin_server_ts from + # the first event we're inserting + origin_server_ts=events_to_create[0]["origin_server_ts"], + ) + next_batch_id = insertion_event["content"][ + EventContentFields.MSC2716_NEXT_BATCH_ID + ] + # Prepend the insertion event to the start of the batch (oldest-in-time) + events_to_create = [insertion_event] + events_to_create + + # Create and persist all of the historical events + event_ids = await self.persist_historical_events( + events_to_create=events_to_create, + room_id=room_id, + initial_prev_event_ids=initial_prev_event_ids, + inherited_depth=inherited_depth, + auth_event_ids=auth_event_ids, + app_service_requester=app_service_requester, + ) + + return event_ids, next_batch_id diff --git a/synapse/rest/client/room_batch.py b/synapse/rest/client/room_batch.py index 1dffcc3147..38ad4c2447 100644 --- a/synapse/rest/client/room_batch.py +++ b/synapse/rest/client/room_batch.py @@ -15,13 +15,12 @@ import logging import re from http import HTTPStatus -from typing import TYPE_CHECKING, Awaitable, List, Tuple +from typing import TYPE_CHECKING, Awaitable, Tuple from twisted.web.server import Request -from synapse.api.constants import EventContentFields, EventTypes +from synapse.api.constants import EventContentFields from synapse.api.errors import AuthError, Codes, SynapseError -from synapse.appservice import ApplicationService from synapse.http.server import HttpServer from synapse.http.servlet import ( RestServlet, @@ -32,7 +31,7 @@ from synapse.http.servlet import ( ) from synapse.http.site import SynapseRequest from synapse.rest.client.transactions import HttpTransactionCache -from synapse.types import JsonDict, Requester, UserID, create_requester +from synapse.types import JsonDict from synapse.util.stringutils import random_string if TYPE_CHECKING: @@ -77,102 +76,12 @@ class RoomBatchSendEventRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() - self.hs = hs self.store = hs.get_datastore() - self.state_store = hs.get_storage().state self.event_creation_handler = hs.get_event_creation_handler() - self.room_member_handler = hs.get_room_member_handler() self.auth = hs.get_auth() + self.room_batch_handler = hs.get_room_batch_handler() self.txns = HttpTransactionCache(hs) - async def _inherit_depth_from_prev_ids(self, prev_event_ids: List[str]) -> int: - ( - most_recent_prev_event_id, - most_recent_prev_event_depth, - ) = await self.store.get_max_depth_of(prev_event_ids) - - # We want to insert the historical event after the `prev_event` but before the successor event - # - # We inherit depth from the successor event instead of the `prev_event` - # because events returned from `/messages` are first sorted by `topological_ordering` - # which is just the `depth` and then tie-break with `stream_ordering`. - # - # We mark these inserted historical events as "backfilled" which gives them a - # negative `stream_ordering`. If we use the same depth as the `prev_event`, - # then our historical event will tie-break and be sorted before the `prev_event` - # when it should come after. - # - # We want to use the successor event depth so they appear after `prev_event` because - # it has a larger `depth` but before the successor event because the `stream_ordering` - # is negative before the successor event. - successor_event_ids = await self.store.get_successor_events( - [most_recent_prev_event_id] - ) - - # If we can't find any successor events, then it's a forward extremity of - # historical messages and we can just inherit from the previous historical - # event which we can already assume has the correct depth where we want - # to insert into. - if not successor_event_ids: - depth = most_recent_prev_event_depth - else: - ( - _, - oldest_successor_depth, - ) = await self.store.get_min_depth_of(successor_event_ids) - - depth = oldest_successor_depth - - return depth - - def _create_insertion_event_dict( - self, sender: str, room_id: str, origin_server_ts: int - ) -> JsonDict: - """Creates an event dict for an "insertion" event with the proper fields - and a random batch ID. - - Args: - sender: The event author MXID - room_id: The room ID that the event belongs to - origin_server_ts: Timestamp when the event was sent - - Returns: - The new event dictionary to insert. - """ - - next_batch_id = random_string(8) - insertion_event = { - "type": EventTypes.MSC2716_INSERTION, - "sender": sender, - "room_id": room_id, - "content": { - EventContentFields.MSC2716_NEXT_BATCH_ID: next_batch_id, - EventContentFields.MSC2716_HISTORICAL: True, - }, - "origin_server_ts": origin_server_ts, - } - - return insertion_event - - async def _create_requester_for_user_id_from_app_service( - self, user_id: str, app_service: ApplicationService - ) -> Requester: - """Creates a new requester for the given user_id - and validates that the app service is allowed to control - the given user. - - Args: - user_id: The author MXID that the app service is controlling - app_service: The app service that controls the user - - Returns: - Requester object - """ - - await self.auth.validate_appservice_can_control_user_id(app_service, user_id) - - return create_requester(user_id, app_service=app_service) - async def on_POST( self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: @@ -200,123 +109,62 @@ class RoomBatchSendEventRestServlet(RestServlet): errcode=Codes.MISSING_PARAM, ) + # Verify the batch_id_from_query corresponds to an actual insertion event + # and have the batch connected. + if batch_id_from_query: + corresponding_insertion_event_id = ( + await self.store.get_insertion_event_by_batch_id( + room_id, batch_id_from_query + ) + ) + if corresponding_insertion_event_id is None: + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "No insertion event corresponds to the given ?batch_id", + errcode=Codes.INVALID_PARAM, + ) + # For the event we are inserting next to (`prev_event_ids_from_query`), # find the most recent auth events (derived from state events) that # allowed that message to be sent. We will use that as a base # to auth our historical messages against. - ( - most_recent_prev_event_id, - _, - ) = await self.store.get_max_depth_of(prev_event_ids_from_query) - # mapping from (type, state_key) -> state_event_id - prev_state_map = await self.state_store.get_state_ids_for_event( - most_recent_prev_event_id + auth_event_ids = await self.room_batch_handler.get_most_recent_auth_event_ids_from_event_id_list( + prev_event_ids_from_query ) - # List of state event ID's - prev_state_ids = list(prev_state_map.values()) - auth_event_ids = prev_state_ids - - state_event_ids_at_start = [] - for state_event in body["state_events_at_start"]: - assert_params_in_dict( - state_event, ["type", "origin_server_ts", "content", "sender"] - ) - logger.debug( - "RoomBatchSendEventRestServlet inserting state_event=%s, auth_event_ids=%s", - state_event, - auth_event_ids, + # Create and persist all of the state events that float off on their own + # before the batch. These will most likely be all of the invite/member + # state events used to auth the upcoming historical messages. + state_event_ids_at_start = ( + await self.room_batch_handler.persist_state_events_at_start( + state_events_at_start=body["state_events_at_start"], + room_id=room_id, + initial_auth_event_ids=auth_event_ids, + app_service_requester=requester, ) + ) + # Update our ongoing auth event ID list with all of the new state we + # just created + auth_event_ids.extend(state_event_ids_at_start) - event_dict = { - "type": state_event["type"], - "origin_server_ts": state_event["origin_server_ts"], - "content": state_event["content"], - "room_id": room_id, - "sender": state_event["sender"], - "state_key": state_event["state_key"], - } - - # Mark all events as historical - event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True - - # Make the state events float off on their own - fake_prev_event_id = "$" + random_string(43) - - # TODO: This is pretty much the same as some other code to handle inserting state in this file - if event_dict["type"] == EventTypes.Member: - membership = event_dict["content"].get("membership", None) - event_id, _ = await self.room_member_handler.update_membership( - await self._create_requester_for_user_id_from_app_service( - state_event["sender"], requester.app_service - ), - target=UserID.from_string(event_dict["state_key"]), - room_id=room_id, - action=membership, - content=event_dict["content"], - outlier=True, - prev_event_ids=[fake_prev_event_id], - # Make sure to use a copy of this list because we modify it - # later in the loop here. Otherwise it will be the same - # reference and also update in the event when we append later. - auth_event_ids=auth_event_ids.copy(), - ) - else: - # TODO: Add some complement tests that adds state that is not member joins - # and will use this code path. Maybe we only want to support join state events - # and can get rid of this `else`? - ( - event, - _, - ) = await self.event_creation_handler.create_and_send_nonmember_event( - await self._create_requester_for_user_id_from_app_service( - state_event["sender"], requester.app_service - ), - event_dict, - outlier=True, - prev_event_ids=[fake_prev_event_id], - # Make sure to use a copy of this list because we modify it - # later in the loop here. Otherwise it will be the same - # reference and also update in the event when we append later. - auth_event_ids=auth_event_ids.copy(), - ) - event_id = event.event_id - - state_event_ids_at_start.append(event_id) - auth_event_ids.append(event_id) - - events_to_create = body["events"] - - inherited_depth = await self._inherit_depth_from_prev_ids( + inherited_depth = await self.room_batch_handler.inherit_depth_from_prev_ids( prev_event_ids_from_query ) + events_to_create = body["events"] + # Figure out which batch to connect to. If they passed in # batch_id_from_query let's use it. The batch ID passed in comes # from the batch_id in the "insertion" event from the previous batch. last_event_in_batch = events_to_create[-1] - batch_id_to_connect_to = batch_id_from_query base_insertion_event = None if batch_id_from_query: + batch_id_to_connect_to = batch_id_from_query # All but the first base insertion event should point at a fake # event, which causes the HS to ask for the state at the start of # the batch later. + fake_prev_event_id = "$" + random_string(43) prev_event_ids = [fake_prev_event_id] - - # Verify the batch_id_from_query corresponds to an actual insertion event - # and have the batch connected. - corresponding_insertion_event_id = ( - await self.store.get_insertion_event_by_batch_id( - room_id, batch_id_from_query - ) - ) - if corresponding_insertion_event_id is None: - raise SynapseError( - HTTPStatus.BAD_REQUEST, - "No insertion event corresponds to the given ?batch_id", - errcode=Codes.INVALID_PARAM, - ) - pass # Otherwise, create an insertion event to act as a starting point. # # We don't always have an insertion event to start hanging more history @@ -327,10 +175,12 @@ class RoomBatchSendEventRestServlet(RestServlet): else: prev_event_ids = prev_event_ids_from_query - base_insertion_event_dict = self._create_insertion_event_dict( - sender=requester.user.to_string(), - room_id=room_id, - origin_server_ts=last_event_in_batch["origin_server_ts"], + base_insertion_event_dict = ( + self.room_batch_handler.create_insertion_event_dict( + sender=requester.user.to_string(), + room_id=room_id, + origin_server_ts=last_event_in_batch["origin_server_ts"], + ) ) base_insertion_event_dict["prev_events"] = prev_event_ids.copy() @@ -338,7 +188,7 @@ class RoomBatchSendEventRestServlet(RestServlet): base_insertion_event, _, ) = await self.event_creation_handler.create_and_send_nonmember_event( - await self._create_requester_for_user_id_from_app_service( + await self.room_batch_handler.create_requester_for_user_id_from_app_service( base_insertion_event_dict["sender"], requester.app_service, ), @@ -353,92 +203,17 @@ class RoomBatchSendEventRestServlet(RestServlet): EventContentFields.MSC2716_NEXT_BATCH_ID ] - # Connect this current batch to the insertion event from the previous batch - batch_event = { - "type": EventTypes.MSC2716_BATCH, - "sender": requester.user.to_string(), - "room_id": room_id, - "content": { - EventContentFields.MSC2716_BATCH_ID: batch_id_to_connect_to, - EventContentFields.MSC2716_HISTORICAL: True, - }, - # Since the batch event is put at the end of the batch, - # where the newest-in-time event is, copy the origin_server_ts from - # the last event we're inserting - "origin_server_ts": last_event_in_batch["origin_server_ts"], - } - # Add the batch event to the end of the batch (newest-in-time) - events_to_create.append(batch_event) - - # Add an "insertion" event to the start of each batch (next to the oldest-in-time - # event in the batch) so the next batch can be connected to this one. - insertion_event = self._create_insertion_event_dict( - sender=requester.user.to_string(), + # Create and persist all of the historical events as well as insertion + # and batch meta events to make the batch navigable in the DAG. + event_ids, next_batch_id = await self.room_batch_handler.handle_batch_of_events( + events_to_create=events_to_create, room_id=room_id, - # Since the insertion event is put at the start of the batch, - # where the oldest-in-time event is, copy the origin_server_ts from - # the first event we're inserting - origin_server_ts=events_to_create[0]["origin_server_ts"], + batch_id_to_connect_to=batch_id_to_connect_to, + initial_prev_event_ids=prev_event_ids, + inherited_depth=inherited_depth, + auth_event_ids=auth_event_ids, + app_service_requester=requester, ) - # Prepend the insertion event to the start of the batch (oldest-in-time) - events_to_create = [insertion_event] + events_to_create - - event_ids = [] - events_to_persist = [] - for ev in events_to_create: - assert_params_in_dict(ev, ["type", "origin_server_ts", "content", "sender"]) - - event_dict = { - "type": ev["type"], - "origin_server_ts": ev["origin_server_ts"], - "content": ev["content"], - "room_id": room_id, - "sender": ev["sender"], # requester.user.to_string(), - "prev_events": prev_event_ids.copy(), - } - - # Mark all events as historical - event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True - - event, context = await self.event_creation_handler.create_event( - await self._create_requester_for_user_id_from_app_service( - ev["sender"], requester.app_service - ), - event_dict, - prev_event_ids=event_dict.get("prev_events"), - auth_event_ids=auth_event_ids, - historical=True, - depth=inherited_depth, - ) - logger.debug( - "RoomBatchSendEventRestServlet inserting event=%s, prev_event_ids=%s, auth_event_ids=%s", - event, - prev_event_ids, - auth_event_ids, - ) - - assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % ( - event.sender, - ) - - events_to_persist.append((event, context)) - event_id = event.event_id - - event_ids.append(event_id) - prev_event_ids = [event_id] - - # Persist events in reverse-chronological order so they have the - # correct stream_ordering as they are backfilled (which decrements). - # Events are sorted by (topological_ordering, stream_ordering) - # where topological_ordering is just depth. - for (event, context) in reversed(events_to_persist): - ev = await self.event_creation_handler.handle_new_client_event( - await self._create_requester_for_user_id_from_app_service( - event["sender"], requester.app_service - ), - event=event, - context=context, - ) insertion_event_id = event_ids[0] batch_event_id = event_ids[-1] @@ -447,9 +222,7 @@ class RoomBatchSendEventRestServlet(RestServlet): response_dict = { "state_event_ids": state_event_ids_at_start, "event_ids": historical_event_ids, - "next_batch_id": insertion_event["content"][ - EventContentFields.MSC2716_NEXT_BATCH_ID - ], + "next_batch_id": next_batch_id, "insertion_event_id": insertion_event_id, "batch_event_id": batch_event_id, } diff --git a/synapse/server.py b/synapse/server.py index 0783df41d4..5bc045d615 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -97,6 +97,7 @@ from synapse.handlers.room import ( RoomCreationHandler, RoomShutdownHandler, ) +from synapse.handlers.room_batch import RoomBatchHandler from synapse.handlers.room_list import RoomListHandler from synapse.handlers.room_member import RoomMemberHandler, RoomMemberMasterHandler from synapse.handlers.room_member_worker import RoomMemberWorkerHandler @@ -437,6 +438,10 @@ class HomeServer(metaclass=abc.ABCMeta): def get_room_creation_handler(self) -> RoomCreationHandler: return RoomCreationHandler(self) + @cache_in_self + def get_room_batch_handler(self) -> RoomBatchHandler: + return RoomBatchHandler(self) + @cache_in_self def get_room_shutdown_handler(self) -> RoomShutdownHandler: return RoomShutdownHandler(self) -- cgit 1.5.1