From 6fc8be9a1b2046e69e8c6f731442887e3addeec0 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 22 Sep 2021 09:45:20 -0400 Subject: Include more information in oEmbed previews. (#10819) * Improved titles (fall back to the author name if there's not title) and include the site name. * Handle photo/video payloads. * Include the original URL in the Open Graph response. * Fix the expiration time (by properly converting from seconds to milliseconds). --- tests/rest/media/v1/test_url_preview.py | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) (limited to 'tests/rest') diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index 9d13899584..d83dfacfed 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -620,11 +620,12 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertIn(b"/matrixdotorg", server.data) self.assertEqual(channel.code, 200) - self.assertIsNone(channel.json_body["og:title"]) - self.assertTrue(channel.json_body["og:image"].startswith("mxc://")) - self.assertEqual(channel.json_body["og:image:height"], 1) - self.assertEqual(channel.json_body["og:image:width"], 1) - self.assertEqual(channel.json_body["og:image:type"], "image/png") + body = channel.json_body + self.assertEqual(body["og:url"], "http://twitter.com/matrixdotorg/status/12345") + self.assertTrue(body["og:image"].startswith("mxc://")) + self.assertEqual(body["og:image:height"], 1) + self.assertEqual(body["og:image:width"], 1) + self.assertEqual(body["og:image:type"], "image/png") def test_oembed_rich(self): """Test an oEmbed endpoint which returns HTML content via the 'rich' type.""" @@ -633,6 +634,8 @@ class URLPreviewTests(unittest.HomeserverTestCase): result = { "version": "1.0", "type": "rich", + # Note that this provides the author, not the title. + "author_name": "Alice", "html": "
Content Preview
", } end_content = json.dumps(result).encode("utf-8") @@ -660,9 +663,14 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.pump() self.assertEqual(channel.code, 200) + body = channel.json_body self.assertEqual( - channel.json_body, - {"og:title": None, "og:description": "Content Preview"}, + body, + { + "og:url": "http://twitter.com/matrixdotorg/status/12345", + "og:title": "Alice", + "og:description": "Content Preview", + }, ) def test_oembed_format(self): @@ -705,7 +713,11 @@ class URLPreviewTests(unittest.HomeserverTestCase): self.assertIn(b"format=json", server.data) self.assertEqual(channel.code, 200) + body = channel.json_body self.assertEqual( - channel.json_body, - {"og:title": None, "og:description": "Content Preview"}, + body, + { + "og:url": "http://www.hulu.com/watch/12345", + "og:description": "Content Preview", + }, ) -- cgit 1.5.1 From e584534403b55ad3f250f92592e30b15b01f0201 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 23 Sep 2021 07:13:34 -0400 Subject: Use direct references for some configuration variables (part 3) (#10885) This avoids the overhead of searching through the various configuration classes by directly referencing the class that the attributes are in. It also improves type hints since mypy can now resolve the types of the configuration variables. --- changelog.d/10885.misc | 1 + synapse/app/homeserver.py | 2 +- synapse/config/consent.py | 9 +++-- synapse/handlers/account_validity.py | 2 +- synapse/handlers/appservice.py | 2 +- synapse/handlers/auth.py | 22 ++++++------ synapse/handlers/cas.py | 8 ++--- synapse/handlers/identity.py | 12 +++---- synapse/handlers/message.py | 4 +-- synapse/handlers/password_policy.py | 4 +-- synapse/handlers/register.py | 11 +++--- synapse/handlers/ui_auth/checkers.py | 17 +++++---- synapse/module_api/__init__.py | 8 +++-- synapse/push/pusher.py | 2 +- synapse/rest/admin/users.py | 4 +-- synapse/rest/client/account.py | 40 +++++++++++----------- synapse/rest/client/auth.py | 10 +++--- synapse/rest/client/login.py | 4 +-- synapse/rest/client/password_policy.py | 4 +-- synapse/rest/client/register.py | 30 ++++++++-------- synapse/rest/consent/consent_resource.py | 9 ++--- synapse/rest/synapse/client/password_reset.py | 10 +++--- synapse/server_notices/consent_server_notices.py | 11 ++++-- synapse/storage/databases/main/appservice.py | 2 +- .../storage/databases/main/monthly_active_users.py | 2 +- synapse/storage/databases/main/registration.py | 2 +- synapse/storage/prepare_database.py | 2 +- synapse/storage/schema/main/delta/30/as_users.py | 2 +- tests/rest/admin/test_room.py | 2 +- tests/rest/client/test_login.py | 2 +- tests/storage/test_appservice.py | 14 +++----- tests/storage/test_cleanup_extrems.py | 2 +- 32 files changed, 137 insertions(+), 119 deletions(-) create mode 100644 changelog.d/10885.misc (limited to 'tests/rest') diff --git a/changelog.d/10885.misc b/changelog.d/10885.misc new file mode 100644 index 0000000000..586a0b3a96 --- /dev/null +++ b/changelog.d/10885.misc @@ -0,0 +1 @@ +Use direct references to config flags. diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index b909f8db8d..886e291e4c 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -195,7 +195,7 @@ class SynapseHomeServer(HomeServer): } ) - if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: from synapse.rest.synapse.client.password_reset import ( PasswordResetSubmitTokenResource, ) diff --git a/synapse/config/consent.py b/synapse/config/consent.py index b05a9bd97f..ecc43b08b9 100644 --- a/synapse/config/consent.py +++ b/synapse/config/consent.py @@ -13,6 +13,7 @@ # limitations under the License. from os import path +from typing import Optional from synapse.config import ConfigError @@ -78,8 +79,8 @@ class ConsentConfig(Config): def __init__(self, *args): super().__init__(*args) - self.user_consent_version = None - self.user_consent_template_dir = None + self.user_consent_version: Optional[str] = None + self.user_consent_template_dir: Optional[str] = None self.user_consent_server_notice_content = None self.user_consent_server_notice_to_guests = False self.block_events_without_consent_error = None @@ -94,7 +95,9 @@ class ConsentConfig(Config): return self.user_consent_version = str(consent_config["version"]) self.user_consent_template_dir = self.abspath(consent_config["template_dir"]) - if not path.isdir(self.user_consent_template_dir): + if not isinstance(self.user_consent_template_dir, str) or not path.isdir( + self.user_consent_template_dir + ): raise ConfigError( "Could not find template directory '%s'" % (self.user_consent_template_dir,) diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index 4724565ba5..5a5f124ddf 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -47,7 +47,7 @@ class AccountValidityHandler: self.send_email_handler = self.hs.get_send_email_handler() self.clock = self.hs.get_clock() - self._app_name = self.hs.config.email_app_name + self._app_name = self.hs.config.email.email_app_name self._account_validity_enabled = ( hs.config.account_validity.account_validity_enabled diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index b7213b67a5..163278708c 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -52,7 +52,7 @@ class ApplicationServicesHandler: self.scheduler = hs.get_application_service_scheduler() self.started_scheduler = False self.clock = hs.get_clock() - self.notify_appservices = hs.config.notify_appservices + self.notify_appservices = hs.config.appservice.notify_appservices self.event_sources = hs.get_event_sources() self.current_max = 0 diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index bcd4249e09..b747f80bc1 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -210,15 +210,15 @@ class AuthHandler(BaseHandler): self.password_providers = [ PasswordProvider.load(module, config, account_handler) - for module, config in hs.config.password_providers + for module, config in hs.config.authproviders.password_providers ] logger.info("Extra password_providers: %s", self.password_providers) self.hs = hs # FIXME better possibility to access registrationHandler later? self.macaroon_gen = hs.get_macaroon_generator() - self._password_enabled = hs.config.password_enabled - self._password_localdb_enabled = hs.config.password_localdb_enabled + self._password_enabled = hs.config.auth.password_enabled + self._password_localdb_enabled = hs.config.auth.password_localdb_enabled # start out by assuming PASSWORD is enabled; we will remove it later if not. login_types = set() @@ -250,7 +250,7 @@ class AuthHandler(BaseHandler): ) # The number of seconds to keep a UI auth session active. - self._ui_auth_session_timeout = hs.config.ui_auth_session_timeout + self._ui_auth_session_timeout = hs.config.auth.ui_auth_session_timeout # Ratelimitier for failed /login attempts self._failed_login_attempts_ratelimiter = Ratelimiter( @@ -739,19 +739,19 @@ class AuthHandler(BaseHandler): return canonical_id def _get_params_recaptcha(self) -> dict: - return {"public_key": self.hs.config.recaptcha_public_key} + return {"public_key": self.hs.config.captcha.recaptcha_public_key} def _get_params_terms(self) -> dict: return { "policies": { "privacy_policy": { - "version": self.hs.config.user_consent_version, + "version": self.hs.config.consent.user_consent_version, "en": { - "name": self.hs.config.user_consent_policy_name, + "name": self.hs.config.consent.user_consent_policy_name, "url": "%s_matrix/consent?v=%s" % ( self.hs.config.server.public_baseurl, - self.hs.config.user_consent_version, + self.hs.config.consent.user_consent_version, ), }, } @@ -1016,7 +1016,7 @@ class AuthHandler(BaseHandler): def can_change_password(self) -> bool: """Get whether users on this server are allowed to change or set a password. - Both `config.password_enabled` and `config.password_localdb_enabled` must be true. + Both `config.auth.password_enabled` and `config.auth.password_localdb_enabled` must be true. Note that any account (even SSO accounts) are allowed to add passwords if the above is true. @@ -1486,7 +1486,7 @@ class AuthHandler(BaseHandler): pw = unicodedata.normalize("NFKC", password) return bcrypt.hashpw( - pw.encode("utf8") + self.hs.config.password_pepper.encode("utf8"), + pw.encode("utf8") + self.hs.config.auth.password_pepper.encode("utf8"), bcrypt.gensalt(self.bcrypt_rounds), ).decode("ascii") @@ -1510,7 +1510,7 @@ class AuthHandler(BaseHandler): pw = unicodedata.normalize("NFKC", password) return bcrypt.checkpw( - pw.encode("utf8") + self.hs.config.password_pepper.encode("utf8"), + pw.encode("utf8") + self.hs.config.auth.password_pepper.encode("utf8"), checked_hash, ) diff --git a/synapse/handlers/cas.py b/synapse/handlers/cas.py index b0b188dc78..5d8f6c50a9 100644 --- a/synapse/handlers/cas.py +++ b/synapse/handlers/cas.py @@ -65,10 +65,10 @@ class CasHandler: self._auth_handler = hs.get_auth_handler() self._registration_handler = hs.get_registration_handler() - self._cas_server_url = hs.config.cas_server_url - self._cas_service_url = hs.config.cas_service_url - self._cas_displayname_attribute = hs.config.cas_displayname_attribute - self._cas_required_attributes = hs.config.cas_required_attributes + self._cas_server_url = hs.config.cas.cas_server_url + self._cas_service_url = hs.config.cas.cas_service_url + self._cas_displayname_attribute = hs.config.cas.cas_displayname_attribute + self._cas_required_attributes = hs.config.cas.cas_required_attributes self._http_client = hs.get_proxied_http_client() diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 8b8f1f41ca..fe8a995892 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -62,7 +62,7 @@ class IdentityHandler(BaseHandler): self.federation_http_client = hs.get_federation_http_client() self.hs = hs - self._web_client_location = hs.config.invite_client_location + self._web_client_location = hs.config.email.invite_client_location # Ratelimiters for `/requestToken` endpoints. self._3pid_validation_ratelimiter_ip = Ratelimiter( @@ -419,7 +419,7 @@ class IdentityHandler(BaseHandler): token_expires = ( self.hs.get_clock().time_msec() - + self.hs.config.email_validation_token_lifetime + + self.hs.config.email.email_validation_token_lifetime ) await self.store.start_or_continue_validation_session( @@ -465,7 +465,7 @@ class IdentityHandler(BaseHandler): if next_link: params["next_link"] = next_link - if self.hs.config.using_identity_server_from_trusted_list: + if self.hs.config.email.using_identity_server_from_trusted_list: # Warn that a deprecated config option is in use logger.warning( 'The config option "trust_identity_server_for_password_resets" ' @@ -518,7 +518,7 @@ class IdentityHandler(BaseHandler): if next_link: params["next_link"] = next_link - if self.hs.config.using_identity_server_from_trusted_list: + if self.hs.config.email.using_identity_server_from_trusted_list: # Warn that a deprecated config option is in use logger.warning( 'The config option "trust_identity_server_for_password_resets" ' @@ -572,12 +572,12 @@ class IdentityHandler(BaseHandler): validation_session = None # Try to validate as email - if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: + if self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE: # Ask our delegated email identity server validation_session = await self.threepid_from_creds( self.hs.config.account_threepid_delegate_email, threepid_creds ) - elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + elif self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: # Get a validated session matching these details validation_session = await self.store.get_threepid_validation_session( "email", client_secret, sid=sid, validated=True diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 7a5d8e6f4e..ad4e4a3d6f 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -443,7 +443,7 @@ class EventCreationHandler: ) self._block_events_without_consent_error = ( - self.config.block_events_without_consent_error + self.config.consent.block_events_without_consent_error ) # we need to construct a ConsentURIBuilder here, as it checks that the necessary @@ -744,7 +744,7 @@ class EventCreationHandler: if u["appservice_id"] is not None: # users registered by an appservice are exempt return - if u["consent_version"] == self.config.user_consent_version: + if u["consent_version"] == self.config.consent.user_consent_version: return consent_uri = self._consent_uri_builder.build_user_consent_uri(user.localpart) diff --git a/synapse/handlers/password_policy.py b/synapse/handlers/password_policy.py index cd21efdcc6..eadd7ced09 100644 --- a/synapse/handlers/password_policy.py +++ b/synapse/handlers/password_policy.py @@ -27,8 +27,8 @@ logger = logging.getLogger(__name__) class PasswordPolicyHandler: def __init__(self, hs: "HomeServer"): - self.policy = hs.config.password_policy - self.enabled = hs.config.password_policy_enabled + self.policy = hs.config.auth.password_policy + self.enabled = hs.config.auth.password_policy_enabled # Regexps for the spec'd policy parameters. self.regexp_digit = re.compile("[0-9]") diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 1c195c65db..01c5e1385d 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -97,6 +97,7 @@ class RegistrationHandler(BaseHandler): self.ratelimiter = hs.get_registration_ratelimiter() self.macaroon_gen = hs.get_macaroon_generator() self._account_validity_handler = hs.get_account_validity_handler() + self._user_consent_version = self.hs.config.consent.user_consent_version self._server_notices_mxid = hs.config.server_notices_mxid self._server_name = hs.hostname @@ -339,7 +340,7 @@ class RegistrationHandler(BaseHandler): auth_provider=(auth_provider_id or ""), ).inc() - if not self.hs.config.user_consent_at_registration: + if not self.hs.config.consent.user_consent_at_registration: if not self.hs.config.auto_join_rooms_for_guests and make_guest: logger.info( "Skipping auto-join for %s because auto-join for guests is disabled", @@ -864,7 +865,9 @@ class RegistrationHandler(BaseHandler): await self._register_msisdn_threepid(user_id, threepid) if auth_result and LoginType.TERMS in auth_result: - await self._on_user_consented(user_id, self.hs.config.user_consent_version) + # The terms type should only exist if consent is enabled. + assert self._user_consent_version is not None + await self._on_user_consented(user_id, self._user_consent_version) async def _on_user_consented(self, user_id: str, consent_version: str) -> None: """A user consented to the terms on registration @@ -910,8 +913,8 @@ class RegistrationHandler(BaseHandler): # getting mail spam where they weren't before if email # notifs are set up on a homeserver) if ( - self.hs.config.email_enable_notifs - and self.hs.config.email_notif_for_new_users + self.hs.config.email.email_enable_notifs + and self.hs.config.email.email_notif_for_new_users and token ): # Pull the ID of the access token back out of the db diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py index ea9325e96a..8f5d465fa1 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py @@ -82,10 +82,10 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self._enabled = bool(hs.config.recaptcha_private_key) + self._enabled = bool(hs.config.captcha.recaptcha_private_key) self._http_client = hs.get_proxied_http_client() - self._url = hs.config.recaptcha_siteverify_api - self._secret = hs.config.recaptcha_private_key + self._url = hs.config.captcha.recaptcha_siteverify_api + self._secret = hs.config.captcha.recaptcha_private_key def is_enabled(self) -> bool: return self._enabled @@ -161,12 +161,17 @@ class _BaseThreepidAuthChecker: self.hs.config.account_threepid_delegate_msisdn, threepid_creds ) elif medium == "email": - if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: + if ( + self.hs.config.email.threepid_behaviour_email + == ThreepidBehaviour.REMOTE + ): assert self.hs.config.account_threepid_delegate_email threepid = await identity_handler.threepid_from_creds( self.hs.config.account_threepid_delegate_email, threepid_creds ) - elif self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + elif ( + self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL + ): threepid = None row = await self.store.get_threepid_validation_session( medium, @@ -218,7 +223,7 @@ class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChec _BaseThreepidAuthChecker.__init__(self, hs) def is_enabled(self) -> bool: - return self.hs.config.threepid_behaviour_email in ( + return self.hs.config.email.threepid_behaviour_email in ( ThreepidBehaviour.REMOTE, ThreepidBehaviour.LOCAL, ) diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 174e6934a8..8ae21bc43c 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -119,14 +119,16 @@ class ModuleApi: self.custom_template_dir = hs.config.server.custom_template_directory try: - app_name = self._hs.config.email_app_name + app_name = self._hs.config.email.email_app_name - self._from_string = self._hs.config.email_notif_from % {"app": app_name} + self._from_string = self._hs.config.email.email_notif_from % { + "app": app_name + } except (KeyError, TypeError): # If substitution failed (which can happen if the string contains # placeholders other than just "app", or if the type of the placeholder is # not a string), fall back to the bare strings. - self._from_string = self._hs.config.email_notif_from + self._from_string = self._hs.config.email.email_notif_from self._raw_from = email.utils.parseaddr(self._from_string)[1] diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py index 29ed346d37..b57e094091 100644 --- a/synapse/push/pusher.py +++ b/synapse/push/pusher.py @@ -77,4 +77,4 @@ class PusherFactory: if isinstance(brand, str): return brand - return self.config.email_app_name + return self.config.email.email_app_name diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 681e491826..46bfec4623 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -368,8 +368,8 @@ class UserRestServletV2(RestServlet): user_id, medium, address, current_time ) if ( - self.hs.config.email_enable_notifs - and self.hs.config.email_notif_for_new_users + self.hs.config.email.email_enable_notifs + and self.hs.config.email.email_notif_for_new_users ): await self.pusher_pool.add_pusher( user_id=user_id, diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index aefaaa8ae8..6a7608d60b 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -64,17 +64,17 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): self.config = hs.config self.identity_handler = hs.get_identity_handler() - if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: self.mailer = Mailer( hs=self.hs, - app_name=self.config.email_app_name, - template_html=self.config.email_password_reset_template_html, - template_text=self.config.email_password_reset_template_text, + app_name=self.config.email.email_app_name, + template_html=self.config.email.email_password_reset_template_html, + template_text=self.config.email.email_password_reset_template_text, ) async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: - if self.config.local_threepid_handling_disabled_due_to_email_config: + if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF: + if self.config.email.local_threepid_handling_disabled_due_to_email_config: logger.warning( "User password resets have been disabled due to lack of email config" ) @@ -129,7 +129,7 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND) - if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: + if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE: assert self.hs.config.account_threepid_delegate_email # Have the configured identity server handle the request @@ -349,17 +349,17 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): self.identity_handler = hs.get_identity_handler() self.store = self.hs.get_datastore() - if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: self.mailer = Mailer( hs=self.hs, - app_name=self.config.email_app_name, - template_html=self.config.email_add_threepid_template_html, - template_text=self.config.email_add_threepid_template_text, + app_name=self.config.email.email_app_name, + template_html=self.config.email.email_add_threepid_template_html, + template_text=self.config.email.email_add_threepid_template_text, ) async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: - if self.config.local_threepid_handling_disabled_due_to_email_config: + if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF: + if self.config.email.local_threepid_handling_disabled_due_to_email_config: logger.warning( "Adding emails have been disabled due to lack of an email config" ) @@ -413,7 +413,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) - if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: + if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE: assert self.hs.config.account_threepid_delegate_email # Have the configured identity server handle the request @@ -534,21 +534,21 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): self.config = hs.config self.clock = hs.get_clock() self.store = hs.get_datastore() - if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: self._failure_email_template = ( - self.config.email_add_threepid_template_failure_html + self.config.email.email_add_threepid_template_failure_html ) async def on_GET(self, request: Request) -> None: - if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: - if self.config.local_threepid_handling_disabled_due_to_email_config: + if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF: + if self.config.email.local_threepid_handling_disabled_due_to_email_config: logger.warning( "Adding emails have been disabled due to lack of an email config" ) raise SynapseError( 400, "Adding an email to your account is disabled on this server" ) - elif self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: + elif self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE: raise SynapseError( 400, "This homeserver is not validating threepids. Use an identity server " @@ -575,7 +575,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): return None # Otherwise show the success template - html = self.config.email_add_threepid_template_success_html_content + html = self.config.email.email_add_threepid_template_success_html_content status_code = 200 except ThreepidValidationError as e: status_code = e.code diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py index 7bb7801472..282861fae2 100644 --- a/synapse/rest/client/auth.py +++ b/synapse/rest/client/auth.py @@ -47,7 +47,7 @@ class AuthRestServlet(RestServlet): self.auth = hs.get_auth() self.auth_handler = hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() - self.recaptcha_template = hs.config.recaptcha_template + self.recaptcha_template = hs.config.captcha.recaptcha_template self.terms_template = hs.config.terms_template self.registration_token_template = hs.config.registration_token_template self.success_template = hs.config.fallback_success_template @@ -62,7 +62,7 @@ class AuthRestServlet(RestServlet): session=session, myurl="%s/r0/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.RECAPTCHA), - sitekey=self.hs.config.recaptcha_public_key, + sitekey=self.hs.config.captcha.recaptcha_public_key, ) elif stagetype == LoginType.TERMS: html = self.terms_template.render( @@ -70,7 +70,7 @@ class AuthRestServlet(RestServlet): terms_url="%s_matrix/consent?v=%s" % ( self.hs.config.server.public_baseurl, - self.hs.config.user_consent_version, + self.hs.config.consent.user_consent_version, ), myurl="%s/r0/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.TERMS), @@ -118,7 +118,7 @@ class AuthRestServlet(RestServlet): session=session, myurl="%s/r0/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.RECAPTCHA), - sitekey=self.hs.config.recaptcha_public_key, + sitekey=self.hs.config.captcha.recaptcha_public_key, error=e.msg, ) else: @@ -139,7 +139,7 @@ class AuthRestServlet(RestServlet): terms_url="%s_matrix/consent?v=%s" % ( self.hs.config.server.public_baseurl, - self.hs.config.user_consent_version, + self.hs.config.consent.user_consent_version, ), myurl="%s/r0/auth/%s/fallback/web" % (CLIENT_API_PREFIX, LoginType.TERMS), diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index a6ede7e2f3..d766e98dce 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -77,7 +77,7 @@ class LoginRestServlet(RestServlet): # SSO configuration. self.saml2_enabled = hs.config.saml2_enabled - self.cas_enabled = hs.config.cas_enabled + self.cas_enabled = hs.config.cas.cas_enabled self.oidc_enabled = hs.config.oidc_enabled self._msc2918_enabled = hs.config.access_token_lifetime is not None @@ -559,7 +559,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: if hs.config.access_token_lifetime is not None: RefreshTokenServlet(hs).register(http_server) SsoRedirectServlet(hs).register(http_server) - if hs.config.cas_enabled: + if hs.config.cas.cas_enabled: CasTicketServlet(hs).register(http_server) diff --git a/synapse/rest/client/password_policy.py b/synapse/rest/client/password_policy.py index 0465fd2292..9f1908004b 100644 --- a/synapse/rest/client/password_policy.py +++ b/synapse/rest/client/password_policy.py @@ -35,8 +35,8 @@ class PasswordPolicyServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() - self.policy = hs.config.password_policy - self.enabled = hs.config.password_policy_enabled + self.policy = hs.config.auth.password_policy + self.enabled = hs.config.auth.password_policy_enabled def on_GET(self, request: Request) -> Tuple[int, JsonDict]: if not self.enabled or not self.policy: diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index abe4d7e205..48b0062cf4 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -75,17 +75,19 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): self.identity_handler = hs.get_identity_handler() self.config = hs.config - if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + if self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: self.mailer = Mailer( hs=self.hs, - app_name=self.config.email_app_name, - template_html=self.config.email_registration_template_html, - template_text=self.config.email_registration_template_text, + app_name=self.config.email.email_app_name, + template_html=self.config.email.email_registration_template_html, + template_text=self.config.email.email_registration_template_text, ) async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.OFF: - if self.hs.config.local_threepid_handling_disabled_due_to_email_config: + if self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF: + if ( + self.hs.config.email.local_threepid_handling_disabled_due_to_email_config + ): logger.warning( "Email registration has been disabled due to lack of email config" ) @@ -137,7 +139,7 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) - if self.config.threepid_behaviour_email == ThreepidBehaviour.REMOTE: + if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE: assert self.hs.config.account_threepid_delegate_email # Have the configured identity server handle the request @@ -259,9 +261,9 @@ class RegistrationSubmitTokenServlet(RestServlet): self.clock = hs.get_clock() self.store = hs.get_datastore() - if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: + if self.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: self._failure_email_template = ( - self.config.email_registration_template_failure_html + self.config.email.email_registration_template_failure_html ) async def on_GET(self, request: Request, medium: str) -> None: @@ -269,8 +271,8 @@ class RegistrationSubmitTokenServlet(RestServlet): raise SynapseError( 400, "This medium is currently not supported for registration" ) - if self.config.threepid_behaviour_email == ThreepidBehaviour.OFF: - if self.config.local_threepid_handling_disabled_due_to_email_config: + if self.config.email.threepid_behaviour_email == ThreepidBehaviour.OFF: + if self.config.email.local_threepid_handling_disabled_due_to_email_config: logger.warning( "User registration via email has been disabled due to lack of email config" ) @@ -303,7 +305,7 @@ class RegistrationSubmitTokenServlet(RestServlet): return None # Otherwise show the success template - html = self.config.email_registration_template_success_html_content + html = self.config.email.email_registration_template_success_html_content status_code = 200 except ThreepidValidationError as e: status_code = e.code @@ -897,12 +899,12 @@ def _calculate_registration_flows( flows.append([LoginType.MSISDN, LoginType.EMAIL_IDENTITY]) # Prepend m.login.terms to all flows if we're requiring consent - if config.user_consent_at_registration: + if config.consent.user_consent_at_registration: for flow in flows: flow.insert(0, LoginType.TERMS) # Prepend recaptcha to all flows if we're requiring captcha - if config.enable_registration_captcha: + if config.captcha.enable_registration_captcha: for flow in flows: flow.insert(0, LoginType.RECAPTCHA) diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py index 06e0fbde22..fc634a492d 100644 --- a/synapse/rest/consent/consent_resource.py +++ b/synapse/rest/consent/consent_resource.py @@ -84,14 +84,15 @@ class ConsentResource(DirectServeHtmlResource): # this is required by the request_handler wrapper self.clock = hs.get_clock() - self._default_consent_version = hs.config.user_consent_version - if self._default_consent_version is None: + # Consent must be configured to create this resource. + default_consent_version = hs.config.consent.user_consent_version + consent_template_directory = hs.config.consent.user_consent_template_dir + if default_consent_version is None or consent_template_directory is None: raise ConfigError( "Consent resource is enabled but user_consent section is " "missing in config file." ) - - consent_template_directory = hs.config.user_consent_template_dir + self._default_consent_version = default_consent_version # TODO: switch to synapse.util.templates.build_jinja_env loader = jinja2.FileSystemLoader(consent_template_directory) diff --git a/synapse/rest/synapse/client/password_reset.py b/synapse/rest/synapse/client/password_reset.py index f2800bf2db..28a67f04e3 100644 --- a/synapse/rest/synapse/client/password_reset.py +++ b/synapse/rest/synapse/client/password_reset.py @@ -47,20 +47,20 @@ class PasswordResetSubmitTokenResource(DirectServeHtmlResource): self.store = hs.get_datastore() self._local_threepid_handling_disabled_due_to_email_config = ( - hs.config.local_threepid_handling_disabled_due_to_email_config + hs.config.email.local_threepid_handling_disabled_due_to_email_config ) self._confirmation_email_template = ( - hs.config.email_password_reset_template_confirmation_html + hs.config.email.email_password_reset_template_confirmation_html ) self._email_password_reset_template_success_html = ( - hs.config.email_password_reset_template_success_html_content + hs.config.email.email_password_reset_template_success_html_content ) self._failure_email_template = ( - hs.config.email_password_reset_template_failure_html + hs.config.email.email_password_reset_template_failure_html ) # This resource should not be mounted if threepid behaviour is not LOCAL - assert hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL + assert hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL async def _async_render_GET(self, request: Request) -> Tuple[int, bytes]: sid = parse_string(request, "sid", required=True) diff --git a/synapse/server_notices/consent_server_notices.py b/synapse/server_notices/consent_server_notices.py index 4e0f814035..e09a25591f 100644 --- a/synapse/server_notices/consent_server_notices.py +++ b/synapse/server_notices/consent_server_notices.py @@ -36,9 +36,11 @@ class ConsentServerNotices: self._users_in_progress: Set[str] = set() - self._current_consent_version = hs.config.user_consent_version - self._server_notice_content = hs.config.user_consent_server_notice_content - self._send_to_guests = hs.config.user_consent_server_notice_to_guests + self._current_consent_version = hs.config.consent.user_consent_version + self._server_notice_content = ( + hs.config.consent.user_consent_server_notice_content + ) + self._send_to_guests = hs.config.consent.user_consent_server_notice_to_guests if self._server_notice_content is not None: if not self._server_notices_manager.is_enabled(): @@ -63,6 +65,9 @@ class ConsentServerNotices: # not enabled return + # A consent version must be given. + assert self._current_consent_version is not None + # make sure we don't send two messages to the same user at once if user_id in self._users_in_progress: return diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py index e2d1b758bd..2da2659f41 100644 --- a/synapse/storage/databases/main/appservice.py +++ b/synapse/storage/databases/main/appservice.py @@ -60,7 +60,7 @@ def _make_exclusive_regex( class ApplicationServiceWorkerStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): self.services_cache = load_appservices( - hs.hostname, hs.config.app_service_config_files + hs.hostname, hs.config.appservice.app_service_config_files ) self.exclusive_user_regex = _make_exclusive_regex(self.services_cache) diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index d213b26703..b76ee51a9b 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -63,7 +63,7 @@ class MonthlyActiveUsersWorkerStore(SQLBaseStore): """Generates current count of monthly active users broken down by service. A service is typically an appservice but also includes native matrix users. Since the `monthly_active_users` table is populated from the `user_ips` table - `config.track_appservice_user_ips` must be set to `true` for this + `config.appservice.track_appservice_user_ips` must be set to `true` for this method to return anything other than native matrix users. Returns: diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index fafadb88fc..52ef9deede 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -388,7 +388,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): "get_users_expiring_soon", select_users_txn, self._clock.time_msec(), - self.config.account_validity_renew_at, + self.config.account_validity.account_validity_renew_at, ) async def set_renewal_mail_status(self, user_id: str, email_sent: bool) -> None: diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py index d4754c904c..f31880b8ec 100644 --- a/synapse/storage/prepare_database.py +++ b/synapse/storage/prepare_database.py @@ -545,7 +545,7 @@ def _apply_module_schemas( database_engine: config: application config """ - for (mod, _config) in config.password_providers: + for (mod, _config) in config.authproviders.password_providers: if not hasattr(mod, "get_db_schema_files"): continue modname = ".".join((mod.__module__, mod.__name__)) diff --git a/synapse/storage/schema/main/delta/30/as_users.py b/synapse/storage/schema/main/delta/30/as_users.py index 8a1f340083..22a7901e15 100644 --- a/synapse/storage/schema/main/delta/30/as_users.py +++ b/synapse/storage/schema/main/delta/30/as_users.py @@ -33,7 +33,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs): config_files = [] try: - config_files = config.app_service_config_files + config_files = config.appservice.app_service_config_files except AttributeError: logger.warning("Could not get app_service_config_files from config") pass diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index e798513ac1..0fa55e03b4 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -47,7 +47,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.event_creation_handler = hs.get_event_creation_handler() - hs.config.user_consent_version = "1" + hs.config.consent.user_consent_version = "1" consent_uri_builder = Mock() consent_uri_builder.build_user_consent_uri.return_value = "http://example.com" diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index f5c195a075..414c8781a9 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -97,7 +97,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase): self.hs.config.enable_registration = True self.hs.config.registrations_require_3pid = [] self.hs.config.auto_join_rooms = [] - self.hs.config.enable_registration_captcha = False + self.hs.config.captcha.enable_registration_captcha = False return self.hs diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 666bffe257..ebadf47948 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -41,9 +41,8 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): self.addCleanup, federation_sender=Mock(), federation_client=Mock() ) - hs.config.app_service_config_files = self.as_yaml_files + hs.config.appservice.app_service_config_files = self.as_yaml_files hs.config.caches.event_cache_size = 1 - hs.config.password_providers = [] self.as_token = "token1" self.as_url = "some_url" @@ -108,9 +107,8 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): self.addCleanup, federation_sender=Mock(), federation_client=Mock() ) - hs.config.app_service_config_files = self.as_yaml_files + hs.config.appservice.app_service_config_files = self.as_yaml_files hs.config.caches.event_cache_size = 1 - hs.config.password_providers = [] self.as_list = [ {"token": "token1", "url": "https://matrix-as.org", "id": "id_1"}, @@ -496,9 +494,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): self.addCleanup, federation_sender=Mock(), federation_client=Mock() ) - hs.config.app_service_config_files = [f1, f2] + hs.config.appservice.app_service_config_files = [f1, f2] hs.config.caches.event_cache_size = 1 - hs.config.password_providers = [] database = hs.get_datastores().databases[0] ApplicationServiceStore( @@ -514,7 +511,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): self.addCleanup, federation_sender=Mock(), federation_client=Mock() ) - hs.config.app_service_config_files = [f1, f2] + hs.config.appservice.app_service_config_files = [f1, f2] hs.config.caches.event_cache_size = 1 hs.config.password_providers = [] @@ -540,9 +537,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): self.addCleanup, federation_sender=Mock(), federation_client=Mock() ) - hs.config.app_service_config_files = [f1, f2] + hs.config.appservice.app_service_config_files = [f1, f2] hs.config.caches.event_cache_size = 1 - hs.config.password_providers = [] with self.assertRaises(ConfigError) as cm: database = hs.get_datastores().databases[0] diff --git a/tests/storage/test_cleanup_extrems.py b/tests/storage/test_cleanup_extrems.py index da98733ce8..7cc5e621ba 100644 --- a/tests/storage/test_cleanup_extrems.py +++ b/tests/storage/test_cleanup_extrems.py @@ -258,7 +258,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase): info, _ = self.get_success(self.room_creator.create_room(self.requester, {})) self.room_id = info["room_id"] self.event_creator = homeserver.get_event_creation_handler() - homeserver.config.user_consent_version = self.CONSENT_VERSION + homeserver.config.consent.user_consent_version = self.CONSENT_VERSION def test_send_dummy_event(self): self._create_extremity_rich_graph() -- cgit 1.5.1 From 47854c71e9bded2c446a251f3ef16f4d5da96ebe Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 23 Sep 2021 12:03:01 -0400 Subject: Use direct references for configuration variables (part 4). (#10893) --- changelog.d/10893.misc | 1 + synapse/api/urls.py | 4 ++-- synapse/app/_base.py | 6 ++++-- synapse/app/admin_cmd.py | 2 +- synapse/app/generic_worker.py | 4 ++-- synapse/app/homeserver.py | 10 +++++----- synapse/app/phone_stats_home.py | 8 +++++--- synapse/config/logger.py | 2 +- synapse/federation/transport/server/_base.py | 4 +++- synapse/groups/groups_server.py | 6 +++--- synapse/handlers/auth.py | 2 +- synapse/handlers/oidc.py | 2 +- synapse/handlers/profile.py | 2 +- synapse/http/matrixfederationclient.py | 5 +++-- synapse/push/httppusher.py | 4 +++- synapse/rest/client/login.py | 12 ++++++------ synapse/rest/consent/consent_resource.py | 4 ++-- synapse/rest/key/v2/local_key_resource.py | 10 +++++----- synapse/rest/key/v2/remote_key_resource.py | 6 ++++-- synapse/rest/media/v1/media_repository.py | 4 +++- synapse/rest/synapse/client/__init__.py | 2 +- synapse/storage/databases/main/roommember.py | 2 +- tests/api/test_auth.py | 4 ++-- tests/app/test_phone_stats_home.py | 2 +- tests/config/test_load.py | 10 +++++----- tests/config/test_ratelimiting.py | 2 +- tests/handlers/test_auth.py | 2 +- tests/replication/_base.py | 2 +- tests/rest/client/test_login.py | 12 ++++++------ tests/rest/client/test_register.py | 2 +- tests/storage/test_appservice.py | 1 - tests/util/test_ratelimitutils.py | 2 +- 32 files changed, 77 insertions(+), 64 deletions(-) create mode 100644 changelog.d/10893.misc (limited to 'tests/rest') diff --git a/changelog.d/10893.misc b/changelog.d/10893.misc new file mode 100644 index 0000000000..586a0b3a96 --- /dev/null +++ b/changelog.d/10893.misc @@ -0,0 +1 @@ +Use direct references to config flags. diff --git a/synapse/api/urls.py b/synapse/api/urls.py index d3270cd6d2..032c69b210 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -39,12 +39,12 @@ class ConsentURIBuilder: Args: hs_config (synapse.config.homeserver.HomeServerConfig): """ - if hs_config.form_secret is None: + if hs_config.key.form_secret is None: raise ConfigError("form_secret not set in config") if hs_config.server.public_baseurl is None: raise ConfigError("public_baseurl not set in config") - self._hmac_secret = hs_config.form_secret.encode("utf-8") + self._hmac_secret = hs_config.key.form_secret.encode("utf-8") self._public_baseurl = hs_config.server.public_baseurl def build_user_consent_uri(self, user_id): diff --git a/synapse/app/_base.py b/synapse/app/_base.py index d1aa2e7fb5..f657f11f76 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -424,12 +424,14 @@ def setup_sentry(hs): hs (synapse.server.HomeServer) """ - if not hs.config.sentry_enabled: + if not hs.config.metrics.sentry_enabled: return import sentry_sdk - sentry_sdk.init(dsn=hs.config.sentry_dsn, release=get_version_string(synapse)) + sentry_sdk.init( + dsn=hs.config.metrics.sentry_dsn, release=get_version_string(synapse) + ) # We set some default tags that give some context to this instance with sentry_sdk.configure_scope() as scope: diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index 5e956b1e27..259d5ec7cc 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -192,7 +192,7 @@ def start(config_options): ): # Since we're meant to be run as a "command" let's not redirect stdio # unless we've actually set log config. - config.no_redirect_stdio = True + config.logging.no_redirect_stdio = True # Explicitly disable background processes config.update_user_directory = False diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index 33afd59c72..e0776689ce 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -395,7 +395,7 @@ class GenericWorkerServer(HomeServer): manhole_globals={"hs": self}, ) elif listener.type == "metrics": - if not self.config.enable_metrics: + if not self.config.metrics.enable_metrics: logger.warning( "Metrics listener configured, but " "enable_metrics is not True!" @@ -488,7 +488,7 @@ def start(config_options): register_start(_base.start, hs) # redirect stdio to the logs, if configured. - if not hs.config.no_redirect_stdio: + if not hs.config.logging.no_redirect_stdio: redirect_stdio_to_logs() _base.start_worker_reactor("synapse-generic-worker", config) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 886e291e4c..f1769f146b 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -269,7 +269,7 @@ class SynapseHomeServer(HomeServer): # https://twistedmatrix.com/trac/ticket/7678 resources[WEB_CLIENT_PREFIX] = File(webclient_loc) - if name == "metrics" and self.config.enable_metrics: + if name == "metrics" and self.config.metrics.enable_metrics: resources[METRICS_PREFIX] = MetricsResource(RegistryProxy) if name == "replication": @@ -278,7 +278,7 @@ class SynapseHomeServer(HomeServer): return resources def start_listening(self): - if self.config.redis_enabled: + if self.config.redis.redis_enabled: # If redis is enabled we connect via the replication command handler # in the same way as the workers (since we're effectively a client # rather than a server). @@ -305,7 +305,7 @@ class SynapseHomeServer(HomeServer): for s in services: reactor.addSystemEventTrigger("before", "shutdown", s.stopListening) elif listener.type == "metrics": - if not self.config.enable_metrics: + if not self.config.metrics.enable_metrics: logger.warning( "Metrics listener configured, but " "enable_metrics is not True!" @@ -366,7 +366,7 @@ def setup(config_options): async def start(): # Load the OIDC provider metadatas, if OIDC is enabled. - if hs.config.oidc_enabled: + if hs.config.oidc.oidc_enabled: oidc = hs.get_oidc_handler() # Loading the provider metadata also ensures the provider config is valid. await oidc.load_metadata() @@ -455,7 +455,7 @@ def main(): hs = setup(sys.argv[1:]) # redirect stdio to the logs, if configured. - if not hs.config.no_redirect_stdio: + if not hs.config.logging.no_redirect_stdio: redirect_stdio_to_logs() run(hs) diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py index 4a95da90f9..49e7a45e5c 100644 --- a/synapse/app/phone_stats_home.py +++ b/synapse/app/phone_stats_home.py @@ -131,10 +131,12 @@ async def phone_stats_home(hs, stats, stats_process=_stats_process): log_level = synapse_logger.getEffectiveLevel() stats["log_level"] = logging.getLevelName(log_level) - logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats)) + logger.info( + "Reporting stats to %s: %s" % (hs.config.metrics.report_stats_endpoint, stats) + ) try: await hs.get_proxied_http_client().put_json( - hs.config.report_stats_endpoint, stats + hs.config.metrics.report_stats_endpoint, stats ) except Exception as e: logger.warning("Error reporting stats: %s", e) @@ -188,7 +190,7 @@ def start_phone_stats_home(hs): clock.looping_call(generate_monthly_active_users, 5 * 60 * 1000) # End of monthly active user settings - if hs.config.report_stats: + if hs.config.metrics.report_stats: logger.info("Scheduling stats reporting for 3 hour intervals") clock.looping_call(phone_stats_home, 3 * 60 * 60 * 1000, hs, stats) diff --git a/synapse/config/logger.py b/synapse/config/logger.py index aca9d467e6..bf8ca7d5fe 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -322,7 +322,7 @@ def setup_logging( """ log_config_path = ( - config.worker_log_config if use_worker_options else config.log_config + config.worker_log_config if use_worker_options else config.logging.log_config ) # Perform one-time logging configuration. diff --git a/synapse/federation/transport/server/_base.py b/synapse/federation/transport/server/_base.py index 624c859f1e..cef65929c5 100644 --- a/synapse/federation/transport/server/_base.py +++ b/synapse/federation/transport/server/_base.py @@ -49,7 +49,9 @@ class Authenticator: self.keyring = hs.get_keyring() self.server_name = hs.hostname self.store = hs.get_datastore() - self.federation_domain_whitelist = hs.config.federation_domain_whitelist + self.federation_domain_whitelist = ( + hs.config.federation.federation_domain_whitelist + ) self.notifier = hs.get_notifier() self.replication_client = None diff --git a/synapse/groups/groups_server.py b/synapse/groups/groups_server.py index d6b75ac27f..449bbc7004 100644 --- a/synapse/groups/groups_server.py +++ b/synapse/groups/groups_server.py @@ -847,16 +847,16 @@ class GroupsServerHandler(GroupsServerWorkerHandler): UserID.from_string(requester_user_id) ) if not is_admin: - if not self.hs.config.enable_group_creation: + if not self.hs.config.groups.enable_group_creation: raise SynapseError( 403, "Only a server admin can create groups on this server" ) localpart = group_id_obj.localpart - if not localpart.startswith(self.hs.config.group_creation_prefix): + if not localpart.startswith(self.hs.config.groups.group_creation_prefix): raise SynapseError( 400, "Can only create groups with prefix %r on this server" - % (self.hs.config.group_creation_prefix,), + % (self.hs.config.groups.group_creation_prefix,), ) profile = content.get("profile", {}) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index b747f80bc1..0f80dfdc43 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -1802,7 +1802,7 @@ class MacaroonGenerator: macaroon = pymacaroons.Macaroon( location=self.hs.config.server.server_name, identifier="key", - key=self.hs.config.macaroon_secret_key, + key=self.hs.config.key.macaroon_secret_key, ) macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index aed5a40a78..3665d91513 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -277,7 +277,7 @@ class OidcProvider: self._token_generator = token_generator self._config = provider - self._callback_url: str = hs.config.oidc_callback_url + self._callback_url: str = hs.config.oidc.oidc_callback_url # Calculate the prefix for OIDC callback paths based on the public_baseurl. # We'll insert this into the Path= parameter of any session cookies we set. diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index f06070bfcf..b23a1541bc 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -309,7 +309,7 @@ class ProfileHandler(BaseHandler): async def on_profile_query(self, args: JsonDict) -> JsonDict: """Handles federation profile query requests.""" - if not self.hs.config.allow_profile_lookup_over_federation: + if not self.hs.config.federation.allow_profile_lookup_over_federation: raise SynapseError( 403, "Profile lookup over federation is disabled on this homeserver", diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index e56fa477bb..cdc36b8d25 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -465,8 +465,9 @@ class MatrixFederationHttpClient: _sec_timeout = self.default_timeout if ( - self.hs.config.federation_domain_whitelist is not None - and request.destination not in self.hs.config.federation_domain_whitelist + self.hs.config.federation.federation_domain_whitelist is not None + and request.destination + not in self.hs.config.federation.federation_domain_whitelist ): raise FederationDeniedError(request.destination) diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 065948f982..eac65572b2 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -73,7 +73,9 @@ class HttpPusher(Pusher): self.failing_since = pusher_config.failing_since self.timed_call: Optional[IDelayedCall] = None self._is_processing = False - self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room + self._group_unread_count_by_room = ( + hs.config.push.push_group_unread_count_by_room + ) self._pusherpool = hs.get_pusherpool() self.data = pusher_config.data diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index d766e98dce..64446fc486 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -69,16 +69,16 @@ class LoginRestServlet(RestServlet): self.hs = hs # JWT configuration variables. - self.jwt_enabled = hs.config.jwt_enabled - self.jwt_secret = hs.config.jwt_secret - self.jwt_algorithm = hs.config.jwt_algorithm - self.jwt_issuer = hs.config.jwt_issuer - self.jwt_audiences = hs.config.jwt_audiences + self.jwt_enabled = hs.config.jwt.jwt_enabled + self.jwt_secret = hs.config.jwt.jwt_secret + self.jwt_algorithm = hs.config.jwt.jwt_algorithm + self.jwt_issuer = hs.config.jwt.jwt_issuer + self.jwt_audiences = hs.config.jwt.jwt_audiences # SSO configuration. self.saml2_enabled = hs.config.saml2_enabled self.cas_enabled = hs.config.cas.cas_enabled - self.oidc_enabled = hs.config.oidc_enabled + self.oidc_enabled = hs.config.oidc.oidc_enabled self._msc2918_enabled = hs.config.access_token_lifetime is not None self.auth = hs.get_auth() diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py index fc634a492d..3d2afacc50 100644 --- a/synapse/rest/consent/consent_resource.py +++ b/synapse/rest/consent/consent_resource.py @@ -100,13 +100,13 @@ class ConsentResource(DirectServeHtmlResource): loader=loader, autoescape=jinja2.select_autoescape(["html", "htm", "xml"]) ) - if hs.config.form_secret is None: + if hs.config.key.form_secret is None: raise ConfigError( "Consent resource is enabled but form_secret is not set in " "config file. It should be set to an arbitrary secret string." ) - self._hmac_secret = hs.config.form_secret.encode("utf-8") + self._hmac_secret = hs.config.key.form_secret.encode("utf-8") async def _async_render_GET(self, request: Request) -> None: version = parse_string(request, "v", default=self._default_consent_version) diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py index ebe243bcfd..12b3ae120c 100644 --- a/synapse/rest/key/v2/local_key_resource.py +++ b/synapse/rest/key/v2/local_key_resource.py @@ -70,19 +70,19 @@ class LocalKey(Resource): Resource.__init__(self) def update_response_body(self, time_now_msec: int) -> None: - refresh_interval = self.config.key_refresh_interval + refresh_interval = self.config.key.key_refresh_interval self.valid_until_ts = int(time_now_msec + refresh_interval) self.response_body = encode_canonical_json(self.response_json_object()) def response_json_object(self) -> JsonDict: verify_keys = {} - for key in self.config.signing_key: + for key in self.config.key.signing_key: verify_key_bytes = key.verify_key.encode() key_id = "%s:%s" % (key.alg, key.version) verify_keys[key_id] = {"key": encode_base64(verify_key_bytes)} old_verify_keys = {} - for key_id, key in self.config.old_signing_keys.items(): + for key_id, key in self.config.key.old_signing_keys.items(): verify_key_bytes = key.encode() old_verify_keys[key_id] = { "key": encode_base64(verify_key_bytes), @@ -95,13 +95,13 @@ class LocalKey(Resource): "verify_keys": verify_keys, "old_verify_keys": old_verify_keys, } - for key in self.config.signing_key: + for key in self.config.key.signing_key: json_object = sign_json(json_object, self.config.server.server_name, key) return json_object def render_GET(self, request: Request) -> int: time_now = self.clock.time_msec() # Update the expiry time if less than half the interval remains. - if time_now + self.config.key_refresh_interval / 2 > self.valid_until_ts: + if time_now + self.config.key.key_refresh_interval / 2 > self.valid_until_ts: self.update_response_body(time_now) return respond_with_json_bytes(request, 200, self.response_body) diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index d8fd7938a4..c111a9d20f 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -97,7 +97,9 @@ class RemoteKey(DirectServeJsonResource): self.fetcher = ServerKeyFetcher(hs) self.store = hs.get_datastore() self.clock = hs.get_clock() - self.federation_domain_whitelist = hs.config.federation_domain_whitelist + self.federation_domain_whitelist = ( + hs.config.federation.federation_domain_whitelist + ) self.config = hs.config async def _async_render_GET(self, request: Request) -> None: @@ -235,7 +237,7 @@ class RemoteKey(DirectServeJsonResource): signed_keys = [] for key_json in json_results: key_json = json_decoder.decode(key_json.decode("utf-8")) - for signing_key in self.config.key_server_signing_keys: + for signing_key in self.config.key.key_server_signing_keys: key_json = sign_json( key_json, self.config.server.server_name, signing_key ) diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 50e4c9e29f..a30007a1e2 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -92,7 +92,9 @@ class MediaRepository: self.recently_accessed_remotes: Set[Tuple[str, str]] = set() self.recently_accessed_locals: Set[str] = set() - self.federation_domain_whitelist = hs.config.federation_domain_whitelist + self.federation_domain_whitelist = ( + hs.config.federation.federation_domain_whitelist + ) # List of StorageProviders where we should search for media and # potentially upload to. diff --git a/synapse/rest/synapse/client/__init__.py b/synapse/rest/synapse/client/__init__.py index 47a2f72b32..086c80b723 100644 --- a/synapse/rest/synapse/client/__init__.py +++ b/synapse/rest/synapse/client/__init__.py @@ -45,7 +45,7 @@ def build_synapse_client_resource_tree(hs: "HomeServer") -> Mapping[str, Resourc # provider-specific SSO bits. Only load these if they are enabled, since they # rely on optional dependencies. - if hs.config.oidc_enabled: + if hs.config.oidc.oidc_enabled: from synapse.rest.synapse.client.oidc import OIDCResource resources["/_synapse/client/oidc"] = OIDCResource(hs) diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index a4ec6bc328..ddb162a4fc 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -82,7 +82,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): if ( self.hs.config.worker.run_background_tasks - and self.hs.config.metrics_flags.known_servers + and self.hs.config.metrics.metrics_flags.known_servers ): self._known_servers_count = 1 self.hs.get_clock().looping_call( diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index f76fea4f66..8a4ef13054 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -217,7 +217,7 @@ class AuthTestCase(unittest.HomeserverTestCase): macaroon = pymacaroons.Macaroon( location=self.hs.config.server_name, identifier="key", - key=self.hs.config.macaroon_secret_key, + key=self.hs.config.key.macaroon_secret_key, ) macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("type = access") @@ -239,7 +239,7 @@ class AuthTestCase(unittest.HomeserverTestCase): macaroon = pymacaroons.Macaroon( location=self.hs.config.server_name, identifier="key", - key=self.hs.config.macaroon_secret_key, + key=self.hs.config.key.macaroon_secret_key, ) macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("type = access") diff --git a/tests/app/test_phone_stats_home.py b/tests/app/test_phone_stats_home.py index d66aeb00eb..19eb4c79d0 100644 --- a/tests/app/test_phone_stats_home.py +++ b/tests/app/test_phone_stats_home.py @@ -172,7 +172,7 @@ class PhoneHomeR30V2TestCase(HomeserverTestCase): # We don't want our tests to actually report statistics, so check # that it's not enabled - assert not hs.config.report_stats + assert not hs.config.metrics.report_stats # This starts the needed data collection that we rely on to calculate # R30v2 metrics. diff --git a/tests/config/test_load.py b/tests/config/test_load.py index 903c69127d..ef6c2beec7 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py @@ -52,10 +52,10 @@ class ConfigLoadingTestCase(unittest.TestCase): hasattr(config, "macaroon_secret_key"), "Want config to have attr macaroon_secret_key", ) - if len(config.macaroon_secret_key) < 5: + if len(config.key.macaroon_secret_key) < 5: self.fail( "Want macaroon secret key to be string of at least length 5," - "was: %r" % (config.macaroon_secret_key,) + "was: %r" % (config.key.macaroon_secret_key,) ) config = HomeServerConfig.load_or_generate_config("", ["-c", self.file]) @@ -63,10 +63,10 @@ class ConfigLoadingTestCase(unittest.TestCase): hasattr(config, "macaroon_secret_key"), "Want config to have attr macaroon_secret_key", ) - if len(config.macaroon_secret_key) < 5: + if len(config.key.macaroon_secret_key) < 5: self.fail( "Want macaroon secret key to be string of at least length 5," - "was: %r" % (config.macaroon_secret_key,) + "was: %r" % (config.key.macaroon_secret_key,) ) def test_load_succeeds_if_macaroon_secret_key_missing(self): @@ -101,7 +101,7 @@ class ConfigLoadingTestCase(unittest.TestCase): # The default Metrics Flags are off by default. config = HomeServerConfig.load_config("", ["-c", self.file]) - self.assertFalse(config.metrics_flags.known_servers) + self.assertFalse(config.metrics.metrics_flags.known_servers) def generate_config(self): with redirect_stdout(StringIO()): diff --git a/tests/config/test_ratelimiting.py b/tests/config/test_ratelimiting.py index 3c7bb32e07..1b63e1adfd 100644 --- a/tests/config/test_ratelimiting.py +++ b/tests/config/test_ratelimiting.py @@ -30,7 +30,7 @@ class RatelimitConfigTestCase(TestCase): config = HomeServerConfig() config.parse_config_dict(config_dict, "", "") - config_obj = config.rc_federation + config_obj = config.ratelimiting.rc_federation self.assertEqual(config_obj.window_size, 20000) self.assertEqual(config_obj.sleep_limit, 693) diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 5f3350e490..12857053e7 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -67,7 +67,7 @@ class AuthTestCase(unittest.HomeserverTestCase): v.satisfy_general(verify_type) v.satisfy_general(verify_nonce) v.satisfy_general(verify_guest) - v.verify(macaroon, self.hs.config.macaroon_secret_key) + v.verify(macaroon, self.hs.config.key.macaroon_secret_key) def test_short_term_login_token_gives_user_id(self): token = self.macaroon_generator.generate_short_term_login_token( diff --git a/tests/replication/_base.py b/tests/replication/_base.py index e9fd991718..c7555c26db 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -328,7 +328,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): # Set up TCP replication between master and the new worker if we don't # have Redis support enabled. - if not worker_hs.config.redis_enabled: + if not worker_hs.config.redis.redis_enabled: repl_handler = ReplicationCommandHandler(worker_hs) client = ClientReplicationStreamProtocol( worker_hs, diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index 414c8781a9..371615a015 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -815,9 +815,9 @@ class JWTTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): self.hs = self.setup_test_homeserver() - self.hs.config.jwt_enabled = True - self.hs.config.jwt_secret = self.jwt_secret - self.hs.config.jwt_algorithm = self.jwt_algorithm + self.hs.config.jwt.jwt_enabled = True + self.hs.config.jwt.jwt_secret = self.jwt_secret + self.hs.config.jwt.jwt_algorithm = self.jwt_algorithm return self.hs def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_secret) -> str: @@ -1023,9 +1023,9 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor, clock): self.hs = self.setup_test_homeserver() - self.hs.config.jwt_enabled = True - self.hs.config.jwt_secret = self.jwt_pubkey - self.hs.config.jwt_algorithm = "RS256" + self.hs.config.jwt.jwt_enabled = True + self.hs.config.jwt.jwt_secret = self.jwt_pubkey + self.hs.config.jwt.jwt_algorithm = "RS256" return self.hs def jwt_encode(self, payload: Dict[str, Any], secret: str = jwt_privatekey) -> str: diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index 9f3ab2c985..72a5a11b46 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -146,7 +146,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertEquals(channel.json_body["errcode"], "M_FORBIDDEN") def test_POST_guest_registration(self): - self.hs.config.macaroon_secret_key = "test" + self.hs.config.key.macaroon_secret_key = "test" self.hs.config.allow_guest_access = True channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index ebadf47948..cf9748f218 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -513,7 +513,6 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): hs.config.appservice.app_service_config_files = [f1, f2] hs.config.caches.event_cache_size = 1 - hs.config.password_providers = [] with self.assertRaises(ConfigError) as cm: database = hs.get_datastores().databases[0] diff --git a/tests/util/test_ratelimitutils.py b/tests/util/test_ratelimitutils.py index 34aaffe859..89d8656634 100644 --- a/tests/util/test_ratelimitutils.py +++ b/tests/util/test_ratelimitutils.py @@ -95,4 +95,4 @@ def build_rc_config(settings: Optional[dict] = None): config_dict.update(settings or {}) config = HomeServerConfig() config.parse_config_dict(config_dict, "", "") - return config.rc_federation + return config.ratelimiting.rc_federation -- cgit 1.5.1 From 90d9fc750514b1ede327f1dfe6e0a1c09b281d6d Mon Sep 17 00:00:00 2001 From: Callum Brown Date: Thu, 23 Sep 2021 18:58:12 +0100 Subject: Allow `.` and `~` chars in registration tokens (#10887) Per updates to MSC3231 in order to use the same grammar as other identifiers. --- changelog.d/10887.bugfix | 1 + synapse/rest/admin/registration_tokens.py | 2 +- tests/rest/admin/test_registration_tokens.py | 8 +++++--- 3 files changed, 7 insertions(+), 4 deletions(-) create mode 100644 changelog.d/10887.bugfix (limited to 'tests/rest') diff --git a/changelog.d/10887.bugfix b/changelog.d/10887.bugfix new file mode 100644 index 0000000000..2d1f67489a --- /dev/null +++ b/changelog.d/10887.bugfix @@ -0,0 +1 @@ +Allow the `.` and `~` characters when creating registration tokens as per the change to [MSC3231](https://github.com/matrix-org/matrix-doc/pull/3231). diff --git a/synapse/rest/admin/registration_tokens.py b/synapse/rest/admin/registration_tokens.py index 5a1c929d85..aba48f6e7b 100644 --- a/synapse/rest/admin/registration_tokens.py +++ b/synapse/rest/admin/registration_tokens.py @@ -113,7 +113,7 @@ class NewRegistrationTokenRestServlet(RestServlet): self.store = hs.get_datastore() self.clock = hs.get_clock() # A string of all the characters allowed to be in a registration_token - self.allowed_chars = string.ascii_letters + string.digits + "-_" + self.allowed_chars = string.ascii_letters + string.digits + "._~-" self.allowed_chars_set = set(self.allowed_chars) async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: diff --git a/tests/rest/admin/test_registration_tokens.py b/tests/rest/admin/test_registration_tokens.py index 4927321e5a..9bac423ae0 100644 --- a/tests/rest/admin/test_registration_tokens.py +++ b/tests/rest/admin/test_registration_tokens.py @@ -95,8 +95,10 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): def test_create_specifying_fields(self): """Create a token specifying the value of all fields.""" + # As many of the allowed characters as possible with length <= 64 + token = "adefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789._~-" data = { - "token": "abcd", + "token": token, "uses_allowed": 1, "expiry_time": self.clock.time_msec() + 1000000, } @@ -109,7 +111,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) - self.assertEqual(channel.json_body["token"], "abcd") + self.assertEqual(channel.json_body["token"], token) self.assertEqual(channel.json_body["uses_allowed"], 1) self.assertEqual(channel.json_body["expiry_time"], data["expiry_time"]) self.assertEqual(channel.json_body["pending"], 0) @@ -193,7 +195,7 @@ class ManageRegistrationTokensTestCase(unittest.HomeserverTestCase): """Check right error is raised when server can't generate unique token.""" # Create all possible single character tokens tokens = [] - for c in string.ascii_letters + string.digits + "-_": + for c in string.ascii_letters + string.digits + "._~-": tokens.append( { "token": c, -- cgit 1.5.1 From 50022cff966a3991fbd8a1e5c98f490d9b335442 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 24 Sep 2021 11:01:25 +0100 Subject: Add reactor to `SynapseRequest` and fix up types. (#10868) --- changelog.d/10868.feature | 1 + synapse/http/server.py | 4 +-- synapse/http/site.py | 37 +++++++++++++++-------- synapse/rest/key/v2/remote_key_resource.py | 9 +++--- synapse/rest/media/v1/_base.py | 7 +++-- synapse/rest/media/v1/config_resource.py | 4 +-- synapse/rest/media/v1/download_resource.py | 5 ++-- synapse/rest/media/v1/media_repository.py | 10 +++++-- synapse/rest/media/v1/preview_url_resource.py | 3 +- synapse/rest/media/v1/thumbnail_resource.py | 15 +++++----- synapse/rest/media/v1/upload_resource.py | 4 +-- tests/http/test_additional_resource.py | 8 +++-- tests/logging/test_terse_json.py | 3 +- tests/replication/test_multi_media_repo.py | 2 +- tests/rest/admin/test_admin.py | 6 ++-- tests/rest/admin/test_media.py | 6 ++-- tests/rest/admin/test_user.py | 2 +- tests/rest/client/test_account.py | 4 +-- tests/rest/client/test_consent.py | 12 +++++--- tests/rest/client/utils.py | 2 +- tests/rest/key/v2/test_remote_key_resource.py | 4 +-- tests/rest/media/v1/test_media_storage.py | 8 ++--- tests/server.py | 6 ++-- tests/test_server.py | 43 ++++++++++++++++++++------- 24 files changed, 123 insertions(+), 82 deletions(-) create mode 100644 changelog.d/10868.feature (limited to 'tests/rest') diff --git a/changelog.d/10868.feature b/changelog.d/10868.feature new file mode 100644 index 0000000000..07e7b2c6a7 --- /dev/null +++ b/changelog.d/10868.feature @@ -0,0 +1 @@ +Speed up responding with large JSON objects to requests. diff --git a/synapse/http/server.py b/synapse/http/server.py index b79fa722e9..e28b56abb9 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -320,7 +320,7 @@ class DirectServeJsonResource(_AsyncResource): def _send_response( self, - request: Request, + request: SynapseRequest, code: int, response_object: Any, ): @@ -629,7 +629,7 @@ def _encode_json_bytes(json_object: Any) -> Iterator[bytes]: def respond_with_json( - request: Request, + request: SynapseRequest, code: int, json_object: Any, send_cors: bool = False, diff --git a/synapse/http/site.py b/synapse/http/site.py index dd4c749e16..755ad56637 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -14,13 +14,14 @@ import contextlib import logging import time -from typing import Optional, Tuple, Union +from typing import Generator, Optional, Tuple, Union import attr from zope.interface import implementer from twisted.internet.interfaces import IAddress, IReactorTime from twisted.python.failure import Failure +from twisted.web.http import HTTPChannel from twisted.web.resource import IResource, Resource from twisted.web.server import Request, Site @@ -61,10 +62,18 @@ class SynapseRequest(Request): logcontext: the log context for this request """ - def __init__(self, channel, *args, max_request_body_size: int = 1024, **kw): - Request.__init__(self, channel, *args, **kw) + def __init__( + self, + channel: HTTPChannel, + site: "SynapseSite", + *args, + max_request_body_size: int = 1024, + **kw, + ): + super().__init__(channel, *args, **kw) self._max_request_body_size = max_request_body_size - self.site: SynapseSite = channel.site + self.synapse_site = site + self.reactor = site.reactor self._channel = channel # this is used by the tests self.start_time = 0.0 @@ -97,7 +106,7 @@ class SynapseRequest(Request): self.get_method(), self.get_redacted_uri(), self.clientproto.decode("ascii", errors="replace"), - self.site.site_tag, + self.synapse_site.site_tag, ) def handleContentChunk(self, data: bytes) -> None: @@ -216,7 +225,7 @@ class SynapseRequest(Request): request=ContextRequest( request_id=request_id, ip_address=self.getClientIP(), - site_tag=self.site.site_tag, + site_tag=self.synapse_site.site_tag, # The requester is going to be unknown at this point. requester=None, authenticated_entity=None, @@ -228,7 +237,7 @@ class SynapseRequest(Request): ) # override the Server header which is set by twisted - self.setHeader("Server", self.site.server_version_string) + self.setHeader("Server", self.synapse_site.server_version_string) with PreserveLoggingContext(self.logcontext): # we start the request metrics timer here with an initial stab @@ -247,7 +256,7 @@ class SynapseRequest(Request): requests_counter.labels(self.get_method(), self.request_metrics.name).inc() @contextlib.contextmanager - def processing(self): + def processing(self) -> Generator[None, None, None]: """Record the fact that we are processing this request. Returns a context manager; the correct way to use this is: @@ -346,10 +355,10 @@ class SynapseRequest(Request): self.start_time, name=servlet_name, method=self.get_method() ) - self.site.access_logger.debug( + self.synapse_site.access_logger.debug( "%s - %s - Received request: %s %s", self.getClientIP(), - self.site.site_tag, + self.synapse_site.site_tag, self.get_method(), self.get_redacted_uri(), ) @@ -388,13 +397,13 @@ class SynapseRequest(Request): if authenticated_entity: requester = f"{authenticated_entity}|{requester}" - self.site.access_logger.log( + self.synapse_site.access_logger.log( log_level, "%s - %s - {%s}" " Processed request: %.3fsec/%.3fsec (%.3fsec, %.3fsec) (%.3fsec/%.3fsec/%d)" ' %sB %s "%s %s %s" "%s" [%d dbevts]', self.getClientIP(), - self.site.site_tag, + self.synapse_site.site_tag, requester, processing_time, response_send_time, @@ -522,7 +531,7 @@ class SynapseSite(Site): site_tag: str, config: ListenerConfig, resource: IResource, - server_version_string, + server_version_string: str, max_request_body_size: int, reactor: IReactorTime, ): @@ -542,6 +551,7 @@ class SynapseSite(Site): Site.__init__(self, resource, reactor=reactor) self.site_tag = site_tag + self.reactor = reactor assert config.http_options is not None proxied = config.http_options.x_forwarded @@ -550,6 +560,7 @@ class SynapseSite(Site): def request_factory(channel, queued: bool) -> Request: return request_class( channel, + self, max_request_body_size=max_request_body_size, queued=queued, ) diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index c111a9d20f..3923ba8439 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -17,12 +17,11 @@ from typing import TYPE_CHECKING, Dict from signedjson.sign import sign_json -from twisted.web.server import Request - from synapse.api.errors import Codes, SynapseError from synapse.crypto.keyring import ServerKeyFetcher from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.servlet import parse_integer, parse_json_object_from_request +from synapse.http.site import SynapseRequest from synapse.types import JsonDict from synapse.util import json_decoder from synapse.util.async_helpers import yieldable_gather_results @@ -102,7 +101,7 @@ class RemoteKey(DirectServeJsonResource): ) self.config = hs.config - async def _async_render_GET(self, request: Request) -> None: + async def _async_render_GET(self, request: SynapseRequest) -> None: assert request.postpath is not None if len(request.postpath) == 1: (server,) = request.postpath @@ -119,7 +118,7 @@ class RemoteKey(DirectServeJsonResource): await self.query_keys(request, query, query_remote_on_cache_miss=True) - async def _async_render_POST(self, request: Request) -> None: + async def _async_render_POST(self, request: SynapseRequest) -> None: content = parse_json_object_from_request(request) query = content["server_keys"] @@ -128,7 +127,7 @@ class RemoteKey(DirectServeJsonResource): async def query_keys( self, - request: Request, + request: SynapseRequest, query: JsonDict, query_remote_on_cache_miss: bool = False, ) -> None: diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py index 7c881f2bdb..014fa893d6 100644 --- a/synapse/rest/media/v1/_base.py +++ b/synapse/rest/media/v1/_base.py @@ -27,6 +27,7 @@ from twisted.web.server import Request from synapse.api.errors import Codes, SynapseError, cs_error from synapse.http.server import finish_request, respond_with_json +from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable from synapse.util.stringutils import is_ascii @@ -74,7 +75,7 @@ def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]: ) -def respond_404(request: Request) -> None: +def respond_404(request: SynapseRequest) -> None: respond_with_json( request, 404, @@ -84,7 +85,7 @@ def respond_404(request: Request) -> None: async def respond_with_file( - request: Request, + request: SynapseRequest, media_type: str, file_path: str, file_size: Optional[int] = None, @@ -221,7 +222,7 @@ def _can_encode_filename_as_token(x: str) -> bool: async def respond_with_responder( - request: Request, + request: SynapseRequest, responder: "Optional[Responder]", media_type: str, file_size: Optional[int], diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/v1/config_resource.py index a1d36e5cf1..712d4e8368 100644 --- a/synapse/rest/media/v1/config_resource.py +++ b/synapse/rest/media/v1/config_resource.py @@ -16,8 +16,6 @@ from typing import TYPE_CHECKING -from twisted.web.server import Request - from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.site import SynapseRequest @@ -39,5 +37,5 @@ class MediaConfigResource(DirectServeJsonResource): await self.auth.get_user_by_req(request) respond_with_json(request, 200, self.limits_dict, send_cors=True) - async def _async_render_OPTIONS(self, request: Request) -> None: + async def _async_render_OPTIONS(self, request: SynapseRequest) -> None: respond_with_json(request, 200, {}, send_cors=True) diff --git a/synapse/rest/media/v1/download_resource.py b/synapse/rest/media/v1/download_resource.py index d6d938953e..6180fa575e 100644 --- a/synapse/rest/media/v1/download_resource.py +++ b/synapse/rest/media/v1/download_resource.py @@ -15,10 +15,9 @@ import logging from typing import TYPE_CHECKING -from twisted.web.server import Request - from synapse.http.server import DirectServeJsonResource, set_cors_headers from synapse.http.servlet import parse_boolean +from synapse.http.site import SynapseRequest from ._base import parse_media_id, respond_404 @@ -37,7 +36,7 @@ class DownloadResource(DirectServeJsonResource): self.media_repo = media_repo self.server_name = hs.hostname - async def _async_render_GET(self, request: Request) -> None: + async def _async_render_GET(self, request: SynapseRequest) -> None: set_cors_headers(request) request.setHeader( b"Content-Security-Policy", diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index a30007a1e2..c1bd81100d 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -23,7 +23,6 @@ import twisted.internet.error import twisted.web.http from twisted.internet.defer import Deferred from twisted.web.resource import Resource -from twisted.web.server import Request from synapse.api.errors import ( FederationDeniedError, @@ -34,6 +33,7 @@ from synapse.api.errors import ( ) from synapse.config._base import ConfigError from synapse.config.repository import ThumbnailRequirement +from synapse.http.site import SynapseRequest from synapse.logging.context import defer_to_thread from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import UserID @@ -189,7 +189,7 @@ class MediaRepository: return "mxc://%s/%s" % (self.server_name, media_id) async def get_local_media( - self, request: Request, media_id: str, name: Optional[str] + self, request: SynapseRequest, media_id: str, name: Optional[str] ) -> None: """Responds to requests for local media, if exists, or returns 404. @@ -223,7 +223,11 @@ class MediaRepository: ) async def get_remote_media( - self, request: Request, server_name: str, media_id: str, name: Optional[str] + self, + request: SynapseRequest, + server_name: str, + media_id: str, + name: Optional[str], ) -> None: """Respond to requests for remote media. diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 9ffa983fbb..128706d297 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -29,7 +29,6 @@ import attr from twisted.internet.defer import Deferred from twisted.internet.error import DNSLookupError -from twisted.web.server import Request from synapse.api.errors import Codes, SynapseError from synapse.http.client import SimpleHttpClient @@ -168,7 +167,7 @@ class PreviewUrlResource(DirectServeJsonResource): self._start_expire_url_cache_data, 10 * 1000 ) - async def _async_render_OPTIONS(self, request: Request) -> None: + async def _async_render_OPTIONS(self, request: SynapseRequest) -> None: request.setHeader(b"Allow", b"OPTIONS, GET") respond_with_json(request, 200, {}, send_cors=True) diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index 22f43d8531..cb2f88676e 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -17,11 +17,10 @@ import logging from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple -from twisted.web.server import Request - from synapse.api.errors import SynapseError from synapse.http.server import DirectServeJsonResource, set_cors_headers from synapse.http.servlet import parse_integer, parse_string +from synapse.http.site import SynapseRequest from synapse.rest.media.v1.media_storage import MediaStorage from ._base import ( @@ -57,7 +56,7 @@ class ThumbnailResource(DirectServeJsonResource): self.dynamic_thumbnails = hs.config.dynamic_thumbnails self.server_name = hs.hostname - async def _async_render_GET(self, request: Request) -> None: + async def _async_render_GET(self, request: SynapseRequest) -> None: set_cors_headers(request) server_name, media_id, _ = parse_media_id(request) width = parse_integer(request, "width", required=True) @@ -88,7 +87,7 @@ class ThumbnailResource(DirectServeJsonResource): async def _respond_local_thumbnail( self, - request: Request, + request: SynapseRequest, media_id: str, width: int, height: int, @@ -121,7 +120,7 @@ class ThumbnailResource(DirectServeJsonResource): async def _select_or_generate_local_thumbnail( self, - request: Request, + request: SynapseRequest, media_id: str, desired_width: int, desired_height: int, @@ -186,7 +185,7 @@ class ThumbnailResource(DirectServeJsonResource): async def _select_or_generate_remote_thumbnail( self, - request: Request, + request: SynapseRequest, server_name: str, media_id: str, desired_width: int, @@ -249,7 +248,7 @@ class ThumbnailResource(DirectServeJsonResource): async def _respond_remote_thumbnail( self, - request: Request, + request: SynapseRequest, server_name: str, media_id: str, width: int, @@ -280,7 +279,7 @@ class ThumbnailResource(DirectServeJsonResource): async def _select_and_respond_with_thumbnail( self, - request: Request, + request: SynapseRequest, desired_width: int, desired_height: int, desired_method: str, diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index 146adca8f1..39b29318bb 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -16,8 +16,6 @@ import logging from typing import IO, TYPE_CHECKING, Dict, List, Optional -from twisted.web.server import Request - from synapse.api.errors import Codes, SynapseError from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.servlet import parse_bytes_from_args @@ -46,7 +44,7 @@ class UploadResource(DirectServeJsonResource): self.max_upload_size = hs.config.max_upload_size self.clock = hs.get_clock() - async def _async_render_OPTIONS(self, request: Request) -> None: + async def _async_render_OPTIONS(self, request: SynapseRequest) -> None: respond_with_json(request, 200, {}, send_cors=True) async def _async_render_POST(self, request: SynapseRequest) -> None: diff --git a/tests/http/test_additional_resource.py b/tests/http/test_additional_resource.py index 768c2ba4ea..391196425c 100644 --- a/tests/http/test_additional_resource.py +++ b/tests/http/test_additional_resource.py @@ -45,7 +45,9 @@ class AdditionalResourceTests(HomeserverTestCase): handler = _AsyncTestCustomEndpoint({}, None).handle_request resource = AdditionalResource(self.hs, handler) - channel = make_request(self.reactor, FakeSite(resource), "GET", "/") + channel = make_request( + self.reactor, FakeSite(resource, self.reactor), "GET", "/" + ) self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body, {"some_key": "some_value_async"}) @@ -54,7 +56,9 @@ class AdditionalResourceTests(HomeserverTestCase): handler = _SyncTestCustomEndpoint({}, None).handle_request resource = AdditionalResource(self.hs, handler) - channel = make_request(self.reactor, FakeSite(resource), "GET", "/") + channel = make_request( + self.reactor, FakeSite(resource, self.reactor), "GET", "/" + ) self.assertEqual(channel.code, 200) self.assertEqual(channel.json_body, {"some_key": "some_value_sync"}) diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py index 1160716929..f73fcd684e 100644 --- a/tests/logging/test_terse_json.py +++ b/tests/logging/test_terse_json.py @@ -152,7 +152,8 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): site = Mock(spec=["site_tag", "server_version_string", "getResourceFor"]) site.site_tag = "test-site" site.server_version_string = "Server v1" - request = SynapseRequest(FakeChannel(site, None)) + site.reactor = Mock() + request = SynapseRequest(FakeChannel(site, None), site) # Call requestReceived to finish instantiating the object. request.content = BytesIO() # Partially skip some of the internal processing of SynapseRequest. diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py index 01b1b0d4a0..13aa5eb51a 100644 --- a/tests/replication/test_multi_media_repo.py +++ b/tests/replication/test_multi_media_repo.py @@ -68,7 +68,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): resource = hs.get_media_repository_resource().children[b"download"] channel = make_request( self.reactor, - FakeSite(resource), + FakeSite(resource, self.reactor), "GET", f"/{target}/{media_id}", shorthand=False, diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index febd40b656..192073c520 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -201,7 +201,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): """Ensure a piece of media is quarantined when trying to access it.""" channel = make_request( self.reactor, - FakeSite(self.download_resource), + FakeSite(self.download_resource, self.reactor), "GET", server_and_media_id, shorthand=False, @@ -271,7 +271,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): # Attempt to access the media channel = make_request( self.reactor, - FakeSite(self.download_resource), + FakeSite(self.download_resource, self.reactor), "GET", server_name_and_media_id, shorthand=False, @@ -458,7 +458,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): # Attempt to access each piece of media channel = make_request( self.reactor, - FakeSite(self.download_resource), + FakeSite(self.download_resource, self.reactor), "GET", server_and_media_id_2, shorthand=False, diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index 2f02934e72..f813866073 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -125,7 +125,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): # Attempt to access media channel = make_request( self.reactor, - FakeSite(download_resource), + FakeSite(download_resource, self.reactor), "GET", server_and_media_id, shorthand=False, @@ -164,7 +164,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): # Attempt to access media channel = make_request( self.reactor, - FakeSite(download_resource), + FakeSite(download_resource, self.reactor), "GET", server_and_media_id, shorthand=False, @@ -525,7 +525,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): channel = make_request( self.reactor, - FakeSite(download_resource), + FakeSite(download_resource, self.reactor), "GET", server_and_media_id, shorthand=False, diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index cc3f16c62a..e79e0e1850 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -2973,7 +2973,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): # Try to access a media and to create `last_access_ts` channel = make_request( self.reactor, - FakeSite(download_resource), + FakeSite(download_resource, self.reactor), "GET", server_and_media_id, shorthand=False, diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index b946fca8b3..9e9e953cf4 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -312,7 +312,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): # Load the password reset confirmation page channel = make_request( self.reactor, - FakeSite(self.submit_token_resource), + FakeSite(self.submit_token_resource, self.reactor), "GET", path, shorthand=False, @@ -326,7 +326,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): # Confirm the password reset channel = make_request( self.reactor, - FakeSite(self.submit_token_resource), + FakeSite(self.submit_token_resource, self.reactor), "POST", path, content=b"", diff --git a/tests/rest/client/test_consent.py b/tests/rest/client/test_consent.py index 65c58ce70a..84d092ca82 100644 --- a/tests/rest/client/test_consent.py +++ b/tests/rest/client/test_consent.py @@ -61,7 +61,11 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): """You can observe the terms form without specifying a user""" resource = consent_resource.ConsentResource(self.hs) channel = make_request( - self.reactor, FakeSite(resource), "GET", "/consent?v=1", shorthand=False + self.reactor, + FakeSite(resource, self.reactor), + "GET", + "/consent?v=1", + shorthand=False, ) self.assertEqual(channel.code, 200) @@ -83,7 +87,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): ) channel = make_request( self.reactor, - FakeSite(resource), + FakeSite(resource, self.reactor), "GET", consent_uri, access_token=access_token, @@ -98,7 +102,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): # POST to the consent page, saying we've agreed channel = make_request( self.reactor, - FakeSite(resource), + FakeSite(resource, self.reactor), "POST", consent_uri + "&v=" + version, access_token=access_token, @@ -110,7 +114,7 @@ class ConsentResourceTestCase(unittest.HomeserverTestCase): # changed channel = make_request( self.reactor, - FakeSite(resource), + FakeSite(resource, self.reactor), "GET", consent_uri, access_token=access_token, diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py index c56e45fc10..3075d3f288 100644 --- a/tests/rest/client/utils.py +++ b/tests/rest/client/utils.py @@ -383,7 +383,7 @@ class RestHelper: path = "/_matrix/media/r0/upload?filename=%s" % (filename,) channel = make_request( self.hs.get_reactor(), - FakeSite(resource), + FakeSite(resource, self.hs.get_reactor()), "POST", path, content=image_data, diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py index a75c0ea3f0..4672a68596 100644 --- a/tests/rest/key/v2/test_remote_key_resource.py +++ b/tests/rest/key/v2/test_remote_key_resource.py @@ -84,7 +84,7 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase): Checks that the response is a 200 and returns the decoded json body. """ channel = FakeChannel(self.site, self.reactor) - req = SynapseRequest(channel) + req = SynapseRequest(channel, self.site) req.content = BytesIO(b"") req.requestReceived( b"GET", @@ -183,7 +183,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): ) channel = FakeChannel(self.site, self.reactor) - req = SynapseRequest(channel) + req = SynapseRequest(channel, self.site) req.content = BytesIO(encode_canonical_json(data)) req.requestReceived( diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 9ea1c2bf25..44a643d506 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -252,7 +252,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): channel = make_request( self.reactor, - FakeSite(self.download_resource), + FakeSite(self.download_resource, self.reactor), "GET", self.media_id, shorthand=False, @@ -384,7 +384,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): params = "?width=32&height=32&method=scale" channel = make_request( self.reactor, - FakeSite(self.thumbnail_resource), + FakeSite(self.thumbnail_resource, self.reactor), "GET", self.media_id + params, shorthand=False, @@ -413,7 +413,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): channel = make_request( self.reactor, - FakeSite(self.thumbnail_resource), + FakeSite(self.thumbnail_resource, self.reactor), "GET", self.media_id + params, shorthand=False, @@ -433,7 +433,7 @@ class MediaRepoTests(unittest.HomeserverTestCase): params = "?width=32&height=32&method=" + method channel = make_request( self.reactor, - FakeSite(self.thumbnail_resource), + FakeSite(self.thumbnail_resource, self.reactor), "GET", self.media_id + params, shorthand=False, diff --git a/tests/server.py b/tests/server.py index b861c7b866..88dfa8058e 100644 --- a/tests/server.py +++ b/tests/server.py @@ -19,6 +19,7 @@ from twisted.internet.interfaces import ( IPullProducer, IPushProducer, IReactorPluggableNameResolver, + IReactorTime, IResolverSimple, ITransport, ) @@ -181,13 +182,14 @@ class FakeSite: site_tag = "test" access_logger = logging.getLogger("synapse.access.http.fake") - def __init__(self, resource: IResource): + def __init__(self, resource: IResource, reactor: IReactorTime): """ Args: resource: the resource to be used for rendering all requests """ self._resource = resource + self.reactor = reactor def getResourceFor(self, request): return self._resource @@ -268,7 +270,7 @@ def make_request( channel = FakeChannel(site, reactor, ip=client_ip) - req = request(channel) + req = request(channel, site) req.content = BytesIO(content) # Twisted expects to be at the end of the content when parsing the request. req.content.seek(SEEK_END) diff --git a/tests/test_server.py b/tests/test_server.py index 407e172e41..f2ffbc895b 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -65,7 +65,10 @@ class JsonResourceTests(unittest.TestCase): ) make_request( - self.reactor, FakeSite(res), b"GET", b"/_matrix/foo/%E2%98%83?a=%E2%98%83" + self.reactor, + FakeSite(res, self.reactor), + b"GET", + b"/_matrix/foo/%E2%98%83?a=%E2%98%83", ) self.assertEqual(got_kwargs, {"room_id": "\N{SNOWMAN}"}) @@ -84,7 +87,9 @@ class JsonResourceTests(unittest.TestCase): "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" ) - channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo") + channel = make_request( + self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foo" + ) self.assertEqual(channel.result["code"], b"500") @@ -100,7 +105,7 @@ class JsonResourceTests(unittest.TestCase): def _callback(request, **kwargs): d = Deferred() d.addCallback(_throw) - self.reactor.callLater(1, d.callback, True) + self.reactor.callLater(0.5, d.callback, True) return make_deferred_yieldable(d) res = JsonResource(self.homeserver) @@ -108,7 +113,9 @@ class JsonResourceTests(unittest.TestCase): "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" ) - channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo") + channel = make_request( + self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foo" + ) self.assertEqual(channel.result["code"], b"500") @@ -126,7 +133,9 @@ class JsonResourceTests(unittest.TestCase): "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" ) - channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foo") + channel = make_request( + self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foo" + ) self.assertEqual(channel.result["code"], b"403") self.assertEqual(channel.json_body["error"], "Forbidden!!one!") @@ -148,7 +157,9 @@ class JsonResourceTests(unittest.TestCase): "GET", [re.compile("^/_matrix/foo$")], _callback, "test_servlet" ) - channel = make_request(self.reactor, FakeSite(res), b"GET", b"/_matrix/foobar") + channel = make_request( + self.reactor, FakeSite(res, self.reactor), b"GET", b"/_matrix/foobar" + ) self.assertEqual(channel.result["code"], b"400") self.assertEqual(channel.json_body["error"], "Unrecognized request") @@ -173,7 +184,9 @@ class JsonResourceTests(unittest.TestCase): ) # The path was registered as GET, but this is a HEAD request. - channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/_matrix/foo") + channel = make_request( + self.reactor, FakeSite(res, self.reactor), b"HEAD", b"/_matrix/foo" + ) self.assertEqual(channel.result["code"], b"200") self.assertNotIn("body", channel.result) @@ -280,7 +293,9 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): res = WrapHtmlRequestHandlerTests.TestResource() res.callback = callback - channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path") + channel = make_request( + self.reactor, FakeSite(res, self.reactor), b"GET", b"/path" + ) self.assertEqual(channel.result["code"], b"200") body = channel.result["body"] @@ -298,7 +313,9 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): res = WrapHtmlRequestHandlerTests.TestResource() res.callback = callback - channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path") + channel = make_request( + self.reactor, FakeSite(res, self.reactor), b"GET", b"/path" + ) self.assertEqual(channel.result["code"], b"301") headers = channel.result["headers"] @@ -319,7 +336,9 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): res = WrapHtmlRequestHandlerTests.TestResource() res.callback = callback - channel = make_request(self.reactor, FakeSite(res), b"GET", b"/path") + channel = make_request( + self.reactor, FakeSite(res, self.reactor), b"GET", b"/path" + ) self.assertEqual(channel.result["code"], b"304") headers = channel.result["headers"] @@ -338,7 +357,9 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): res = WrapHtmlRequestHandlerTests.TestResource() res.callback = callback - channel = make_request(self.reactor, FakeSite(res), b"HEAD", b"/path") + channel = make_request( + self.reactor, FakeSite(res, self.reactor), b"HEAD", b"/path" + ) self.assertEqual(channel.result["code"], b"200") self.assertNotIn("body", channel.result) -- cgit 1.5.1 From bb7fdd821b07016a43bdbb245eda5b35356863c0 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 24 Sep 2021 07:25:21 -0400 Subject: Use direct references for configuration variables (part 5). (#10897) --- changelog.d/10897.misc | 1 + synapse/app/_base.py | 4 ++-- synapse/app/admin_cmd.py | 6 +++--- synapse/app/generic_worker.py | 6 +++--- synapse/app/homeserver.py | 2 +- synapse/config/logger.py | 4 +++- synapse/crypto/context_factory.py | 4 ++-- synapse/events/spamcheck.py | 2 +- synapse/events/third_party_rules.py | 4 ++-- synapse/handlers/auth.py | 10 ++++++---- synapse/handlers/directory.py | 6 +++--- synapse/handlers/federation.py | 2 +- synapse/handlers/message.py | 8 ++++---- synapse/handlers/register.py | 2 +- synapse/handlers/room.py | 8 +++++--- synapse/handlers/room_list.py | 2 +- synapse/handlers/room_member.py | 2 +- synapse/handlers/saml.py | 15 +++++++------- synapse/handlers/sso.py | 10 ++++++---- synapse/handlers/stats.py | 2 +- synapse/handlers/user_directory.py | 2 +- synapse/logging/opentracing.py | 6 +++--- synapse/replication/http/_base.py | 4 ++-- synapse/replication/tcp/handler.py | 4 ++-- synapse/rest/admin/__init__.py | 2 +- synapse/rest/client/login.py | 2 +- synapse/rest/client/user_directory.py | 2 +- synapse/rest/client/versions.py | 6 +++--- synapse/rest/client/voip.py | 12 +++++------ synapse/rest/media/v1/config_resource.py | 2 +- synapse/rest/media/v1/media_repository.py | 20 +++++++++++-------- synapse/rest/media/v1/preview_url_resource.py | 10 +++++----- synapse/rest/media/v1/storage_provider.py | 2 +- synapse/rest/media/v1/thumbnail_resource.py | 2 +- synapse/rest/media/v1/upload_resource.py | 2 +- synapse/rest/synapse/client/__init__.py | 2 +- .../rest/synapse/client/saml2/metadata_resource.py | 2 +- synapse/server_notices/server_notices_manager.py | 23 +++++++++++----------- synapse/storage/databases/main/registration.py | 2 +- synapse/storage/databases/main/stats.py | 2 +- synapse/storage/databases/main/user_directory.py | 4 ++-- tests/handlers/test_directory.py | 4 +++- tests/handlers/test_stats.py | 8 ++++---- tests/handlers/test_user_directory.py | 6 +++--- tests/rest/admin/test_media.py | 4 ++-- tests/rest/admin/test_user.py | 2 +- tests/rest/media/v1/test_media_storage.py | 2 +- .../test_resource_limits_server_notices.py | 2 +- 48 files changed, 128 insertions(+), 113 deletions(-) create mode 100644 changelog.d/10897.misc (limited to 'tests/rest') diff --git a/changelog.d/10897.misc b/changelog.d/10897.misc new file mode 100644 index 0000000000..586a0b3a96 --- /dev/null +++ b/changelog.d/10897.misc @@ -0,0 +1 @@ +Use direct references to config flags. diff --git a/synapse/app/_base.py b/synapse/app/_base.py index f657f11f76..548f6dcde9 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -88,8 +88,8 @@ def start_worker_reactor(appname, config, run_command=reactor.run): appname, soft_file_limit=config.soft_file_limit, gc_thresholds=config.gc_thresholds, - pid_file=config.worker_pid_file, - daemonize=config.worker_daemonize, + pid_file=config.worker.worker_pid_file, + daemonize=config.worker.worker_daemonize, print_pidfile=config.print_pidfile, logger=logger, run_command=run_command, diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index 259d5ec7cc..f2c5b75247 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -186,9 +186,9 @@ def start(config_options): config.worker.worker_app = "synapse.app.admin_cmd" if ( - not config.worker_daemonize - and not config.worker_log_file - and not config.worker_log_config + not config.worker.worker_daemonize + and not config.worker.worker_log_file + and not config.worker.worker_log_config ): # Since we're meant to be run as a "command" let's not redirect stdio # unless we've actually set log config. diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index e0776689ce..3036e1b4a0 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -140,7 +140,7 @@ class KeyUploadServlet(RestServlet): self.auth = hs.get_auth() self.store = hs.get_datastore() self.http_client = hs.get_simple_http_client() - self.main_uri = hs.config.worker_main_http_uri + self.main_uri = hs.config.worker.worker_main_http_uri async def on_POST(self, request: Request, device_id: Optional[str]): requester = await self.auth.get_user_by_req(request, allow_guest=True) @@ -321,7 +321,7 @@ class GenericWorkerServer(HomeServer): elif name == "federation": resources.update({FEDERATION_PREFIX: TransportLayerServer(self)}) elif name == "media": - if self.config.can_load_media_repo: + if self.config.media.can_load_media_repo: media_repo = self.get_media_repository_resource() # We need to serve the admin servlets for media on the @@ -384,7 +384,7 @@ class GenericWorkerServer(HomeServer): logger.info("Synapse worker now listening on port %d", port) def start_listening(self): - for listener in self.config.worker_listeners: + for listener in self.config.worker.worker_listeners: if listener.type == "http": self._listen_http(listener) elif listener.type == "manhole": diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index f1769f146b..205831dcda 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -234,7 +234,7 @@ class SynapseHomeServer(HomeServer): ) if name in ["media", "federation", "client"]: - if self.config.enable_media_repo: + if self.config.media.enable_media_repo: media_repo = self.get_media_repository_resource() resources.update( {MEDIA_PREFIX: media_repo, LEGACY_MEDIA_PREFIX: media_repo} diff --git a/synapse/config/logger.py b/synapse/config/logger.py index bf8ca7d5fe..0a08231e5a 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -322,7 +322,9 @@ def setup_logging( """ log_config_path = ( - config.worker_log_config if use_worker_options else config.logging.log_config + config.worker.worker_log_config + if use_worker_options + else config.logging.log_config ) # Perform one-time logging configuration. diff --git a/synapse/crypto/context_factory.py b/synapse/crypto/context_factory.py index d310976fe3..2a6110eb10 100644 --- a/synapse/crypto/context_factory.py +++ b/synapse/crypto/context_factory.py @@ -74,8 +74,8 @@ class ServerContextFactory(ContextFactory): context.set_options( SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3 | SSL.OP_NO_TLSv1 | SSL.OP_NO_TLSv1_1 ) - context.use_certificate_chain_file(config.tls_certificate_file) - context.use_privatekey(config.tls_private_key) + context.use_certificate_chain_file(config.tls.tls_certificate_file) + context.use_privatekey(config.tls.tls_private_key) # https://hynek.me/articles/hardening-your-web-servers-ssl-ciphers/ context.set_cipher_list( diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index 57f1d53fa8..19ee246f96 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -78,7 +78,7 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer"): """ spam_checkers: List[Any] = [] api = hs.get_module_api() - for module, config in hs.config.spam_checkers: + for module, config in hs.config.spamchecker.spam_checkers: # Older spam checkers don't accept the `api` argument, so we # try and detect support. spam_args = inspect.getfullargspec(module) diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index 7a6eb3e516..d94b1bb4d2 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -42,10 +42,10 @@ def load_legacy_third_party_event_rules(hs: "HomeServer"): """Wrapper that loads a third party event rules module configured using the old configuration, and registers the hooks they implement. """ - if hs.config.third_party_event_rules is None: + if hs.config.thirdpartyrules.third_party_event_rules is None: return - module, config = hs.config.third_party_event_rules + module, config = hs.config.thirdpartyrules.third_party_event_rules api = hs.get_module_api() third_party_rules = module(config=config, module_api=api) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 0f80dfdc43..a8c717efd5 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -277,23 +277,25 @@ class AuthHandler(BaseHandler): # after the SSO completes and before redirecting them back to their client. # It notifies the user they are about to give access to their matrix account # to the client. - self._sso_redirect_confirm_template = hs.config.sso_redirect_confirm_template + self._sso_redirect_confirm_template = ( + hs.config.sso.sso_redirect_confirm_template + ) # The following template is shown during user interactive authentication # in the fallback auth scenario. It notifies the user that they are # authenticating for an operation to occur on their account. - self._sso_auth_confirm_template = hs.config.sso_auth_confirm_template + self._sso_auth_confirm_template = hs.config.sso.sso_auth_confirm_template # The following template is shown during the SSO authentication process if # the account is deactivated. self._sso_account_deactivated_template = ( - hs.config.sso_account_deactivated_template + hs.config.sso.sso_account_deactivated_template ) self._server_name = hs.config.server.server_name # cast to tuple for use with str.startswith - self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist) + self._whitelisted_sso_clients = tuple(hs.config.sso.sso_client_whitelist) # A mapping of user ID to extra attributes to include in the login # response. diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index d487fee627..5cfba3c817 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -48,7 +48,7 @@ class DirectoryHandler(BaseHandler): self.event_creation_handler = hs.get_event_creation_handler() self.store = hs.get_datastore() self.config = hs.config - self.enable_room_list_search = hs.config.enable_room_list_search + self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search self.require_membership = hs.config.require_membership_for_aliases self.third_party_event_rules = hs.get_third_party_event_rules() @@ -143,7 +143,7 @@ class DirectoryHandler(BaseHandler): ): raise AuthError(403, "This user is not permitted to create this alias") - if not self.config.is_alias_creation_allowed( + if not self.config.roomdirectory.is_alias_creation_allowed( user_id, room_id, room_alias_str ): # Lets just return a generic message, as there may be all sorts of @@ -459,7 +459,7 @@ class DirectoryHandler(BaseHandler): if canonical_alias: room_aliases.append(canonical_alias) - if not self.config.is_publishing_room_allowed( + if not self.config.roomdirectory.is_publishing_room_allowed( user_id, room_id, room_aliases ): # Lets just return a generic message, as there may be all sorts of diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 4523b25636..b17ef2a9a1 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -91,7 +91,7 @@ class FederationHandler(BaseHandler): self.spam_checker = hs.get_spam_checker() self.event_creation_handler = hs.get_event_creation_handler() self._event_auth_handler = hs.get_event_auth_handler() - self._server_notices_mxid = hs.config.server_notices_mxid + self._server_notices_mxid = hs.config.servernotices.server_notices_mxid self.config = hs.config self.http_client = hs.get_proxied_blacklisted_http_client() self._replication = hs.get_replication_data_handler() diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index ad4e4a3d6f..c66aefe2c4 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -692,10 +692,10 @@ class EventCreationHandler: return False async def _is_server_notices_room(self, room_id: str) -> bool: - if self.config.server_notices_mxid is None: + if self.config.servernotices.server_notices_mxid is None: return False user_ids = await self.store.get_users_in_room(room_id) - return self.config.server_notices_mxid in user_ids + return self.config.servernotices.server_notices_mxid in user_ids async def assert_accepted_privacy_policy(self, requester: Requester) -> None: """Check if a user has accepted the privacy policy @@ -731,8 +731,8 @@ class EventCreationHandler: # exempt the system notices user if ( - self.config.server_notices_mxid is not None - and user_id == self.config.server_notices_mxid + self.config.servernotices.server_notices_mxid is not None + and user_id == self.config.servernotices.server_notices_mxid ): return diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 01c5e1385d..4f99f137a2 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -98,7 +98,7 @@ class RegistrationHandler(BaseHandler): self.macaroon_gen = hs.get_macaroon_generator() self._account_validity_handler = hs.get_account_validity_handler() self._user_consent_version = self.hs.config.consent.user_consent_version - self._server_notices_mxid = hs.config.server_notices_mxid + self._server_notices_mxid = hs.config.servernotices.server_notices_mxid self._server_name = hs.hostname self.spam_checker = hs.get_spam_checker() diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index b5768220d9..408b7d7b74 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -126,7 +126,7 @@ class RoomCreationHandler(BaseHandler): for preset_name, preset_config in self._presets_dict.items(): encrypted = ( preset_name - in self.config.encryption_enabled_by_default_for_room_presets + in self.config.room.encryption_enabled_by_default_for_room_presets ) preset_config["encrypted"] = encrypted @@ -141,7 +141,7 @@ class RoomCreationHandler(BaseHandler): self._upgrade_response_cache: ResponseCache[Tuple[str, str]] = ResponseCache( hs.get_clock(), "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS ) - self._server_notices_mxid = hs.config.server_notices_mxid + self._server_notices_mxid = hs.config.servernotices.server_notices_mxid self.third_party_event_rules = hs.get_third_party_event_rules() @@ -757,7 +757,9 @@ class RoomCreationHandler(BaseHandler): ) if is_public: - if not self.config.is_publishing_room_allowed(user_id, room_id, room_alias): + if not self.config.roomdirectory.is_publishing_room_allowed( + user_id, room_id, room_alias + ): # Lets just return a generic message, as there may be all sorts of # reasons why we said no. TODO: Allow configurable error messages # per alias creation rule? diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index c83ff585e3..c3d4199ed1 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -52,7 +52,7 @@ EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None) class RoomListHandler(BaseHandler): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.enable_room_list_search = hs.config.enable_room_list_search + self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search self.response_cache: ResponseCache[ Tuple[Optional[int], Optional[str], Optional[ThirdPartyInstanceID]] ] = ResponseCache(hs.get_clock(), "room_list") diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 7bb3f0bc47..1a56c82fbd 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -88,7 +88,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self.clock = hs.get_clock() self.spam_checker = hs.get_spam_checker() self.third_party_event_rules = hs.get_third_party_event_rules() - self._server_notices_mxid = self.config.server_notices_mxid + self._server_notices_mxid = self.config.servernotices.server_notices_mxid self._enable_lookup = hs.config.enable_3pid_lookup self.allow_per_room_profiles = self.config.allow_per_room_profiles diff --git a/synapse/handlers/saml.py b/synapse/handlers/saml.py index 185befbe9f..2fed9f377a 100644 --- a/synapse/handlers/saml.py +++ b/synapse/handlers/saml.py @@ -54,19 +54,18 @@ class Saml2SessionData: class SamlHandler(BaseHandler): def __init__(self, hs: "HomeServer"): super().__init__(hs) - self._saml_client = Saml2Client(hs.config.saml2_sp_config) - self._saml_idp_entityid = hs.config.saml2_idp_entityid + self._saml_client = Saml2Client(hs.config.saml2.saml2_sp_config) + self._saml_idp_entityid = hs.config.saml2.saml2_idp_entityid - self._saml2_session_lifetime = hs.config.saml2_session_lifetime + self._saml2_session_lifetime = hs.config.saml2.saml2_session_lifetime self._grandfathered_mxid_source_attribute = ( - hs.config.saml2_grandfathered_mxid_source_attribute + hs.config.saml2.saml2_grandfathered_mxid_source_attribute ) self._saml2_attribute_requirements = hs.config.saml2.attribute_requirements - self._error_template = hs.config.sso_error_template # plugin to do custom mapping from saml response to mxid - self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class( - hs.config.saml2_user_mapping_provider_config, + self._user_mapping_provider = hs.config.saml2.saml2_user_mapping_provider_class( + hs.config.saml2.saml2_user_mapping_provider_config, ModuleApi(hs, hs.get_auth_handler()), ) @@ -411,7 +410,7 @@ class DefaultSamlMappingProvider: self._mxid_mapper = parsed_config.mxid_mapper self._grandfathered_mxid_source_attribute = ( - module_api._hs.config.saml2_grandfathered_mxid_source_attribute + module_api._hs.config.saml2.saml2_grandfathered_mxid_source_attribute ) def get_remote_user_id( diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index e044251a13..49fde01cf0 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -184,15 +184,17 @@ class SsoHandler: self._server_name = hs.hostname self._registration_handler = hs.get_registration_handler() self._auth_handler = hs.get_auth_handler() - self._error_template = hs.config.sso_error_template - self._bad_user_template = hs.config.sso_auth_bad_user_template + self._error_template = hs.config.sso.sso_error_template + self._bad_user_template = hs.config.sso.sso_auth_bad_user_template self._profile_handler = hs.get_profile_handler() # The following template is shown after a successful user interactive # authentication session. It tells the user they can close the window. - self._sso_auth_success_template = hs.config.sso_auth_success_template + self._sso_auth_success_template = hs.config.sso.sso_auth_success_template - self._sso_update_profile_information = hs.config.sso_update_profile_information + self._sso_update_profile_information = ( + hs.config.sso.sso_update_profile_information + ) # a lock on the mappings self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock()) diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index 9fc53333fc..bd3e6f2ec7 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -46,7 +46,7 @@ class StatsHandler: self.notifier = hs.get_notifier() self.is_mine_id = hs.is_mine_id - self.stats_enabled = hs.config.stats_enabled + self.stats_enabled = hs.config.stats.stats_enabled # The current position in the current_state_delta stream self.pos: Optional[int] = None diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 8dc46d7674..b91e7cb501 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -61,7 +61,7 @@ class UserDirectoryHandler(StateDeltasHandler): self.notifier = hs.get_notifier() self.is_mine_id = hs.is_mine_id self.update_user_directory = hs.config.update_user_directory - self.search_all_users = hs.config.user_directory_search_all_users + self.search_all_users = hs.config.userdirectory.user_directory_search_all_users self.spam_checker = hs.get_spam_checker() # The current position in the current_state_delta stream self.pos: Optional[int] = None diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index c6c4d3bd29..03d2dd94f6 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -363,7 +363,7 @@ def noop_context_manager(*args, **kwargs): def init_tracer(hs: "HomeServer"): """Set the whitelists and initialise the JaegerClient tracer""" global opentracing - if not hs.config.opentracer_enabled: + if not hs.config.tracing.opentracer_enabled: # We don't have a tracer opentracing = None return @@ -377,12 +377,12 @@ def init_tracer(hs: "HomeServer"): # Pull out the jaeger config if it was given. Otherwise set it to something sensible. # See https://github.com/jaegertracing/jaeger-client-python/blob/master/jaeger_client/config.py - set_homeserver_whitelist(hs.config.opentracer_whitelist) + set_homeserver_whitelist(hs.config.tracing.opentracer_whitelist) from jaeger_client.metrics.prometheus import PrometheusMetricsFactory config = JaegerConfig( - config=hs.config.jaeger_config, + config=hs.config.tracing.jaeger_config, service_name=f"{hs.config.server.server_name} {hs.get_instance_name()}", scope_manager=LogContextScopeManager(hs.config), metrics_factory=PrometheusMetricsFactory(), diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 25589b0042..f1b78d09f9 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -168,8 +168,8 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): client = hs.get_simple_http_client() local_instance_name = hs.get_instance_name() - master_host = hs.config.worker_replication_host - master_port = hs.config.worker_replication_http_port + master_host = hs.config.worker.worker_replication_host + master_port = hs.config.worker.worker_replication_http_port instance_map = hs.config.worker.instance_map diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 509ed7fb13..1438a82b60 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -322,8 +322,8 @@ class ReplicationCommandHandler: else: client_name = hs.get_instance_name() self._factory = DirectTcpReplicationClientFactory(hs, client_name, self) - host = hs.config.worker_replication_host - port = hs.config.worker_replication_port + host = hs.config.worker.worker_replication_host + port = hs.config.worker.worker_replication_port hs.get_reactor().connectTCP(host.encode(), port, self._factory) def get_streams(self) -> Dict[str, Stream]: diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index a03774c98a..e1506deb2b 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -267,7 +267,7 @@ def register_servlets_for_client_rest_resource( # Load the media repo ones if we're using them. Otherwise load the servlets which # don't need a media repo (typically readonly admin APIs). - if hs.config.can_load_media_repo: + if hs.config.media.can_load_media_repo: register_servlets_for_media_repo(hs, http_server) else: ListMediaInRoom(hs).register(http_server) diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index 64446fc486..fa5c173f4b 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -76,7 +76,7 @@ class LoginRestServlet(RestServlet): self.jwt_audiences = hs.config.jwt.jwt_audiences # SSO configuration. - self.saml2_enabled = hs.config.saml2_enabled + self.saml2_enabled = hs.config.saml2.saml2_enabled self.cas_enabled = hs.config.cas.cas_enabled self.oidc_enabled = hs.config.oidc.oidc_enabled self._msc2918_enabled = hs.config.access_token_lifetime is not None diff --git a/synapse/rest/client/user_directory.py b/synapse/rest/client/user_directory.py index 8852811114..a47d9bd01d 100644 --- a/synapse/rest/client/user_directory.py +++ b/synapse/rest/client/user_directory.py @@ -58,7 +58,7 @@ class UserDirectorySearchRestServlet(RestServlet): requester = await self.auth.get_user_by_req(request, allow_guest=False) user_id = requester.user.to_string() - if not self.hs.config.user_directory_search_enabled: + if not self.hs.config.userdirectory.user_directory_search_enabled: return 200, {"limited": False, "results": []} body = parse_json_object_from_request(request) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index a1a815cf82..b52a296d8f 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -42,15 +42,15 @@ class VersionsRestServlet(RestServlet): # Calculate these once since they shouldn't change after start-up. self.e2ee_forced_public = ( RoomCreationPreset.PUBLIC_CHAT - in self.config.encryption_enabled_by_default_for_room_presets + in self.config.room.encryption_enabled_by_default_for_room_presets ) self.e2ee_forced_private = ( RoomCreationPreset.PRIVATE_CHAT - in self.config.encryption_enabled_by_default_for_room_presets + in self.config.room.encryption_enabled_by_default_for_room_presets ) self.e2ee_forced_trusted_private = ( RoomCreationPreset.TRUSTED_PRIVATE_CHAT - in self.config.encryption_enabled_by_default_for_room_presets + in self.config.room.encryption_enabled_by_default_for_room_presets ) def on_GET(self, request: Request) -> Tuple[int, JsonDict]: diff --git a/synapse/rest/client/voip.py b/synapse/rest/client/voip.py index 9d46ed3af3..ea2b8aa45f 100644 --- a/synapse/rest/client/voip.py +++ b/synapse/rest/client/voip.py @@ -37,14 +37,14 @@ class VoipRestServlet(RestServlet): async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req( - request, self.hs.config.turn_allow_guests + request, self.hs.config.voip.turn_allow_guests ) - turnUris = self.hs.config.turn_uris - turnSecret = self.hs.config.turn_shared_secret - turnUsername = self.hs.config.turn_username - turnPassword = self.hs.config.turn_password - userLifetime = self.hs.config.turn_user_lifetime + turnUris = self.hs.config.voip.turn_uris + turnSecret = self.hs.config.voip.turn_shared_secret + turnUsername = self.hs.config.voip.turn_username + turnPassword = self.hs.config.voip.turn_password + userLifetime = self.hs.config.voip.turn_user_lifetime if turnUris and turnSecret and userLifetime: expiry = (self.hs.get_clock().time_msec() + userLifetime) / 1000 diff --git a/synapse/rest/media/v1/config_resource.py b/synapse/rest/media/v1/config_resource.py index 712d4e8368..a95804d327 100644 --- a/synapse/rest/media/v1/config_resource.py +++ b/synapse/rest/media/v1/config_resource.py @@ -31,7 +31,7 @@ class MediaConfigResource(DirectServeJsonResource): config = hs.config self.clock = hs.get_clock() self.auth = hs.get_auth() - self.limits_dict = {"m.upload.size": config.max_upload_size} + self.limits_dict = {"m.upload.size": config.media.max_upload_size} async def _async_render_GET(self, request: SynapseRequest) -> None: await self.auth.get_user_by_req(request) diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index c1bd81100d..abd88a2d4f 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -76,16 +76,16 @@ class MediaRepository: self.clock = hs.get_clock() self.server_name = hs.hostname self.store = hs.get_datastore() - self.max_upload_size = hs.config.max_upload_size - self.max_image_pixels = hs.config.max_image_pixels + self.max_upload_size = hs.config.media.max_upload_size + self.max_image_pixels = hs.config.media.max_image_pixels Thumbnailer.set_limits(self.max_image_pixels) - self.primary_base_path: str = hs.config.media_store_path + self.primary_base_path: str = hs.config.media.media_store_path self.filepaths: MediaFilePaths = MediaFilePaths(self.primary_base_path) - self.dynamic_thumbnails = hs.config.dynamic_thumbnails - self.thumbnail_requirements = hs.config.thumbnail_requirements + self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails + self.thumbnail_requirements = hs.config.media.thumbnail_requirements self.remote_media_linearizer = Linearizer(name="media_remote") @@ -100,7 +100,11 @@ class MediaRepository: # potentially upload to. storage_providers = [] - for clz, provider_config, wrapper_config in hs.config.media_storage_providers: + for ( + clz, + provider_config, + wrapper_config, + ) in hs.config.media.media_storage_providers: backend = clz(hs, provider_config) provider = StorageProviderWrapper( backend, @@ -975,7 +979,7 @@ class MediaRepositoryResource(Resource): def __init__(self, hs: "HomeServer"): # If we're not configured to use it, raise if we somehow got here. - if not hs.config.can_load_media_repo: + if not hs.config.media.can_load_media_repo: raise ConfigError("Synapse is not configured to use a media repo.") super().__init__() @@ -986,7 +990,7 @@ class MediaRepositoryResource(Resource): self.putChild( b"thumbnail", ThumbnailResource(hs, media_repo, media_repo.media_storage) ) - if hs.config.url_preview_enabled: + if hs.config.media.url_preview_enabled: self.putChild( b"preview_url", PreviewUrlResource(hs, media_repo, media_repo.media_storage), diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 128706d297..0b0c4d6469 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -125,14 +125,14 @@ class PreviewUrlResource(DirectServeJsonResource): self.auth = hs.get_auth() self.clock = hs.get_clock() self.filepaths = media_repo.filepaths - self.max_spider_size = hs.config.max_spider_size + self.max_spider_size = hs.config.media.max_spider_size self.server_name = hs.hostname self.store = hs.get_datastore() self.client = SimpleHttpClient( hs, treq_args={"browser_like_redirects": True}, - ip_whitelist=hs.config.url_preview_ip_range_whitelist, - ip_blacklist=hs.config.url_preview_ip_range_blacklist, + ip_whitelist=hs.config.media.url_preview_ip_range_whitelist, + ip_blacklist=hs.config.media.url_preview_ip_range_blacklist, use_proxy=True, ) self.media_repo = media_repo @@ -150,8 +150,8 @@ class PreviewUrlResource(DirectServeJsonResource): or instance_running_jobs == hs.get_instance_name() ) - self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist - self.url_preview_accept_language = hs.config.url_preview_accept_language + self.url_preview_url_blacklist = hs.config.media.url_preview_url_blacklist + self.url_preview_accept_language = hs.config.media.url_preview_accept_language # memory cache mapping urls to an ObservableDeferred returning # JSON-encoded OG metadata diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py index 6c9969e55f..289e4297f2 100644 --- a/synapse/rest/media/v1/storage_provider.py +++ b/synapse/rest/media/v1/storage_provider.py @@ -125,7 +125,7 @@ class FileStorageProviderBackend(StorageProvider): def __init__(self, hs: "HomeServer", config: str): self.hs = hs - self.cache_directory = hs.config.media_store_path + self.cache_directory = hs.config.media.media_store_path self.base_directory = config def __str__(self) -> str: diff --git a/synapse/rest/media/v1/thumbnail_resource.py b/synapse/rest/media/v1/thumbnail_resource.py index cb2f88676e..ed91ef5a42 100644 --- a/synapse/rest/media/v1/thumbnail_resource.py +++ b/synapse/rest/media/v1/thumbnail_resource.py @@ -53,7 +53,7 @@ class ThumbnailResource(DirectServeJsonResource): self.store = hs.get_datastore() self.media_repo = media_repo self.media_storage = media_storage - self.dynamic_thumbnails = hs.config.dynamic_thumbnails + self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails self.server_name = hs.hostname async def _async_render_GET(self, request: SynapseRequest) -> None: diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index 39b29318bb..7dcb1428e4 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -41,7 +41,7 @@ class UploadResource(DirectServeJsonResource): self.clock = hs.get_clock() self.server_name = hs.hostname self.auth = hs.get_auth() - self.max_upload_size = hs.config.max_upload_size + self.max_upload_size = hs.config.media.max_upload_size self.clock = hs.get_clock() async def _async_render_OPTIONS(self, request: SynapseRequest) -> None: diff --git a/synapse/rest/synapse/client/__init__.py b/synapse/rest/synapse/client/__init__.py index 086c80b723..6ad558f5d1 100644 --- a/synapse/rest/synapse/client/__init__.py +++ b/synapse/rest/synapse/client/__init__.py @@ -50,7 +50,7 @@ def build_synapse_client_resource_tree(hs: "HomeServer") -> Mapping[str, Resourc resources["/_synapse/client/oidc"] = OIDCResource(hs) - if hs.config.saml2_enabled: + if hs.config.saml2.saml2_enabled: from synapse.rest.synapse.client.saml2 import SAML2Resource res = SAML2Resource(hs) diff --git a/synapse/rest/synapse/client/saml2/metadata_resource.py b/synapse/rest/synapse/client/saml2/metadata_resource.py index 64378ed57b..d8eae3970d 100644 --- a/synapse/rest/synapse/client/saml2/metadata_resource.py +++ b/synapse/rest/synapse/client/saml2/metadata_resource.py @@ -30,7 +30,7 @@ class SAML2MetadataResource(Resource): def __init__(self, hs: "HomeServer"): Resource.__init__(self) - self.sp_config = hs.config.saml2_sp_config + self.sp_config = hs.config.saml2.saml2_sp_config def render_GET(self, request: Request) -> bytes: metadata_xml = saml2.metadata.create_metadata_string( diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index d87a538917..cd1c5ff6f4 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -39,7 +39,7 @@ class ServerNoticesManager: self._server_name = hs.hostname self._notifier = hs.get_notifier() - self.server_notices_mxid = self._config.server_notices_mxid + self.server_notices_mxid = self._config.servernotices.server_notices_mxid def is_enabled(self): """Checks if server notices are enabled on this server. @@ -47,7 +47,7 @@ class ServerNoticesManager: Returns: bool """ - return self._config.server_notices_mxid is not None + return self.server_notices_mxid is not None async def send_notice( self, @@ -71,9 +71,9 @@ class ServerNoticesManager: room_id = await self.get_or_create_notice_room_for_user(user_id) await self.maybe_invite_user_to_room(user_id, room_id) - system_mxid = self._config.server_notices_mxid + assert self.server_notices_mxid is not None requester = create_requester( - system_mxid, authenticated_entity=self._server_name + self.server_notices_mxid, authenticated_entity=self._server_name ) logger.info("Sending server notice to %s", user_id) @@ -81,7 +81,7 @@ class ServerNoticesManager: event_dict = { "type": type, "room_id": room_id, - "sender": system_mxid, + "sender": self.server_notices_mxid, "content": event_content, } @@ -106,7 +106,7 @@ class ServerNoticesManager: Returns: room id of notice room. """ - if not self.is_enabled(): + if self.server_notices_mxid is None: raise Exception("Server notices not enabled") assert self._is_mine_id(user_id), "Cannot send server notices to remote users" @@ -139,12 +139,12 @@ class ServerNoticesManager: # avatar, we have to use both. join_profile = None if ( - self._config.server_notices_mxid_display_name is not None - or self._config.server_notices_mxid_avatar_url is not None + self._config.servernotices.server_notices_mxid_display_name is not None + or self._config.servernotices.server_notices_mxid_avatar_url is not None ): join_profile = { - "displayname": self._config.server_notices_mxid_display_name, - "avatar_url": self._config.server_notices_mxid_avatar_url, + "displayname": self._config.servernotices.server_notices_mxid_display_name, + "avatar_url": self._config.servernotices.server_notices_mxid_avatar_url, } requester = create_requester( @@ -154,7 +154,7 @@ class ServerNoticesManager: requester, config={ "preset": RoomCreationPreset.PRIVATE_CHAT, - "name": self._config.server_notices_room_name, + "name": self._config.servernotices.server_notices_room_name, "power_level_content_override": {"users_default": -10}, }, ratelimit=False, @@ -178,6 +178,7 @@ class ServerNoticesManager: user_id: The ID of the user to invite. room_id: The ID of the room to invite the user to. """ + assert self.server_notices_mxid is not None requester = create_requester( self.server_notices_mxid, authenticated_entity=self._server_name ) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 52ef9deede..c83089ee63 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -2015,7 +2015,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): (user_id_obj.localpart, create_profile_with_displayname), ) - if self.hs.config.stats_enabled: + if self.hs.config.stats.stats_enabled: # we create a new completed user statistics row # we don't strictly need current_token since this user really can't diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index 343d6efc92..e20033bb28 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -98,7 +98,7 @@ class StatsStore(StateDeltasStore): self.server_name = hs.hostname self.clock = self.hs.get_clock() - self.stats_enabled = hs.config.stats_enabled + self.stats_enabled = hs.config.stats.stats_enabled self.stats_delta_processing_lock = DeferredLock() diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 7ca04237a5..90d65edc42 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -551,7 +551,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): super().__init__(database, db_conn, hs) self._prefer_local_users_in_search = ( - hs.config.user_directory_search_prefer_local_users + hs.config.userdirectory.user_directory_search_prefer_local_users ) self._server_name = hs.config.server.server_name @@ -741,7 +741,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): } """ - if self.hs.config.user_directory_search_all_users: + if self.hs.config.userdirectory.user_directory_search_all_users: join_args = (user_id,) where_clause = "user_id != ?" else: diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index a0a48b564e..6a2e76ca4a 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -405,7 +405,9 @@ class TestCreateAliasACL(unittest.HomeserverTestCase): rd_config = RoomDirectoryConfig() rd_config.read_config(config) - self.hs.config.is_alias_creation_allowed = rd_config.is_alias_creation_allowed + self.hs.config.roomdirectory.is_alias_creation_allowed = ( + rd_config.is_alias_creation_allowed + ) return hs diff --git a/tests/handlers/test_stats.py b/tests/handlers/test_stats.py index 1ba4c05b9b..24b7ef6efc 100644 --- a/tests/handlers/test_stats.py +++ b/tests/handlers/test_stats.py @@ -118,7 +118,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.assertEqual(len(r), 0) # Disable stats - self.hs.config.stats_enabled = False + self.hs.config.stats.stats_enabled = False self.handler.stats_enabled = False u1 = self.register_user("u1", "pass") @@ -134,7 +134,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.assertEqual(len(r), 0) # Enable stats - self.hs.config.stats_enabled = True + self.hs.config.stats.stats_enabled = True self.handler.stats_enabled = True # Do the initial population of the user directory via the background update @@ -469,7 +469,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): behaviour eventually to still keep current rows. """ - self.hs.config.stats_enabled = False + self.hs.config.stats.stats_enabled = False self.handler.stats_enabled = False u1 = self.register_user("u1", "pass") @@ -481,7 +481,7 @@ class StatsRoomTests(unittest.HomeserverTestCase): self.assertIsNone(self._get_current_stats("room", r1)) self.assertIsNone(self._get_current_stats("user", u1)) - self.hs.config.stats_enabled = True + self.hs.config.stats.stats_enabled = True self.handler.stats_enabled = True self._perform_background_initial_update() diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index ba32585a14..266333c553 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -451,7 +451,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): visible. """ self.handler.search_all_users = True - self.hs.config.user_directory_search_all_users = True + self.hs.config.userdirectory.user_directory_search_all_users = True u1 = self.register_user("user1", "pass") self.register_user("user2", "pass") @@ -607,7 +607,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase): return hs def test_disabling_room_list(self): - self.config.user_directory_search_enabled = True + self.config.userdirectory.user_directory_search_enabled = True # First we create a room with another user so that user dir is non-empty # for our user @@ -624,7 +624,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase): self.assertTrue(len(channel.json_body["results"]) > 0) # Disable user directory and check search returns nothing - self.config.user_directory_search_enabled = False + self.config.userdirectory.user_directory_search_enabled = False channel = self.make_request( "POST", b"user_directory/search", b'{"search_term":"user2"}' ) diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index f813866073..ce30a19213 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -43,7 +43,7 @@ class DeleteMediaByIDTestCase(unittest.HomeserverTestCase): self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") - self.filepaths = MediaFilePaths(hs.config.media_store_path) + self.filepaths = MediaFilePaths(hs.config.media.media_store_path) def test_no_auth(self): """ @@ -200,7 +200,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") - self.filepaths = MediaFilePaths(hs.config.media_store_path) + self.filepaths = MediaFilePaths(hs.config.media.media_store_path) self.url = "/_synapse/admin/v1/media/%s/delete" % self.server_name def test_no_auth(self): diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index e79e0e1850..ee3ae9cce4 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -2473,7 +2473,7 @@ class UserMediaRestTestCase(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.store = hs.get_datastore() self.media_repo = hs.get_media_repository_resource() - self.filepaths = MediaFilePaths(hs.config.media_store_path) + self.filepaths = MediaFilePaths(hs.config.media.media_store_path) self.admin_user = self.register_user("admin", "pass", admin=True) self.admin_user_tok = self.login("admin", "pass") diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py index 44a643d506..4ae00755c9 100644 --- a/tests/rest/media/v1/test_media_storage.py +++ b/tests/rest/media/v1/test_media_storage.py @@ -53,7 +53,7 @@ class MediaStorageTests(unittest.HomeserverTestCase): self.primary_base_path = os.path.join(self.test_dir, "primary") self.secondary_base_path = os.path.join(self.test_dir, "secondary") - hs.config.media_store_path = self.primary_base_path + hs.config.media.media_store_path = self.primary_base_path storage_providers = [FileStorageProviderBackend(hs, self.secondary_base_path)] diff --git a/tests/server_notices/test_resource_limits_server_notices.py b/tests/server_notices/test_resource_limits_server_notices.py index 8701b5f7e3..7f25200a5d 100644 --- a/tests/server_notices/test_resource_limits_server_notices.py +++ b/tests/server_notices/test_resource_limits_server_notices.py @@ -326,7 +326,7 @@ class TestResourceLimitsServerNoticesWithRealRooms(unittest.HomeserverTestCase): for event in events: if ( event["type"] == EventTypes.Message - and event["sender"] == self.hs.config.server_notices_mxid + and event["sender"] == self.hs.config.servernotices.server_notices_mxid ): notice_in_room = True -- cgit 1.5.1 From b10257e87972d158f4b6a0c7d1fe7239014ea10a Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Fri, 24 Sep 2021 16:38:23 +0200 Subject: Add a spamchecker callback to allow or deny room creation based on invites (#10898) This is in the context of creating new module callbacks that modules in https://github.com/matrix-org/synapse-dinsic can use, in an effort to reconcile the spam checker API in synapse-dinsic with the one in mainline. This adds a callback that's fairly similar to user_may_create_room except it also allows processing based on the invites sent at room creation. --- changelog.d/10898.feature | 1 + docs/modules/spam_checker_callbacks.md | 29 ++++++++ synapse/events/spamcheck.py | 42 ++++++++++++ synapse/handlers/room.py | 14 ++-- tests/rest/client/test_rooms.py | 119 ++++++++++++++++++++++++++++++++- 5 files changed, 199 insertions(+), 6 deletions(-) create mode 100644 changelog.d/10898.feature (limited to 'tests/rest') diff --git a/changelog.d/10898.feature b/changelog.d/10898.feature new file mode 100644 index 0000000000..97fa39fd0c --- /dev/null +++ b/changelog.d/10898.feature @@ -0,0 +1 @@ +Add a `user_may_create_room_with_invites` spam checker callback to allow modules to allow or deny a room creation request based on the invites and/or 3PID invites it includes. diff --git a/docs/modules/spam_checker_callbacks.md b/docs/modules/spam_checker_callbacks.md index 81574a015c..7920ac5f8f 100644 --- a/docs/modules/spam_checker_callbacks.md +++ b/docs/modules/spam_checker_callbacks.md @@ -38,6 +38,35 @@ async def user_may_create_room(user: str) -> bool Called when processing a room creation request. The module must return a `bool` indicating whether the given user (represented by their Matrix user ID) is allowed to create a room. +### `user_may_create_room_with_invites` + +```python +async def user_may_create_room_with_invites( + user: str, + invites: List[str], + threepid_invites: List[Dict[str, str]], +) -> bool +``` + +Called when processing a room creation request (right after `user_may_create_room`). +The module is given the Matrix user ID of the user trying to create a room, as well as a +list of Matrix users to invite and a list of third-party identifiers (3PID, e.g. email +addresses) to invite. + +An invited Matrix user to invite is represented by their Matrix user IDs, and an invited +3PIDs is represented by a dict that includes the 3PID medium (e.g. "email") through its +`medium` key and its address (e.g. "alice@example.com") through its `address` key. + +See [the Matrix specification](https://matrix.org/docs/spec/appendices#pid-types) for more +information regarding third-party identifiers. + +If no invite and/or 3PID invite were specified in the room creation request, the +corresponding list(s) will be empty. + +**Note**: This callback is not called when a room is cloned (e.g. during a room upgrade) +since no invites are sent when cloning a room. To cover this case, modules also need to +implement `user_may_create_room`. + ### `user_may_create_room_alias` ```python diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index 19ee246f96..c389f70b8d 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -46,6 +46,9 @@ CHECK_EVENT_FOR_SPAM_CALLBACK = Callable[ ] USER_MAY_INVITE_CALLBACK = Callable[[str, str, str], Awaitable[bool]] USER_MAY_CREATE_ROOM_CALLBACK = Callable[[str], Awaitable[bool]] +USER_MAY_CREATE_ROOM_WITH_INVITES_CALLBACK = Callable[ + [str, List[str], List[Dict[str, str]]], Awaitable[bool] +] USER_MAY_CREATE_ROOM_ALIAS_CALLBACK = Callable[[str, RoomAlias], Awaitable[bool]] USER_MAY_PUBLISH_ROOM_CALLBACK = Callable[[str, str], Awaitable[bool]] CHECK_USERNAME_FOR_SPAM_CALLBACK = Callable[[Dict[str, str]], Awaitable[bool]] @@ -164,6 +167,9 @@ class SpamChecker: self._check_event_for_spam_callbacks: List[CHECK_EVENT_FOR_SPAM_CALLBACK] = [] self._user_may_invite_callbacks: List[USER_MAY_INVITE_CALLBACK] = [] self._user_may_create_room_callbacks: List[USER_MAY_CREATE_ROOM_CALLBACK] = [] + self._user_may_create_room_with_invites_callbacks: List[ + USER_MAY_CREATE_ROOM_WITH_INVITES_CALLBACK + ] = [] self._user_may_create_room_alias_callbacks: List[ USER_MAY_CREATE_ROOM_ALIAS_CALLBACK ] = [] @@ -183,6 +189,9 @@ class SpamChecker: check_event_for_spam: Optional[CHECK_EVENT_FOR_SPAM_CALLBACK] = None, user_may_invite: Optional[USER_MAY_INVITE_CALLBACK] = None, user_may_create_room: Optional[USER_MAY_CREATE_ROOM_CALLBACK] = None, + user_may_create_room_with_invites: Optional[ + USER_MAY_CREATE_ROOM_WITH_INVITES_CALLBACK + ] = None, user_may_create_room_alias: Optional[ USER_MAY_CREATE_ROOM_ALIAS_CALLBACK ] = None, @@ -203,6 +212,11 @@ class SpamChecker: if user_may_create_room is not None: self._user_may_create_room_callbacks.append(user_may_create_room) + if user_may_create_room_with_invites is not None: + self._user_may_create_room_with_invites_callbacks.append( + user_may_create_room_with_invites, + ) + if user_may_create_room_alias is not None: self._user_may_create_room_alias_callbacks.append( user_may_create_room_alias, @@ -283,6 +297,34 @@ class SpamChecker: return True + async def user_may_create_room_with_invites( + self, + userid: str, + invites: List[str], + threepid_invites: List[Dict[str, str]], + ) -> bool: + """Checks if a given user may create a room with invites + + If this method returns false, the creation request will be rejected. + + Args: + userid: The ID of the user attempting to create a room + invites: The IDs of the Matrix users to be invited if the room creation is + allowed. + threepid_invites: The threepids to be invited if the room creation is allowed, + as a dict including a "medium" key indicating the threepid's medium (e.g. + "email") and an "address" key indicating the threepid's address (e.g. + "alice@example.com") + + Returns: + True if the user may create the room, otherwise False + """ + for callback in self._user_may_create_room_with_invites_callbacks: + if await callback(userid, invites, threepid_invites) is False: + return False + + return True + async def user_may_create_room_alias( self, userid: str, room_alias: RoomAlias ) -> bool: diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 408b7d7b74..8fede5e935 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -649,8 +649,16 @@ class RoomCreationHandler(BaseHandler): requester, config, is_requester_admin=is_requester_admin ) - if not is_requester_admin and not await self.spam_checker.user_may_create_room( - user_id + invite_3pid_list = config.get("invite_3pid", []) + invite_list = config.get("invite", []) + + if not is_requester_admin and not ( + await self.spam_checker.user_may_create_room(user_id) + and await self.spam_checker.user_may_create_room_with_invites( + user_id, + invite_list, + invite_3pid_list, + ) ): raise SynapseError(403, "You are not permitted to create rooms") @@ -684,8 +692,6 @@ class RoomCreationHandler(BaseHandler): if mapping: raise SynapseError(400, "Room alias already taken", Codes.ROOM_IN_USE) - invite_3pid_list = config.get("invite_3pid", []) - invite_list = config.get("invite", []) for i in invite_list: try: uid = UserID.from_string(i) diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index ef847f0f5f..30bdaa9c27 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -18,7 +18,7 @@ """Tests REST events for /rooms paths.""" import json -from typing import Iterable +from typing import Dict, Iterable, List, Optional from unittest.mock import Mock, call from urllib import parse as urlparse @@ -30,7 +30,7 @@ from synapse.api.errors import Codes, HttpResponseException from synapse.handlers.pagination import PurgeStatus from synapse.rest import admin from synapse.rest.client import account, directory, login, profile, room, sync -from synapse.types import JsonDict, RoomAlias, UserID, create_requester +from synapse.types import JsonDict, Requester, RoomAlias, UserID, create_requester from synapse.util.stringutils import random_string from tests import unittest @@ -669,6 +669,121 @@ class RoomsCreateTestCase(RoomBase): channel = self.make_request("POST", "/createRoom", content) self.assertEqual(200, channel.code) + def test_spamchecker_invites(self): + """Tests the user_may_create_room_with_invites spam checker callback.""" + + # Mock do_3pid_invite, so we don't fail from failing to send a 3PID invite to an + # IS. + async def do_3pid_invite( + room_id: str, + inviter: UserID, + medium: str, + address: str, + id_server: str, + requester: Requester, + txn_id: Optional[str], + id_access_token: Optional[str] = None, + ) -> int: + return 0 + + do_3pid_invite_mock = Mock(side_effect=do_3pid_invite) + self.hs.get_room_member_handler().do_3pid_invite = do_3pid_invite_mock + + # Add a mock callback for user_may_create_room_with_invites. Make it allow any + # room creation request for now. + return_value = True + + async def user_may_create_room_with_invites( + user: str, + invites: List[str], + threepid_invites: List[Dict[str, str]], + ) -> bool: + return return_value + + callback_mock = Mock(side_effect=user_may_create_room_with_invites) + self.hs.get_spam_checker()._user_may_create_room_with_invites_callbacks.append( + callback_mock, + ) + + # The MXIDs we'll try to invite. + invited_mxids = [ + "@alice1:red", + "@alice2:red", + "@alice3:red", + "@alice4:red", + ] + + # The 3PIDs we'll try to invite. + invited_3pids = [ + { + "id_server": "example.com", + "id_access_token": "sometoken", + "medium": "email", + "address": "alice1@example.com", + }, + { + "id_server": "example.com", + "id_access_token": "sometoken", + "medium": "email", + "address": "alice2@example.com", + }, + { + "id_server": "example.com", + "id_access_token": "sometoken", + "medium": "email", + "address": "alice3@example.com", + }, + ] + + # Create a room and invite the Matrix users, and check that it succeeded. + channel = self.make_request( + "POST", + "/createRoom", + json.dumps({"invite": invited_mxids}).encode("utf8"), + ) + self.assertEqual(200, channel.code) + + # Check that the callback was called with the right arguments. + expected_call_args = ((self.user_id, invited_mxids, []),) + self.assertEquals( + callback_mock.call_args, + expected_call_args, + callback_mock.call_args, + ) + + # Create a room and invite the 3PIDs, and check that it succeeded. + channel = self.make_request( + "POST", + "/createRoom", + json.dumps({"invite_3pid": invited_3pids}).encode("utf8"), + ) + self.assertEqual(200, channel.code) + + # Check that do_3pid_invite was called the right amount of time + self.assertEquals(do_3pid_invite_mock.call_count, len(invited_3pids)) + + # Check that the callback was called with the right arguments. + expected_call_args = ((self.user_id, [], invited_3pids),) + self.assertEquals( + callback_mock.call_args, + expected_call_args, + callback_mock.call_args, + ) + + # Now deny any room creation. + return_value = False + + # Create a room and invite the 3PIDs, and check that it failed. + channel = self.make_request( + "POST", + "/createRoom", + json.dumps({"invite_3pid": invited_3pids}).encode("utf8"), + ) + self.assertEqual(403, channel.code) + + # Check that do_3pid_invite wasn't called this time. + self.assertEquals(do_3pid_invite_mock.call_count, len(invited_3pids)) + class RoomTopicTestCase(RoomBase): """Tests /rooms/$room_id/topic REST events.""" -- cgit 1.5.1 From f7768f62cbf7579a1a91e694f83d47d275373369 Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Mon, 27 Sep 2021 12:55:27 +0100 Subject: Avoid storing URL cache files in storage providers (#10911) URL cache files are short-lived and it does not make sense to offload them (eg. to the cloud) or back them up. --- changelog.d/10911.bugfix | 1 + docs/upgrade.md | 7 ++ synapse/rest/media/v1/filepath.py | 11 ++- synapse/rest/media/v1/preview_url_resource.py | 1 - synapse/rest/media/v1/storage_provider.py | 10 ++ tests/rest/media/v1/test_url_preview.py | 130 ++++++++++++++++++++++++++ 6 files changed, 154 insertions(+), 6 deletions(-) create mode 100644 changelog.d/10911.bugfix (limited to 'tests/rest') diff --git a/changelog.d/10911.bugfix b/changelog.d/10911.bugfix new file mode 100644 index 0000000000..96e36bb15a --- /dev/null +++ b/changelog.d/10911.bugfix @@ -0,0 +1 @@ +Avoid storing URL cache files in storage providers. Server admins may safely delete the `url_cache/` and `url_cache_thumbnails/` directories from any configured storage providers to reclaim space. diff --git a/docs/upgrade.md b/docs/upgrade.md index f9b832cb3f..a8221372df 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -85,6 +85,13 @@ process, for example: dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb ``` +# Upgrading to v1.44.0 + +## The URL preview cache is no longer mirrored to storage providers +The `url_cache/` and `url_cache_thumbnails/` directories in the media store are +no longer mirrored to storage providers. These two directories can be safely +deleted from any configured storage providers to reclaim space. + # Upgrading to v1.43.0 ## The spaces summary APIs can now be handled by workers diff --git a/synapse/rest/media/v1/filepath.py b/synapse/rest/media/v1/filepath.py index 39bbe4e874..08bd85f664 100644 --- a/synapse/rest/media/v1/filepath.py +++ b/synapse/rest/media/v1/filepath.py @@ -195,23 +195,24 @@ class MediaFilePaths: url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel) - def url_cache_thumbnail_directory(self, media_id: str) -> str: + def url_cache_thumbnail_directory_rel(self, media_id: str) -> str: # Media id is of the form # E.g.: 2017-09-28-fsdRDt24DS234dsf if NEW_FORMAT_ID_RE.match(media_id): - return os.path.join( - self.base_path, "url_cache_thumbnails", media_id[:10], media_id[11:] - ) + return os.path.join("url_cache_thumbnails", media_id[:10], media_id[11:]) else: return os.path.join( - self.base_path, "url_cache_thumbnails", media_id[0:2], media_id[2:4], media_id[4:], ) + url_cache_thumbnail_directory = _wrap_in_base_path( + url_cache_thumbnail_directory_rel + ) + def url_cache_thumbnail_dirs_to_delete(self, media_id: str) -> List[str]: "The dirs to try and remove if we delete the media_id thumbnails" # Media id is of the form diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 0b0c4d6469..79a42b2455 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -485,7 +485,6 @@ class PreviewUrlResource(DirectServeJsonResource): async def _expire_url_cache_data(self) -> None: """Clean up expired url cache content, media and thumbnails.""" - # TODO: Delete from backup media store assert self._worker_run_media_background_jobs diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py index da78fcee5e..18bf977d3d 100644 --- a/synapse/rest/media/v1/storage_provider.py +++ b/synapse/rest/media/v1/storage_provider.py @@ -93,6 +93,11 @@ class StorageProviderWrapper(StorageProvider): if file_info.server_name and not self.store_remote: return None + if file_info.url_cache: + # The URL preview cache is short lived and not worth offloading or + # backing up. + return None + if self.store_synchronous: # store_file is supposed to return an Awaitable, but guard # against improper implementations. @@ -110,6 +115,11 @@ class StorageProviderWrapper(StorageProvider): run_in_background(store) async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]: + if file_info.url_cache: + # Files in the URL preview cache definitely aren't stored here, + # so avoid any potentially slow I/O or network access. + return None + # store_file is supposed to return an Awaitable, but guard # against improper implementations. return await maybe_awaitable(self.backend.fetch(path, file_info)) diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index d83dfacfed..4d09b5d07e 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -21,6 +21,7 @@ from twisted.internet.error import DNSLookupError from twisted.test.proto_helpers import AccumulatingProtocol from synapse.config.oembed import OEmbedEndpointConfig +from synapse.util.stringutils import parse_and_validate_mxc_uri from tests import unittest from tests.server import FakeTransport @@ -721,3 +722,132 @@ class URLPreviewTests(unittest.HomeserverTestCase): "og:description": "Content Preview", }, ) + + def _download_image(self): + """Downloads an image into the URL cache. + + Returns: + A (host, media_id) tuple representing the MXC URI of the image. + """ + self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")] + + channel = self.make_request( + "GET", + "preview_url?url=http://cdn.twitter.com/matrixdotorg", + shorthand=False, + await_result=False, + ) + self.pump() + + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\nContent-Type: image/png\r\n\r\n" + % (len(SMALL_PNG),) + + SMALL_PNG + ) + + self.pump() + self.assertEqual(channel.code, 200) + body = channel.json_body + mxc_uri = body["og:image"] + host, _port, media_id = parse_and_validate_mxc_uri(mxc_uri) + self.assertIsNone(_port) + return host, media_id + + def test_storage_providers_exclude_files(self): + """Test that files are not stored in or fetched from storage providers.""" + host, media_id = self._download_image() + + rel_file_path = self.preview_url.filepaths.url_cache_filepath_rel(media_id) + media_store_path = os.path.join(self.media_store_path, rel_file_path) + storage_provider_path = os.path.join(self.storage_path, rel_file_path) + + # Check storage + self.assertTrue(os.path.isfile(media_store_path)) + self.assertFalse( + os.path.isfile(storage_provider_path), + "URL cache file was unexpectedly stored in a storage provider", + ) + + # Check fetching + channel = self.make_request( + "GET", + f"download/{host}/{media_id}", + shorthand=False, + await_result=False, + ) + self.pump() + self.assertEqual(channel.code, 200) + + # Move cached file into the storage provider + os.makedirs(os.path.dirname(storage_provider_path), exist_ok=True) + os.rename(media_store_path, storage_provider_path) + + channel = self.make_request( + "GET", + f"download/{host}/{media_id}", + shorthand=False, + await_result=False, + ) + self.pump() + self.assertEqual( + channel.code, + 404, + "URL cache file was unexpectedly retrieved from a storage provider", + ) + + def test_storage_providers_exclude_thumbnails(self): + """Test that thumbnails are not stored in or fetched from storage providers.""" + host, media_id = self._download_image() + + rel_thumbnail_path = ( + self.preview_url.filepaths.url_cache_thumbnail_directory_rel(media_id) + ) + media_store_thumbnail_path = os.path.join( + self.media_store_path, rel_thumbnail_path + ) + storage_provider_thumbnail_path = os.path.join( + self.storage_path, rel_thumbnail_path + ) + + # Check storage + self.assertTrue(os.path.isdir(media_store_thumbnail_path)) + self.assertFalse( + os.path.isdir(storage_provider_thumbnail_path), + "URL cache thumbnails were unexpectedly stored in a storage provider", + ) + + # Check fetching + channel = self.make_request( + "GET", + f"thumbnail/{host}/{media_id}?width=32&height=32&method=scale", + shorthand=False, + await_result=False, + ) + self.pump() + self.assertEqual(channel.code, 200) + + # Remove the original, otherwise thumbnails will regenerate + rel_file_path = self.preview_url.filepaths.url_cache_filepath_rel(media_id) + media_store_path = os.path.join(self.media_store_path, rel_file_path) + os.remove(media_store_path) + + # Move cached thumbnails into the storage provider + os.makedirs(os.path.dirname(storage_provider_thumbnail_path), exist_ok=True) + os.rename(media_store_thumbnail_path, storage_provider_thumbnail_path) + + channel = self.make_request( + "GET", + f"thumbnail/{host}/{media_id}?width=32&height=32&method=scale", + shorthand=False, + await_result=False, + ) + self.pump() + self.assertEqual( + channel.code, + 404, + "URL cache thumbnail was unexpectedly retrieved from a storage provider", + ) -- cgit 1.5.1