From 94b620a5edd6b5bc55c8aad6e00a11cc6bf210fa Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 29 Sep 2021 06:44:15 -0400 Subject: Use direct references for configuration variables (part 6). (#10916) --- synapse/storage/databases/main/registration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'synapse/storage/databases/main/registration.py') diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index c83089ee63..7279b0924e 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -207,7 +207,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): return False now = self._clock.time_msec() - trial_duration_ms = self.config.mau_trial_days * 24 * 60 * 60 * 1000 + trial_duration_ms = self.config.server.mau_trial_days * 24 * 60 * 60 * 1000 is_trial = (now - info["creation_ts"] * 1000) < trial_duration_ms return is_trial -- cgit 1.5.1 From a0f48ee89d88fd7b6da8023dbba607a69073152e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 4 Oct 2021 07:18:54 -0400 Subject: Use direct references for configuration variables (part 7). (#10959) --- changelog.d/10959.misc | 1 + synapse/handlers/auth.py | 2 +- synapse/handlers/identity.py | 13 ++++++++++--- synapse/handlers/profile.py | 4 ++-- synapse/handlers/register.py | 9 ++++++--- synapse/handlers/room_member.py | 2 +- synapse/handlers/ui_auth/checkers.py | 14 ++++++++------ synapse/rest/admin/users.py | 4 ++-- synapse/rest/client/account.py | 22 +++++++++++----------- synapse/rest/client/auth.py | 6 ++++-- synapse/rest/client/capabilities.py | 6 +++--- synapse/rest/client/login.py | 6 +++--- synapse/rest/client/register.py | 26 +++++++++++++------------- synapse/rest/well_known.py | 4 ++-- synapse/storage/databases/main/registration.py | 2 +- synapse/util/threepids.py | 4 ++-- tests/config/test_load.py | 6 +++--- tests/handlers/test_profile.py | 4 ++-- tests/rest/admin/test_user.py | 4 ++-- tests/rest/client/test_account.py | 4 ++-- tests/rest/client/test_identity.py | 2 +- tests/rest/client/test_register.py | 4 ++-- tests/unittest.py | 2 +- 23 files changed, 83 insertions(+), 68 deletions(-) create mode 100644 changelog.d/10959.misc (limited to 'synapse/storage/databases/main/registration.py') diff --git a/changelog.d/10959.misc b/changelog.d/10959.misc new file mode 100644 index 0000000000..586a0b3a96 --- /dev/null +++ b/changelog.d/10959.misc @@ -0,0 +1 @@ +Use direct references to config flags. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index a8c717efd5..2d0f3d566c 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -198,7 +198,7 @@ class AuthHandler(BaseHandler): if inst.is_enabled(): self.checkers[inst.AUTH_TYPE] = inst # type: ignore - self.bcrypt_rounds = hs.config.bcrypt_rounds + self.bcrypt_rounds = hs.config.registration.bcrypt_rounds # we can't use hs.get_module_api() here, because to do so will create an # import loop. diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index a0640fcac0..c881475c25 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -573,9 +573,15 @@ class IdentityHandler(BaseHandler): # Try to validate as email if self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE: + # Remote emails will only be used if a valid identity server is provided. + assert ( + self.hs.config.registration.account_threepid_delegate_email is not None + ) + # Ask our delegated email identity server validation_session = await self.threepid_from_creds( - self.hs.config.account_threepid_delegate_email, threepid_creds + self.hs.config.registration.account_threepid_delegate_email, + threepid_creds, ) elif self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL: # Get a validated session matching these details @@ -587,10 +593,11 @@ class IdentityHandler(BaseHandler): return validation_session # Try to validate as msisdn - if self.hs.config.account_threepid_delegate_msisdn: + if self.hs.config.registration.account_threepid_delegate_msisdn: # Ask our delegated msisdn identity server validation_session = await self.threepid_from_creds( - self.hs.config.account_threepid_delegate_msisdn, threepid_creds + self.hs.config.registration.account_threepid_delegate_msisdn, + threepid_creds, ) return validation_session diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 425c0d4973..2e19706c69 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -178,7 +178,7 @@ class ProfileHandler(BaseHandler): if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's displayname") - if not by_admin and not self.hs.config.enable_set_displayname: + if not by_admin and not self.hs.config.registration.enable_set_displayname: profile = await self.store.get_profileinfo(target_user.localpart) if profile.display_name: raise SynapseError( @@ -268,7 +268,7 @@ class ProfileHandler(BaseHandler): if not by_admin and target_user != requester.user: raise AuthError(400, "Cannot set another user's avatar_url") - if not by_admin and not self.hs.config.enable_set_avatar_url: + if not by_admin and not self.hs.config.registration.enable_set_avatar_url: profile = await self.store.get_profileinfo(target_user.localpart) if profile.avatar_url: raise SynapseError( diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index cb4eb0720b..441af7a848 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -116,8 +116,8 @@ class RegistrationHandler(BaseHandler): self._register_device_client = self.register_device_inner self.pusher_pool = hs.get_pusherpool() - self.session_lifetime = hs.config.session_lifetime - self.access_token_lifetime = hs.config.access_token_lifetime + self.session_lifetime = hs.config.registration.session_lifetime + self.access_token_lifetime = hs.config.registration.access_token_lifetime init_counters_for_auth_provider("") @@ -343,7 +343,10 @@ class RegistrationHandler(BaseHandler): # If the user does not need to consent at registration, auto-join any # configured rooms. if not self.hs.config.consent.user_consent_at_registration: - if not self.hs.config.auto_join_rooms_for_guests and make_guest: + if ( + not self.hs.config.registration.auto_join_rooms_for_guests + and make_guest + ): logger.info( "Skipping auto-join for %s because auto-join for guests is disabled", user_id, diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 29b3e41cc9..c8fb24a20c 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -89,7 +89,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self.spam_checker = hs.get_spam_checker() self.third_party_event_rules = hs.get_third_party_event_rules() self._server_notices_mxid = self.config.servernotices.server_notices_mxid - self._enable_lookup = hs.config.enable_3pid_lookup + self._enable_lookup = hs.config.registration.enable_3pid_lookup self.allow_per_room_profiles = self.config.server.allow_per_room_profiles self._join_rate_limiter_local = Ratelimiter( diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py index 8f5d465fa1..184730ebe8 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py @@ -153,21 +153,23 @@ class _BaseThreepidAuthChecker: # msisdns are currently always ThreepidBehaviour.REMOTE if medium == "msisdn": - if not self.hs.config.account_threepid_delegate_msisdn: + if not self.hs.config.registration.account_threepid_delegate_msisdn: raise SynapseError( 400, "Phone number verification is not enabled on this homeserver" ) threepid = await identity_handler.threepid_from_creds( - self.hs.config.account_threepid_delegate_msisdn, threepid_creds + self.hs.config.registration.account_threepid_delegate_msisdn, + threepid_creds, ) elif medium == "email": if ( self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE ): - assert self.hs.config.account_threepid_delegate_email + assert self.hs.config.registration.account_threepid_delegate_email threepid = await identity_handler.threepid_from_creds( - self.hs.config.account_threepid_delegate_email, threepid_creds + self.hs.config.registration.account_threepid_delegate_email, + threepid_creds, ) elif ( self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL @@ -240,7 +242,7 @@ class MsisdnAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChecker): _BaseThreepidAuthChecker.__init__(self, hs) def is_enabled(self) -> bool: - return bool(self.hs.config.account_threepid_delegate_msisdn) + return bool(self.hs.config.registration.account_threepid_delegate_msisdn) async def check_auth(self, authdict: dict, clientip: str) -> Any: return await self._check_threepid("msisdn", authdict) @@ -252,7 +254,7 @@ class RegistrationTokenAuthChecker(UserInteractiveAuthChecker): def __init__(self, hs: "HomeServer"): super().__init__(hs) self.hs = hs - self._enabled = bool(hs.config.registration_requires_token) + self._enabled = bool(hs.config.registration.registration_requires_token) self.store = hs.get_datastore() def is_enabled(self) -> bool: diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 46bfec4623..f20aa65301 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -442,7 +442,7 @@ class UserRegisterServlet(RestServlet): async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: self._clear_old_nonces() - if not self.hs.config.registration_shared_secret: + if not self.hs.config.registration.registration_shared_secret: raise SynapseError(400, "Shared secret registration is not enabled") body = parse_json_object_from_request(request) @@ -498,7 +498,7 @@ class UserRegisterServlet(RestServlet): got_mac = body["mac"] want_mac_builder = hmac.new( - key=self.hs.config.registration_shared_secret.encode(), + key=self.hs.config.registration.registration_shared_secret.encode(), digestmod=hashlib.sha1, ) want_mac_builder.update(nonce.encode("utf8")) diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index fff133ef10..6b272658fc 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -130,11 +130,11 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND) if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE: - assert self.hs.config.account_threepid_delegate_email + assert self.hs.config.registration.account_threepid_delegate_email # Have the configured identity server handle the request ret = await self.identity_handler.requestEmailToken( - self.hs.config.account_threepid_delegate_email, + self.hs.config.registration.account_threepid_delegate_email, email, client_secret, send_attempt, @@ -414,11 +414,11 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE: - assert self.hs.config.account_threepid_delegate_email + assert self.hs.config.registration.account_threepid_delegate_email # Have the configured identity server handle the request ret = await self.identity_handler.requestEmailToken( - self.hs.config.account_threepid_delegate_email, + self.hs.config.registration.account_threepid_delegate_email, email, client_secret, send_attempt, @@ -496,7 +496,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): raise SynapseError(400, "MSISDN is already in use", Codes.THREEPID_IN_USE) - if not self.hs.config.account_threepid_delegate_msisdn: + if not self.hs.config.registration.account_threepid_delegate_msisdn: logger.warning( "No upstream msisdn account_threepid_delegate configured on the server to " "handle this request" @@ -507,7 +507,7 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): ) ret = await self.identity_handler.requestMsisdnToken( - self.hs.config.account_threepid_delegate_msisdn, + self.hs.config.registration.account_threepid_delegate_msisdn, country, phone_number, client_secret, @@ -604,7 +604,7 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet): self.identity_handler = hs.get_identity_handler() async def on_POST(self, request: Request) -> Tuple[int, JsonDict]: - if not self.config.account_threepid_delegate_msisdn: + if not self.config.registration.account_threepid_delegate_msisdn: raise SynapseError( 400, "This homeserver is not validating phone numbers. Use an identity server " @@ -617,7 +617,7 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet): # Proxy submit_token request to msisdn threepid delegate response = await self.identity_handler.proxy_msisdn_submit_token( - self.config.account_threepid_delegate_msisdn, + self.config.registration.account_threepid_delegate_msisdn, body["client_secret"], body["sid"], body["token"], @@ -644,7 +644,7 @@ class ThreepidRestServlet(RestServlet): return 200, {"threepids": threepids} async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - if not self.hs.config.enable_3pid_changes: + if not self.hs.config.registration.enable_3pid_changes: raise SynapseError( 400, "3PID changes are disabled on this server", Codes.FORBIDDEN ) @@ -693,7 +693,7 @@ class ThreepidAddRestServlet(RestServlet): @interactive_auth_handler async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - if not self.hs.config.enable_3pid_changes: + if not self.hs.config.registration.enable_3pid_changes: raise SynapseError( 400, "3PID changes are disabled on this server", Codes.FORBIDDEN ) @@ -801,7 +801,7 @@ class ThreepidDeleteRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: - if not self.hs.config.enable_3pid_changes: + if not self.hs.config.registration.enable_3pid_changes: raise SynapseError( 400, "3PID changes are disabled on this server", Codes.FORBIDDEN ) diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py index 282861fae2..c9ad35a3ad 100644 --- a/synapse/rest/client/auth.py +++ b/synapse/rest/client/auth.py @@ -49,8 +49,10 @@ class AuthRestServlet(RestServlet): self.registration_handler = hs.get_registration_handler() self.recaptcha_template = hs.config.captcha.recaptcha_template self.terms_template = hs.config.terms_template - self.registration_token_template = hs.config.registration_token_template - self.success_template = hs.config.fallback_success_template + self.registration_token_template = ( + hs.config.registration.registration_token_template + ) + self.success_template = hs.config.registration.fallback_success_template async def on_GET(self, request: SynapseRequest, stagetype: str) -> None: session = parse_string(request, "session") diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py index d6b6256413..2a3e24ae7e 100644 --- a/synapse/rest/client/capabilities.py +++ b/synapse/rest/client/capabilities.py @@ -64,13 +64,13 @@ class CapabilitiesRestServlet(RestServlet): if self.config.experimental.msc3283_enabled: response["capabilities"]["org.matrix.msc3283.set_displayname"] = { - "enabled": self.config.enable_set_displayname + "enabled": self.config.registration.enable_set_displayname } response["capabilities"]["org.matrix.msc3283.set_avatar_url"] = { - "enabled": self.config.enable_set_avatar_url + "enabled": self.config.registration.enable_set_avatar_url } response["capabilities"]["org.matrix.msc3283.3pid_changes"] = { - "enabled": self.config.enable_3pid_changes + "enabled": self.config.registration.enable_3pid_changes } return 200, response diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index fa5c173f4b..d49a647b03 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -79,7 +79,7 @@ class LoginRestServlet(RestServlet): self.saml2_enabled = hs.config.saml2.saml2_enabled self.cas_enabled = hs.config.cas.cas_enabled self.oidc_enabled = hs.config.oidc.oidc_enabled - self._msc2918_enabled = hs.config.access_token_lifetime is not None + self._msc2918_enabled = hs.config.registration.access_token_lifetime is not None self.auth = hs.get_auth() @@ -447,7 +447,7 @@ class RefreshTokenServlet(RestServlet): def __init__(self, hs: "HomeServer"): self._auth_handler = hs.get_auth_handler() self._clock = hs.get_clock() - self.access_token_lifetime = hs.config.access_token_lifetime + self.access_token_lifetime = hs.config.registration.access_token_lifetime async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: refresh_submission = parse_json_object_from_request(request) @@ -556,7 +556,7 @@ class CasTicketServlet(RestServlet): def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: LoginRestServlet(hs).register(http_server) - if hs.config.access_token_lifetime is not None: + if hs.config.registration.access_token_lifetime is not None: RefreshTokenServlet(hs).register(http_server) SsoRedirectServlet(hs).register(http_server) if hs.config.cas.cas_enabled: diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py index a6eb6f6410..bf3cb34146 100644 --- a/synapse/rest/client/register.py +++ b/synapse/rest/client/register.py @@ -140,11 +140,11 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) if self.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE: - assert self.hs.config.account_threepid_delegate_email + assert self.hs.config.registration.account_threepid_delegate_email # Have the configured identity server handle the request ret = await self.identity_handler.requestEmailToken( - self.hs.config.account_threepid_delegate_email, + self.hs.config.registration.account_threepid_delegate_email, email, client_secret, send_attempt, @@ -221,7 +221,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): 400, "Phone number is already in use", Codes.THREEPID_IN_USE ) - if not self.hs.config.account_threepid_delegate_msisdn: + if not self.hs.config.registration.account_threepid_delegate_msisdn: logger.warning( "No upstream msisdn account_threepid_delegate configured on the server to " "handle this request" @@ -231,7 +231,7 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): ) ret = await self.identity_handler.requestMsisdnToken( - self.hs.config.account_threepid_delegate_msisdn, + self.hs.config.registration.account_threepid_delegate_msisdn, country, phone_number, client_secret, @@ -341,7 +341,7 @@ class UsernameAvailabilityRestServlet(RestServlet): ) async def on_GET(self, request: Request) -> Tuple[int, JsonDict]: - if not self.hs.config.enable_registration: + if not self.hs.config.registration.enable_registration: raise SynapseError( 403, "Registration has been disabled", errcode=Codes.FORBIDDEN ) @@ -391,7 +391,7 @@ class RegistrationTokenValidityRestServlet(RestServlet): async def on_GET(self, request: Request) -> Tuple[int, JsonDict]: await self.ratelimiter.ratelimit(None, (request.getClientIP(),)) - if not self.hs.config.enable_registration: + if not self.hs.config.registration.enable_registration: raise SynapseError( 403, "Registration has been disabled", errcode=Codes.FORBIDDEN ) @@ -419,8 +419,8 @@ class RegisterRestServlet(RestServlet): self.ratelimiter = hs.get_registration_ratelimiter() self.password_policy_handler = hs.get_password_policy_handler() self.clock = hs.get_clock() - self._registration_enabled = self.hs.config.enable_registration - self._msc2918_enabled = hs.config.access_token_lifetime is not None + self._registration_enabled = self.hs.config.registration.enable_registration + self._msc2918_enabled = hs.config.registration.access_token_lifetime is not None self._registration_flows = _calculate_registration_flows( hs.config, self.auth_handler @@ -800,7 +800,7 @@ class RegisterRestServlet(RestServlet): async def _do_guest_registration( self, params: JsonDict, address: Optional[str] = None ) -> Tuple[int, JsonDict]: - if not self.hs.config.allow_guest_access: + if not self.hs.config.registration.allow_guest_access: raise SynapseError(403, "Guest access is disabled") user_id = await self.registration_handler.register_user( make_guest=True, address=address @@ -849,13 +849,13 @@ def _calculate_registration_flows( """ # FIXME: need a better error than "no auth flow found" for scenarios # where we required 3PID for registration but the user didn't give one - require_email = "email" in config.registrations_require_3pid - require_msisdn = "msisdn" in config.registrations_require_3pid + require_email = "email" in config.registration.registrations_require_3pid + require_msisdn = "msisdn" in config.registration.registrations_require_3pid show_msisdn = True show_email = True - if config.disable_msisdn_registration: + if config.registration.disable_msisdn_registration: show_msisdn = False require_msisdn = False @@ -909,7 +909,7 @@ def _calculate_registration_flows( flow.insert(0, LoginType.RECAPTCHA) # Prepend registration token to all flows if we're requiring a token - if config.registration_requires_token: + if config.registration.registration_requires_token: for flow in flows: flow.insert(0, LoginType.REGISTRATION_TOKEN) diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py index c80a3a99aa..7ac01faab4 100644 --- a/synapse/rest/well_known.py +++ b/synapse/rest/well_known.py @@ -39,9 +39,9 @@ class WellKnownBuilder: result = {"m.homeserver": {"base_url": self._config.server.public_baseurl}} - if self._config.default_identity_server: + if self._config.registration.default_identity_server: result["m.identity_server"] = { - "base_url": self._config.default_identity_server + "base_url": self._config.registration.default_identity_server } return result diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 7279b0924e..de262fbf5a 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -1710,7 +1710,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): We do this by grandfathering in existing user threepids assuming that they used one of the server configured trusted identity servers. """ - id_servers = set(self.config.trusted_third_party_id_servers) + id_servers = set(self.config.registration.trusted_third_party_id_servers) def _bg_user_threepids_grandfather_txn(txn): sql = """ diff --git a/synapse/util/threepids.py b/synapse/util/threepids.py index baa9190a9a..389adf00f6 100644 --- a/synapse/util/threepids.py +++ b/synapse/util/threepids.py @@ -44,8 +44,8 @@ def check_3pid_allowed(hs: "HomeServer", medium: str, address: str) -> bool: bool: whether the 3PID medium/address is allowed to be added to this HS """ - if hs.config.allowed_local_3pids: - for constraint in hs.config.allowed_local_3pids: + if hs.config.registration.allowed_local_3pids: + for constraint in hs.config.registration.allowed_local_3pids: logger.debug( "Checking 3PID %s (%s) against %s (%s)", address, diff --git a/tests/config/test_load.py b/tests/config/test_load.py index ef6c2beec7..8e49ca26d9 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py @@ -84,16 +84,16 @@ class ConfigLoadingTestCase(unittest.TestCase): ) # Check that disable_registration clobbers enable_registration. config = HomeServerConfig.load_config("", ["-c", self.file]) - self.assertFalse(config.enable_registration) + self.assertFalse(config.registration.enable_registration) config = HomeServerConfig.load_or_generate_config("", ["-c", self.file]) - self.assertFalse(config.enable_registration) + self.assertFalse(config.registration.enable_registration) # Check that either config value is clobbered by the command line. config = HomeServerConfig.load_or_generate_config( "", ["-c", self.file, "--enable-registration"] ) - self.assertTrue(config.enable_registration) + self.assertTrue(config.registration.enable_registration) def test_stats_enabled(self): self.generate_config_and_remove_lines_containing("enable_metrics") diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 57cc3e2646..c153018fd8 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -110,7 +110,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) def test_set_my_name_if_disabled(self): - self.hs.config.enable_set_displayname = False + self.hs.config.registration.enable_set_displayname = False # Setting displayname for the first time is allowed self.get_success( @@ -225,7 +225,7 @@ class ProfileTestCase(unittest.HomeserverTestCase): ) def test_set_my_avatar_if_disabled(self): - self.hs.config.enable_set_avatar_url = False + self.hs.config.registration.enable_set_avatar_url = False # Setting displayname for the first time is allowed self.get_success( diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index a285d5a7fe..6ed9e42173 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -59,7 +59,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): self.hs = self.setup_test_homeserver() - self.hs.config.registration_shared_secret = "shared" + self.hs.config.registration.registration_shared_secret = "shared" self.hs.get_media_repository = Mock() self.hs.get_deactivate_account_handler = Mock() @@ -71,7 +71,7 @@ class UserRegisterTestCase(unittest.HomeserverTestCase): If there is no shared secret, registration through this method will be prevented. """ - self.hs.config.registration_shared_secret = None + self.hs.config.registration.registration_shared_secret = None channel = self.make_request("POST", self.url, b"{}") diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index 2f44547bfb..89d85b0a17 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -664,7 +664,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): def test_add_email_if_disabled(self): """Test adding email to profile when doing so is disallowed""" - self.hs.config.enable_3pid_changes = False + self.hs.config.registration.enable_3pid_changes = False client_secret = "foobar" session_id = self._request_token(self.email, client_secret) @@ -734,7 +734,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): def test_delete_email_if_disabled(self): """Test deleting an email from profile when disallowed""" - self.hs.config.enable_3pid_changes = False + self.hs.config.registration.enable_3pid_changes = False # Add a threepid self.get_success( diff --git a/tests/rest/client/test_identity.py b/tests/rest/client/test_identity.py index ca2e8ff8ef..becb4e8dcc 100644 --- a/tests/rest/client/test_identity.py +++ b/tests/rest/client/test_identity.py @@ -37,7 +37,7 @@ class IdentityTestCase(unittest.HomeserverTestCase): return self.hs def test_3pid_lookup_disabled(self): - self.hs.config.enable_3pid_lookup = False + self.hs.config.registration.enable_3pid_lookup = False self.register_user("kermit", "monkey") tok = self.login("kermit", "monkey") diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index af135d57e1..66dcfc9f88 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -147,7 +147,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): def test_POST_guest_registration(self): self.hs.config.key.macaroon_secret_key = "test" - self.hs.config.allow_guest_access = True + self.hs.config.registration.allow_guest_access = True channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") @@ -156,7 +156,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): self.assertDictContainsSubset(det_data, channel.json_body) def test_POST_disabled_guest_registration(self): - self.hs.config.allow_guest_access = False + self.hs.config.registration.allow_guest_access = False channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}") diff --git a/tests/unittest.py b/tests/unittest.py index 0807467e39..1f803564f6 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -560,7 +560,7 @@ class HomeserverTestCase(TestCase): Returns: The MXID of the new user. """ - self.hs.config.registration_shared_secret = "shared" + self.hs.config.registration.registration_shared_secret = "shared" # Create the user channel = self.make_request("GET", "/_synapse/admin/v1/register") -- cgit 1.5.1 From f4b1a9a527273ef71b2f7d970642b7af45462e0f Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 6 Oct 2021 10:47:41 -0400 Subject: Require direct references to configuration variables. (#10985) This removes the magic allowing accessing configurable variables directly from the config object. It is now required that a specific configuration class is used (e.g. `config.foo` must be replaced with `config.server.foo`). --- changelog.d/10985.misc | 1 + scripts/synapse_port_db | 4 +- scripts/update_synapse_database | 2 +- synapse/app/_base.py | 2 +- synapse/app/admin_cmd.py | 4 +- synapse/app/homeserver.py | 2 +- synapse/config/_base.py | 64 ++++---------------------- synapse/config/account_validity.py | 2 +- synapse/config/cas.py | 2 +- synapse/config/emailconfig.py | 9 ++-- synapse/config/key.py | 6 ++- synapse/config/oidc.py | 2 +- synapse/config/registration.py | 7 ++- synapse/config/repository.py | 2 +- synapse/config/saml2.py | 2 +- synapse/config/server_notices.py | 4 +- synapse/config/sso.py | 6 ++- synapse/handlers/account_validity.py | 8 +--- synapse/handlers/room_member.py | 7 ++- synapse/replication/tcp/client.py | 2 +- synapse/replication/tcp/handler.py | 7 ++- synapse/rest/client/auth.py | 2 +- synapse/rest/client/push_rule.py | 4 +- synapse/storage/databases/main/push_rule.py | 4 +- synapse/storage/databases/main/registration.py | 4 +- tests/config/test_base.py | 21 +++++---- tests/config/test_cache.py | 50 ++++++++------------ tests/config/test_load.py | 12 +++-- tests/config/test_tls.py | 38 +++++++-------- tests/storage/test_appservice.py | 2 +- tests/storage/test_txn_limit.py | 2 +- 31 files changed, 124 insertions(+), 160 deletions(-) create mode 100644 changelog.d/10985.misc (limited to 'synapse/storage/databases/main/registration.py') diff --git a/changelog.d/10985.misc b/changelog.d/10985.misc new file mode 100644 index 0000000000..586a0b3a96 --- /dev/null +++ b/changelog.d/10985.misc @@ -0,0 +1 @@ +Use direct references to config flags. diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index fa6ac6d93a..a947d9e49e 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -215,7 +215,7 @@ class MockHomeserver: def __init__(self, config): self.clock = Clock(reactor) self.config = config - self.hostname = config.server_name + self.hostname = config.server.server_name self.version_string = "Synapse/" + get_version_string(synapse) def get_clock(self): @@ -583,7 +583,7 @@ class Porter(object): return self.postgres_store = self.build_db_store( - self.hs_config.get_single_database() + self.hs_config.database.get_single_database() ) await self.run_background_updates_on_postgres() diff --git a/scripts/update_synapse_database b/scripts/update_synapse_database index 26b29b0b45..6c088bad93 100755 --- a/scripts/update_synapse_database +++ b/scripts/update_synapse_database @@ -36,7 +36,7 @@ class MockHomeserver(HomeServer): def __init__(self, config, **kwargs): super(MockHomeserver, self).__init__( - config.server_name, reactor=reactor, config=config, **kwargs + config.server.server_name, reactor=reactor, config=config, **kwargs ) self.version_string = "Synapse/" + get_version_string(synapse) diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 749bc1deb9..4a204a5823 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -301,7 +301,7 @@ def refresh_certificate(hs): if not hs.config.server.has_tls_listener(): return - hs.config.read_certificate_from_disk() + hs.config.tls.read_certificate_from_disk() hs.tls_server_context_factory = context_factory.ServerContextFactory(hs.config) if hs._listening_services: diff --git a/synapse/app/admin_cmd.py b/synapse/app/admin_cmd.py index 556bcc124e..13d20af457 100644 --- a/synapse/app/admin_cmd.py +++ b/synapse/app/admin_cmd.py @@ -197,9 +197,9 @@ def start(config_options): # Explicitly disable background processes config.server.update_user_directory = False config.worker.run_background_tasks = False - config.start_pushers = False + config.worker.start_pushers = False config.pusher_shard_config.instances = [] - config.send_federation = False + config.worker.send_federation = False config.federation_shard_config.instances = [] synapse.events.USE_FROZEN_DICTS = config.server.use_frozen_dicts diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 2b2d4bbf83..422f03cc04 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -234,7 +234,7 @@ class SynapseHomeServer(HomeServer): ) if name in ["media", "federation", "client"]: - if self.config.media.enable_media_repo: + if self.config.server.enable_media_repo: media_repo = self.get_media_repository_resource() resources.update( {MEDIA_PREFIX: media_repo, LEGACY_MEDIA_PREFIX: media_repo} diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 26152b0924..7c4428a138 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -118,21 +118,6 @@ class Config: "synapse", "res/templates" ) - def __getattr__(self, item: str) -> Any: - """ - Try and fetch a configuration option that does not exist on this class. - - This is so that existing configs that rely on `self.value`, where value - is actually from a different config section, continue to work. - """ - if item in ["generate_config_section", "read_config"]: - raise AttributeError(item) - - if self.root is None: - raise AttributeError(item) - else: - return self.root._get_unclassed_config(self.section, item) - @staticmethod def parse_size(value): if isinstance(value, int): @@ -289,7 +274,9 @@ class Config: env.filters.update( { "format_ts": _format_ts_filter, - "mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl), + "mxc_to_http": _create_mxc_to_http_filter( + self.root.server.public_baseurl + ), } ) @@ -311,8 +298,6 @@ class RootConfig: config_classes = [] def __init__(self): - self._configs = OrderedDict() - for config_class in self.config_classes: if config_class.section is None: raise ValueError("%r requires a section name" % (config_class,)) @@ -321,42 +306,7 @@ class RootConfig: conf = config_class(self) except Exception as e: raise Exception("Failed making %s: %r" % (config_class.section, e)) - self._configs[config_class.section] = conf - - def __getattr__(self, item: str) -> Any: - """ - Redirect lookups on this object either to config objects, or values on - config objects, so that `config.tls.blah` works, as well as legacy uses - of things like `config.server.server_name`. It will first look up the config - section name, and then values on those config classes. - """ - if item in self._configs.keys(): - return self._configs[item] - - return self._get_unclassed_config(None, item) - - def _get_unclassed_config(self, asking_section: Optional[str], item: str): - """ - Fetch a config value from one of the instantiated config classes that - has not been fetched directly. - - Args: - asking_section: If this check is coming from a Config child, which - one? This section will not be asked if it has the value. - item: The configuration value key. - - Raises: - AttributeError if no config classes have the config key. The body - will contain what sections were checked. - """ - for key, val in self._configs.items(): - if key == asking_section: - continue - - if item in dir(val): - return getattr(val, item) - - raise AttributeError(item, "not found in %s" % (list(self._configs.keys()),)) + setattr(self, config_class.section, conf) def invoke_all(self, func_name: str, *args, **kwargs) -> MutableMapping[str, Any]: """ @@ -373,9 +323,11 @@ class RootConfig: """ res = OrderedDict() - for name, config in self._configs.items(): + for config_class in self.config_classes: + config = getattr(self, config_class.section) + if hasattr(config, func_name): - res[name] = getattr(config, func_name)(*args, **kwargs) + res[config_class.section] = getattr(config, func_name)(*args, **kwargs) return res diff --git a/synapse/config/account_validity.py b/synapse/config/account_validity.py index ffaffc4931..b56c2a24df 100644 --- a/synapse/config/account_validity.py +++ b/synapse/config/account_validity.py @@ -76,7 +76,7 @@ class AccountValidityConfig(Config): ) if self.account_validity_renew_by_email_enabled: - if not self.public_baseurl: + if not self.root.server.public_baseurl: raise ConfigError("Can't send renewal emails without 'public_baseurl'") # Load account validity templates. diff --git a/synapse/config/cas.py b/synapse/config/cas.py index 901f4123e1..9b58ecf3d8 100644 --- a/synapse/config/cas.py +++ b/synapse/config/cas.py @@ -37,7 +37,7 @@ class CasConfig(Config): # The public baseurl is required because it is used by the redirect # template. - public_baseurl = self.public_baseurl + public_baseurl = self.root.server.public_baseurl if not public_baseurl: raise ConfigError("cas_config requires a public_baseurl to be set") diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index 936abe6178..8ff59aa2f8 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -19,7 +19,6 @@ import email.utils import logging import os from enum import Enum -from typing import Optional import attr @@ -135,7 +134,7 @@ class EmailConfig(Config): # msisdn is currently always remote while Synapse does not support any method of # sending SMS messages ThreepidBehaviour.REMOTE - if self.account_threepid_delegate_email + if self.root.registration.account_threepid_delegate_email else ThreepidBehaviour.LOCAL ) # Prior to Synapse v1.4.0, there was another option that defined whether Synapse would @@ -144,7 +143,7 @@ class EmailConfig(Config): # identity server in the process. self.using_identity_server_from_trusted_list = False if ( - not self.account_threepid_delegate_email + not self.root.registration.account_threepid_delegate_email and config.get("trust_identity_server_for_password_resets", False) is True ): # Use the first entry in self.trusted_third_party_id_servers instead @@ -156,7 +155,7 @@ class EmailConfig(Config): # trusted_third_party_id_servers does not contain a scheme whereas # account_threepid_delegate_email is expected to. Presume https - self.account_threepid_delegate_email: Optional[str] = ( + self.root.registration.account_threepid_delegate_email = ( "https://" + first_trusted_identity_server ) self.using_identity_server_from_trusted_list = True @@ -335,7 +334,7 @@ class EmailConfig(Config): "client_base_url", email_config.get("riot_base_url", None) ) - if self.account_validity_renew_by_email_enabled: + if self.root.account_validity.account_validity_renew_by_email_enabled: expiry_template_html = email_config.get( "expiry_template_html", "notice_expiry.html" ) diff --git a/synapse/config/key.py b/synapse/config/key.py index 94a9063043..015dbb8a67 100644 --- a/synapse/config/key.py +++ b/synapse/config/key.py @@ -145,11 +145,13 @@ class KeyConfig(Config): # list of TrustedKeyServer objects self.key_servers = list( - _parse_key_servers(key_servers, self.federation_verify_certificates) + _parse_key_servers( + key_servers, self.root.tls.federation_verify_certificates + ) ) self.macaroon_secret_key = config.get( - "macaroon_secret_key", self.registration_shared_secret + "macaroon_secret_key", self.root.registration.registration_shared_secret ) if not self.macaroon_secret_key: diff --git a/synapse/config/oidc.py b/synapse/config/oidc.py index 7e67fbada1..10f5796330 100644 --- a/synapse/config/oidc.py +++ b/synapse/config/oidc.py @@ -58,7 +58,7 @@ class OIDCConfig(Config): "Multiple OIDC providers have the idp_id %r." % idp_id ) - public_baseurl = self.public_baseurl + public_baseurl = self.root.server.public_baseurl if public_baseurl is None: raise ConfigError("oidc_config requires a public_baseurl to be set") self.oidc_callback_url = public_baseurl + "_synapse/client/oidc/callback" diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 7cffdacfa5..a3d2a38c4c 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -45,7 +45,10 @@ class RegistrationConfig(Config): account_threepid_delegates = config.get("account_threepid_delegates") or {} self.account_threepid_delegate_email = account_threepid_delegates.get("email") self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn") - if self.account_threepid_delegate_msisdn and not self.public_baseurl: + if ( + self.account_threepid_delegate_msisdn + and not self.root.server.public_baseurl + ): raise ConfigError( "The configuration option `public_baseurl` is required if " "`account_threepid_delegate.msisdn` is set, such that " @@ -85,7 +88,7 @@ class RegistrationConfig(Config): if mxid_localpart: # Convert the localpart to a full mxid. self.auto_join_user_id = UserID( - mxid_localpart, self.server_name + mxid_localpart, self.root.server.server_name ).to_string() if self.autocreate_auto_join_rooms: diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 7481f3bf5f..69906a98d4 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -94,7 +94,7 @@ class ContentRepositoryConfig(Config): # Only enable the media repo if either the media repo is enabled or the # current worker app is the media repo. if ( - self.enable_media_repo is False + self.root.server.enable_media_repo is False and config.get("worker_app") != "synapse.app.media_repository" ): self.can_load_media_repo = False diff --git a/synapse/config/saml2.py b/synapse/config/saml2.py index 05e983625d..9c51b6a25a 100644 --- a/synapse/config/saml2.py +++ b/synapse/config/saml2.py @@ -199,7 +199,7 @@ class SAML2Config(Config): """ import saml2 - public_baseurl = self.public_baseurl + public_baseurl = self.root.server.public_baseurl if public_baseurl is None: raise ConfigError("saml2_config requires a public_baseurl to be set") diff --git a/synapse/config/server_notices.py b/synapse/config/server_notices.py index 48bf3241b6..bde4e879d9 100644 --- a/synapse/config/server_notices.py +++ b/synapse/config/server_notices.py @@ -73,7 +73,9 @@ class ServerNoticesConfig(Config): return mxid_localpart = c["system_mxid_localpart"] - self.server_notices_mxid = UserID(mxid_localpart, self.server_name).to_string() + self.server_notices_mxid = UserID( + mxid_localpart, self.root.server.server_name + ).to_string() self.server_notices_mxid_display_name = c.get("system_mxid_display_name", None) self.server_notices_mxid_avatar_url = c.get("system_mxid_avatar_url", None) # todo: i18n diff --git a/synapse/config/sso.py b/synapse/config/sso.py index 524a7ff3aa..11a9b76aa0 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -103,8 +103,10 @@ class SSOConfig(Config): # the client's. # public_baseurl is an optional setting, so we only add the fallback's URL to the # list if it's provided (because we can't figure out what that URL is otherwise). - if self.public_baseurl: - login_fallback_url = self.public_baseurl + "_matrix/static/client/login" + if self.root.server.public_baseurl: + login_fallback_url = ( + self.root.server.public_baseurl + "_matrix/static/client/login" + ) self.sso_client_whitelist.append(login_fallback_url) def generate_config_section(self, **kwargs): diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index 5a5f124ddf..87e415df75 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -67,12 +67,8 @@ class AccountValidityHandler: and self._account_validity_renew_by_email_enabled ): # Don't do email-specific configuration if renewal by email is disabled. - self._template_html = ( - hs.config.account_validity.account_validity_template_html - ) - self._template_text = ( - hs.config.account_validity.account_validity_template_text - ) + self._template_html = hs.config.email.account_validity_template_html + self._template_text = hs.config.email.account_validity_template_text self._renew_email_subject = ( hs.config.account_validity.account_validity_renew_email_subject ) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 0b79dbcf8d..c05461bf2a 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -1499,8 +1499,11 @@ class RoomMemberMasterHandler(RoomMemberHandler): if len(remote_room_hosts) == 0: raise SynapseError(404, "No known servers") - check_complexity = self.hs.config.limit_remote_rooms.enabled - if check_complexity and self.hs.config.limit_remote_rooms.admins_can_join: + check_complexity = self.hs.config.server.limit_remote_rooms.enabled + if ( + check_complexity + and self.hs.config.server.limit_remote_rooms.admins_can_join + ): check_complexity = not await self.auth.is_server_admin(user) if check_complexity: diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 37769ace48..961c17762e 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -117,7 +117,7 @@ class ReplicationDataHandler: self._instance_name = hs.get_instance_name() self._typing_handler = hs.get_typing_handler() - self._notify_pushers = hs.config.start_pushers + self._notify_pushers = hs.config.worker.start_pushers self._pusher_pool = hs.get_pusherpool() self._presence_handler = hs.get_presence_handler() diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index d64d1dbacd..6aa9318027 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -171,7 +171,10 @@ class ReplicationCommandHandler: if hs.config.worker.worker_app is not None: continue - if stream.NAME == FederationStream.NAME and hs.config.send_federation: + if ( + stream.NAME == FederationStream.NAME + and hs.config.worker.send_federation + ): # We only support federation stream if federation sending # has been disabled on the master. continue @@ -225,7 +228,7 @@ class ReplicationCommandHandler: self._is_master = hs.config.worker.worker_app is None self._federation_sender = None - if self._is_master and not hs.config.send_federation: + if self._is_master and not hs.config.worker.send_federation: self._federation_sender = hs.get_federation_sender() self._server_notices_sender = None diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py index c9ad35a3ad..9c15a04338 100644 --- a/synapse/rest/client/auth.py +++ b/synapse/rest/client/auth.py @@ -48,7 +48,7 @@ class AuthRestServlet(RestServlet): self.auth_handler = hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() self.recaptcha_template = hs.config.captcha.recaptcha_template - self.terms_template = hs.config.terms_template + self.terms_template = hs.config.consent.terms_template self.registration_token_template = ( hs.config.registration.registration_token_template ) diff --git a/synapse/rest/client/push_rule.py b/synapse/rest/client/push_rule.py index ecebc46e8d..6f796d5e50 100644 --- a/synapse/rest/client/push_rule.py +++ b/synapse/rest/client/push_rule.py @@ -61,7 +61,9 @@ class PushRuleRestServlet(RestServlet): self.notifier = hs.get_notifier() self._is_worker = hs.config.worker.worker_app is not None - self._users_new_default_push_rules = hs.config.users_new_default_push_rules + self._users_new_default_push_rules = ( + hs.config.server.users_new_default_push_rules + ) async def on_PUT(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]: if self._is_worker: diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index a7fb8cd848..b81e33964a 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -101,7 +101,9 @@ class PushRulesWorkerStore( prefilled_cache=push_rules_prefill, ) - self._users_new_default_push_rules = hs.config.users_new_default_push_rules + self._users_new_default_push_rules = ( + hs.config.server.users_new_default_push_rules + ) @abc.abstractmethod def get_max_push_rules_stream_id(self): diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index de262fbf5a..7de4ad7f9b 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -1778,7 +1778,9 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): super().__init__(database, db_conn, hs) - self._ignore_unknown_session_error = hs.config.request_token_inhibit_3pid_errors + self._ignore_unknown_session_error = ( + hs.config.server.request_token_inhibit_3pid_errors + ) self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id") self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id") diff --git a/tests/config/test_base.py b/tests/config/test_base.py index baa5313fb3..6a52f862f4 100644 --- a/tests/config/test_base.py +++ b/tests/config/test_base.py @@ -14,23 +14,28 @@ import os.path import tempfile +from unittest.mock import Mock from synapse.config import ConfigError +from synapse.config._base import Config from synapse.util.stringutils import random_string from tests import unittest -class BaseConfigTestCase(unittest.HomeserverTestCase): - def prepare(self, reactor, clock, hs): - self.hs = hs +class BaseConfigTestCase(unittest.TestCase): + def setUp(self): + # The root object needs a server property with a public_baseurl. + root = Mock() + root.server.public_baseurl = "http://test" + self.config = Config(root) def test_loading_missing_templates(self): # Use a temporary directory that exists on the system, but that isn't likely to # contain template files with tempfile.TemporaryDirectory() as tmp_dir: # Attempt to load an HTML template from our custom template directory - template = self.hs.config.read_templates(["sso_error.html"], (tmp_dir,))[0] + template = self.config.read_templates(["sso_error.html"], (tmp_dir,))[0] # If no errors, we should've gotten the default template instead @@ -60,7 +65,7 @@ class BaseConfigTestCase(unittest.HomeserverTestCase): # Attempt to load the template from our custom template directory template = ( - self.hs.config.read_templates([template_filename], (tmp_dir,)) + self.config.read_templates([template_filename], (tmp_dir,)) )[0] # Render the template @@ -97,7 +102,7 @@ class BaseConfigTestCase(unittest.HomeserverTestCase): # Retrieve the template. template = ( - self.hs.config.read_templates( + self.config.read_templates( [template_filename], (td.name for td in tempdirs), ) @@ -118,7 +123,7 @@ class BaseConfigTestCase(unittest.HomeserverTestCase): # Retrieve the template. template = ( - self.hs.config.read_templates( + self.config.read_templates( [other_template_name], (td.name for td in tempdirs), ) @@ -134,6 +139,6 @@ class BaseConfigTestCase(unittest.HomeserverTestCase): def test_loading_template_from_nonexistent_custom_directory(self): with self.assertRaises(ConfigError): - self.hs.config.read_templates( + self.config.read_templates( ["some_filename.html"], ("a_nonexistent_directory",) ) diff --git a/tests/config/test_cache.py b/tests/config/test_cache.py index 857d9cd096..f518abdb7a 100644 --- a/tests/config/test_cache.py +++ b/tests/config/test_cache.py @@ -12,39 +12,32 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.config._base import Config, RootConfig from synapse.config.cache import CacheConfig, add_resizable_cache from synapse.util.caches.lrucache import LruCache from tests.unittest import TestCase -class FakeServer(Config): - section = "server" - - -class TestConfig(RootConfig): - config_classes = [FakeServer, CacheConfig] - - class CacheConfigTests(TestCase): def setUp(self): # Reset caches before each test - TestConfig().caches.reset() + self.config = CacheConfig() + + def tearDown(self): + self.config.reset() def test_individual_caches_from_environ(self): """ Individual cache factors will be loaded from the environment. """ config = {} - t = TestConfig() - t.caches._environ = { + self.config._environ = { "SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2", "SYNAPSE_NOT_CACHE": "BLAH", } - t.read_config(config, config_dir_path="", data_dir_path="") + self.config.read_config(config, config_dir_path="", data_dir_path="") - self.assertEqual(dict(t.caches.cache_factors), {"something_or_other": 2.0}) + self.assertEqual(dict(self.config.cache_factors), {"something_or_other": 2.0}) def test_config_overrides_environ(self): """ @@ -52,15 +45,14 @@ class CacheConfigTests(TestCase): over those in the config. """ config = {"caches": {"per_cache_factors": {"foo": 2, "bar": 3}}} - t = TestConfig() - t.caches._environ = { + self.config._environ = { "SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2", "SYNAPSE_CACHE_FACTOR_FOO": 1, } - t.read_config(config, config_dir_path="", data_dir_path="") + self.config.read_config(config, config_dir_path="", data_dir_path="") self.assertEqual( - dict(t.caches.cache_factors), + dict(self.config.cache_factors), {"foo": 1.0, "bar": 3.0, "something_or_other": 2.0}, ) @@ -76,8 +68,7 @@ class CacheConfigTests(TestCase): self.assertEqual(cache.max_size, 50) config = {"caches": {"per_cache_factors": {"foo": 3}}} - t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") + self.config.read_config(config) self.assertEqual(cache.max_size, 300) @@ -88,8 +79,7 @@ class CacheConfigTests(TestCase): there is one. """ config = {"caches": {"per_cache_factors": {"foo": 2}}} - t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") + self.config.read_config(config, config_dir_path="", data_dir_path="") cache = LruCache(100) add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) @@ -106,8 +96,7 @@ class CacheConfigTests(TestCase): self.assertEqual(cache.max_size, 50) config = {"caches": {"global_factor": 4}} - t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") + self.config.read_config(config, config_dir_path="", data_dir_path="") self.assertEqual(cache.max_size, 400) @@ -118,8 +107,7 @@ class CacheConfigTests(TestCase): is no per-cache factor. """ config = {"caches": {"global_factor": 1.5}} - t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") + self.config.read_config(config, config_dir_path="", data_dir_path="") cache = LruCache(100) add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) @@ -133,12 +121,11 @@ class CacheConfigTests(TestCase): "per_cache_factors": {"*cache_a*": 5, "cache_b": 6, "cache_c": 2} } } - t = TestConfig() - t.caches._environ = { + self.config._environ = { "SYNAPSE_CACHE_FACTOR_CACHE_A": "2", "SYNAPSE_CACHE_FACTOR_CACHE_B": 3, } - t.read_config(config, config_dir_path="", data_dir_path="") + self.config.read_config(config, config_dir_path="", data_dir_path="") cache_a = LruCache(100) add_resizable_cache("*cache_a*", cache_resize_callback=cache_a.set_cache_factor) @@ -158,11 +145,10 @@ class CacheConfigTests(TestCase): """ config = {"caches": {"event_cache_size": "10k"}} - t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") + self.config.read_config(config, config_dir_path="", data_dir_path="") cache = LruCache( - max_size=t.caches.event_cache_size, + max_size=self.config.event_cache_size, apply_cache_factor_from_config=False, ) add_resizable_cache("event_cache", cache_resize_callback=cache.set_cache_factor) diff --git a/tests/config/test_load.py b/tests/config/test_load.py index 8e49ca26d9..59635de205 100644 --- a/tests/config/test_load.py +++ b/tests/config/test_load.py @@ -49,7 +49,7 @@ class ConfigLoadingTestCase(unittest.TestCase): config = HomeServerConfig.load_config("", ["-c", self.file]) self.assertTrue( - hasattr(config, "macaroon_secret_key"), + hasattr(config.key, "macaroon_secret_key"), "Want config to have attr macaroon_secret_key", ) if len(config.key.macaroon_secret_key) < 5: @@ -60,7 +60,7 @@ class ConfigLoadingTestCase(unittest.TestCase): config = HomeServerConfig.load_or_generate_config("", ["-c", self.file]) self.assertTrue( - hasattr(config, "macaroon_secret_key"), + hasattr(config.key, "macaroon_secret_key"), "Want config to have attr macaroon_secret_key", ) if len(config.key.macaroon_secret_key) < 5: @@ -74,8 +74,12 @@ class ConfigLoadingTestCase(unittest.TestCase): config1 = HomeServerConfig.load_config("", ["-c", self.file]) config2 = HomeServerConfig.load_config("", ["-c", self.file]) config3 = HomeServerConfig.load_or_generate_config("", ["-c", self.file]) - self.assertEqual(config1.macaroon_secret_key, config2.macaroon_secret_key) - self.assertEqual(config1.macaroon_secret_key, config3.macaroon_secret_key) + self.assertEqual( + config1.key.macaroon_secret_key, config2.key.macaroon_secret_key + ) + self.assertEqual( + config1.key.macaroon_secret_key, config3.key.macaroon_secret_key + ) def test_disable_registration(self): self.generate_config() diff --git a/tests/config/test_tls.py b/tests/config/test_tls.py index b6bc1876b5..9ba5781573 100644 --- a/tests/config/test_tls.py +++ b/tests/config/test_tls.py @@ -42,9 +42,9 @@ class TLSConfigTests(TestCase): """ config = {} t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") + t.tls.read_config(config, config_dir_path="", data_dir_path="") - self.assertEqual(t.federation_client_minimum_tls_version, "1") + self.assertEqual(t.tls.federation_client_minimum_tls_version, "1") def test_tls_client_minimum_set(self): """ @@ -52,29 +52,29 @@ class TLSConfigTests(TestCase): """ config = {"federation_client_minimum_tls_version": 1} t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") - self.assertEqual(t.federation_client_minimum_tls_version, "1") + t.tls.read_config(config, config_dir_path="", data_dir_path="") + self.assertEqual(t.tls.federation_client_minimum_tls_version, "1") config = {"federation_client_minimum_tls_version": 1.1} t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") - self.assertEqual(t.federation_client_minimum_tls_version, "1.1") + t.tls.read_config(config, config_dir_path="", data_dir_path="") + self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.1") config = {"federation_client_minimum_tls_version": 1.2} t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") - self.assertEqual(t.federation_client_minimum_tls_version, "1.2") + t.tls.read_config(config, config_dir_path="", data_dir_path="") + self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.2") # Also test a string version config = {"federation_client_minimum_tls_version": "1"} t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") - self.assertEqual(t.federation_client_minimum_tls_version, "1") + t.tls.read_config(config, config_dir_path="", data_dir_path="") + self.assertEqual(t.tls.federation_client_minimum_tls_version, "1") config = {"federation_client_minimum_tls_version": "1.2"} t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") - self.assertEqual(t.federation_client_minimum_tls_version, "1.2") + t.tls.read_config(config, config_dir_path="", data_dir_path="") + self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.2") def test_tls_client_minimum_1_point_3_missing(self): """ @@ -91,7 +91,7 @@ class TLSConfigTests(TestCase): config = {"federation_client_minimum_tls_version": 1.3} t = TestConfig() with self.assertRaises(ConfigError) as e: - t.read_config(config, config_dir_path="", data_dir_path="") + t.tls.read_config(config, config_dir_path="", data_dir_path="") self.assertEqual( e.exception.args[0], ( @@ -112,8 +112,8 @@ class TLSConfigTests(TestCase): config = {"federation_client_minimum_tls_version": 1.3} t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") - self.assertEqual(t.federation_client_minimum_tls_version, "1.3") + t.tls.read_config(config, config_dir_path="", data_dir_path="") + self.assertEqual(t.tls.federation_client_minimum_tls_version, "1.3") def test_tls_client_minimum_set_passed_through_1_2(self): """ @@ -121,7 +121,7 @@ class TLSConfigTests(TestCase): """ config = {"federation_client_minimum_tls_version": 1.2} t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") + t.tls.read_config(config, config_dir_path="", data_dir_path="") cf = FederationPolicyForHTTPS(t) options = _get_ssl_context_options(cf._verify_ssl_context) @@ -137,7 +137,7 @@ class TLSConfigTests(TestCase): """ config = {"federation_client_minimum_tls_version": 1} t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") + t.tls.read_config(config, config_dir_path="", data_dir_path="") cf = FederationPolicyForHTTPS(t) options = _get_ssl_context_options(cf._verify_ssl_context) @@ -159,7 +159,7 @@ class TLSConfigTests(TestCase): } t = TestConfig() e = self.assertRaises( - ConfigError, t.read_config, config, config_dir_path="", data_dir_path="" + ConfigError, t.tls.read_config, config, config_dir_path="", data_dir_path="" ) self.assertIn("IDNA domain names", str(e)) @@ -174,7 +174,7 @@ class TLSConfigTests(TestCase): ] } t = TestConfig() - t.read_config(config, config_dir_path="", data_dir_path="") + t.tls.read_config(config, config_dir_path="", data_dir_path="") cf = FederationPolicyForHTTPS(t) diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index cf9748f218..f26d5acf9c 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -126,7 +126,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): self.db_pool = database._db_pool self.engine = database.engine - db_config = hs.config.get_single_database() + db_config = hs.config.database.get_single_database() self.store = TestTransactionStore( database, make_conn(db_config, self.engine, "test"), hs ) diff --git a/tests/storage/test_txn_limit.py b/tests/storage/test_txn_limit.py index 6ff3ebb137..ace82cbf42 100644 --- a/tests/storage/test_txn_limit.py +++ b/tests/storage/test_txn_limit.py @@ -22,7 +22,7 @@ class SQLTransactionLimitTestCase(unittest.HomeserverTestCase): return self.setup_test_homeserver(db_txn_limit=1000) def test_config(self): - db_config = self.hs.config.get_single_database() + db_config = self.hs.config.database.get_single_database() self.assertEqual(db_config.config["txn_limit"], 1000) def test_select(self): -- cgit 1.5.1 From 51a5da74ccd383806378b53ee8a09e27a8829f31 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 8 Oct 2021 15:25:16 +0100 Subject: Annotate synapse.storage.util (#10892) Also mark `synapse.streams` as having has no untyped defs Co-authored-by: Sean Quah <8349537+squahtx@users.noreply.github.com> --- changelog.d/10892.misc | 1 + mypy.ini | 6 + .../slave/storage/_slaved_id_tracker.py | 4 +- synapse/replication/slave/storage/pushers.py | 10 +- synapse/storage/databases/main/pusher.py | 10 +- synapse/storage/databases/main/registration.py | 9 +- synapse/storage/util/id_generators.py | 143 +++++++++++++-------- synapse/storage/util/sequence.py | 6 +- 8 files changed, 124 insertions(+), 65 deletions(-) create mode 100644 changelog.d/10892.misc (limited to 'synapse/storage/databases/main/registration.py') diff --git a/changelog.d/10892.misc b/changelog.d/10892.misc new file mode 100644 index 0000000000..c8c471159b --- /dev/null +++ b/changelog.d/10892.misc @@ -0,0 +1 @@ +Add further type hints to `synapse.storage.util`. diff --git a/mypy.ini b/mypy.ini index e7cb80b6eb..bc2b59ff56 100644 --- a/mypy.ini +++ b/mypy.ini @@ -105,6 +105,12 @@ disallow_untyped_defs = True [mypy-synapse.state.*] disallow_untyped_defs = True +[mypy-synapse.storage.util.*] +disallow_untyped_defs = True + +[mypy-synapse.streams.*] +disallow_untyped_defs = True + [mypy-synapse.util.batching_queue] disallow_untyped_defs = True diff --git a/synapse/replication/slave/storage/_slaved_id_tracker.py b/synapse/replication/slave/storage/_slaved_id_tracker.py index 2cb7489047..8c1bf9227a 100644 --- a/synapse/replication/slave/storage/_slaved_id_tracker.py +++ b/synapse/replication/slave/storage/_slaved_id_tracker.py @@ -13,14 +13,14 @@ # limitations under the License. from typing import List, Optional, Tuple -from synapse.storage.types import Connection +from synapse.storage.database import LoggingDatabaseConnection from synapse.storage.util.id_generators import _load_current_id class SlavedIdTracker: def __init__( self, - db_conn: Connection, + db_conn: LoggingDatabaseConnection, table: str, column: str, extra_tables: Optional[List[Tuple[str, str]]] = None, diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py index 2672a2c94b..cea90c0f1b 100644 --- a/synapse/replication/slave/storage/pushers.py +++ b/synapse/replication/slave/storage/pushers.py @@ -15,9 +15,8 @@ from typing import TYPE_CHECKING from synapse.replication.tcp.streams import PushersStream -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.pusher import PusherWorkerStore -from synapse.storage.types import Connection from ._base import BaseSlavedStore from ._slaved_id_tracker import SlavedIdTracker @@ -27,7 +26,12 @@ if TYPE_CHECKING: class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore): - def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self._pushers_id_gen = SlavedIdTracker( # type: ignore db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index a93caae8d0..b73ce53c91 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -18,8 +18,7 @@ from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, from synapse.push import PusherConfig, ThrottleParams from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import DatabasePool -from synapse.storage.types import Connection +from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import JsonDict from synapse.util import json_encoder @@ -32,7 +31,12 @@ logger = logging.getLogger(__name__) class PusherWorkerStore(SQLBaseStore): - def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self._pushers_id_gen = StreamIdGenerator( db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 7de4ad7f9b..181841ee06 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -26,7 +26,7 @@ from synapse.metrics.background_process_metrics import wrap_as_background_proces from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.databases.main.stats import StatsStore -from synapse.storage.types import Connection, Cursor +from synapse.storage.types import Cursor from synapse.storage.util.id_generators import IdGenerator from synapse.storage.util.sequence import build_sequence_generator from synapse.types import UserID, UserInfo @@ -1775,7 +1775,12 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): - def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): super().__init__(database, db_conn, hs) self._ignore_unknown_session_error = ( diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 6f7cbe40f4..852bd79fee 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -16,42 +16,62 @@ import logging import threading from collections import OrderedDict from contextlib import contextmanager -from typing import Dict, Iterable, List, Optional, Set, Tuple, Union +from types import TracebackType +from typing import ( + AsyncContextManager, + ContextManager, + Dict, + Generator, + Generic, + Iterable, + List, + Optional, + Sequence, + Set, + Tuple, + Type, + TypeVar, + Union, + cast, +) import attr from sortedcontainers import SortedSet from synapse.metrics.background_process_metrics import run_as_background_process -from synapse.storage.database import DatabasePool, LoggingTransaction +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, +) from synapse.storage.types import Cursor from synapse.storage.util.sequence import PostgresSequenceGenerator logger = logging.getLogger(__name__) +T = TypeVar("T") + + class IdGenerator: - def __init__(self, db_conn, table, column): + def __init__( + self, + db_conn: LoggingDatabaseConnection, + table: str, + column: str, + ): self._lock = threading.Lock() self._next_id = _load_current_id(db_conn, table, column) - def get_next(self): + def get_next(self) -> int: with self._lock: self._next_id += 1 return self._next_id -def _load_current_id(db_conn, table, column, step=1): - """ - - Args: - db_conn (object): - table (str): - column (str): - step (int): - - Returns: - int - """ +def _load_current_id( + db_conn: LoggingDatabaseConnection, table: str, column: str, step: int = 1 +) -> int: # debug logging for https://github.com/matrix-org/synapse/issues/7968 logger.info("initialising stream generator for %s(%s)", table, column) cur = db_conn.cursor(txn_name="_load_current_id") @@ -59,7 +79,9 @@ def _load_current_id(db_conn, table, column, step=1): cur.execute("SELECT MAX(%s) FROM %s" % (column, table)) else: cur.execute("SELECT MIN(%s) FROM %s" % (column, table)) - (val,) = cur.fetchone() + result = cur.fetchone() + assert result is not None + (val,) = result cur.close() current_id = int(val) if val else step return (max if step > 0 else min)(current_id, step) @@ -93,16 +115,16 @@ class StreamIdGenerator: def __init__( self, - db_conn, - table, - column, + db_conn: LoggingDatabaseConnection, + table: str, + column: str, extra_tables: Iterable[Tuple[str, str]] = (), - step=1, - ): + step: int = 1, + ) -> None: assert step != 0 self._lock = threading.Lock() - self._step = step - self._current = _load_current_id(db_conn, table, column, step) + self._step: int = step + self._current: int = _load_current_id(db_conn, table, column, step) for table, column in extra_tables: self._current = (max if step > 0 else min)( self._current, _load_current_id(db_conn, table, column, step) @@ -115,7 +137,7 @@ class StreamIdGenerator: # The key and values are the same, but we never look at the values. self._unfinished_ids: OrderedDict[int, int] = OrderedDict() - def get_next(self): + def get_next(self) -> AsyncContextManager[int]: """ Usage: async with stream_id_gen.get_next() as stream_id: @@ -128,7 +150,7 @@ class StreamIdGenerator: self._unfinished_ids[next_id] = next_id @contextmanager - def manager(): + def manager() -> Generator[int, None, None]: try: yield next_id finally: @@ -137,7 +159,7 @@ class StreamIdGenerator: return _AsyncCtxManagerWrapper(manager()) - def get_next_mult(self, n): + def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]: """ Usage: async with stream_id_gen.get_next(n) as stream_ids: @@ -155,7 +177,7 @@ class StreamIdGenerator: self._unfinished_ids[next_id] = next_id @contextmanager - def manager(): + def manager() -> Generator[Sequence[int], None, None]: try: yield next_ids finally: @@ -215,7 +237,7 @@ class MultiWriterIdGenerator: def __init__( self, - db_conn, + db_conn: LoggingDatabaseConnection, db: DatabasePool, stream_name: str, instance_name: str, @@ -223,7 +245,7 @@ class MultiWriterIdGenerator: sequence_name: str, writers: List[str], positive: bool = True, - ): + ) -> None: self._db = db self._stream_name = stream_name self._instance_name = instance_name @@ -285,9 +307,9 @@ class MultiWriterIdGenerator: def _load_current_ids( self, - db_conn, + db_conn: LoggingDatabaseConnection, tables: List[Tuple[str, str, str]], - ): + ) -> None: cur = db_conn.cursor(txn_name="_load_current_ids") # Load the current positions of all writers for the stream. @@ -335,7 +357,9 @@ class MultiWriterIdGenerator: "agg": "MAX" if self._positive else "-MIN", } cur.execute(sql) - (stream_id,) = cur.fetchone() + result = cur.fetchone() + assert result is not None + (stream_id,) = result max_stream_id = max(max_stream_id, stream_id) @@ -354,7 +378,7 @@ class MultiWriterIdGenerator: self._persisted_upto_position = min_stream_id - rows = [] + rows: List[Tuple[str, int]] = [] for table, instance_column, id_column in tables: sql = """ SELECT %(instance)s, %(id)s FROM %(table)s @@ -367,7 +391,8 @@ class MultiWriterIdGenerator: } cur.execute(sql, (min_stream_id * self._return_factor,)) - rows.extend(cur) + # Cast safety: this corresponds to the types returned by the query above. + rows.extend(cast(Iterable[Tuple[str, int]], cur)) # Sort so that we handle rows in order for each instance. rows.sort() @@ -385,13 +410,13 @@ class MultiWriterIdGenerator: cur.close() - def _load_next_id_txn(self, txn) -> int: + def _load_next_id_txn(self, txn: Cursor) -> int: return self._sequence_gen.get_next_id_txn(txn) - def _load_next_mult_id_txn(self, txn, n: int) -> List[int]: + def _load_next_mult_id_txn(self, txn: Cursor, n: int) -> List[int]: return self._sequence_gen.get_next_mult_txn(txn, n) - def get_next(self): + def get_next(self) -> AsyncContextManager[int]: """ Usage: async with stream_id_gen.get_next() as stream_id: @@ -403,9 +428,12 @@ class MultiWriterIdGenerator: if self._writers and self._instance_name not in self._writers: raise Exception("Tried to allocate stream ID on non-writer") - return _MultiWriterCtxManager(self) + # Cast safety: the second argument to _MultiWriterCtxManager, multiple_ids, + # controls the return type. If `None` or omitted, the context manager yields + # a single integer stream_id; otherwise it yields a list of stream_ids. + return cast(AsyncContextManager[int], _MultiWriterCtxManager(self)) - def get_next_mult(self, n: int): + def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]: """ Usage: async with stream_id_gen.get_next_mult(5) as stream_ids: @@ -417,9 +445,10 @@ class MultiWriterIdGenerator: if self._writers and self._instance_name not in self._writers: raise Exception("Tried to allocate stream ID on non-writer") - return _MultiWriterCtxManager(self, n) + # Cast safety: see get_next. + return cast(AsyncContextManager[List[int]], _MultiWriterCtxManager(self, n)) - def get_next_txn(self, txn: LoggingTransaction): + def get_next_txn(self, txn: LoggingTransaction) -> int: """ Usage: @@ -457,7 +486,7 @@ class MultiWriterIdGenerator: return self._return_factor * next_id - def _mark_id_as_finished(self, next_id: int): + def _mark_id_as_finished(self, next_id: int) -> None: """The ID has finished being processed so we should advance the current position if possible. """ @@ -534,7 +563,7 @@ class MultiWriterIdGenerator: for name, i in self._current_positions.items() } - def advance(self, instance_name: str, new_id: int): + def advance(self, instance_name: str, new_id: int) -> None: """Advance the position of the named writer to the given ID, if greater than existing entry. """ @@ -560,7 +589,7 @@ class MultiWriterIdGenerator: with self._lock: return self._return_factor * self._persisted_upto_position - def _add_persisted_position(self, new_id: int): + def _add_persisted_position(self, new_id: int) -> None: """Record that we have persisted a position. This is used to keep the `_current_positions` up to date. @@ -606,7 +635,7 @@ class MultiWriterIdGenerator: # do. break - def _update_stream_positions_table_txn(self, txn: Cursor): + def _update_stream_positions_table_txn(self, txn: Cursor) -> None: """Update the `stream_positions` table with newly persisted position.""" if not self._writers: @@ -628,20 +657,25 @@ class MultiWriterIdGenerator: txn.execute(sql, (self._stream_name, self._instance_name, pos)) -@attr.s(slots=True) -class _AsyncCtxManagerWrapper: +@attr.s(frozen=True, auto_attribs=True) +class _AsyncCtxManagerWrapper(Generic[T]): """Helper class to convert a plain context manager to an async one. This is mainly useful if you have a plain context manager but the interface requires an async one. """ - inner = attr.ib() + inner: ContextManager[T] - async def __aenter__(self): + async def __aenter__(self) -> T: return self.inner.__enter__() - async def __aexit__(self, exc_type, exc, tb): + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ) -> Optional[bool]: return self.inner.__exit__(exc_type, exc, tb) @@ -671,7 +705,12 @@ class _MultiWriterCtxManager: else: return [i * self.id_gen._return_factor for i in self.stream_ids] - async def __aexit__(self, exc_type, exc, tb): + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ) -> bool: for i in self.stream_ids: self.id_gen._mark_id_as_finished(i) diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py index bb33e04fb1..75268cbe15 100644 --- a/synapse/storage/util/sequence.py +++ b/synapse/storage/util/sequence.py @@ -81,7 +81,7 @@ class SequenceGenerator(metaclass=abc.ABCMeta): id_column: str, stream_name: Optional[str] = None, positive: bool = True, - ): + ) -> None: """Should be called during start up to test that the current value of the sequence is greater than or equal to the maximum ID in the table. @@ -122,7 +122,7 @@ class PostgresSequenceGenerator(SequenceGenerator): id_column: str, stream_name: Optional[str] = None, positive: bool = True, - ): + ) -> None: """See SequenceGenerator.check_consistency for docstring.""" txn = db_conn.cursor(txn_name="sequence.check_consistency") @@ -244,7 +244,7 @@ class LocalSequenceGenerator(SequenceGenerator): id_column: str, stream_name: Optional[str] = None, positive: bool = True, - ): + ) -> None: # There is nothing to do for in memory sequences pass -- cgit 1.5.1