From 94b620a5edd6b5bc55c8aad6e00a11cc6bf210fa Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 29 Sep 2021 06:44:15 -0400 Subject: Use direct references for configuration variables (part 6). (#10916) --- synapse/handlers/profile.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/handlers/profile.py') diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index b23a1541bc..425c0d4973 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -397,7 +397,7 @@ class ProfileHandler(BaseHandler): # when building a membership event. In this case, we must allow the # lookup. if ( - not self.hs.config.limit_profile_requests_to_users_who_share_rooms + not self.hs.config.server.limit_profile_requests_to_users_who_share_rooms or not requester ): return -- cgit 1.5.1 From a0f48ee89d88fd7b6da8023dbba607a69073152e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 4 Oct 2021 07:18:54 -0400 Subject: Use direct references for configuration variables (part 7). (#10959) --- changelog.d/10959.misc | 1 + synapse/handlers/auth.py | 2 +- synapse/handlers/identity.py | 13 ++++++++++--- synapse/handlers/profile.py | 4 ++-- synapse/handlers/register.py | 9 ++++++--- synapse/handlers/room_member.py | 2 +- synapse/handlers/ui_auth/checkers.py | 14 ++++++++------ synapse/rest/admin/users.py | 4 ++-- synapse/rest/client/account.py | 22 +++++++++++----------- synapse/rest/client/auth.py | 6 ++++-- synapse/rest/client/capabilities.py | 6 +++--- synapse/rest/client/login.py | 6 +++--- synapse/rest/client/register.py | 26 +++++++++++++------------- synapse/rest/well_known.py | 4 ++-- synapse/storage/databases/main/registration.py | 2 +- synapse/util/threepids.py | 4 ++-- tests/config/test_load.py | 6 +++--- tests/handlers/test_profile.py | 4 ++-- tests/rest/admin/test_user.py | 4 ++-- tests/rest/client/test_account.py | 4 ++-- tests/rest/client/test_identity.py | 2 +- tests/rest/client/test_register.py | 4 ++-- tests/unittest.py | 2 +- 23 files changed, 83 insertions(+), 68 deletions(-) create mode 100644 changelog.d/10959.misc (limited to 'synapse/handlers/profile.py') diff --git a/changelog.d/10959.misc b/changelog.d/10959.misc new file mode 100644 index 0000000000..586a0b3a96 --- /dev/null +++ b/changelog.d/10959.misc @@ -0,0 +1 @@ +Use direct references to config flags. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index a8c717efd5..2d0f3d566c 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -198,7 +198,7 @@ class AuthHandler(BaseHandler): if inst.is_enabled(): self.checkers[inst.AUTH_TYPE] = inst # type: ignore - self.bcrypt_rounds = hs.config.bcrypt_rounds + self.bcrypt_rounds = hs.config.registration.bcrypt_rounds # we can't use hs.get_module_api() here, because to do so will create an # import loop. diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index a0640fcac0..c881475c25 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -573,9 +573,15 @@ class IdentityHandler(BaseHandler): # Try to validate as email if self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE: + # Remote emails will only be used if a valid identity server is provided. + assert ( + self.hs.config.registration.account_threepid_delegate_email is not None + ) + # Ask our delegated email identity server validation_session = await self.threepid_from_creds( - self.hs.config.account_threepid_delegate_email, threepid_creds + self.hs.config.registration.account_threepid_delegate_email, + threepid_creds, ) elif self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: # Get a validated session matching these details @@ -587,10 +593,11 @@ class IdentityHandler(BaseHandler): return validation_session # Try to validate as msisdn - if self.hs.config.account_threepid_delegate_msisdn: + if self.hs.config.registration.account_threepid_delegate_msisdn: # Ask our delegated msisdn identity server validation_session = await self.threepid_from_creds( - self.hs.config.account_threepid_delegate_msisdn, threepid_creds + self.hs.config.registration.account_threepid_delegate_msisdn, + threepid_creds, ) return validation_session diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 425c0d4973..2e19706c69 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -178,7 +178,7 @@ class ProfileHandler(BaseHandler): if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's displayname") - if not by_admin and not self.hs.config.enable_set_displayname: + if not by_admin and not self.hs.config.registration.enable_set_displayname: profile = await self.store.get_profileinfo(target_user.localpart) if profile.display_name: raise SynapseError( @@ -268,7 +268,7 @@ class ProfileHandler(BaseHandler): if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's avatar_url") - if not by_admin and not self.hs.config.enable_set_avatar_url: + if not by_admin and not self.hs.config.registration.enable_set_avatar_url: profile = await self.store.get_profileinfo(target_user.localpart) if profile.avatar_url: raise SynapseError( diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index cb4eb0720b..441af7a848 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -116,8 +116,8 @@ class RegistrationHandler(BaseHandler): self._register_device_client = self.register_device_inner self.pusher_pool = hs.get_pusherpool() - self.session_lifetime = hs.config.session_lifetime - self.access_token_lifetime = hs.config.access_token_lifetime + self.session_lifetime = hs.config.registration.session_lifetime + self.access_token_lifetime = hs.config.registration.access_token_lifetime init_counters_for_auth_provider("") @@ -343,7 +343,10 @@ class RegistrationHandler(BaseHandler): # If the user does not need to consent at registration, auto-join any # configured rooms. if not self.hs.config.consent.user_consent_at_registration: - if not self.hs.config.auto_join_rooms_for_guests and make_guest: + if ( + not self.hs.config.registration.auto_join_rooms_for_guests + and make_guest + ): logger.info( "Skipping auto-join for %s because auto-join for guests is disabled", user_id, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 29b3e41cc9..c8fb24a20c 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -89,7 +89,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self.spam_checker = hs.get_spam_checker() self.third_party_event_rules = hs.get_third_party_event_rules() self._server_notices_mxid = self.config.servernotices.server_notices_mxid - self._enable_lookup = hs.config.enable_3pid_lookup + self._enable_lookup = hs.config.registration.enable_3pid_lookup self.allow_per_room_profiles = self.config.server.allow_per_room_profiles self._join_rate_limiter_local = Ratelimiter( diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py index 8f5d465fa1..184730ebe8 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py @@ -153,21 +153,23 @@ class _BaseThreepidAuthChecker: # msisdns are currently always ThreepidBehaviour.REMOTE if medium == "msisdn": - if not self.hs.config.account_threepid_delegate_msisdn: + if not self.hs.config.registration.account_threepid_delegate_msisdn: raise SynapseError( 400, "Phone number verification is not enabled on this homeserver" ) threepid = await identity_handler.threepid_from_creds( - self.hs.config.account_threepid_delegate_msisdn, threepid_creds + self.hs.config.registration.account_threepid_delegate_msisdn, + threepid_creds, ) elif medium == "email": if ( self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE ): - assert self.hs.config.account_threepid_delegate_email + assert self.hs.config.registration.account_threepid_delegate_email threepid = await identity_handler.threepid_from_creds( - self.hs.config.account_threepid_delegate_email, threepid_creds + self.hs.config.registration.account_threepid_delegate_email, + threepid_creds, ) elif ( self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL @@ -240,7 +242,7 @@ class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker): _BaseThreepidAuthChecker.__init__(self, hs) def is_enabled(self) -> bool: - return bool(self.hs.config.account_threepid_delegate_msisdn) + return bool(self.hs.config.registration.account_threepid_delegate_msisdn) async def check_auth(self, authdict: dict, clientip: str) -> Any: return await self._check_threepid("msisdn", authdict) @@ -252,7 +254,7 @@ class RegistrationTokenAuthChecker(UserInteractiveAuthChecker): def __init__(self, hs: "HomeServer"): super().__init__(hs) self.hs = hs - self._enabled = bool(hs.config.registration_requires_token) + self._enabled = bool(hs.config.registration.registration_requires_token) self.store = hs.get_datastore() def is_enabled(self) -> bool: diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 46bfec4623..f20aa65301 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -442,7 +442,7 @@ class UserRegisterServlet(RestServlet): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: self._clear_old_nonces() - if not self.hs.config.registration_shared_secret: + if not self.hs.config.registration.registration_shared_secret: raise SynapseError(400, "Shared secret registration is not enabled") body = parse_json_object_from_request(request) @@ -498,7 +498,7 @@ class UserRegisterServlet(RestServlet): got_mac = body["mac"] want_mac_builder = hmac.new( - key=self.hs.config.registration_shared_secret.encode(), + key=self.hs.config.registration.registration_shared_secret.encode(), digestmod=hashlib.sha1, ) want_mac_builder.update(nonce.encode("utf8")) diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index fff133ef10..6b272658fc 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -130,11 +130,11 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND) if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE: - assert self.hs.config.account_threepid_delegate_email + assert self.hs.config.registration.account_threepid_delegate_email # Have the configured identity server handle the request ret = await self.identity_handler.requestEmailToken( - self.hs.config.account_threepid_delegate_email, + self.hs.config.registration.account_threepid_delegate_email, email, client_secret, send_attempt, @@ -414,11 +414,11 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE: - assert self.hs.config.account_threepid_delegate_email + assert self.hs.config.registration.account_threepid_delegate_email # Have the configured identity server handle the request ret = await self.identity_handler.requestEmailToken( - self.hs.config.account_threepid_delegate_email, + self.hs.config.registration.account_threepid_delegate_email, email, client_secret, send_attempt, @@ -496,7 +496,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) - if not self.hs.config.account_threepid_delegate_msisdn: + if not self.hs.config.registration.account_threepid_delegate_msisdn: logger.warning( "No upstream msisdn account_threepid_delegate configured on the server to " "handle this request" @@ -507,7 +507,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): ) ret = await self.identity_handler.requestMsisdnToken( - self.hs.config.account_threepid_delegate_msisdn, + self.hs.config.registration.account_threepid_delegate_msisdn, country, phone_number, client_secret, @@ -604,7 +604,7 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet): self.identity_handler = hs.get_identity_handler() async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: - if not self.config.account_threepid_delegate_msisdn: + if not self.config.registration.account_threepid_delegate_msisdn: raise SynapseError( 400, "This homeserver is not validating phone numbers. Use an identity server " @@ -617,7 +617,7 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet): # Proxy submit_token request to msisdn threepid delegate response = await self.identity_handler.proxy_msisdn_submit_token( - self.config.account_threepid_delegate_msisdn, + self.config.registration.account_threepid_delegate_msisdn, body["client_secret"], body["sid"], body["token"], @@ -644,7 +644,7 @@ class ThreepidRestServlet(RestServlet): return 200, {"threepids": threepids} async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - if not self.hs.config.enable_3pid_changes: + if not self.hs.config.registration.enable_3pid_changes: raise SynapseError( 400, "3PID changes are disabled on this server", Codes.FORBIDDEN ) @@ -693,7 +693,7 @@ class ThreepidAddRestServlet(RestServlet): @interactive_auth_handler async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - if not self.hs.config.enable_3pid_changes: + if not self.hs.config.registration.enable_3pid_changes: raise SynapseError( 400, "3PID changes are disabled on this server", Codes.FORBIDDEN ) @@ -801,7 +801,7 @@ class ThreepidDeleteRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - if not self.hs.config.enable_3pid_changes: + if not self.hs.config.registration.enable_3pid_changes: raise SynapseError( 400, "3PID changes are disabled on this server", Codes.FORBIDDEN ) diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py index 282861fae2..c9ad35a3ad 100644 --- a/synapse/rest/client/auth.py +++ b/synapse/rest/client/auth.py @@ -49,8 +49,10 @@ class AuthRestServlet(RestServlet): self.registration_handler = hs.get_registration_handler() self.recaptcha_template = hs.config.captcha.recaptcha_template self.terms_template = hs.config.terms_template - self.registration_token_template = hs.config.registration_token_template - self.success_template = hs.config.fallback_success_template + self.registration_token_template = ( + hs.config.registration.registration_token_template + ) + self.success_template = hs.config.registration.fallback_success_template async def on_GET(self, request: SynapseRequest, stagetype: str) -> None: session = parse_string(request, "session") diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py index d6b6256413..2a3e24ae7e 100644 --- a/synapse/rest/client/capabilities.py +++ b/synapse/rest/client/capabilities.py @@ -64,13 +64,13 @@ class CapabilitiesRestServlet(RestServlet): if self.config.experimental.msc3283_enabled: response["capabilities"]["org.matrix.msc3283.set_displayname"] = { - "enabled": self.config.enable_set_displayname + "enabled": self.config.registration.enable_set_displayname } response["capabilities"]["org.matrix.msc3283.set_avatar_url"] = { - "enabled": self.config.enable_set_avatar_url + "enabled": self.config.registration.enable_set_avatar_url } response["capabilities"]["org.matrix.msc3283.3pid_changes"] = { - "enabled": self.config.enable_3pid_changes + "enabled": self.config.registration.enable_3pid_changes } return 200, response diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index fa5c173f4b..d49a647b03 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -79,7 +79,7 @@ class LoginRestServlet(RestServlet): self.saml2_enabled = hs.config.saml2.saml2_enabled self.cas_enabled = hs.config.cas.cas_enabled self.oidc_enabled = hs.config.oidc.oidc_enabled - self._msc2918_enabled = hs.config.access_token_lifetime is not None + self._msc2918_enabled = hs.config.registration.access_token_lifetime is not None self.auth = hs.get_auth() @@ -447,7 +447,7 @@ class RefreshTokenServlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth_handler = hs.get_auth_handler() self._clock = hs.get_clock() - self.access_token_lifetime = hs.config.access_token_lifetime + self.access_token_lifetime = hs.config.registration.access_token_lifetime async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: refresh_submission = parse_json_object_from_request(request) @@ -556,7 +556,7 @@ class CasTicketServlet(RestServlet): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: LoginRestServlet(hs).register(http_server) - if hs.config.access_token_lifetime is not None: + if hs.config.registration.access_token_lifetime is not None: RefreshTokenServlet(hs).register(http_server) SsoRedirectServlet(hs).register(http_server) if hs.config.cas.cas_enabled: diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index a6eb6f6410..bf3cb34146 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -140,11 +140,11 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE: - assert self.hs.config.account_threepid_delegate_email + assert self.hs.config.registration.account_threepid_delegate_email # Have the configured identity server handle the request ret = await self.identity_handler.requestEmailToken( - self.hs.config.account_threepid_delegate_email, + self.hs.config.registration.account_threepid_delegate_email, email, client_secret, send_attempt, @@ -221,7 +221,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): 400, "Phone number is already in use", Codes.THREEPID_IN_USE ) - if not self.hs.config.account_threepid_delegate_msisdn: + if not self.hs.config.registration.account_threepid_delegate_msisdn: logger.warning( "No upstream msisdn account_threepid_delegate configured on the server to " "handle this request" @@ -231,7 +231,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): ) ret = await self.identity_handler.requestMsisdnToken( - self.hs.config.account_threepid_delegate_msisdn, + self.hs.config.registration.account_threepid_delegate_msisdn, country, phone_number, client_secret, @@ -341,7 +341,7 @@ class UsernameAvailabilityRestServlet(RestServlet): ) async def on_GET(self, request: Request) -> Tuple[int, JsonDict]: - if not self.hs.config.enable_registration: + if not self.hs.config.registration.enable_registration: raise SynapseError( 403, "Registration has been disabled", errcode=Codes.FORBIDDEN ) @@ -391,7 +391,7 @@ class RegistrationTokenValidityRestServlet(RestServlet): async def on_GET(self, request: Request) -> Tuple[int, JsonDict]: await self.ratelimiter.ratelimit(None, (request.getClientIP(),)) - if not self.hs.config.enable_registration: + if not self.hs.config.registration.enable_registration: raise SynapseError( 403, "Registration has been disabled", errcode=Codes.FORBIDDEN ) @@ -419,8 +419,8 @@ class RegisterRestServlet(RestServlet): self.ratelimiter = hs.get_registration_ratelimiter() self.password_policy_handler = hs.get_password_policy_handler() self.clock = hs.get_clock() - self._registration_enabled = self.hs.config.enable_registration - self._msc2918_enabled = hs.config.access_token_lifetime is not None + self._registration_enabled = self.hs.config.registration.enable_registration + self._msc2918_enabled = hs.config.registration.access_token_lifetime is not None self._registration_flows = _calculate_registration_flows( hs.config, self.auth_handler @@ -800,7 +800,7 @@ class RegisterRestServlet(RestServlet): async def _do_guest_registration( self, params: JsonDict, address: Optional[str] = None ) -> Tuple[int, JsonDict]: - if not self.hs.config.allow_guest_access: + if not self.hs.config.registration.allow_guest_access: raise SynapseError(403, "Guest access is disabled") user_id = await self.registration_handler.register_user( make_guest=True, address=address @@ -849,13 +849,13 @@ def _calculate_registration_flows( """ # FIXME: need a better error than "no auth flow found" for scenarios # where we required 3PID for registration but the user didn't give one - require_email = "email" in config.registrations_require_3pid - require_msisdn = "msisdn" in config.registrations_require_3pid + require_email = "email" in config.registration.registrations_require_3pid + require_msisdn = "msisdn" in config.registration.registrations_require_3pid show_msisdn = True show_email = True - if config.disable_msisdn_registration: + if config.registration.disable_msisdn_registration: show_msisdn = False require_msisdn = False @@ -909,7 +909,7 @@ def _calculate_registration_flows( flow.insert(0, LoginType.RECAPTCHA) # Prepend registration token to all flows if we're requiring a token - if config.registration_requires_token: + if config.registration.registration_requires_token: for flow in flows: flow.insert(0, LoginType.REGISTRATION_TOKEN) diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py index c80a3a99aa..7ac01faab4 100644 --- a/synapse/rest/well_known.py +++ b/synapse/rest/well_known.py @@ -39,9 +39,9 @@ class WellKnownBuilder: result = {"m.homeserver": {"base_url": self._config.server.public_baseurl}} - if self._config.default_identity_server: + if self._config.registration.default_identity_server: result["m.identity_server"] = { - "base_url": self._config.default_identity_server + "base_url": self._config.registration.default_identity_server } return result diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 7279b0924e..de262fbf5a 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -1710,7 +1710,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): We do this by grandfathering in existing user threepids assuming that they used one of the server configured trusted identity servers. """ - id_servers = set(self.config.trusted_third_party_id_servers) + id_servers = set(self.config.registration.trusted_third_party_id_servers) def _bg_user_threepids_grandfather_txn(txn): sql = """ diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py index baa9190a9a..389adf00f6 100644 --- a/synapse/util/threepids.py +++ b/synapse/util/threepids.py @@ -44,8 +44,8 @@ def check_3pid_allowed(hs: "HomeServer", medium: str, address: str) -> bool: bool: whether the 3PID medium/address is allowed to be added to this HS """ - if hs.config.allowed_local_3pids: - for constraint in hs.config.allowed_local_3pids: + if hs.config.registration.allowed_local_3pids: + for constraint in hs.config.registration.allowed_local_3pids: logger.debug( "Checking 3PID %s (%s) against %s (%s)", address, diff --git a/tests/config/test_load.py b/tests/config/test_load.py index ef6c2beec7..8e49ca26d9 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py @@ -84,16 +84,16 @@ class ConfigLoadingTestCase(unittest.TestCase): ) # Check that disable_registration clobbers enable_registration. config = HomeServerConfig.load_config("", ["-c", self.file]) - self.assertFalse(config.enable_registration) + self.assertFalse(config.registration.enable_registration) config = HomeServerConfig.load_or_generate_config("", ["-c", self.file]) - self.assertFalse(config.enable_registration) + self.assertFalse(config.registration.enable_registration) # Check that either config value is clobbered by the command line. config = HomeServerConfig.load_or_generate_config( "", ["-c", self.file, "--enable-registration"] ) - self.assertTrue(config.enable_registration) + self.assertTrue(config.registration.enable_registration) def test_stats_enabled(self): self.generate_config_and_remove_lines_containing("enable_metrics") diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 57cc3e2646..c153018fd8 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -110,7 +110,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) def test_set_my_name_if_disabled(self): - self.hs.config.enable_set_displayname = False + self.hs.config.registration.enable_set_displayname = False # Setting displayname for the first time is allowed self.get_success( @@ -225,7 +225,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) def test_set_my_avatar_if_disabled(self): - self.hs.config.enable_set_avatar_url = False + self.hs.config.registration.enable_set_avatar_url = False # Setting displayname for the first time is allowed self.get_success( diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index a285d5a7fe..6ed9e42173 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -59,7 +59,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): self.hs = self.setup_test_homeserver() - self.hs.config.registration_shared_secret = "shared" + self.hs.config.registration.registration_shared_secret = "shared" self.hs.get_media_repository = Mock() self.hs.get_deactivate_account_handler = Mock() @@ -71,7 +71,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): If there is no shared secret, registration through this method will be prevented. """ - self.hs.config.registration_shared_secret = None + self.hs.config.registration.registration_shared_secret = None channel = self.make_request("POST", self.url, b"{}") diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index 2f44547bfb..89d85b0a17 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -664,7 +664,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): def test_add_email_if_disabled(self): """Test adding email to profile when doing so is disallowed""" - self.hs.config.enable_3pid_changes = False + self.hs.config.registration.enable_3pid_changes = False client_secret = "foobar" session_id = self._request_token(self.email, client_secret) @@ -734,7 +734,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): def test_delete_email_if_disabled(self): """Test deleting an email from profile when disallowed""" - self.hs.config.enable_3pid_changes = False + self.hs.config.registration.enable_3pid_changes = False # Add a threepid self.get_success( diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py index ca2e8ff8ef..becb4e8dcc 100644 --- a/tests/rest/client/test_identity.py +++ b/tests/rest/client/test_identity.py @@ -37,7 +37,7 @@ class IdentityTestCase(unittest.HomeserverTestCase): return self.hs def test_3pid_lookup_disabled(self): - self.hs.config.enable_3pid_lookup = False + self.hs.config.registration.enable_3pid_lookup = False self.register_user("kermit", "monkey") tok = self.login("kermit", "monkey") diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index af135d57e1..66dcfc9f88 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -147,7 +147,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): def test_POST_guest_registration(self): self.hs.config.key.macaroon_secret_key = "test" - self.hs.config.allow_guest_access = True + self.hs.config.registration.allow_guest_access = True channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") @@ -156,7 +156,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertDictContainsSubset(det_data, channel.json_body) def test_POST_disabled_guest_registration(self): - self.hs.config.allow_guest_access = False + self.hs.config.registration.allow_guest_access = False channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") diff --git a/tests/unittest.py b/tests/unittest.py index 0807467e39..1f803564f6 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -560,7 +560,7 @@ class HomeserverTestCase(TestCase): Returns: The MXID of the new user. """ - self.hs.config.registration_shared_secret = "shared" + self.hs.config.registration.registration_shared_secret = "shared" # Create the user channel = self.make_request("GET", "/_synapse/admin/v1/register") -- cgit 1.5.1 From eb9ddc8c2e807e691fd1820f88f7c0bf43822661 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 8 Oct 2021 07:44:43 -0400 Subject: Remove the deprecated BaseHandler. (#11005) The shared ratelimit function was replaced with a dedicated RequestRatelimiter class (accessible from the HomeServer object). Other properties were copied to each sub-class that inherited from BaseHandler. --- changelog.d/11005.misc | 1 + synapse/api/ratelimiting.py | 86 +++++++++++++++++++++++ synapse/handlers/_base.py | 120 --------------------------------- synapse/handlers/admin.py | 7 +- synapse/handlers/auth.py | 8 +-- synapse/handlers/deactivate_account.py | 6 +- synapse/handlers/device.py | 10 +-- synapse/handlers/directory.py | 9 ++- synapse/handlers/events.py | 12 ++-- synapse/handlers/federation.py | 6 +- synapse/handlers/identity.py | 7 +- synapse/handlers/initial_sync.py | 8 +-- synapse/handlers/message.py | 7 +- synapse/handlers/profile.py | 11 +-- synapse/handlers/read_marker.py | 5 +- synapse/handlers/receipts.py | 6 +- synapse/handlers/register.py | 9 ++- synapse/handlers/room.py | 15 +++-- synapse/handlers/room_list.py | 7 +- synapse/handlers/room_member.py | 8 +-- synapse/handlers/saml.py | 7 +- synapse/handlers/search.py | 9 +-- synapse/handlers/set_password.py | 6 +- synapse/server.py | 11 ++- 24 files changed, 166 insertions(+), 215 deletions(-) create mode 100644 changelog.d/11005.misc delete mode 100644 synapse/handlers/_base.py (limited to 'synapse/handlers/profile.py') diff --git a/changelog.d/11005.misc b/changelog.d/11005.misc new file mode 100644 index 0000000000..a893591971 --- /dev/null +++ b/changelog.d/11005.misc @@ -0,0 +1 @@ +Remove the deprecated `BaseHandler` object. diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index cbdd74025b..e8964097d3 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -17,6 +17,7 @@ from collections import OrderedDict from typing import Hashable, Optional, Tuple from synapse.api.errors import LimitExceededError +from synapse.config.ratelimiting import RateLimitConfig from synapse.storage.databases.main import DataStore from synapse.types import Requester from synapse.util import Clock @@ -233,3 +234,88 @@ class Ratelimiter: raise LimitExceededError( retry_after_ms=int(1000 * (time_allowed - time_now_s)) ) + + +class RequestRatelimiter: + def __init__( + self, + store: DataStore, + clock: Clock, + rc_message: RateLimitConfig, + rc_admin_redaction: Optional[RateLimitConfig], + ): + self.store = store + self.clock = clock + + # The rate_hz and burst_count are overridden on a per-user basis + self.request_ratelimiter = Ratelimiter( + store=self.store, clock=self.clock, rate_hz=0, burst_count=0 + ) + self._rc_message = rc_message + + # Check whether ratelimiting room admin message redaction is enabled + # by the presence of rate limits in the config + if rc_admin_redaction: + self.admin_redaction_ratelimiter: Optional[Ratelimiter] = Ratelimiter( + store=self.store, + clock=self.clock, + rate_hz=rc_admin_redaction.per_second, + burst_count=rc_admin_redaction.burst_count, + ) + else: + self.admin_redaction_ratelimiter = None + + async def ratelimit( + self, + requester: Requester, + update: bool = True, + is_admin_redaction: bool = False, + ) -> None: + """Ratelimits requests. + + Args: + requester + update: Whether to record that a request is being processed. + Set to False when doing multiple checks for one request (e.g. + to check up front if we would reject the request), and set to + True for the last call for a given request. + is_admin_redaction: Whether this is a room admin/moderator + redacting an event. If so then we may apply different + ratelimits depending on config. + + Raises: + LimitExceededError if the request should be ratelimited + """ + user_id = requester.user.to_string() + + # The AS user itself is never rate limited. + app_service = self.store.get_app_service_by_user_id(user_id) + if app_service is not None: + return # do not ratelimit app service senders + + messages_per_second = self._rc_message.per_second + burst_count = self._rc_message.burst_count + + # Check if there is a per user override in the DB. + override = await self.store.get_ratelimit_for_user(user_id) + if override: + # If overridden with a null Hz then ratelimiting has been entirely + # disabled for the user + if not override.messages_per_second: + return + + messages_per_second = override.messages_per_second + burst_count = override.burst_count + + if is_admin_redaction and self.admin_redaction_ratelimiter: + # If we have separate config for admin redactions, use a separate + # ratelimiter as to not have user_ids clash + await self.admin_redaction_ratelimiter.ratelimit(requester, update=update) + else: + # Override rate and burst count per-user + await self.request_ratelimiter.ratelimit( + requester, + rate_hz=messages_per_second, + burst_count=burst_count, + update=update, + ) diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py deleted file mode 100644 index 0ccef884e7..0000000000 --- a/synapse/handlers/_base.py +++ /dev/null @@ -1,120 +0,0 @@ -# Copyright 2014 - 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from typing import TYPE_CHECKING, Optional - -from synapse.api.ratelimiting import Ratelimiter -from synapse.types import Requester - -if TYPE_CHECKING: - from synapse.server import HomeServer - -logger = logging.getLogger(__name__) - - -class BaseHandler: - """ - Common base class for the event handlers. - - Deprecated: new code should not use this. Instead, Handler classes should define the - fields they actually need. The utility methods should either be factored out to - standalone helper functions, or to different Handler classes. - """ - - def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() - self.auth = hs.get_auth() - self.notifier = hs.get_notifier() - self.state_handler = hs.get_state_handler() - self.distributor = hs.get_distributor() - self.clock = hs.get_clock() - self.hs = hs - - # The rate_hz and burst_count are overridden on a per-user basis - self.request_ratelimiter = Ratelimiter( - store=self.store, clock=self.clock, rate_hz=0, burst_count=0 - ) - self._rc_message = self.hs.config.ratelimiting.rc_message - - # Check whether ratelimiting room admin message redaction is enabled - # by the presence of rate limits in the config - if self.hs.config.ratelimiting.rc_admin_redaction: - self.admin_redaction_ratelimiter: Optional[Ratelimiter] = Ratelimiter( - store=self.store, - clock=self.clock, - rate_hz=self.hs.config.ratelimiting.rc_admin_redaction.per_second, - burst_count=self.hs.config.ratelimiting.rc_admin_redaction.burst_count, - ) - else: - self.admin_redaction_ratelimiter = None - - self.server_name = hs.hostname - - self.event_builder_factory = hs.get_event_builder_factory() - - async def ratelimit( - self, - requester: Requester, - update: bool = True, - is_admin_redaction: bool = False, - ) -> None: - """Ratelimits requests. - - Args: - requester - update: Whether to record that a request is being processed. - Set to False when doing multiple checks for one request (e.g. - to check up front if we would reject the request), and set to - True for the last call for a given request. - is_admin_redaction: Whether this is a room admin/moderator - redacting an event. If so then we may apply different - ratelimits depending on config. - - Raises: - LimitExceededError if the request should be ratelimited - """ - user_id = requester.user.to_string() - - # The AS user itself is never rate limited. - app_service = self.store.get_app_service_by_user_id(user_id) - if app_service is not None: - return # do not ratelimit app service senders - - messages_per_second = self._rc_message.per_second - burst_count = self._rc_message.burst_count - - # Check if there is a per user override in the DB. - override = await self.store.get_ratelimit_for_user(user_id) - if override: - # If overridden with a null Hz then ratelimiting has been entirely - # disabled for the user - if not override.messages_per_second: - return - - messages_per_second = override.messages_per_second - burst_count = override.burst_count - - if is_admin_redaction and self.admin_redaction_ratelimiter: - # If we have separate config for admin redactions, use a separate - # ratelimiter as to not have user_ids clash - await self.admin_redaction_ratelimiter.ratelimit(requester, update=update) - else: - # Override rate and burst count per-user - await self.request_ratelimiter.ratelimit( - requester, - rate_hz=messages_per_second, - burst_count=burst_count, - update=update, - ) diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py index bfa7f2c545..a53cd62d3c 100644 --- a/synapse/handlers/admin.py +++ b/synapse/handlers/admin.py @@ -21,18 +21,15 @@ from synapse.events import EventBase from synapse.types import JsonDict, RoomStreamToken, StateMap, UserID from synapse.visibility import filter_events_for_client -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.server import HomeServer logger = logging.getLogger(__name__) -class AdminHandler(BaseHandler): +class AdminHandler: def __init__(self, hs: "HomeServer"): - super().__init__(hs) - + self.store = hs.get_datastore() self.storage = hs.get_storage() self.state_store = self.storage.state diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 2d0f3d566c..f4612a5b92 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -52,7 +52,6 @@ from synapse.api.errors import ( UserDeactivatedError, ) from synapse.api.ratelimiting import Ratelimiter -from synapse.handlers._base import BaseHandler from synapse.handlers.ui_auth import ( INTERACTIVE_AUTH_CHECKERS, UIAuthSessionDataConstants, @@ -186,12 +185,13 @@ class LoginTokenAttributes: auth_provider_id = attr.ib(type=str) -class AuthHandler(BaseHandler): +class AuthHandler: SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000 def __init__(self, hs: "HomeServer"): - super().__init__(hs) - + self.store = hs.get_datastore() + self.auth = hs.get_auth() + self.clock = hs.get_clock() self.checkers: Dict[str, UserInteractiveAuthChecker] = {} for auth_checker_class in INTERACTIVE_AUTH_CHECKERS: inst = auth_checker_class(hs) diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py index 12bdca7445..e88c3c27ce 100644 --- a/synapse/handlers/deactivate_account.py +++ b/synapse/handlers/deactivate_account.py @@ -19,19 +19,17 @@ from synapse.api.errors import SynapseError from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import Requester, UserID, create_requester -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.server import HomeServer logger = logging.getLogger(__name__) -class DeactivateAccountHandler(BaseHandler): +class DeactivateAccountHandler: """Handler which deals with deactivating user accounts.""" def __init__(self, hs: "HomeServer"): - super().__init__(hs) + self.store = hs.get_datastore() self.hs = hs self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 35334725d7..75e6019760 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -40,8 +40,6 @@ from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import measure_func from synapse.util.retryutils import NotRetryingDestination -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.server import HomeServer @@ -50,14 +48,16 @@ logger = logging.getLogger(__name__) MAX_DEVICE_DISPLAY_NAME_LEN = 100 -class DeviceWorkerHandler(BaseHandler): +class DeviceWorkerHandler: def __init__(self, hs: "HomeServer"): - super().__init__(hs) - + self.clock = hs.get_clock() self.hs = hs + self.store = hs.get_datastore() + self.notifier = hs.get_notifier() self.state = hs.get_state_handler() self.state_store = hs.get_storage().state self._auth_handler = hs.get_auth_handler() + self.server_name = hs.hostname @trace async def get_devices_by_user(self, user_id: str) -> List[JsonDict]: diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 9078781d5a..14ed7d9879 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -31,18 +31,16 @@ from synapse.appservice import ApplicationService from synapse.storage.databases.main.directory import RoomAliasMapping from synapse.types import JsonDict, Requester, RoomAlias, UserID, get_domain_from_id -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.server import HomeServer logger = logging.getLogger(__name__) -class DirectoryHandler(BaseHandler): +class DirectoryHandler: def __init__(self, hs: "HomeServer"): - super().__init__(hs) - + self.auth = hs.get_auth() + self.hs = hs self.state = hs.get_state_handler() self.appservice_handler = hs.get_application_service_handler() self.event_creation_handler = hs.get_event_creation_handler() @@ -51,6 +49,7 @@ class DirectoryHandler(BaseHandler): self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search self.require_membership = hs.config.server.require_membership_for_aliases self.third_party_event_rules = hs.get_third_party_event_rules() + self.server_name = hs.hostname self.federation = hs.get_federation_client() hs.get_federation_registry().register_query_handler( diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 4b3f037072..1f64534a8a 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -25,8 +25,6 @@ from synapse.streams.config import PaginationConfig from synapse.types import JsonDict, UserID from synapse.visibility import filter_events_for_client -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.server import HomeServer @@ -34,11 +32,11 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class EventStreamHandler(BaseHandler): +class EventStreamHandler: def __init__(self, hs: "HomeServer"): - super().__init__(hs) - + self.store = hs.get_datastore() self.clock = hs.get_clock() + self.hs = hs self.notifier = hs.get_notifier() self.state = hs.get_state_handler() @@ -138,9 +136,9 @@ class EventStreamHandler(BaseHandler): return chunk -class EventHandler(BaseHandler): +class EventHandler: def __init__(self, hs: "HomeServer"): - super().__init__(hs) + self.store = hs.get_datastore() self.storage = hs.get_storage() async def get_event( diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 043ca4a224..3e341bd287 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -53,7 +53,6 @@ from synapse.events import EventBase from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator from synapse.federation.federation_client import InvalidResponseError -from synapse.handlers._base import BaseHandler from synapse.http.servlet import assert_params_in_dict from synapse.logging.context import ( make_deferred_yieldable, @@ -78,15 +77,13 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class FederationHandler(BaseHandler): +class FederationHandler: """Handles general incoming federation requests Incoming events are *not* handled here, for which see FederationEventHandler. """ def __init__(self, hs: "HomeServer"): - super().__init__(hs) - self.hs = hs self.store = hs.get_datastore() @@ -99,6 +96,7 @@ class FederationHandler(BaseHandler): self.is_mine_id = hs.is_mine_id self.spam_checker = hs.get_spam_checker() self.event_creation_handler = hs.get_event_creation_handler() + self.event_builder_factory = hs.get_event_builder_factory() self._event_auth_handler = hs.get_event_auth_handler() self._server_notices_mxid = hs.config.servernotices.server_notices_mxid self.config = hs.config diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index c881475c25..9c319b5383 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -39,8 +39,6 @@ from synapse.util.stringutils import ( valid_id_server_location, ) -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.server import HomeServer @@ -49,10 +47,9 @@ logger = logging.getLogger(__name__) id_server_scheme = "https://" -class IdentityHandler(BaseHandler): +class IdentityHandler: def __init__(self, hs: "HomeServer"): - super().__init__(hs) - + self.store = hs.get_datastore() # An HTTP client for contacting trusted URLs. self.http_client = SimpleHttpClient(hs) # An HTTP client for contacting identity servers specified by clients. diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index 9ad39a65d8..d4e4556155 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -31,8 +31,6 @@ from synapse.util.async_helpers import concurrently_execute from synapse.util.caches.response_cache import ResponseCache from synapse.visibility import filter_events_for_client -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.server import HomeServer @@ -40,9 +38,11 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class InitialSyncHandler(BaseHandler): +class InitialSyncHandler: def __init__(self, hs: "HomeServer"): - super().__init__(hs) + self.store = hs.get_datastore() + self.auth = hs.get_auth() + self.state_handler = hs.get_state_handler() self.hs = hs self.state = hs.get_state_handler() self.clock = hs.get_clock() diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index ccd7827207..4de9f4b828 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -62,8 +62,6 @@ from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import measure_func from synapse.visibility import filter_events_for_client -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.events.third_party_rules import ThirdPartyEventRules from synapse.server import HomeServer @@ -433,8 +431,7 @@ class EventCreationHandler: self.send_event = ReplicationSendEventRestServlet.make_client(hs) - # This is only used to get at ratelimit function - self.base_handler = BaseHandler(hs) + self.request_ratelimiter = hs.get_request_ratelimiter() # We arbitrarily limit concurrent event creation for a room to 5. # This is to stop us from diverging history *too* much. @@ -1322,7 +1319,7 @@ class EventCreationHandler: original_event and event.sender != original_event.sender ) - await self.base_handler.ratelimit( + await self.request_ratelimiter.ratelimit( requester, is_admin_redaction=is_admin_redaction ) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 2e19706c69..e6c3cf585b 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -32,8 +32,6 @@ from synapse.types import ( get_domain_from_id, ) -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.server import HomeServer @@ -43,7 +41,7 @@ MAX_DISPLAYNAME_LEN = 256 MAX_AVATAR_URL_LEN = 1000 -class ProfileHandler(BaseHandler): +class ProfileHandler: """Handles fetching and updating user profile information. ProfileHandler can be instantiated directly on workers and will @@ -54,7 +52,9 @@ class ProfileHandler(BaseHandler): PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000 def __init__(self, hs: "HomeServer"): - super().__init__(hs) + self.store = hs.get_datastore() + self.clock = hs.get_clock() + self.hs = hs self.federation = hs.get_federation_client() hs.get_federation_registry().register_query_handler( @@ -62,6 +62,7 @@ class ProfileHandler(BaseHandler): ) self.user_directory_handler = hs.get_user_directory_handler() + self.request_ratelimiter = hs.get_request_ratelimiter() if hs.config.worker.run_background_tasks: self.clock.looping_call( @@ -346,7 +347,7 @@ class ProfileHandler(BaseHandler): if not self.hs.is_mine(target_user): return - await self.ratelimit(requester) + await self.request_ratelimiter.ratelimit(requester) # Do not actually update the room state for shadow-banned users. if requester.shadow_banned: diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py index bd8160e7ed..58593e570e 100644 --- a/synapse/handlers/read_marker.py +++ b/synapse/handlers/read_marker.py @@ -17,17 +17,14 @@ from typing import TYPE_CHECKING from synapse.util.async_helpers import Linearizer -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.server import HomeServer logger = logging.getLogger(__name__) -class ReadMarkerHandler(BaseHandler): +class ReadMarkerHandler: def __init__(self, hs: "HomeServer"): - super().__init__(hs) self.server_name = hs.config.server.server_name self.store = hs.get_datastore() self.account_data_handler = hs.get_account_data_handler() diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index f21f33ada2..374e961e3b 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -16,7 +16,6 @@ from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple from synapse.api.constants import ReadReceiptEventFields from synapse.appservice import ApplicationService -from synapse.handlers._base import BaseHandler from synapse.streams import EventSource from synapse.types import JsonDict, ReadReceipt, UserID, get_domain_from_id @@ -26,10 +25,9 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class ReceiptsHandler(BaseHandler): +class ReceiptsHandler: def __init__(self, hs: "HomeServer"): - super().__init__(hs) - + self.notifier = hs.get_notifier() self.server_name = hs.config.server.server_name self.store = hs.get_datastore() self.event_auth_handler = hs.get_event_auth_handler() diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 441af7a848..a0e6a01775 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -41,8 +41,6 @@ from synapse.spam_checker_api import RegistrationBehaviour from synapse.storage.state import StateFilter from synapse.types import RoomAlias, UserID, create_requester -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.server import HomeServer @@ -85,9 +83,10 @@ class LoginDict(TypedDict): refresh_token: Optional[str] -class RegistrationHandler(BaseHandler): +class RegistrationHandler: def __init__(self, hs: "HomeServer"): - super().__init__(hs) + self.store = hs.get_datastore() + self.clock = hs.get_clock() self.hs = hs self.auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() @@ -515,7 +514,7 @@ class RegistrationHandler(BaseHandler): # we don't have a local user in the room to craft up an invite with. requires_invite = await self.store.is_host_joined( room_id, - self.server_name, + self._server_name, ) if requires_invite: diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index d40dbd761d..7072bca1fc 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -76,8 +76,6 @@ from synapse.util.caches.response_cache import ResponseCache from synapse.util.stringutils import parse_and_validate_server_name from synapse.visibility import filter_events_for_client -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.server import HomeServer @@ -88,15 +86,18 @@ id_server_scheme = "https://" FIVE_MINUTES_IN_MS = 5 * 60 * 1000 -class RoomCreationHandler(BaseHandler): +class RoomCreationHandler: def __init__(self, hs: "HomeServer"): - super().__init__(hs) - + self.store = hs.get_datastore() + self.auth = hs.get_auth() + self.clock = hs.get_clock() + self.hs = hs self.spam_checker = hs.get_spam_checker() self.event_creation_handler = hs.get_event_creation_handler() self.room_member_handler = hs.get_room_member_handler() self._event_auth_handler = hs.get_event_auth_handler() self.config = hs.config + self.request_ratelimiter = hs.get_request_ratelimiter() # Room state based off defined presets self._presets_dict: Dict[str, Dict[str, Any]] = { @@ -162,7 +163,7 @@ class RoomCreationHandler(BaseHandler): Raises: ShadowBanError if the requester is shadow-banned. """ - await self.ratelimit(requester) + await self.request_ratelimiter.ratelimit(requester) user_id = requester.user.to_string() @@ -665,7 +666,7 @@ class RoomCreationHandler(BaseHandler): raise SynapseError(403, "You are not permitted to create rooms") if ratelimit: - await self.ratelimit(requester) + await self.request_ratelimiter.ratelimit(requester) room_version_id = config.get( "room_version", self.config.server.default_room_version.identifier diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index c3d4199ed1..ba7a14d651 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -36,8 +36,6 @@ from synapse.types import JsonDict, ThirdPartyInstanceID from synapse.util.caches.descriptors import _CacheContext, cached from synapse.util.caches.response_cache import ResponseCache -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.server import HomeServer @@ -49,9 +47,10 @@ REMOTE_ROOM_LIST_POLL_INTERVAL = 60 * 1000 EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None) -class RoomListHandler(BaseHandler): +class RoomListHandler: def __init__(self, hs: "HomeServer"): - super().__init__(hs) + self.store = hs.get_datastore() + self.hs = hs self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search self.response_cache: ResponseCache[ Tuple[Optional[int], Optional[str], Optional[ThirdPartyInstanceID]] diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index eef337feeb..74e6c7eca6 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -51,8 +51,6 @@ from synapse.types import ( from synapse.util.async_helpers import Linearizer from synapse.util.distributor import user_left_room -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.server import HomeServer @@ -118,9 +116,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count, ) - # This is only used to get at the ratelimit function. It's fine there are - # multiple of these as it doesn't store state. - self.base_handler = BaseHandler(hs) + self.request_ratelimiter = hs.get_request_ratelimiter() @abc.abstractmethod async def _remote_join( @@ -1275,7 +1271,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # We need to rate limit *before* we send out any 3PID invites, so we # can't just rely on the standard ratelimiting of events. - await self.base_handler.ratelimit(requester) + await self.request_ratelimiter.ratelimit(requester) can_invite = await self.third_party_event_rules.check_threepid_can_be_invited( medium, address, room_id diff --git a/synapse/handlers/saml.py b/synapse/handlers/saml.py index 2fed9f377a..727d75a50c 100644 --- a/synapse/handlers/saml.py +++ b/synapse/handlers/saml.py @@ -22,7 +22,6 @@ from saml2.client import Saml2Client from synapse.api.errors import SynapseError from synapse.config import ConfigError -from synapse.handlers._base import BaseHandler from synapse.handlers.sso import MappingException, UserAttributes from synapse.http.servlet import parse_string from synapse.http.site import SynapseRequest @@ -51,9 +50,11 @@ class Saml2SessionData: ui_auth_session_id: Optional[str] = None -class SamlHandler(BaseHandler): +class SamlHandler: def __init__(self, hs: "HomeServer"): - super().__init__(hs) + self.store = hs.get_datastore() + self.clock = hs.get_clock() + self.server_name = hs.hostname self._saml_client = Saml2Client(hs.config.saml2.saml2_sp_config) self._saml_idp_entityid = hs.config.saml2.saml2_idp_entityid diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 6d3333ee00..a3ffa26be8 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -26,17 +26,18 @@ from synapse.storage.state import StateFilter from synapse.types import JsonDict, UserID from synapse.visibility import filter_events_for_client -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.server import HomeServer logger = logging.getLogger(__name__) -class SearchHandler(BaseHandler): +class SearchHandler: def __init__(self, hs: "HomeServer"): - super().__init__(hs) + self.store = hs.get_datastore() + self.state_handler = hs.get_state_handler() + self.clock = hs.get_clock() + self.hs = hs self._event_serializer = hs.get_event_client_serializer() self.storage = hs.get_storage() self.state_store = self.storage.state diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py index a63fac8283..706ad72761 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py @@ -17,19 +17,17 @@ from typing import TYPE_CHECKING, Optional from synapse.api.errors import Codes, StoreError, SynapseError from synapse.types import Requester -from ._base import BaseHandler - if TYPE_CHECKING: from synapse.server import HomeServer logger = logging.getLogger(__name__) -class SetPasswordHandler(BaseHandler): +class SetPasswordHandler: """Handler which deals with changing user account passwords""" def __init__(self, hs: "HomeServer"): - super().__init__(hs) + self.store = hs.get_datastore() self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() diff --git a/synapse/server.py b/synapse/server.py index 637eb15b78..0783df41d4 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -39,7 +39,7 @@ from twisted.web.resource import IResource from synapse.api.auth import Auth from synapse.api.filtering import Filtering -from synapse.api.ratelimiting import Ratelimiter +from synapse.api.ratelimiting import Ratelimiter, RequestRatelimiter from synapse.appservice.api import ApplicationServiceApi from synapse.appservice.scheduler import ApplicationServiceScheduler from synapse.config.homeserver import HomeServerConfig @@ -816,3 +816,12 @@ class HomeServer(metaclass=abc.ABCMeta): def should_send_federation(self) -> bool: "Should this server be sending federation traffic directly?" return self.config.worker.send_federation + + @cache_in_self + def get_request_ratelimiter(self) -> RequestRatelimiter: + return RequestRatelimiter( + self.get_datastore(), + self.get_clock(), + self.config.ratelimiting.rc_message, + self.config.ratelimiting.rc_admin_redaction, + ) -- cgit 1.5.1