From 39230d217104f3cd7aba9065dc478f935ce1e614 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 24 Mar 2020 14:45:33 +0000 Subject: Clean up some LoggingContext stuff (#7120) * Pull Sentinel out of LoggingContext ... and drop a few unnecessary references to it * Factor out LoggingContext.current_context move `current_context` and `set_context` out to top-level functions. Mostly this means that I can more easily trace what's actually referring to LoggingContext, but I think it's generally neater. * move copy-to-parent into `stop` this really just makes `start` and `stop` more symetric. It also means that it behaves correctly if you manually `set_log_context` rather than using the context manager. * Replace `LoggingContext.alive` with `finished` Turn `alive` into `finished` and make it a bit better defined. --- tests/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'tests/utils.py') diff --git a/tests/utils.py b/tests/utils.py index 513f358f4f..968d109f77 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -35,7 +35,7 @@ from synapse.config.homeserver import HomeServerConfig from synapse.config.server import DEFAULT_ROOM_VERSION from synapse.federation.transport import server as federation_server from synapse.http.server import HttpServer -from synapse.logging.context import LoggingContext +from synapse.logging.context import current_context, set_current_context from synapse.server import HomeServer from synapse.storage import DataStore from synapse.storage.engines import PostgresEngine, create_engine @@ -493,10 +493,10 @@ class MockClock(object): return self.time() * 1000 def call_later(self, delay, callback, *args, **kwargs): - current_context = LoggingContext.current_context() + ctx = current_context() def wrapped_callback(): - LoggingContext.thread_local.current_context = current_context + set_current_context(ctx) callback(*args, **kwargs) t = [self.now + delay, wrapped_callback, False] -- cgit 1.5.1 From eed7c5b89eee6951ac17861b1695817470bace36 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 15 Apr 2020 12:40:18 -0400 Subject: Convert auth handler to async/await (#7261) --- changelog.d/7261.misc | 1 + synapse/handlers/auth.py | 173 ++++++++++++++++++--------------------- synapse/handlers/device.py | 12 ++- synapse/handlers/register.py | 28 +++++-- synapse/handlers/set_password.py | 13 ++- synapse/module_api/__init__.py | 6 +- tests/api/test_auth.py | 64 +++++++++------ tests/handlers/test_auth.py | 80 +++++++++++------- tests/handlers/test_register.py | 4 +- tests/utils.py | 13 ++- 10 files changed, 224 insertions(+), 170 deletions(-) create mode 100644 changelog.d/7261.misc (limited to 'tests/utils.py') diff --git a/changelog.d/7261.misc b/changelog.d/7261.misc new file mode 100644 index 0000000000..88165f0105 --- /dev/null +++ b/changelog.d/7261.misc @@ -0,0 +1 @@ +Convert auth handler to async/await. diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index fbfbd44a2e..0aae929ecc 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -18,14 +18,12 @@ import logging import time import unicodedata import urllib.parse -from typing import Any, Dict, Iterable, List, Optional +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import attr import bcrypt # type: ignore[import] import pymacaroons -from twisted.internet import defer - import synapse.util.stringutils as stringutils from synapse.api.constants import LoginType from synapse.api.errors import ( @@ -170,15 +168,14 @@ class AuthHandler(BaseHandler): # cast to tuple for use with str.startswith self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist) - @defer.inlineCallbacks - def validate_user_via_ui_auth( + async def validate_user_via_ui_auth( self, requester: Requester, request: SynapseRequest, request_body: Dict[str, Any], clientip: str, description: str, - ): + ) -> dict: """ Checks that the user is who they claim to be, via a UI auth. @@ -199,7 +196,7 @@ class AuthHandler(BaseHandler): describes the operation happening on their account. Returns: - defer.Deferred[dict]: the parameters for this request (which may + The parameters for this request (which may have been given only in a previous call). Raises: @@ -229,7 +226,7 @@ class AuthHandler(BaseHandler): flows = [[login_type] for login_type in self._supported_ui_auth_types] try: - result, params, _ = yield self.check_auth( + result, params, _ = await self.check_auth( flows, request, request_body, clientip, description ) except LoginError: @@ -268,15 +265,14 @@ class AuthHandler(BaseHandler): """ return self.checkers.keys() - @defer.inlineCallbacks - def check_auth( + async def check_auth( self, flows: List[List[str]], request: SynapseRequest, clientdict: Dict[str, Any], clientip: str, description: str, - ): + ) -> Tuple[dict, dict, str]: """ Takes a dictionary sent by the client in the login / registration protocol and handles the User-Interactive Auth flow. @@ -306,8 +302,7 @@ class AuthHandler(BaseHandler): describes the operation happening on their account. Returns: - defer.Deferred[dict, dict, str]: a deferred tuple of - (creds, params, session_id). + A tuple of (creds, params, session_id). 'creds' contains the authenticated credentials of each stage. @@ -380,7 +375,7 @@ class AuthHandler(BaseHandler): if "type" in authdict: login_type = authdict["type"] # type: str try: - result = yield self._check_auth_dict(authdict, clientip) + result = await self._check_auth_dict(authdict, clientip) if result: creds[login_type] = result self._save_session(session) @@ -419,8 +414,9 @@ class AuthHandler(BaseHandler): ret.update(errordict) raise InteractiveAuthIncompleteError(ret) - @defer.inlineCallbacks - def add_oob_auth(self, stagetype: str, authdict: Dict[str, Any], clientip: str): + async def add_oob_auth( + self, stagetype: str, authdict: Dict[str, Any], clientip: str + ) -> bool: """ Adds the result of out-of-band authentication into an existing auth session. Currently used for adding the result of fallback auth. @@ -435,7 +431,7 @@ class AuthHandler(BaseHandler): sess["creds"] = {} creds = sess["creds"] - result = yield self.checkers[stagetype].check_auth(authdict, clientip) + result = await self.checkers[stagetype].check_auth(authdict, clientip) if result: creds[stagetype] = result self._save_session(sess) @@ -489,8 +485,9 @@ class AuthHandler(BaseHandler): sess = self._get_session_info(session_id) return sess.setdefault("serverdict", {}).get(key, default) - @defer.inlineCallbacks - def _check_auth_dict(self, authdict: Dict[str, Any], clientip: str): + async def _check_auth_dict( + self, authdict: Dict[str, Any], clientip: str + ) -> Union[Dict[str, Any], str]: """Attempt to validate the auth dict provided by a client Args: @@ -498,7 +495,7 @@ class AuthHandler(BaseHandler): clientip: IP address of the client Returns: - Deferred: result of the stage verification. + Result of the stage verification. Raises: StoreError if there was a problem accessing the database @@ -508,7 +505,7 @@ class AuthHandler(BaseHandler): login_type = authdict["type"] checker = self.checkers.get(login_type) if checker is not None: - res = yield checker.check_auth(authdict, clientip=clientip) + res = await checker.check_auth(authdict, clientip=clientip) return res # build a v1-login-style dict out of the authdict and fall back to the @@ -518,7 +515,7 @@ class AuthHandler(BaseHandler): if user_id is None: raise SynapseError(400, "", Codes.MISSING_PARAM) - (canonical_id, callback) = yield self.validate_login(user_id, authdict) + (canonical_id, callback) = await self.validate_login(user_id, authdict) return canonical_id def _get_params_recaptcha(self) -> dict: @@ -584,8 +581,7 @@ class AuthHandler(BaseHandler): return self.sessions[session_id] - @defer.inlineCallbacks - def get_access_token_for_user_id( + async def get_access_token_for_user_id( self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int] ): """ @@ -615,10 +611,10 @@ class AuthHandler(BaseHandler): ) logger.info("Logging in user %s on device %s%s", user_id, device_id, fmt_expiry) - yield self.auth.check_auth_blocking(user_id) + await self.auth.check_auth_blocking(user_id) access_token = self.macaroon_gen.generate_access_token(user_id) - yield self.store.add_access_token_to_user( + await self.store.add_access_token_to_user( user_id, access_token, device_id, valid_until_ms ) @@ -628,15 +624,14 @@ class AuthHandler(BaseHandler): # device, so we double-check it here. if device_id is not None: try: - yield self.store.get_device(user_id, device_id) + await self.store.get_device(user_id, device_id) except StoreError: - yield self.store.delete_access_token(access_token) + await self.store.delete_access_token(access_token) raise StoreError(400, "Login raced against device deletion") return access_token - @defer.inlineCallbacks - def check_user_exists(self, user_id: str): + async def check_user_exists(self, user_id: str) -> Optional[str]: """ Checks to see if a user with the given id exists. Will check case insensitively, but return None if there are multiple inexact matches. @@ -645,25 +640,25 @@ class AuthHandler(BaseHandler): user_id: complete @user:id Returns: - defer.Deferred: (unicode) canonical_user_id, or None if zero or - multiple matches + The canonical_user_id, or None if zero or multiple matches """ - res = yield self._find_user_id_and_pwd_hash(user_id) + res = await self._find_user_id_and_pwd_hash(user_id) if res is not None: return res[0] return None - @defer.inlineCallbacks - def _find_user_id_and_pwd_hash(self, user_id: str): + async def _find_user_id_and_pwd_hash( + self, user_id: str + ) -> Optional[Tuple[str, str]]: """Checks to see if a user with the given id exists. Will check case insensitively, but will return None if there are multiple inexact matches. Returns: - tuple: A 2-tuple of `(canonical_user_id, password_hash)` - None: if there is not exactly one match + A 2-tuple of `(canonical_user_id, password_hash)` or `None` + if there is not exactly one match """ - user_infos = yield self.store.get_users_by_id_case_insensitive(user_id) + user_infos = await self.store.get_users_by_id_case_insensitive(user_id) result = None if not user_infos: @@ -696,8 +691,9 @@ class AuthHandler(BaseHandler): """ return self._supported_login_types - @defer.inlineCallbacks - def validate_login(self, username: str, login_submission: Dict[str, Any]): + async def validate_login( + self, username: str, login_submission: Dict[str, Any] + ) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]: """Authenticates the user for the /login API Also used by the user-interactive auth flow to validate @@ -708,7 +704,7 @@ class AuthHandler(BaseHandler): login_submission: the whole of the login submission (including 'type' and other relevant fields) Returns: - Deferred[str, func]: canonical user id, and optional callback + A tuple of the canonical user id, and optional callback to be called once the access token and device id are issued Raises: StoreError if there was a problem accessing the database @@ -737,7 +733,7 @@ class AuthHandler(BaseHandler): for provider in self.password_providers: if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD: known_login_type = True - is_valid = yield provider.check_password(qualified_user_id, password) + is_valid = await provider.check_password(qualified_user_id, password) if is_valid: return qualified_user_id, None @@ -769,7 +765,7 @@ class AuthHandler(BaseHandler): % (login_type, missing_fields), ) - result = yield provider.check_auth(username, login_type, login_dict) + result = await provider.check_auth(username, login_type, login_dict) if result: if isinstance(result, str): result = (result, None) @@ -778,8 +774,8 @@ class AuthHandler(BaseHandler): if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled: known_login_type = True - canonical_user_id = yield self._check_local_password( - qualified_user_id, password + canonical_user_id = await self._check_local_password( + qualified_user_id, password # type: ignore ) if canonical_user_id: @@ -792,8 +788,9 @@ class AuthHandler(BaseHandler): # login, it turns all LoginErrors into a 401 anyway. raise LoginError(403, "Invalid password", errcode=Codes.FORBIDDEN) - @defer.inlineCallbacks - def check_password_provider_3pid(self, medium: str, address: str, password: str): + async def check_password_provider_3pid( + self, medium: str, address: str, password: str + ) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], None]]]: """Check if a password provider is able to validate a thirdparty login Args: @@ -802,9 +799,8 @@ class AuthHandler(BaseHandler): password: The password of the user. Returns: - Deferred[(str|None, func|None)]: A tuple of `(user_id, - callback)`. If authentication is successful, `user_id` is a `str` - containing the authenticated, canonical user ID. `callback` is + A tuple of `(user_id, callback)`. If authentication is successful, + `user_id`is the authenticated, canonical user ID. `callback` is then either a function to be later run after the server has completed login/registration, or `None`. If authentication was unsuccessful, `user_id` and `callback` are both `None`. @@ -816,7 +812,7 @@ class AuthHandler(BaseHandler): # success, to a str (which is the user_id) or a tuple of # (user_id, callback_func), where callback_func should be run # after we've finished everything else - result = yield provider.check_3pid_auth(medium, address, password) + result = await provider.check_3pid_auth(medium, address, password) if result: # Check if the return value is a str or a tuple if isinstance(result, str): @@ -826,8 +822,7 @@ class AuthHandler(BaseHandler): return None, None - @defer.inlineCallbacks - def _check_local_password(self, user_id: str, password: str): + async def _check_local_password(self, user_id: str, password: str) -> Optional[str]: """Authenticate a user against the local password database. user_id is checked case insensitively, but will return None if there are @@ -837,28 +832,26 @@ class AuthHandler(BaseHandler): user_id: complete @user:id password: the provided password Returns: - Deferred[unicode] the canonical_user_id, or Deferred[None] if - unknown user/bad password + The canonical_user_id, or None if unknown user/bad password """ - lookupres = yield self._find_user_id_and_pwd_hash(user_id) + lookupres = await self._find_user_id_and_pwd_hash(user_id) if not lookupres: return None (user_id, password_hash) = lookupres # If the password hash is None, the account has likely been deactivated if not password_hash: - deactivated = yield self.store.get_user_deactivated_status(user_id) + deactivated = await self.store.get_user_deactivated_status(user_id) if deactivated: raise UserDeactivatedError("This account has been deactivated") - result = yield self.validate_hash(password, password_hash) + result = await self.validate_hash(password, password_hash) if not result: logger.warning("Failed password login for user %s", user_id) return None return user_id - @defer.inlineCallbacks - def validate_short_term_login_token_and_get_user_id(self, login_token: str): + async def validate_short_term_login_token_and_get_user_id(self, login_token: str): auth_api = self.hs.get_auth() user_id = None try: @@ -868,26 +861,23 @@ class AuthHandler(BaseHandler): except Exception: raise AuthError(403, "Invalid token", errcode=Codes.FORBIDDEN) - yield self.auth.check_auth_blocking(user_id) + await self.auth.check_auth_blocking(user_id) return user_id - @defer.inlineCallbacks - def delete_access_token(self, access_token: str): + async def delete_access_token(self, access_token: str): """Invalidate a single access token Args: access_token: access token to be deleted - Returns: - Deferred """ - user_info = yield self.auth.get_user_by_access_token(access_token) - yield self.store.delete_access_token(access_token) + user_info = await self.auth.get_user_by_access_token(access_token) + await self.store.delete_access_token(access_token) # see if any of our auth providers want to know about this for provider in self.password_providers: if hasattr(provider, "on_logged_out"): - yield provider.on_logged_out( + await provider.on_logged_out( user_id=str(user_info["user"]), device_id=user_info["device_id"], access_token=access_token, @@ -895,12 +885,11 @@ class AuthHandler(BaseHandler): # delete pushers associated with this access token if user_info["token_id"] is not None: - yield self.hs.get_pusherpool().remove_pushers_by_access_token( + await self.hs.get_pusherpool().remove_pushers_by_access_token( str(user_info["user"]), (user_info["token_id"],) ) - @defer.inlineCallbacks - def delete_access_tokens_for_user( + async def delete_access_tokens_for_user( self, user_id: str, except_token_id: Optional[str] = None, @@ -914,10 +903,8 @@ class AuthHandler(BaseHandler): device_id: ID of device the tokens are associated with. If None, tokens associated with any device (or no device) will be deleted - Returns: - Deferred """ - tokens_and_devices = yield self.store.user_delete_access_tokens( + tokens_and_devices = await self.store.user_delete_access_tokens( user_id, except_token_id=except_token_id, device_id=device_id ) @@ -925,17 +912,18 @@ class AuthHandler(BaseHandler): for provider in self.password_providers: if hasattr(provider, "on_logged_out"): for token, token_id, device_id in tokens_and_devices: - yield provider.on_logged_out( + await provider.on_logged_out( user_id=user_id, device_id=device_id, access_token=token ) # delete pushers associated with the access tokens - yield self.hs.get_pusherpool().remove_pushers_by_access_token( + await self.hs.get_pusherpool().remove_pushers_by_access_token( user_id, (token_id for _, token_id, _ in tokens_and_devices) ) - @defer.inlineCallbacks - def add_threepid(self, user_id: str, medium: str, address: str, validated_at: int): + async def add_threepid( + self, user_id: str, medium: str, address: str, validated_at: int + ): # check if medium has a valid value if medium not in ["email", "msisdn"]: raise SynapseError( @@ -956,14 +944,13 @@ class AuthHandler(BaseHandler): if medium == "email": address = address.lower() - yield self.store.user_add_threepid( + await self.store.user_add_threepid( user_id, medium, address, validated_at, self.hs.get_clock().time_msec() ) - @defer.inlineCallbacks - def delete_threepid( + async def delete_threepid( self, user_id: str, medium: str, address: str, id_server: Optional[str] = None - ): + ) -> bool: """Attempts to unbind the 3pid on the identity servers and deletes it from the local database. @@ -976,7 +963,7 @@ class AuthHandler(BaseHandler): identity server specified when binding (if known). Returns: - Deferred[bool]: Returns True if successfully unbound the 3pid on + Returns True if successfully unbound the 3pid on the identity server, False if identity server doesn't support the unbind API. """ @@ -986,11 +973,11 @@ class AuthHandler(BaseHandler): address = address.lower() identity_handler = self.hs.get_handlers().identity_handler - result = yield identity_handler.try_unbind_threepid( + result = await identity_handler.try_unbind_threepid( user_id, {"medium": medium, "address": address, "id_server": id_server} ) - yield self.store.user_delete_threepid(user_id, medium, address) + await self.store.user_delete_threepid(user_id, medium, address) return result def _save_session(self, session: Dict[str, Any]) -> None: @@ -1000,14 +987,14 @@ class AuthHandler(BaseHandler): session["last_used"] = self.hs.get_clock().time_msec() self.sessions[session["id"]] = session - def hash(self, password: str): + async def hash(self, password: str) -> str: """Computes a secure hash of password. Args: password: Password to hash. Returns: - Deferred(unicode): Hashed password. + Hashed password. """ def _do_hash(): @@ -1019,9 +1006,11 @@ class AuthHandler(BaseHandler): bcrypt.gensalt(self.bcrypt_rounds), ).decode("ascii") - return defer_to_thread(self.hs.get_reactor(), _do_hash) + return await defer_to_thread(self.hs.get_reactor(), _do_hash) - def validate_hash(self, password: str, stored_hash: bytes): + async def validate_hash( + self, password: str, stored_hash: Union[bytes, str] + ) -> bool: """Validates that self.hash(password) == stored_hash. Args: @@ -1029,7 +1018,7 @@ class AuthHandler(BaseHandler): stored_hash: Expected hash value. Returns: - Deferred(bool): Whether self.hash(password) == stored_hash. + Whether self.hash(password) == stored_hash. """ def _do_validate_hash(): @@ -1045,9 +1034,9 @@ class AuthHandler(BaseHandler): if not isinstance(stored_hash, bytes): stored_hash = stored_hash.encode("ascii") - return defer_to_thread(self.hs.get_reactor(), _do_validate_hash) + return await defer_to_thread(self.hs.get_reactor(), _do_validate_hash) else: - return defer.succeed(False) + return False def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str: """ diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 993499f446..9bd941b5a0 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -338,8 +338,10 @@ class DeviceHandler(DeviceWorkerHandler): else: raise - yield self._auth_handler.delete_access_tokens_for_user( - user_id, device_id=device_id + yield defer.ensureDeferred( + self._auth_handler.delete_access_tokens_for_user( + user_id, device_id=device_id + ) ) yield self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id) @@ -391,8 +393,10 @@ class DeviceHandler(DeviceWorkerHandler): # Delete access tokens and e2e keys for each device. Not optimised as it is not # considered as part of a critical path. for device_id in device_ids: - yield self._auth_handler.delete_access_tokens_for_user( - user_id, device_id=device_id + yield defer.ensureDeferred( + self._auth_handler.delete_access_tokens_for_user( + user_id, device_id=device_id + ) ) yield self.store.delete_e2e_keys_by_device( user_id=user_id, device_id=device_id diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 7ffc194f0c..3a65b46ecd 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -166,7 +166,9 @@ class RegistrationHandler(BaseHandler): yield self.auth.check_auth_blocking(threepid=threepid) password_hash = None if password: - password_hash = yield self._auth_handler.hash(password) + password_hash = yield defer.ensureDeferred( + self._auth_handler.hash(password) + ) if localpart is not None: yield self.check_username(localpart, guest_access_token=guest_access_token) @@ -540,8 +542,10 @@ class RegistrationHandler(BaseHandler): user_id, ["guest = true"] ) else: - access_token = yield self._auth_handler.get_access_token_for_user_id( - user_id, device_id=device_id, valid_until_ms=valid_until_ms + access_token = yield defer.ensureDeferred( + self._auth_handler.get_access_token_for_user_id( + user_id, device_id=device_id, valid_until_ms=valid_until_ms + ) ) return (device_id, access_token) @@ -617,8 +621,13 @@ class RegistrationHandler(BaseHandler): logger.info("Can't add incomplete 3pid") return - yield self._auth_handler.add_threepid( - user_id, threepid["medium"], threepid["address"], threepid["validated_at"] + yield defer.ensureDeferred( + self._auth_handler.add_threepid( + user_id, + threepid["medium"], + threepid["address"], + threepid["validated_at"], + ) ) # And we add an email pusher for them by default, but only @@ -670,6 +679,11 @@ class RegistrationHandler(BaseHandler): return None raise - yield self._auth_handler.add_threepid( - user_id, threepid["medium"], threepid["address"], threepid["validated_at"] + yield defer.ensureDeferred( + self._auth_handler.add_threepid( + user_id, + threepid["medium"], + threepid["address"], + threepid["validated_at"], + ) ) diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py index 7d1263caf2..63d8f9aa0d 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py @@ -15,8 +15,6 @@ import logging from typing import Optional -from twisted.internet import defer - from synapse.api.errors import Codes, StoreError, SynapseError from synapse.types import Requester @@ -34,8 +32,7 @@ class SetPasswordHandler(BaseHandler): self._device_handler = hs.get_device_handler() self._password_policy_handler = hs.get_password_policy_handler() - @defer.inlineCallbacks - def set_password( + async def set_password( self, user_id: str, new_password: str, @@ -46,10 +43,10 @@ class SetPasswordHandler(BaseHandler): raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN) self._password_policy_handler.validate_password(new_password) - password_hash = yield self._auth_handler.hash(new_password) + password_hash = await self._auth_handler.hash(new_password) try: - yield self.store.user_set_password_hash(user_id, password_hash) + await self.store.user_set_password_hash(user_id, password_hash) except StoreError as e: if e.code == 404: raise SynapseError(404, "Unknown user", Codes.NOT_FOUND) @@ -61,12 +58,12 @@ class SetPasswordHandler(BaseHandler): except_access_token_id = requester.access_token_id if requester else None # First delete all of their other devices. - yield self._device_handler.delete_all_devices_for_user( + await self._device_handler.delete_all_devices_for_user( user_id, except_device_id=except_device_id ) # and now delete any access tokens which weren't associated with # devices (or were associated with this device). - yield self._auth_handler.delete_access_tokens_for_user( + await self._auth_handler.delete_access_tokens_for_user( user_id, except_token_id=except_access_token_id ) diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index afc3598e11..d678c0eb9b 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -86,7 +86,7 @@ class ModuleApi(object): Deferred[str|None]: Canonical (case-corrected) user_id, or None if the user is not registered. """ - return self._auth_handler.check_user_exists(user_id) + return defer.ensureDeferred(self._auth_handler.check_user_exists(user_id)) @defer.inlineCallbacks def register(self, localpart, displayname=None, emails=[]): @@ -196,7 +196,9 @@ class ModuleApi(object): yield self._hs.get_device_handler().delete_device(user_id, device_id) else: # no associated device. Just delete the access token. - yield self._auth_handler.delete_access_token(access_token) + yield defer.ensureDeferred( + self._auth_handler.delete_access_token(access_token) + ) def run_db_interaction(self, desc, func, *args, **kwargs): """Run a function with a database connection diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index 6121efcfa9..cc0b10e7f6 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -68,7 +68,7 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - requester = yield self.auth.get_user_by_req(request) + requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request)) self.assertEquals(requester.user.to_string(), self.test_user) def test_get_user_by_req_user_bad_token(self): @@ -105,7 +105,7 @@ class AuthTestCase(unittest.TestCase): request.getClientIP.return_value = "127.0.0.1" request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - requester = yield self.auth.get_user_by_req(request) + requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request)) self.assertEquals(requester.user.to_string(), self.test_user) @defer.inlineCallbacks @@ -125,7 +125,7 @@ class AuthTestCase(unittest.TestCase): request.getClientIP.return_value = "192.168.10.10" request.args[b"access_token"] = [self.test_token] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - requester = yield self.auth.get_user_by_req(request) + requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request)) self.assertEquals(requester.user.to_string(), self.test_user) def test_get_user_by_req_appservice_valid_token_bad_ip(self): @@ -188,7 +188,7 @@ class AuthTestCase(unittest.TestCase): request.args[b"access_token"] = [self.test_token] request.args[b"user_id"] = [masquerading_user_id] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - requester = yield self.auth.get_user_by_req(request) + requester = yield defer.ensureDeferred(self.auth.get_user_by_req(request)) self.assertEquals( requester.user.to_string(), masquerading_user_id.decode("utf8") ) @@ -225,7 +225,9 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("type = access") macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) - user_info = yield self.auth.get_user_by_access_token(macaroon.serialize()) + user_info = yield defer.ensureDeferred( + self.auth.get_user_by_access_token(macaroon.serialize()) + ) user = user_info["user"] self.assertEqual(UserID.from_string(user_id), user) @@ -250,7 +252,9 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("guest = true") serialized = macaroon.serialize() - user_info = yield self.auth.get_user_by_access_token(serialized) + user_info = yield defer.ensureDeferred( + self.auth.get_user_by_access_token(serialized) + ) user = user_info["user"] is_guest = user_info["is_guest"] self.assertEqual(UserID.from_string(user_id), user) @@ -260,10 +264,13 @@ class AuthTestCase(unittest.TestCase): @defer.inlineCallbacks def test_cannot_use_regular_token_as_guest(self): USER_ID = "@percy:matrix.org" - self.store.add_access_token_to_user = Mock() + self.store.add_access_token_to_user = Mock(return_value=defer.succeed(None)) + self.store.get_device = Mock(return_value=defer.succeed(None)) - token = yield self.hs.handlers.auth_handler.get_access_token_for_user_id( - USER_ID, "DEVICE", valid_until_ms=None + token = yield defer.ensureDeferred( + self.hs.handlers.auth_handler.get_access_token_for_user_id( + USER_ID, "DEVICE", valid_until_ms=None + ) ) self.store.add_access_token_to_user.assert_called_with( USER_ID, token, "DEVICE", None @@ -286,7 +293,9 @@ class AuthTestCase(unittest.TestCase): request = Mock(args={}) request.args[b"access_token"] = [token.encode("ascii")] request.requestHeaders.getRawHeaders = mock_getRawHeaders() - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + requester = yield defer.ensureDeferred( + self.auth.get_user_by_req(request, allow_guest=True) + ) self.assertEqual(UserID.from_string(USER_ID), requester.user) self.assertFalse(requester.is_guest) @@ -301,7 +310,9 @@ class AuthTestCase(unittest.TestCase): request.requestHeaders.getRawHeaders = mock_getRawHeaders() with self.assertRaises(InvalidClientCredentialsError) as cm: - yield self.auth.get_user_by_req(request, allow_guest=True) + yield defer.ensureDeferred( + self.auth.get_user_by_req(request, allow_guest=True) + ) self.assertEqual(401, cm.exception.code) self.assertEqual("Guest access token used for regular user", cm.exception.msg) @@ -316,7 +327,7 @@ class AuthTestCase(unittest.TestCase): small_number_of_users = 1 # Ensure no error thrown - yield self.auth.check_auth_blocking() + yield defer.ensureDeferred(self.auth.check_auth_blocking()) self.hs.config.limit_usage_by_mau = True @@ -325,7 +336,7 @@ class AuthTestCase(unittest.TestCase): ) with self.assertRaises(ResourceLimitError) as e: - yield self.auth.check_auth_blocking() + yield defer.ensureDeferred(self.auth.check_auth_blocking()) self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact) self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEquals(e.exception.code, 403) @@ -334,7 +345,7 @@ class AuthTestCase(unittest.TestCase): self.store.get_monthly_active_count = Mock( return_value=defer.succeed(small_number_of_users) ) - yield self.auth.check_auth_blocking() + yield defer.ensureDeferred(self.auth.check_auth_blocking()) @defer.inlineCallbacks def test_blocking_mau__depending_on_user_type(self): @@ -343,15 +354,19 @@ class AuthTestCase(unittest.TestCase): self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100)) # Support users allowed - yield self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT) + yield defer.ensureDeferred( + self.auth.check_auth_blocking(user_type=UserTypes.SUPPORT) + ) self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100)) # Bots not allowed with self.assertRaises(ResourceLimitError): - yield self.auth.check_auth_blocking(user_type=UserTypes.BOT) + yield defer.ensureDeferred( + self.auth.check_auth_blocking(user_type=UserTypes.BOT) + ) self.store.get_monthly_active_count = Mock(return_value=defer.succeed(100)) # Real users not allowed with self.assertRaises(ResourceLimitError): - yield self.auth.check_auth_blocking() + yield defer.ensureDeferred(self.auth.check_auth_blocking()) @defer.inlineCallbacks def test_reserved_threepid(self): @@ -362,21 +377,22 @@ class AuthTestCase(unittest.TestCase): unknown_threepid = {"medium": "email", "address": "unreserved@server.com"} self.hs.config.mau_limits_reserved_threepids = [threepid] - yield self.store.register_user(user_id="user1", password_hash=None) with self.assertRaises(ResourceLimitError): - yield self.auth.check_auth_blocking() + yield defer.ensureDeferred(self.auth.check_auth_blocking()) with self.assertRaises(ResourceLimitError): - yield self.auth.check_auth_blocking(threepid=unknown_threepid) + yield defer.ensureDeferred( + self.auth.check_auth_blocking(threepid=unknown_threepid) + ) - yield self.auth.check_auth_blocking(threepid=threepid) + yield defer.ensureDeferred(self.auth.check_auth_blocking(threepid=threepid)) @defer.inlineCallbacks def test_hs_disabled(self): self.hs.config.hs_disabled = True self.hs.config.hs_disabled_message = "Reason for being disabled" with self.assertRaises(ResourceLimitError) as e: - yield self.auth.check_auth_blocking() + yield defer.ensureDeferred(self.auth.check_auth_blocking()) self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact) self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEquals(e.exception.code, 403) @@ -393,7 +409,7 @@ class AuthTestCase(unittest.TestCase): self.hs.config.hs_disabled = True self.hs.config.hs_disabled_message = "Reason for being disabled" with self.assertRaises(ResourceLimitError) as e: - yield self.auth.check_auth_blocking() + yield defer.ensureDeferred(self.auth.check_auth_blocking()) self.assertEquals(e.exception.admin_contact, self.hs.config.admin_contact) self.assertEquals(e.exception.errcode, Codes.RESOURCE_LIMIT_EXCEEDED) self.assertEquals(e.exception.code, 403) @@ -404,4 +420,4 @@ class AuthTestCase(unittest.TestCase): user = "@user:server" self.hs.config.server_notices_mxid = user self.hs.config.hs_disabled_message = "Reason for being disabled" - yield self.auth.check_auth_blocking(user) + yield defer.ensureDeferred(self.auth.check_auth_blocking(user)) diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index b03103d96f..52c4ac8b11 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -82,16 +82,16 @@ class AuthTestCase(unittest.TestCase): self.hs.clock.now = 1000 token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000) - user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - token + user_id = yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id(token) ) self.assertEqual("a_user", user_id) # when we advance the clock, the token should be rejected self.hs.clock.now = 6000 with self.assertRaises(synapse.api.errors.AuthError): - yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - token + yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id(token) ) @defer.inlineCallbacks @@ -99,8 +99,10 @@ class AuthTestCase(unittest.TestCase): token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000) macaroon = pymacaroons.Macaroon.deserialize(token) - user_id = yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - macaroon.serialize() + user_id = yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + macaroon.serialize() + ) ) self.assertEqual("a_user", user_id) @@ -109,20 +111,26 @@ class AuthTestCase(unittest.TestCase): macaroon.add_first_party_caveat("user_id = b_user") with self.assertRaises(synapse.api.errors.AuthError): - yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - macaroon.serialize() + yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + macaroon.serialize() + ) ) @defer.inlineCallbacks def test_mau_limits_disabled(self): self.hs.config.limit_usage_by_mau = False # Ensure does not throw exception - yield self.auth_handler.get_access_token_for_user_id( - "user_a", device_id=None, valid_until_ms=None + yield defer.ensureDeferred( + self.auth_handler.get_access_token_for_user_id( + "user_a", device_id=None, valid_until_ms=None + ) ) - yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - self._get_macaroon().serialize() + yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + self._get_macaroon().serialize() + ) ) @defer.inlineCallbacks @@ -133,16 +141,20 @@ class AuthTestCase(unittest.TestCase): ) with self.assertRaises(ResourceLimitError): - yield self.auth_handler.get_access_token_for_user_id( - "user_a", device_id=None, valid_until_ms=None + yield defer.ensureDeferred( + self.auth_handler.get_access_token_for_user_id( + "user_a", device_id=None, valid_until_ms=None + ) ) self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.large_number_of_users) ) with self.assertRaises(ResourceLimitError): - yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - self._get_macaroon().serialize() + yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + self._get_macaroon().serialize() + ) ) @defer.inlineCallbacks @@ -154,16 +166,20 @@ class AuthTestCase(unittest.TestCase): return_value=defer.succeed(self.hs.config.max_mau_value) ) with self.assertRaises(ResourceLimitError): - yield self.auth_handler.get_access_token_for_user_id( - "user_a", device_id=None, valid_until_ms=None + yield defer.ensureDeferred( + self.auth_handler.get_access_token_for_user_id( + "user_a", device_id=None, valid_until_ms=None + ) ) self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.hs.config.max_mau_value) ) with self.assertRaises(ResourceLimitError): - yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - self._get_macaroon().serialize() + yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + self._get_macaroon().serialize() + ) ) # If in monthly active cohort self.hs.get_datastore().user_last_seen_monthly_active = Mock( @@ -172,8 +188,10 @@ class AuthTestCase(unittest.TestCase): self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.hs.config.max_mau_value) ) - yield self.auth_handler.get_access_token_for_user_id( - "user_a", device_id=None, valid_until_ms=None + yield defer.ensureDeferred( + self.auth_handler.get_access_token_for_user_id( + "user_a", device_id=None, valid_until_ms=None + ) ) self.hs.get_datastore().user_last_seen_monthly_active = Mock( return_value=defer.succeed(self.hs.get_clock().time_msec()) @@ -181,8 +199,10 @@ class AuthTestCase(unittest.TestCase): self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.hs.config.max_mau_value) ) - yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - self._get_macaroon().serialize() + yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + self._get_macaroon().serialize() + ) ) @defer.inlineCallbacks @@ -193,15 +213,19 @@ class AuthTestCase(unittest.TestCase): return_value=defer.succeed(self.small_number_of_users) ) # Ensure does not raise exception - yield self.auth_handler.get_access_token_for_user_id( - "user_a", device_id=None, valid_until_ms=None + yield defer.ensureDeferred( + self.auth_handler.get_access_token_for_user_id( + "user_a", device_id=None, valid_until_ms=None + ) ) self.hs.get_datastore().get_monthly_active_count = Mock( return_value=defer.succeed(self.small_number_of_users) ) - yield self.auth_handler.validate_short_term_login_token_and_get_user_id( - self._get_macaroon().serialize() + yield defer.ensureDeferred( + self.auth_handler.validate_short_term_login_token_and_get_user_id( + self._get_macaroon().serialize() + ) ) def _get_macaroon(self): diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index e7b638dbfe..f1dc51d6c9 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -294,7 +294,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase): create_profile_with_displayname=user.localpart, ) else: - yield self.hs.get_auth_handler().delete_access_tokens_for_user(user_id) + yield defer.ensureDeferred( + self.hs.get_auth_handler().delete_access_tokens_for_user(user_id) + ) yield self.store.add_access_token_to_user( user_id=user_id, token=token, device_id=None, valid_until_ms=None diff --git a/tests/utils.py b/tests/utils.py index 968d109f77..2079e0143d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -332,10 +332,15 @@ def setup_test_homeserver( # Need to let the HS build an auth handler and then mess with it # because AuthHandler's constructor requires the HS, so we can't make one # beforehand and pass it in to the HS's constructor (chicken / egg) - hs.get_auth_handler().hash = lambda p: hashlib.md5(p.encode("utf8")).hexdigest() - hs.get_auth_handler().validate_hash = ( - lambda p, h: hashlib.md5(p.encode("utf8")).hexdigest() == h - ) + async def hash(p): + return hashlib.md5(p.encode("utf8")).hexdigest() + + hs.get_auth_handler().hash = hash + + async def validate_hash(p, h): + return hashlib.md5(p.encode("utf8")).hexdigest() == h + + hs.get_auth_handler().validate_hash = validate_hash fed = kargs.get("resource_for_federation", None) if fed: -- cgit 1.5.1 From fb8ff79efd0897b0b7bf52b0c4bb4061a4ef4018 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 28 Apr 2020 14:21:48 +0100 Subject: Fix collation for postgres for unit tests (#7359) When running the UTs against a postgres deatbase, we need to set the collation correctly. --- changelog.d/7359.misc | 1 + tests/utils.py | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) create mode 100644 changelog.d/7359.misc (limited to 'tests/utils.py') diff --git a/changelog.d/7359.misc b/changelog.d/7359.misc new file mode 100644 index 0000000000..b99f257d9a --- /dev/null +++ b/changelog.d/7359.misc @@ -0,0 +1 @@ +Fix collation for postgres for unit tests. diff --git a/tests/utils.py b/tests/utils.py index 2079e0143d..037cb134f0 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -74,7 +74,10 @@ def setupdb(): db_conn.autocommit = True cur = db_conn.cursor() cur.execute("DROP DATABASE IF EXISTS %s;" % (POSTGRES_BASE_DB,)) - cur.execute("CREATE DATABASE %s;" % (POSTGRES_BASE_DB,)) + cur.execute( + "CREATE DATABASE %s ENCODING 'UTF8' LC_COLLATE='C' LC_CTYPE='C' " + "template=template0;" % (POSTGRES_BASE_DB,) + ) cur.close() db_conn.close() -- cgit 1.5.1 From 627b0f5f2753e6910adb7a877541d50f5936b8a5 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 30 Apr 2020 13:47:49 -0400 Subject: Persist user interactive authentication sessions (#7302) By persisting the user interactive authentication sessions to the database, this fixes situations where a user hits different works throughout their auth session and also allows sessions to persist through restarts of Synapse. --- changelog.d/7302.bugfix | 1 + synapse/app/generic_worker.py | 2 + synapse/handlers/auth.py | 175 +++++-------- synapse/handlers/cas_handler.py | 2 +- synapse/handlers/saml_handler.py | 2 +- synapse/rest/client/v2_alpha/auth.py | 4 +- synapse/rest/client/v2_alpha/register.py | 4 +- synapse/storage/data_stores/main/__init__.py | 2 + .../main/schema/delta/58/03persist_ui_auth.sql | 36 +++ synapse/storage/data_stores/main/ui_auth.py | 279 +++++++++++++++++++++ synapse/storage/engines/sqlite.py | 1 + tests/rest/client/v2_alpha/test_auth.py | 40 +++ tests/utils.py | 8 +- tox.ini | 3 +- 14 files changed, 434 insertions(+), 125 deletions(-) create mode 100644 changelog.d/7302.bugfix create mode 100644 synapse/storage/data_stores/main/schema/delta/58/03persist_ui_auth.sql create mode 100644 synapse/storage/data_stores/main/ui_auth.py (limited to 'tests/utils.py') diff --git a/changelog.d/7302.bugfix b/changelog.d/7302.bugfix new file mode 100644 index 0000000000..820646d1f9 --- /dev/null +++ b/changelog.d/7302.bugfix @@ -0,0 +1 @@ +Persist user interactive authentication sessions across workers and Synapse restarts. diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index d125327f08..0ace7b787d 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -127,6 +127,7 @@ from synapse.storage.data_stores.main.monthly_active_users import ( MonthlyActiveUsersWorkerStore, ) from synapse.storage.data_stores.main.presence import UserPresenceState +from synapse.storage.data_stores.main.ui_auth import UIAuthWorkerStore from synapse.storage.data_stores.main.user_directory import UserDirectoryStore from synapse.types import ReadReceipt from synapse.util.async_helpers import Linearizer @@ -439,6 +440,7 @@ class GenericWorkerSlavedStore( # FIXME(#3714): We need to add UserDirectoryStore as we write directly # rather than going via the correct worker. UserDirectoryStore, + UIAuthWorkerStore, SlavedDeviceInboxStore, SlavedDeviceStore, SlavedReceiptsStore, diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index dbe165ce1e..7613e5b6ab 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -41,10 +41,10 @@ from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker from synapse.http.server import finish_request from synapse.http.site import SynapseRequest from synapse.logging.context import defer_to_thread +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.module_api import ModuleApi from synapse.push.mailer import load_jinja2_templates from synapse.types import Requester, UserID -from synapse.util.caches.expiringcache import ExpiringCache from ._base import BaseHandler @@ -69,15 +69,6 @@ class AuthHandler(BaseHandler): self.bcrypt_rounds = hs.config.bcrypt_rounds - # This is not a cache per se, but a store of all current sessions that - # expire after N hours - self.sessions = ExpiringCache( - cache_name="register_sessions", - clock=hs.get_clock(), - expiry_ms=self.SESSION_EXPIRE_MS, - reset_expiry_on_get=True, - ) - account_handler = ModuleApi(hs, self) self.password_providers = [ module(config=config, account_handler=account_handler) @@ -119,6 +110,15 @@ class AuthHandler(BaseHandler): self._clock = self.hs.get_clock() + # Expire old UI auth sessions after a period of time. + if hs.config.worker_app is None: + self._clock.looping_call( + run_as_background_process, + 5 * 60 * 1000, + "expire_old_sessions", + self._expire_old_sessions, + ) + # Load the SSO HTML templates. # The following template is shown to the user during a client login via SSO, @@ -301,16 +301,21 @@ class AuthHandler(BaseHandler): if "session" in authdict: sid = authdict["session"] + # Convert the URI and method to strings. + uri = request.uri.decode("utf-8") + method = request.uri.decode("utf-8") + # If there's no session ID, create a new session. if not sid: - session = self._create_session( - clientdict, (request.uri, request.method, clientdict), description + session = await self.store.create_ui_auth_session( + clientdict, uri, method, description ) - session_id = session["id"] else: - session = self._get_session_info(sid) - session_id = sid + try: + session = await self.store.get_ui_auth_session(sid) + except StoreError: + raise SynapseError(400, "Unknown session ID: %s" % (sid,)) if not clientdict: # This was designed to allow the client to omit the parameters @@ -322,15 +327,15 @@ class AuthHandler(BaseHandler): # on a homeserver. # Revisit: Assuming the REST APIs do sensible validation, the data # isn't arbitrary. - clientdict = session["clientdict"] + clientdict = session.clientdict # Ensure that the queried operation does not vary between stages of # the UI authentication session. This is done by generating a stable # comparator based on the URI, method, and body (minus the auth dict) # and storing it during the initial query. Subsequent queries ensure # that this comparator has not changed. - comparator = (request.uri, request.method, clientdict) - if session["ui_auth"] != comparator: + comparator = (uri, method, clientdict) + if (session.uri, session.method, session.clientdict) != comparator: raise SynapseError( 403, "Requested operation has changed during the UI authentication session.", @@ -338,11 +343,9 @@ class AuthHandler(BaseHandler): if not authdict: raise InteractiveAuthIncompleteError( - self._auth_dict_for_flows(flows, session_id) + self._auth_dict_for_flows(flows, session.session_id) ) - creds = session["creds"] - # check auth type currently being presented errordict = {} # type: Dict[str, Any] if "type" in authdict: @@ -350,8 +353,9 @@ class AuthHandler(BaseHandler): try: result = await self._check_auth_dict(authdict, clientip) if result: - creds[login_type] = result - self._save_session(session) + await self.store.mark_ui_auth_stage_complete( + session.session_id, login_type, result + ) except LoginError as e: if login_type == LoginType.EMAIL_IDENTITY: # riot used to have a bug where it would request a new @@ -367,6 +371,7 @@ class AuthHandler(BaseHandler): # so that the client can have another go. errordict = e.error_dict() + creds = await self.store.get_completed_ui_auth_stages(session.session_id) for f in flows: if len(set(f) - set(creds)) == 0: # it's very useful to know what args are stored, but this can @@ -380,9 +385,9 @@ class AuthHandler(BaseHandler): list(clientdict), ) - return creds, clientdict, session_id + return creds, clientdict, session.session_id - ret = self._auth_dict_for_flows(flows, session_id) + ret = self._auth_dict_for_flows(flows, session.session_id) ret["completed"] = list(creds) ret.update(errordict) raise InteractiveAuthIncompleteError(ret) @@ -399,13 +404,11 @@ class AuthHandler(BaseHandler): if "session" not in authdict: raise LoginError(400, "", Codes.MISSING_PARAM) - sess = self._get_session_info(authdict["session"]) - creds = sess["creds"] - result = await self.checkers[stagetype].check_auth(authdict, clientip) if result: - creds[stagetype] = result - self._save_session(sess) + await self.store.mark_ui_auth_stage_complete( + authdict["session"], stagetype, result + ) return True return False @@ -427,7 +430,7 @@ class AuthHandler(BaseHandler): sid = authdict["session"] return sid - def set_session_data(self, session_id: str, key: str, value: Any) -> None: + async def set_session_data(self, session_id: str, key: str, value: Any) -> None: """ Store a key-value pair into the sessions data associated with this request. This data is stored server-side and cannot be modified by @@ -438,11 +441,12 @@ class AuthHandler(BaseHandler): key: The key to store the data under value: The data to store """ - sess = self._get_session_info(session_id) - sess["serverdict"][key] = value - self._save_session(sess) + try: + await self.store.set_ui_auth_session_data(session_id, key, value) + except StoreError: + raise SynapseError(400, "Unknown session ID: %s" % (session_id,)) - def get_session_data( + async def get_session_data( self, session_id: str, key: str, default: Optional[Any] = None ) -> Any: """ @@ -453,8 +457,18 @@ class AuthHandler(BaseHandler): key: The key to store the data under default: Value to return if the key has not been set """ - sess = self._get_session_info(session_id) - return sess["serverdict"].get(key, default) + try: + return await self.store.get_ui_auth_session_data(session_id, key, default) + except StoreError: + raise SynapseError(400, "Unknown session ID: %s" % (session_id,)) + + async def _expire_old_sessions(self): + """ + Invalidate any user interactive authentication sessions that have expired. + """ + now = self._clock.time_msec() + expiration_time = now - self.SESSION_EXPIRE_MS + await self.store.delete_old_ui_auth_sessions(expiration_time) async def _check_auth_dict( self, authdict: Dict[str, Any], clientip: str @@ -534,67 +548,6 @@ class AuthHandler(BaseHandler): "params": params, } - def _create_session( - self, - clientdict: Dict[str, Any], - ui_auth: Tuple[bytes, bytes, Dict[str, Any]], - description: str, - ) -> dict: - """ - Creates a new user interactive authentication session. - - The session can be used to track data across multiple requests, e.g. for - interactive authentication. - - Each session has the following keys: - - id: - A unique identifier for this session. Passed back to the client - and returned for each stage. - clientdict: - The dictionary from the client root level, not the 'auth' key. - ui_auth: - A tuple which is checked at each stage of the authentication to - ensure that the asked for operation has not changed. - creds: - A map, which maps each auth-type (str) to the relevant identity - authenticated by that auth-type (mostly str, but for captcha, bool). - serverdict: - A map of data that is stored server-side and cannot be modified - by the client. - description: - A string description of the operation that the current - authentication is authorising. - Returns: - The newly created session. - """ - session_id = None - while session_id is None or session_id in self.sessions: - session_id = stringutils.random_string(24) - - self.sessions[session_id] = { - "id": session_id, - "clientdict": clientdict, - "ui_auth": ui_auth, - "creds": {}, - "serverdict": {}, - "description": description, - } - - return self.sessions[session_id] - - def _get_session_info(self, session_id: str) -> dict: - """ - Gets a session given a session ID. - - The session can be used to track data across multiple requests, e.g. for - interactive authentication. - """ - try: - return self.sessions[session_id] - except KeyError: - raise SynapseError(400, "Unknown session ID: %s" % (session_id,)) - async def get_access_token_for_user_id( self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int] ): @@ -994,13 +947,6 @@ class AuthHandler(BaseHandler): await self.store.user_delete_threepid(user_id, medium, address) return result - def _save_session(self, session: Dict[str, Any]) -> None: - """Update the last used time on the session to now and add it back to the session store.""" - # TODO: Persistent storage - logger.debug("Saving session %s", session) - session["last_used"] = self.hs.get_clock().time_msec() - self.sessions[session["id"]] = session - async def hash(self, password: str) -> str: """Computes a secure hash of password. @@ -1052,7 +998,7 @@ class AuthHandler(BaseHandler): else: return False - def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str: + async def start_sso_ui_auth(self, redirect_url: str, session_id: str) -> str: """ Get the HTML for the SSO redirect confirmation page. @@ -1063,12 +1009,15 @@ class AuthHandler(BaseHandler): Returns: The HTML to render. """ - session = self._get_session_info(session_id) + try: + session = await self.store.get_ui_auth_session(session_id) + except StoreError: + raise SynapseError(400, "Unknown session ID: %s" % (session_id,)) return self._sso_auth_confirm_template.render( - description=session["description"], redirect_url=redirect_url, + description=session.description, redirect_url=redirect_url, ) - def complete_sso_ui_auth( + async def complete_sso_ui_auth( self, registered_user_id: str, session_id: str, request: SynapseRequest, ): """Having figured out a mxid for this user, complete the HTTP request @@ -1080,13 +1029,11 @@ class AuthHandler(BaseHandler): process. """ # Mark the stage of the authentication as successful. - sess = self._get_session_info(session_id) - creds = sess["creds"] - # Save the user who authenticated with SSO, this will be used to ensure # that the account be modified is also the person who logged in. - creds[LoginType.SSO] = registered_user_id - self._save_session(sess) + await self.store.mark_ui_auth_stage_complete( + session_id, LoginType.SSO, registered_user_id + ) # Render the HTML and return. html_bytes = self._sso_auth_success_template.encode("utf-8") diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index 5cb3f9d133..64aaa1335c 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -206,7 +206,7 @@ class CasHandler: registered_user_id = await self._auth_handler.check_user_exists(user_id) if session: - self._auth_handler.complete_sso_ui_auth( + await self._auth_handler.complete_sso_ui_auth( registered_user_id, session, request, ) diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index 7c9454b504..96f2dd36ad 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -149,7 +149,7 @@ class SamlHandler: # Complete the interactive auth session or the login. if current_session and current_session.ui_auth_session_id: - self._auth_handler.complete_sso_ui_auth( + await self._auth_handler.complete_sso_ui_auth( user_id, current_session.ui_auth_session_id, request ) diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py index 11599f5005..24dd3d3e96 100644 --- a/synapse/rest/client/v2_alpha/auth.py +++ b/synapse/rest/client/v2_alpha/auth.py @@ -140,7 +140,7 @@ class AuthRestServlet(RestServlet): self._cas_server_url = hs.config.cas_server_url self._cas_service_url = hs.config.cas_service_url - def on_GET(self, request, stagetype): + async def on_GET(self, request, stagetype): session = parse_string(request, "session") if not session: raise SynapseError(400, "No session supplied") @@ -180,7 +180,7 @@ class AuthRestServlet(RestServlet): else: raise SynapseError(400, "Homeserver not configured for SSO.") - html = self.auth_handler.start_sso_ui_auth(sso_redirect_url, session) + html = await self.auth_handler.start_sso_ui_auth(sso_redirect_url, session) else: raise SynapseError(404, "Unknown auth stage type") diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index d1b5c49989..af08cc6cce 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -499,7 +499,7 @@ class RegisterRestServlet(RestServlet): # registered a user for this session, so we could just return the # user here. We carry on and go through the auth checks though, # for paranoia. - registered_user_id = self.auth_handler.get_session_data( + registered_user_id = await self.auth_handler.get_session_data( session_id, "registered_user_id", None ) @@ -598,7 +598,7 @@ class RegisterRestServlet(RestServlet): # remember that we've now registered that user account, and with # what user ID (since the user may not have specified) - self.auth_handler.set_session_data( + await self.auth_handler.set_session_data( session_id, "registered_user_id", registered_user_id ) diff --git a/synapse/storage/data_stores/main/__init__.py b/synapse/storage/data_stores/main/__init__.py index bd7c3a00ea..ceba10882c 100644 --- a/synapse/storage/data_stores/main/__init__.py +++ b/synapse/storage/data_stores/main/__init__.py @@ -66,6 +66,7 @@ from .stats import StatsStore from .stream import StreamStore from .tags import TagsStore from .transactions import TransactionStore +from .ui_auth import UIAuthStore from .user_directory import UserDirectoryStore from .user_erasure_store import UserErasureStore @@ -112,6 +113,7 @@ class DataStore( StatsStore, RelationsStore, CacheInvalidationStore, + UIAuthStore, ): def __init__(self, database: Database, db_conn, hs): self.hs = hs diff --git a/synapse/storage/data_stores/main/schema/delta/58/03persist_ui_auth.sql b/synapse/storage/data_stores/main/schema/delta/58/03persist_ui_auth.sql new file mode 100644 index 0000000000..dcb593fc2d --- /dev/null +++ b/synapse/storage/data_stores/main/schema/delta/58/03persist_ui_auth.sql @@ -0,0 +1,36 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS ui_auth_sessions( + session_id TEXT NOT NULL, -- The session ID passed to the client. + creation_time BIGINT NOT NULL, -- The time this session was created (epoch time in milliseconds). + serverdict TEXT NOT NULL, -- A JSON dictionary of arbitrary data added by Synapse. + clientdict TEXT NOT NULL, -- A JSON dictionary of arbitrary data from the client. + uri TEXT NOT NULL, -- The URI the UI authentication session is using. + method TEXT NOT NULL, -- The HTTP method the UI authentication session is using. + -- The clientdict, uri, and method make up an tuple that must be immutable + -- throughout the lifetime of the UI Auth session. + description TEXT NOT NULL, -- A human readable description of the operation which caused the UI Auth flow to occur. + UNIQUE (session_id) +); + +CREATE TABLE IF NOT EXISTS ui_auth_sessions_credentials( + session_id TEXT NOT NULL, -- The corresponding UI Auth session. + stage_type TEXT NOT NULL, -- The stage type. + result TEXT NOT NULL, -- The result of the stage verification, stored as JSON. + UNIQUE (session_id, stage_type), + FOREIGN KEY (session_id) + REFERENCES ui_auth_sessions (session_id) +); diff --git a/synapse/storage/data_stores/main/ui_auth.py b/synapse/storage/data_stores/main/ui_auth.py new file mode 100644 index 0000000000..c8eebc9378 --- /dev/null +++ b/synapse/storage/data_stores/main/ui_auth.py @@ -0,0 +1,279 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json +from typing import Any, Dict, Optional, Union + +import attr + +import synapse.util.stringutils as stringutils +from synapse.api.errors import StoreError +from synapse.storage._base import SQLBaseStore +from synapse.types import JsonDict + + +@attr.s +class UIAuthSessionData: + session_id = attr.ib(type=str) + # The dictionary from the client root level, not the 'auth' key. + clientdict = attr.ib(type=JsonDict) + # The URI and method the session was intiatied with. These are checked at + # each stage of the authentication to ensure that the asked for operation + # has not changed. + uri = attr.ib(type=str) + method = attr.ib(type=str) + # A string description of the operation that the current authentication is + # authorising. + description = attr.ib(type=str) + + +class UIAuthWorkerStore(SQLBaseStore): + """ + Manage user interactive authentication sessions. + """ + + async def create_ui_auth_session( + self, clientdict: JsonDict, uri: str, method: str, description: str, + ) -> UIAuthSessionData: + """ + Creates a new user interactive authentication session. + + The session can be used to track the stages necessary to authenticate a + user across multiple HTTP requests. + + Args: + clientdict: + The dictionary from the client root level, not the 'auth' key. + uri: + The URI this session was initiated with, this is checked at each + stage of the authentication to ensure that the asked for + operation has not changed. + method: + The method this session was initiated with, this is checked at each + stage of the authentication to ensure that the asked for + operation has not changed. + description: + A string description of the operation that the current + authentication is authorising. + Returns: + The newly created session. + Raises: + StoreError if a unique session ID cannot be generated. + """ + # The clientdict gets stored as JSON. + clientdict_json = json.dumps(clientdict) + + # autogen a session ID and try to create it. We may clash, so just + # try a few times till one goes through, giving up eventually. + attempts = 0 + while attempts < 5: + session_id = stringutils.random_string(24) + + try: + await self.db.simple_insert( + table="ui_auth_sessions", + values={ + "session_id": session_id, + "clientdict": clientdict_json, + "uri": uri, + "method": method, + "description": description, + "serverdict": "{}", + "creation_time": self.hs.get_clock().time_msec(), + }, + desc="create_ui_auth_session", + ) + return UIAuthSessionData( + session_id, clientdict, uri, method, description + ) + except self.db.engine.module.IntegrityError: + attempts += 1 + raise StoreError(500, "Couldn't generate a session ID.") + + async def get_ui_auth_session(self, session_id: str) -> UIAuthSessionData: + """Retrieve a UI auth session. + + Args: + session_id: The ID of the session. + Returns: + A dict containing the device information. + Raises: + StoreError if the session is not found. + """ + result = await self.db.simple_select_one( + table="ui_auth_sessions", + keyvalues={"session_id": session_id}, + retcols=("clientdict", "uri", "method", "description"), + desc="get_ui_auth_session", + ) + + result["clientdict"] = json.loads(result["clientdict"]) + + return UIAuthSessionData(session_id, **result) + + async def mark_ui_auth_stage_complete( + self, session_id: str, stage_type: str, result: Union[str, bool, JsonDict], + ): + """ + Mark a session stage as completed. + + Args: + session_id: The ID of the corresponding session. + stage_type: The completed stage type. + result: The result of the stage verification. + Raises: + StoreError if the session cannot be found. + """ + # Add (or update) the results of the current stage to the database. + # + # Note that we need to allow for the same stage to complete multiple + # times here so that registration is idempotent. + try: + await self.db.simple_upsert( + table="ui_auth_sessions_credentials", + keyvalues={"session_id": session_id, "stage_type": stage_type}, + values={"result": json.dumps(result)}, + desc="mark_ui_auth_stage_complete", + ) + except self.db.engine.module.IntegrityError: + raise StoreError(400, "Unknown session ID: %s" % (session_id,)) + + async def get_completed_ui_auth_stages( + self, session_id: str + ) -> Dict[str, Union[str, bool, JsonDict]]: + """ + Retrieve the completed stages of a UI authentication session. + + Args: + session_id: The ID of the session. + Returns: + The completed stages mapped to the result of the verification of + that auth-type. + """ + results = {} + for row in await self.db.simple_select_list( + table="ui_auth_sessions_credentials", + keyvalues={"session_id": session_id}, + retcols=("stage_type", "result"), + desc="get_completed_ui_auth_stages", + ): + results[row["stage_type"]] = json.loads(row["result"]) + + return results + + async def set_ui_auth_session_data(self, session_id: str, key: str, value: Any): + """ + Store a key-value pair into the sessions data associated with this + request. This data is stored server-side and cannot be modified by + the client. + + Args: + session_id: The ID of this session as returned from check_auth + key: The key to store the data under + value: The data to store + Raises: + StoreError if the session cannot be found. + """ + await self.db.runInteraction( + "set_ui_auth_session_data", + self._set_ui_auth_session_data_txn, + session_id, + key, + value, + ) + + def _set_ui_auth_session_data_txn(self, txn, session_id: str, key: str, value: Any): + # Get the current value. + result = self.db.simple_select_one_txn( + txn, + table="ui_auth_sessions", + keyvalues={"session_id": session_id}, + retcols=("serverdict",), + ) + + # Update it and add it back to the database. + serverdict = json.loads(result["serverdict"]) + serverdict[key] = value + + self.db.simple_update_one_txn( + txn, + table="ui_auth_sessions", + keyvalues={"session_id": session_id}, + updatevalues={"serverdict": json.dumps(serverdict)}, + ) + + async def get_ui_auth_session_data( + self, session_id: str, key: str, default: Optional[Any] = None + ) -> Any: + """ + Retrieve data stored with set_session_data + + Args: + session_id: The ID of this session as returned from check_auth + key: The key to store the data under + default: Value to return if the key has not been set + Raises: + StoreError if the session cannot be found. + """ + result = await self.db.simple_select_one( + table="ui_auth_sessions", + keyvalues={"session_id": session_id}, + retcols=("serverdict",), + desc="get_ui_auth_session_data", + ) + + serverdict = json.loads(result["serverdict"]) + + return serverdict.get(key, default) + + +class UIAuthStore(UIAuthWorkerStore): + def delete_old_ui_auth_sessions(self, expiration_time: int): + """ + Remove sessions which were last used earlier than the expiration time. + + Args: + expiration_time: The latest time that is still considered valid. + This is an epoch time in milliseconds. + + """ + return self.db.runInteraction( + "delete_old_ui_auth_sessions", + self._delete_old_ui_auth_sessions_txn, + expiration_time, + ) + + def _delete_old_ui_auth_sessions_txn(self, txn, expiration_time: int): + # Get the expired sessions. + sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?" + txn.execute(sql, [expiration_time]) + session_ids = [r[0] for r in txn.fetchall()] + + # Delete the corresponding completed credentials. + self.db.simple_delete_many_txn( + txn, + table="ui_auth_sessions_credentials", + column="session_id", + iterable=session_ids, + keyvalues={}, + ) + + # Finally, delete the sessions. + self.db.simple_delete_many_txn( + txn, + table="ui_auth_sessions", + column="session_id", + iterable=session_ids, + keyvalues={}, + ) diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index 3bc2e8b986..215a949442 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -85,6 +85,7 @@ class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]): prepare_database(db_conn, self, config=None) db_conn.create_function("rank", 1, _rank) + db_conn.execute("PRAGMA foreign_keys = ON;") def is_deadlock(self, error): return False diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index 624bf5ada2..587be7b2e7 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -181,3 +181,43 @@ class FallbackAuthTests(unittest.HomeserverTestCase): ) self.render(request) self.assertEqual(channel.code, 403) + + def test_complete_operation_unknown_session(self): + """ + Attempting to mark an invalid session as complete should error. + """ + + # Make the initial request to register. (Later on a different password + # will be used.) + request, channel = self.make_request( + "POST", + "register", + {"username": "user", "type": "m.login.password", "password": "bar"}, + ) + self.render(request) + + # Returns a 401 as per the spec + self.assertEqual(request.code, 401) + # Grab the session + session = channel.json_body["session"] + # Assert our configured public key is being given + self.assertEqual( + channel.json_body["params"]["m.login.recaptcha"]["public_key"], "brokencake" + ) + + request, channel = self.make_request( + "GET", "auth/m.login.recaptcha/fallback/web?session=" + session + ) + self.render(request) + self.assertEqual(request.code, 200) + + # Attempt to complete an unknown session, which should return an error. + unknown_session = session + "unknown" + request, channel = self.make_request( + "POST", + "auth/m.login.recaptcha/fallback/web?session=" + + unknown_session + + "&g-recaptcha-response=a", + ) + self.render(request) + self.assertEqual(request.code, 400) diff --git a/tests/utils.py b/tests/utils.py index 037cb134f0..f9be62b499 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -512,8 +512,8 @@ class MockClock(object): return t - def looping_call(self, function, interval): - self.loopers.append([function, interval / 1000.0, self.now]) + def looping_call(self, function, interval, *args, **kwargs): + self.loopers.append([function, interval / 1000.0, self.now, args, kwargs]) def cancel_call_later(self, timer, ignore_errs=False): if timer[2]: @@ -543,9 +543,9 @@ class MockClock(object): self.timers.append(t) for looped in self.loopers: - func, interval, last = looped + func, interval, last, args, kwargs = looped if last + interval < self.now: - func() + func(*args, **kwargs) looped[2] = self.now def advance_time_msec(self, ms): diff --git a/tox.ini b/tox.ini index 2630857436..eccc44e436 100644 --- a/tox.ini +++ b/tox.ini @@ -200,8 +200,9 @@ commands = mypy \ synapse/replication \ synapse/rest \ synapse/spam_checker_api \ - synapse/storage/engines \ + synapse/storage/data_stores/main/ui_auth.py \ synapse/storage/database.py \ + synapse/storage/engines \ synapse/streams \ synapse/util/caches/stream_change_cache.py \ tests/replication/tcp/streams \ -- cgit 1.5.1 From 7cb8b4bc67042a39bd1b0e05df46089a2fce1955 Mon Sep 17 00:00:00 2001 From: Amber Brown Date: Tue, 12 May 2020 03:45:23 +1000 Subject: Allow configuration of Synapse's cache without using synctl or environment variables (#6391) --- changelog.d/6391.feature | 1 + docs/sample_config.yaml | 43 +++++- synapse/api/auth.py | 4 +- synapse/app/homeserver.py | 5 +- synapse/config/cache.py | 164 ++++++++++++++++++++++ synapse/config/database.py | 6 - synapse/config/homeserver.py | 2 + synapse/http/client.py | 6 +- synapse/metrics/_exposition.py | 12 +- synapse/push/bulk_push_rule_evaluator.py | 4 +- synapse/push/push_rule_evaluator.py | 4 +- synapse/replication/slave/storage/client_ips.py | 3 +- synapse/state/__init__.py | 4 +- synapse/storage/data_stores/main/client_ips.py | 3 +- synapse/storage/data_stores/main/events_worker.py | 5 +- synapse/storage/data_stores/state/store.py | 6 +- synapse/util/caches/__init__.py | 144 ++++++++++--------- synapse/util/caches/descriptors.py | 36 ++++- synapse/util/caches/expiringcache.py | 29 +++- synapse/util/caches/lrucache.py | 52 +++++-- synapse/util/caches/response_cache.py | 2 +- synapse/util/caches/stream_change_cache.py | 33 ++++- synapse/util/caches/ttlcache.py | 2 +- tests/config/test_cache.py | 127 +++++++++++++++++ tests/storage/test__base.py | 8 +- tests/storage/test_appservice.py | 10 +- tests/storage/test_base.py | 3 +- tests/test_metrics.py | 34 +++++ tests/util/test_expiring_cache.py | 2 +- tests/util/test_lrucache.py | 6 +- tests/util/test_stream_change_cache.py | 5 +- tests/utils.py | 1 + 32 files changed, 620 insertions(+), 146 deletions(-) create mode 100644 changelog.d/6391.feature create mode 100644 synapse/config/cache.py create mode 100644 tests/config/test_cache.py (limited to 'tests/utils.py') diff --git a/changelog.d/6391.feature b/changelog.d/6391.feature new file mode 100644 index 0000000000..f123426e23 --- /dev/null +++ b/changelog.d/6391.feature @@ -0,0 +1 @@ +Synapse's cache factor can now be configured in `homeserver.yaml` by the `caches.global_factor` setting. Additionally, `caches.per_cache_factors` controls the cache factors for individual caches. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 5abeaf519b..8a8415b9a2 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -603,6 +603,45 @@ acme: +## Caching ## + +# Caching can be configured through the following options. +# +# A cache 'factor' is a multiplier that can be applied to each of +# Synapse's caches in order to increase or decrease the maximum +# number of entries that can be stored. + +# The number of events to cache in memory. Not affected by +# caches.global_factor. +# +#event_cache_size: 10K + +caches: + # Controls the global cache factor, which is the default cache factor + # for all caches if a specific factor for that cache is not otherwise + # set. + # + # This can also be set by the "SYNAPSE_CACHE_FACTOR" environment + # variable. Setting by environment variable takes priority over + # setting through the config file. + # + # Defaults to 0.5, which will half the size of all caches. + # + #global_factor: 1.0 + + # A dictionary of cache name to cache factor for that individual + # cache. Overrides the global cache factor for a given cache. + # + # These can also be set through environment variables comprised + # of "SYNAPSE_CACHE_FACTOR_" + the name of the cache in capital + # letters and underscores. Setting by environment variable + # takes priority over setting through the config file. + # Ex. SYNAPSE_CACHE_FACTOR_GET_USERS_WHO_SHARE_ROOM_WITH_USER=2.0 + # + per_cache_factors: + #get_users_who_share_room_with_user: 2.0 + + ## Database ## # The 'database' setting defines the database that synapse uses to store all of @@ -646,10 +685,6 @@ database: args: database: DATADIR/homeserver.db -# Number of events to cache in memory. -# -#event_cache_size: 10K - ## Logging ## diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 1ad5ff9410..e009b1a760 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -37,7 +37,7 @@ from synapse.api.errors import ( from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase from synapse.types import StateMap, UserID -from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache +from synapse.util.caches import register_cache from synapse.util.caches.lrucache import LruCache from synapse.util.metrics import Measure @@ -73,7 +73,7 @@ class Auth(object): self.store = hs.get_datastore() self.state = hs.get_state_handler() - self.token_cache = LruCache(CACHE_SIZE_FACTOR * 10000) + self.token_cache = LruCache(10000) register_cache("cache", "token_cache", self.token_cache) self._auth_blocking = AuthBlocking(self.hs) diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index bc8695d8dd..d7f337e586 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -69,7 +69,6 @@ from synapse.server import HomeServer from synapse.storage import DataStore from synapse.storage.engines import IncorrectDatabaseSetup from synapse.storage.prepare_database import UpgradeDatabaseException -from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.httpresourcetree import create_resource_tree from synapse.util.manhole import manhole from synapse.util.module_loader import load_module @@ -516,8 +515,8 @@ def phone_stats_home(hs, stats, stats_process=_stats_process): daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages() stats["daily_sent_messages"] = daily_sent_messages - stats["cache_factor"] = CACHE_SIZE_FACTOR - stats["event_cache_size"] = hs.config.event_cache_size + stats["cache_factor"] = hs.config.caches.global_factor + stats["event_cache_size"] = hs.config.caches.event_cache_size # # Performance statistics diff --git a/synapse/config/cache.py b/synapse/config/cache.py new file mode 100644 index 0000000000..91036a012e --- /dev/null +++ b/synapse/config/cache.py @@ -0,0 +1,164 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Callable, Dict + +from ._base import Config, ConfigError + +# The prefix for all cache factor-related environment variables +_CACHES = {} +_CACHE_PREFIX = "SYNAPSE_CACHE_FACTOR" +_DEFAULT_FACTOR_SIZE = 0.5 +_DEFAULT_EVENT_CACHE_SIZE = "10K" + + +class CacheProperties(object): + def __init__(self): + # The default factor size for all caches + self.default_factor_size = float( + os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE) + ) + self.resize_all_caches_func = None + + +properties = CacheProperties() + + +def add_resizable_cache(cache_name: str, cache_resize_callback: Callable): + """Register a cache that's size can dynamically change + + Args: + cache_name: A reference to the cache + cache_resize_callback: A callback function that will be ran whenever + the cache needs to be resized + """ + _CACHES[cache_name.lower()] = cache_resize_callback + + # Ensure all loaded caches are sized appropriately + # + # This method should only run once the config has been read, + # as it uses values read from it + if properties.resize_all_caches_func: + properties.resize_all_caches_func() + + +class CacheConfig(Config): + section = "caches" + _environ = os.environ + + @staticmethod + def reset(): + """Resets the caches to their defaults. Used for tests.""" + properties.default_factor_size = float( + os.environ.get(_CACHE_PREFIX, _DEFAULT_FACTOR_SIZE) + ) + properties.resize_all_caches_func = None + _CACHES.clear() + + def generate_config_section(self, **kwargs): + return """\ + ## Caching ## + + # Caching can be configured through the following options. + # + # A cache 'factor' is a multiplier that can be applied to each of + # Synapse's caches in order to increase or decrease the maximum + # number of entries that can be stored. + + # The number of events to cache in memory. Not affected by + # caches.global_factor. + # + #event_cache_size: 10K + + caches: + # Controls the global cache factor, which is the default cache factor + # for all caches if a specific factor for that cache is not otherwise + # set. + # + # This can also be set by the "SYNAPSE_CACHE_FACTOR" environment + # variable. Setting by environment variable takes priority over + # setting through the config file. + # + # Defaults to 0.5, which will half the size of all caches. + # + #global_factor: 1.0 + + # A dictionary of cache name to cache factor for that individual + # cache. Overrides the global cache factor for a given cache. + # + # These can also be set through environment variables comprised + # of "SYNAPSE_CACHE_FACTOR_" + the name of the cache in capital + # letters and underscores. Setting by environment variable + # takes priority over setting through the config file. + # Ex. SYNAPSE_CACHE_FACTOR_GET_USERS_WHO_SHARE_ROOM_WITH_USER=2.0 + # + per_cache_factors: + #get_users_who_share_room_with_user: 2.0 + """ + + def read_config(self, config, **kwargs): + self.event_cache_size = self.parse_size( + config.get("event_cache_size", _DEFAULT_EVENT_CACHE_SIZE) + ) + self.cache_factors = {} # type: Dict[str, float] + + cache_config = config.get("caches") or {} + self.global_factor = cache_config.get( + "global_factor", properties.default_factor_size + ) + if not isinstance(self.global_factor, (int, float)): + raise ConfigError("caches.global_factor must be a number.") + + # Set the global one so that it's reflected in new caches + properties.default_factor_size = self.global_factor + + # Load cache factors from the config + individual_factors = cache_config.get("per_cache_factors") or {} + if not isinstance(individual_factors, dict): + raise ConfigError("caches.per_cache_factors must be a dictionary") + + # Override factors from environment if necessary + individual_factors.update( + { + key[len(_CACHE_PREFIX) + 1 :].lower(): float(val) + for key, val in self._environ.items() + if key.startswith(_CACHE_PREFIX + "_") + } + ) + + for cache, factor in individual_factors.items(): + if not isinstance(factor, (int, float)): + raise ConfigError( + "caches.per_cache_factors.%s must be a number" % (cache.lower(),) + ) + self.cache_factors[cache.lower()] = factor + + # Resize all caches (if necessary) with the new factors we've loaded + self.resize_all_caches() + + # Store this function so that it can be called from other classes without + # needing an instance of Config + properties.resize_all_caches_func = self.resize_all_caches + + def resize_all_caches(self): + """Ensure all cache sizes are up to date + + For each cache, run the mapped callback function with either + a specific cache factor or the default, global one. + """ + for cache_name, callback in _CACHES.items(): + new_factor = self.cache_factors.get(cache_name, self.global_factor) + callback(new_factor) diff --git a/synapse/config/database.py b/synapse/config/database.py index 5b662d1b01..1064c2697b 100644 --- a/synapse/config/database.py +++ b/synapse/config/database.py @@ -68,10 +68,6 @@ database: name: sqlite3 args: database: %(database_path)s - -# Number of events to cache in memory. -# -#event_cache_size: 10K """ @@ -116,8 +112,6 @@ class DatabaseConfig(Config): self.databases = [] def read_config(self, config, **kwargs): - self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K")) - # We *experimentally* support specifying multiple databases via the # `databases` key. This is a map from a label to database config in the # same format as the `database` config option, plus an extra diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 996d3e6bf7..2c7b3a699f 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -17,6 +17,7 @@ from ._base import RootConfig from .api import ApiConfig from .appservice import AppServiceConfig +from .cache import CacheConfig from .captcha import CaptchaConfig from .cas import CasConfig from .consent_config import ConsentConfig @@ -55,6 +56,7 @@ class HomeServerConfig(RootConfig): config_classes = [ ServerConfig, TlsConfig, + CacheConfig, DatabaseConfig, LoggingConfig, RatelimitConfig, diff --git a/synapse/http/client.py b/synapse/http/client.py index 58eb47c69c..3cef747a4d 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -49,7 +49,6 @@ from synapse.http.proxyagent import ProxyAgent from synapse.logging.context import make_deferred_yieldable from synapse.logging.opentracing import set_tag, start_active_span, tags from synapse.util.async_helpers import timeout_deferred -from synapse.util.caches import CACHE_SIZE_FACTOR logger = logging.getLogger(__name__) @@ -241,7 +240,10 @@ class SimpleHttpClient(object): # tends to do so in batches, so we need to allow the pool to keep # lots of idle connections around. pool = HTTPConnectionPool(self.reactor) - pool.maxPersistentPerHost = max((100 * CACHE_SIZE_FACTOR, 5)) + # XXX: The justification for using the cache factor here is that larger instances + # will need both more cache and more connections. + # Still, this should probably be a separate dial + pool.maxPersistentPerHost = max((100 * hs.config.caches.global_factor, 5)) pool.cachedConnectionTimeout = 2 * 60 self.agent = ProxyAgent( diff --git a/synapse/metrics/_exposition.py b/synapse/metrics/_exposition.py index a248103191..ab7f948ed4 100644 --- a/synapse/metrics/_exposition.py +++ b/synapse/metrics/_exposition.py @@ -33,6 +33,8 @@ from prometheus_client import REGISTRY from twisted.web.resource import Resource +from synapse.util import caches + try: from prometheus_client.samples import Sample except ImportError: @@ -103,13 +105,15 @@ def nameify_sample(sample): def generate_latest(registry, emit_help=False): - output = [] - for metric in registry.collect(): + # Trigger the cache metrics to be rescraped, which updates the common + # metrics but do not produce metrics themselves + for collector in caches.collectors_by_name.values(): + collector.collect() - if metric.name.startswith("__unused"): - continue + output = [] + for metric in registry.collect(): if not metric.samples: # No samples, don't bother. continue diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 433ca2f416..e75d964ac8 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -51,6 +51,7 @@ push_rules_delta_state_cache_metric = register_cache( "cache", "push_rules_delta_state_cache_metric", cache=[], # Meaningless size, as this isn't a cache that stores values + resizable=False, ) @@ -67,7 +68,8 @@ class BulkPushRuleEvaluator(object): self.room_push_rule_cache_metrics = register_cache( "cache", "room_push_rule_cache", - cache=[], # Meaningless size, as this isn't a cache that stores values + cache=[], # Meaningless size, as this isn't a cache that stores values, + resizable=False, ) @defer.inlineCallbacks diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py index 4cd702b5fa..11032491af 100644 --- a/synapse/push/push_rule_evaluator.py +++ b/synapse/push/push_rule_evaluator.py @@ -22,7 +22,7 @@ from six import string_types from synapse.events import EventBase from synapse.types import UserID -from synapse.util.caches import CACHE_SIZE_FACTOR, register_cache +from synapse.util.caches import register_cache from synapse.util.caches.lrucache import LruCache logger = logging.getLogger(__name__) @@ -165,7 +165,7 @@ class PushRuleEvaluatorForEvent(object): # Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches -regex_cache = LruCache(50000 * CACHE_SIZE_FACTOR) +regex_cache = LruCache(50000) register_cache("cache", "regex_push_cache", regex_cache) diff --git a/synapse/replication/slave/storage/client_ips.py b/synapse/replication/slave/storage/client_ips.py index fbf996e33a..1a38f53dfb 100644 --- a/synapse/replication/slave/storage/client_ips.py +++ b/synapse/replication/slave/storage/client_ips.py @@ -15,7 +15,6 @@ from synapse.storage.data_stores.main.client_ips import LAST_SEEN_GRANULARITY from synapse.storage.database import Database -from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.caches.descriptors import Cache from ._base import BaseSlavedStore @@ -26,7 +25,7 @@ class SlavedClientIpStore(BaseSlavedStore): super(SlavedClientIpStore, self).__init__(database, db_conn, hs) self.client_ip_last_seen = Cache( - name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR + name="client_ip_last_seen", keylen=4, max_entries=50000 ) def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 4afefc6b1d..2fa529fcd0 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -35,7 +35,6 @@ from synapse.state import v1, v2 from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour from synapse.types import StateMap from synapse.util.async_helpers import Linearizer -from synapse.util.caches import get_cache_factor_for from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import Measure, measure_func @@ -53,7 +52,6 @@ state_groups_histogram = Histogram( KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key")) -SIZE_OF_CACHE = 100000 * get_cache_factor_for("state_cache") EVICTION_TIMEOUT_SECONDS = 60 * 60 @@ -447,7 +445,7 @@ class StateResolutionHandler(object): self._state_cache = ExpiringCache( cache_name="state_cache", clock=self.clock, - max_len=SIZE_OF_CACHE, + max_len=100000, expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000, iterable=True, reset_expiry_on_get=True, diff --git a/synapse/storage/data_stores/main/client_ips.py b/synapse/storage/data_stores/main/client_ips.py index 92bc06919b..71f8d43a76 100644 --- a/synapse/storage/data_stores/main/client_ips.py +++ b/synapse/storage/data_stores/main/client_ips.py @@ -22,7 +22,6 @@ from twisted.internet import defer from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore from synapse.storage.database import Database, make_tuple_comparison_clause -from synapse.util.caches import CACHE_SIZE_FACTOR from synapse.util.caches.descriptors import Cache logger = logging.getLogger(__name__) @@ -361,7 +360,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): def __init__(self, database: Database, db_conn, hs): self.client_ip_last_seen = Cache( - name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR + name="client_ip_last_seen", keylen=4, max_entries=50000 ) super(ClientIpStore, self).__init__(database, db_conn, hs) diff --git a/synapse/storage/data_stores/main/events_worker.py b/synapse/storage/data_stores/main/events_worker.py index 73df6b33ba..b8c1bbdf99 100644 --- a/synapse/storage/data_stores/main/events_worker.py +++ b/synapse/storage/data_stores/main/events_worker.py @@ -75,7 +75,10 @@ class EventsWorkerStore(SQLBaseStore): super(EventsWorkerStore, self).__init__(database, db_conn, hs) self._get_event_cache = Cache( - "*getEvent*", keylen=3, max_entries=hs.config.event_cache_size + "*getEvent*", + keylen=3, + max_entries=hs.config.caches.event_cache_size, + apply_cache_factor_from_config=False, ) self._event_fetch_lock = threading.Condition() diff --git a/synapse/storage/data_stores/state/store.py b/synapse/storage/data_stores/state/store.py index 57a5267663..f3ad1e4369 100644 --- a/synapse/storage/data_stores/state/store.py +++ b/synapse/storage/data_stores/state/store.py @@ -28,7 +28,6 @@ from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateSt from synapse.storage.database import Database from synapse.storage.state import StateFilter from synapse.types import StateMap -from synapse.util.caches import get_cache_factor_for from synapse.util.caches.descriptors import cached from synapse.util.caches.dictionary_cache import DictionaryCache @@ -90,11 +89,10 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore): self._state_group_cache = DictionaryCache( "*stateGroupCache*", # TODO: this hasn't been tuned yet - 50000 * get_cache_factor_for("stateGroupCache"), + 50000, ) self._state_group_members_cache = DictionaryCache( - "*stateGroupMembersCache*", - 500000 * get_cache_factor_for("stateGroupMembersCache"), + "*stateGroupMembersCache*", 500000, ) @cached(max_entries=10000, iterable=True) diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index da5077b471..4b8a0c7a8f 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd -# Copyright 2019 The Matrix.org Foundation C.I.C. +# Copyright 2019, 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -15,27 +15,17 @@ # limitations under the License. import logging -import os -from typing import Dict +from typing import Callable, Dict, Optional import six from six.moves import intern -from prometheus_client.core import REGISTRY, Gauge, GaugeMetricFamily - -logger = logging.getLogger(__name__) - -CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.5)) +import attr +from prometheus_client.core import Gauge +from synapse.config.cache import add_resizable_cache -def get_cache_factor_for(cache_name): - env_var = "SYNAPSE_CACHE_FACTOR_" + cache_name.upper() - factor = os.environ.get(env_var) - if factor: - return float(factor) - - return CACHE_SIZE_FACTOR - +logger = logging.getLogger(__name__) caches_by_name = {} collectors_by_name = {} # type: Dict @@ -44,6 +34,7 @@ cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"]) cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"]) cache_evicted = Gauge("synapse_util_caches_cache:evicted_size", "", ["name"]) cache_total = Gauge("synapse_util_caches_cache:total", "", ["name"]) +cache_max_size = Gauge("synapse_util_caches_cache_max_size", "", ["name"]) response_cache_size = Gauge("synapse_util_caches_response_cache:size", "", ["name"]) response_cache_hits = Gauge("synapse_util_caches_response_cache:hits", "", ["name"]) @@ -53,67 +44,82 @@ response_cache_evicted = Gauge( response_cache_total = Gauge("synapse_util_caches_response_cache:total", "", ["name"]) -def register_cache(cache_type, cache_name, cache, collect_callback=None): - """Register a cache object for metric collection. +@attr.s +class CacheMetric(object): + + _cache = attr.ib() + _cache_type = attr.ib(type=str) + _cache_name = attr.ib(type=str) + _collect_callback = attr.ib(type=Optional[Callable]) + + hits = attr.ib(default=0) + misses = attr.ib(default=0) + evicted_size = attr.ib(default=0) + + def inc_hits(self): + self.hits += 1 + + def inc_misses(self): + self.misses += 1 + + def inc_evictions(self, size=1): + self.evicted_size += size + + def describe(self): + return [] + + def collect(self): + try: + if self._cache_type == "response_cache": + response_cache_size.labels(self._cache_name).set(len(self._cache)) + response_cache_hits.labels(self._cache_name).set(self.hits) + response_cache_evicted.labels(self._cache_name).set(self.evicted_size) + response_cache_total.labels(self._cache_name).set( + self.hits + self.misses + ) + else: + cache_size.labels(self._cache_name).set(len(self._cache)) + cache_hits.labels(self._cache_name).set(self.hits) + cache_evicted.labels(self._cache_name).set(self.evicted_size) + cache_total.labels(self._cache_name).set(self.hits + self.misses) + if getattr(self._cache, "max_size", None): + cache_max_size.labels(self._cache_name).set(self._cache.max_size) + if self._collect_callback: + self._collect_callback() + except Exception as e: + logger.warning("Error calculating metrics for %s: %s", self._cache_name, e) + raise + + +def register_cache( + cache_type: str, + cache_name: str, + cache, + collect_callback: Optional[Callable] = None, + resizable: bool = True, + resize_callback: Optional[Callable] = None, +) -> CacheMetric: + """Register a cache object for metric collection and resizing. Args: - cache_type (str): - cache_name (str): name of the cache - cache (object): cache itself - collect_callback (callable|None): if not None, a function which is called during - metric collection to update additional metrics. + cache_type + cache_name: name of the cache + cache: cache itself + collect_callback: If given, a function which is called during metric + collection to update additional metrics. + resizable: Whether this cache supports being resized. + resize_callback: A function which can be called to resize the cache. Returns: CacheMetric: an object which provides inc_{hits,misses,evictions} methods """ + if resizable: + if not resize_callback: + resize_callback = getattr(cache, "set_cache_factor") + add_resizable_cache(cache_name, resize_callback) - # Check if the metric is already registered. Unregister it, if so. - # This usually happens during tests, as at runtime these caches are - # effectively singletons. + metric = CacheMetric(cache, cache_type, cache_name, collect_callback) metric_name = "cache_%s_%s" % (cache_type, cache_name) - if metric_name in collectors_by_name.keys(): - REGISTRY.unregister(collectors_by_name[metric_name]) - - class CacheMetric(object): - - hits = 0 - misses = 0 - evicted_size = 0 - - def inc_hits(self): - self.hits += 1 - - def inc_misses(self): - self.misses += 1 - - def inc_evictions(self, size=1): - self.evicted_size += size - - def describe(self): - return [] - - def collect(self): - try: - if cache_type == "response_cache": - response_cache_size.labels(cache_name).set(len(cache)) - response_cache_hits.labels(cache_name).set(self.hits) - response_cache_evicted.labels(cache_name).set(self.evicted_size) - response_cache_total.labels(cache_name).set(self.hits + self.misses) - else: - cache_size.labels(cache_name).set(len(cache)) - cache_hits.labels(cache_name).set(self.hits) - cache_evicted.labels(cache_name).set(self.evicted_size) - cache_total.labels(cache_name).set(self.hits + self.misses) - if collect_callback: - collect_callback() - except Exception as e: - logger.warning("Error calculating metrics for %s: %s", cache_name, e) - raise - - yield GaugeMetricFamily("__unused", "") - - metric = CacheMetric() - REGISTRY.register(metric) caches_by_name[cache_name] = cache collectors_by_name[metric_name] = metric return metric diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 2e8f6543e5..cd48262420 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + import functools import inspect import logging @@ -30,7 +31,6 @@ from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, preserve_fn from synapse.util import unwrapFirstError from synapse.util.async_helpers import ObservableDeferred -from synapse.util.caches import get_cache_factor_for from synapse.util.caches.lrucache import LruCache from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry @@ -81,7 +81,6 @@ class CacheEntry(object): class Cache(object): __slots__ = ( "cache", - "max_entries", "name", "keylen", "thread", @@ -89,7 +88,29 @@ class Cache(object): "_pending_deferred_cache", ) - def __init__(self, name, max_entries=1000, keylen=1, tree=False, iterable=False): + def __init__( + self, + name: str, + max_entries: int = 1000, + keylen: int = 1, + tree: bool = False, + iterable: bool = False, + apply_cache_factor_from_config: bool = True, + ): + """ + Args: + name: The name of the cache + max_entries: Maximum amount of entries that the cache will hold + keylen: The length of the tuple used as the cache key + tree: Use a TreeCache instead of a dict as the underlying cache type + iterable: If True, count each item in the cached object as an entry, + rather than each cached object + apply_cache_factor_from_config: Whether cache factors specified in the + config file affect `max_entries` + + Returns: + Cache + """ cache_type = TreeCache if tree else dict self._pending_deferred_cache = cache_type() @@ -99,6 +120,7 @@ class Cache(object): cache_type=cache_type, size_callback=(lambda d: len(d)) if iterable else None, evicted_callback=self._on_evicted, + apply_cache_factor_from_config=apply_cache_factor_from_config, ) self.name = name @@ -111,6 +133,10 @@ class Cache(object): collect_callback=self._metrics_collection_callback, ) + @property + def max_entries(self): + return self.cache.max_size + def _on_evicted(self, evicted_count): self.metrics.inc_evictions(evicted_count) @@ -370,13 +396,11 @@ class CacheDescriptor(_CacheDescriptorBase): cache_context=cache_context, ) - max_entries = int(max_entries * get_cache_factor_for(orig.__name__)) - self.max_entries = max_entries self.tree = tree self.iterable = iterable - def __get__(self, obj, objtype=None): + def __get__(self, obj, owner): cache = Cache( name=self.orig.__name__, max_entries=self.max_entries, diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index cddf1ed515..2726b67b6d 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -18,6 +18,7 @@ from collections import OrderedDict from six import iteritems, itervalues +from synapse.config import cache as cache_config from synapse.metrics.background_process_metrics import run_as_background_process from synapse.util.caches import register_cache @@ -51,15 +52,16 @@ class ExpiringCache(object): an item on access. Defaults to False. iterable (bool): If true, the size is calculated by summing the sizes of all entries, rather than the number of entries. - """ self._cache_name = cache_name + self._original_max_size = max_len + + self._max_size = int(max_len * cache_config.properties.default_factor_size) + self._clock = clock - self._max_len = max_len self._expiry_ms = expiry_ms - self._reset_expiry_on_get = reset_expiry_on_get self._cache = OrderedDict() @@ -82,9 +84,11 @@ class ExpiringCache(object): def __setitem__(self, key, value): now = self._clock.time_msec() self._cache[key] = _CacheEntry(now, value) + self.evict() + def evict(self): # Evict if there are now too many items - while self._max_len and len(self) > self._max_len: + while self._max_size and len(self) > self._max_size: _key, value = self._cache.popitem(last=False) if self.iterable: self.metrics.inc_evictions(len(value.value)) @@ -170,6 +174,23 @@ class ExpiringCache(object): else: return len(self._cache) + def set_cache_factor(self, factor: float) -> bool: + """ + Set the cache factor for this individual cache. + + This will trigger a resize if it changes, which may require evicting + items from the cache. + + Returns: + bool: Whether the cache changed size or not. + """ + new_size = int(self._original_max_size * factor) + if new_size != self._max_size: + self._max_size = new_size + self.evict() + return True + return False + class _CacheEntry(object): __slots__ = ["time", "value"] diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 1536cb64f3..29fabac3cd 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -13,10 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. - import threading from functools import wraps +from typing import Callable, Optional, Type, Union +from synapse.config import cache as cache_config from synapse.util.caches.treecache import TreeCache @@ -52,17 +53,18 @@ class LruCache(object): def __init__( self, - max_size, - keylen=1, - cache_type=dict, - size_callback=None, - evicted_callback=None, + max_size: int, + keylen: int = 1, + cache_type: Type[Union[dict, TreeCache]] = dict, + size_callback: Optional[Callable] = None, + evicted_callback: Optional[Callable] = None, + apply_cache_factor_from_config: bool = True, ): """ Args: - max_size (int): + max_size: The maximum amount of entries the cache can hold - keylen (int): + keylen: The length of the tuple used as the cache key cache_type (type): type of underlying cache to be used. Typically one of dict @@ -73,9 +75,23 @@ class LruCache(object): evicted_callback (func(int)|None): if not None, called on eviction with the size of the evicted entry + + apply_cache_factor_from_config (bool): If true, `max_size` will be + multiplied by a cache factor derived from the homeserver config """ cache = cache_type() self.cache = cache # Used for introspection. + + # Save the original max size, and apply the default size factor. + self._original_max_size = max_size + # We previously didn't apply the cache factor here, and as such some caches were + # not affected by the global cache factor. Add an option here to disable applying + # the cache factor when a cache is created + if apply_cache_factor_from_config: + self.max_size = int(max_size * cache_config.properties.default_factor_size) + else: + self.max_size = int(max_size) + list_root = _Node(None, None, None, None) list_root.next_node = list_root list_root.prev_node = list_root @@ -83,7 +99,7 @@ class LruCache(object): lock = threading.Lock() def evict(): - while cache_len() > max_size: + while cache_len() > self.max_size: todelete = list_root.prev_node evicted_len = delete_node(todelete) cache.pop(todelete.key, None) @@ -236,6 +252,7 @@ class LruCache(object): return key in cache self.sentinel = object() + self._on_resize = evict self.get = cache_get self.set = cache_set self.setdefault = cache_set_default @@ -266,3 +283,20 @@ class LruCache(object): def __contains__(self, key): return self.contains(key) + + def set_cache_factor(self, factor: float) -> bool: + """ + Set the cache factor for this individual cache. + + This will trigger a resize if it changes, which may require evicting + items from the cache. + + Returns: + bool: Whether the cache changed size or not. + """ + new_size = int(self._original_max_size * factor) + if new_size != self.max_size: + self.max_size = new_size + self._on_resize() + return True + return False diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index b68f9fe0d4..a6c60888e5 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -38,7 +38,7 @@ class ResponseCache(object): self.timeout_sec = timeout_ms / 1000.0 self._name = name - self._metrics = register_cache("response_cache", name, self) + self._metrics = register_cache("response_cache", name, self, resizable=False) def size(self): return len(self.pending_result_cache) diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index e54f80d76e..2a161bf244 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +import math from typing import Dict, FrozenSet, List, Mapping, Optional, Set, Union from six import integer_types @@ -46,7 +47,8 @@ class StreamChangeCache: max_size=10000, prefilled_cache: Optional[Mapping[EntityType, int]] = None, ): - self._max_size = int(max_size * caches.CACHE_SIZE_FACTOR) + self._original_max_size = max_size + self._max_size = math.floor(max_size) self._entity_to_key = {} # type: Dict[EntityType, int] # map from stream id to the a set of entities which changed at that stream id. @@ -58,12 +60,31 @@ class StreamChangeCache: # self._earliest_known_stream_pos = current_stream_pos self.name = name - self.metrics = caches.register_cache("cache", self.name, self._cache) + self.metrics = caches.register_cache( + "cache", self.name, self._cache, resize_callback=self.set_cache_factor + ) if prefilled_cache: for entity, stream_pos in prefilled_cache.items(): self.entity_has_changed(entity, stream_pos) + def set_cache_factor(self, factor: float) -> bool: + """ + Set the cache factor for this individual cache. + + This will trigger a resize if it changes, which may require evicting + items from the cache. + + Returns: + bool: Whether the cache changed size or not. + """ + new_size = math.floor(self._original_max_size * factor) + if new_size != self._max_size: + self.max_size = new_size + self._evict() + return True + return False + def has_entity_changed(self, entity: EntityType, stream_pos: int) -> bool: """Returns True if the entity may have been updated since stream_pos """ @@ -171,6 +192,7 @@ class StreamChangeCache: e1 = self._cache[stream_pos] = set() e1.add(entity) self._entity_to_key[entity] = stream_pos + self._evict() # if the cache is too big, remove entries while len(self._cache) > self._max_size: @@ -179,6 +201,13 @@ class StreamChangeCache: for entity in r: del self._entity_to_key[entity] + def _evict(self): + while len(self._cache) > self._max_size: + k, r = self._cache.popitem(0) + self._earliest_known_stream_pos = max(k, self._earliest_known_stream_pos) + for entity in r: + self._entity_to_key.pop(entity, None) + def get_max_pos_of_last_change(self, entity: EntityType) -> int: """Returns an upper bound of the stream id of the last change to an diff --git a/synapse/util/caches/ttlcache.py b/synapse/util/caches/ttlcache.py index 99646c7cf0..6437aa907e 100644 --- a/synapse/util/caches/ttlcache.py +++ b/synapse/util/caches/ttlcache.py @@ -38,7 +38,7 @@ class TTLCache(object): self._timer = timer - self._metrics = register_cache("ttl", cache_name, self) + self._metrics = register_cache("ttl", cache_name, self, resizable=False) def set(self, key, value, ttl): """Add/update an entry in the cache diff --git a/tests/config/test_cache.py b/tests/config/test_cache.py new file mode 100644 index 0000000000..2920279125 --- /dev/null +++ b/tests/config/test_cache.py @@ -0,0 +1,127 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.config._base import Config, RootConfig +from synapse.config.cache import CacheConfig, add_resizable_cache +from synapse.util.caches.lrucache import LruCache + +from tests.unittest import TestCase + + +class FakeServer(Config): + section = "server" + + +class TestConfig(RootConfig): + config_classes = [FakeServer, CacheConfig] + + +class CacheConfigTests(TestCase): + def setUp(self): + # Reset caches before each test + TestConfig().caches.reset() + + def test_individual_caches_from_environ(self): + """ + Individual cache factors will be loaded from the environment. + """ + config = {} + t = TestConfig() + t.caches._environ = { + "SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2", + "SYNAPSE_NOT_CACHE": "BLAH", + } + t.read_config(config, config_dir_path="", data_dir_path="") + + self.assertEqual(dict(t.caches.cache_factors), {"something_or_other": 2.0}) + + def test_config_overrides_environ(self): + """ + Individual cache factors defined in the environment will take precedence + over those in the config. + """ + config = {"caches": {"per_cache_factors": {"foo": 2, "bar": 3}}} + t = TestConfig() + t.caches._environ = { + "SYNAPSE_CACHE_FACTOR_SOMETHING_OR_OTHER": "2", + "SYNAPSE_CACHE_FACTOR_FOO": 1, + } + t.read_config(config, config_dir_path="", data_dir_path="") + + self.assertEqual( + dict(t.caches.cache_factors), + {"foo": 1.0, "bar": 3.0, "something_or_other": 2.0}, + ) + + def test_individual_instantiated_before_config_load(self): + """ + If a cache is instantiated before the config is read, it will be given + the default cache size in the interim, and then resized once the config + is loaded. + """ + cache = LruCache(100) + + add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) + self.assertEqual(cache.max_size, 50) + + config = {"caches": {"per_cache_factors": {"foo": 3}}} + t = TestConfig() + t.read_config(config, config_dir_path="", data_dir_path="") + + self.assertEqual(cache.max_size, 300) + + def test_individual_instantiated_after_config_load(self): + """ + If a cache is instantiated after the config is read, it will be + immediately resized to the correct size given the per_cache_factor if + there is one. + """ + config = {"caches": {"per_cache_factors": {"foo": 2}}} + t = TestConfig() + t.read_config(config, config_dir_path="", data_dir_path="") + + cache = LruCache(100) + add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) + self.assertEqual(cache.max_size, 200) + + def test_global_instantiated_before_config_load(self): + """ + If a cache is instantiated before the config is read, it will be given + the default cache size in the interim, and then resized to the new + default cache size once the config is loaded. + """ + cache = LruCache(100) + add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) + self.assertEqual(cache.max_size, 50) + + config = {"caches": {"global_factor": 4}} + t = TestConfig() + t.read_config(config, config_dir_path="", data_dir_path="") + + self.assertEqual(cache.max_size, 400) + + def test_global_instantiated_after_config_load(self): + """ + If a cache is instantiated after the config is read, it will be + immediately resized to the correct size given the global factor if there + is no per-cache factor. + """ + config = {"caches": {"global_factor": 1.5}} + t = TestConfig() + t.read_config(config, config_dir_path="", data_dir_path="") + + cache = LruCache(100) + add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) + self.assertEqual(cache.max_size, 150) diff --git a/tests/storage/test__base.py b/tests/storage/test__base.py index e37260a820..5a50e4fdd4 100644 --- a/tests/storage/test__base.py +++ b/tests/storage/test__base.py @@ -25,8 +25,8 @@ from synapse.util.caches.descriptors import Cache, cached from tests import unittest -class CacheTestCase(unittest.TestCase): - def setUp(self): +class CacheTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, homeserver): self.cache = Cache("test") def test_empty(self): @@ -96,7 +96,7 @@ class CacheTestCase(unittest.TestCase): cache.get(3) -class CacheDecoratorTestCase(unittest.TestCase): +class CacheDecoratorTestCase(unittest.HomeserverTestCase): @defer.inlineCallbacks def test_passthrough(self): class A(object): @@ -239,7 +239,7 @@ class CacheDecoratorTestCase(unittest.TestCase): callcount2 = [0] class A(object): - @cached(max_entries=4) # HACK: This makes it 2 due to cache factor + @cached(max_entries=2) def func(self, key): callcount[0] += 1 return key diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py index 31710949a8..ef296e7dab 100644 --- a/tests/storage/test_appservice.py +++ b/tests/storage/test_appservice.py @@ -43,7 +43,7 @@ class ApplicationServiceStoreTestCase(unittest.TestCase): ) hs.config.app_service_config_files = self.as_yaml_files - hs.config.event_cache_size = 1 + hs.config.caches.event_cache_size = 1 hs.config.password_providers = [] self.as_token = "token1" @@ -110,7 +110,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase): ) hs.config.app_service_config_files = self.as_yaml_files - hs.config.event_cache_size = 1 + hs.config.caches.event_cache_size = 1 hs.config.password_providers = [] self.as_list = [ @@ -422,7 +422,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): ) hs.config.app_service_config_files = [f1, f2] - hs.config.event_cache_size = 1 + hs.config.caches.event_cache_size = 1 hs.config.password_providers = [] database = hs.get_datastores().databases[0] @@ -440,7 +440,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): ) hs.config.app_service_config_files = [f1, f2] - hs.config.event_cache_size = 1 + hs.config.caches.event_cache_size = 1 hs.config.password_providers = [] with self.assertRaises(ConfigError) as cm: @@ -464,7 +464,7 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase): ) hs.config.app_service_config_files = [f1, f2] - hs.config.event_cache_size = 1 + hs.config.caches.event_cache_size = 1 hs.config.password_providers = [] with self.assertRaises(ConfigError) as cm: diff --git a/tests/storage/test_base.py b/tests/storage/test_base.py index cdee0a9e60..278961c331 100644 --- a/tests/storage/test_base.py +++ b/tests/storage/test_base.py @@ -51,7 +51,8 @@ class SQLBaseStoreTestCase(unittest.TestCase): config = Mock() config._disable_native_upserts = True - config.event_cache_size = 1 + config.caches = Mock() + config.caches.event_cache_size = 1 hs = TestHomeServer("test", config=config) sqlite_config = {"name": "sqlite3"} diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 270f853d60..f5f63d8ed6 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -15,6 +15,7 @@ # limitations under the License. from synapse.metrics import REGISTRY, InFlightGauge, generate_latest +from synapse.util.caches.descriptors import Cache from tests import unittest @@ -129,3 +130,36 @@ class BuildInfoTests(unittest.TestCase): self.assertTrue(b"osversion=" in items[0]) self.assertTrue(b"pythonversion=" in items[0]) self.assertTrue(b"version=" in items[0]) + + +class CacheMetricsTests(unittest.HomeserverTestCase): + def test_cache_metric(self): + """ + Caches produce metrics reflecting their state when scraped. + """ + CACHE_NAME = "cache_metrics_test_fgjkbdfg" + cache = Cache(CACHE_NAME, max_entries=777) + + items = { + x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii") + for x in filter( + lambda x: b"cache_metrics_test_fgjkbdfg" in x, + generate_latest(REGISTRY).split(b"\n"), + ) + } + + self.assertEqual(items["synapse_util_caches_cache_size"], "0.0") + self.assertEqual(items["synapse_util_caches_cache_max_size"], "777.0") + + cache.prefill("1", "hi") + + items = { + x.split(b"{")[0].decode("ascii"): x.split(b" ")[1].decode("ascii") + for x in filter( + lambda x: b"cache_metrics_test_fgjkbdfg" in x, + generate_latest(REGISTRY).split(b"\n"), + ) + } + + self.assertEqual(items["synapse_util_caches_cache_size"], "1.0") + self.assertEqual(items["synapse_util_caches_cache_max_size"], "777.0") diff --git a/tests/util/test_expiring_cache.py b/tests/util/test_expiring_cache.py index 50bc7702d2..49ffeebd0e 100644 --- a/tests/util/test_expiring_cache.py +++ b/tests/util/test_expiring_cache.py @@ -21,7 +21,7 @@ from tests.utils import MockClock from .. import unittest -class ExpiringCacheTestCase(unittest.TestCase): +class ExpiringCacheTestCase(unittest.HomeserverTestCase): def test_get_set(self): clock = MockClock() cache = ExpiringCache("test", clock, max_len=1) diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index 786947375d..0adb2174af 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -22,7 +22,7 @@ from synapse.util.caches.treecache import TreeCache from .. import unittest -class LruCacheTestCase(unittest.TestCase): +class LruCacheTestCase(unittest.HomeserverTestCase): def test_get_set(self): cache = LruCache(1) cache["key"] = "value" @@ -84,7 +84,7 @@ class LruCacheTestCase(unittest.TestCase): self.assertEquals(len(cache), 0) -class LruCacheCallbacksTestCase(unittest.TestCase): +class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): def test_get(self): m = Mock() cache = LruCache(1) @@ -233,7 +233,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase): self.assertEquals(m3.call_count, 1) -class LruCacheSizedTestCase(unittest.TestCase): +class LruCacheSizedTestCase(unittest.HomeserverTestCase): def test_evict(self): cache = LruCache(5, size_callback=len) cache["key1"] = [0] diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py index 6857933540..13b753e367 100644 --- a/tests/util/test_stream_change_cache.py +++ b/tests/util/test_stream_change_cache.py @@ -1,11 +1,9 @@ -from mock import patch - from synapse.util.caches.stream_change_cache import StreamChangeCache from tests import unittest -class StreamChangeCacheTests(unittest.TestCase): +class StreamChangeCacheTests(unittest.HomeserverTestCase): """ Tests for StreamChangeCache. """ @@ -54,7 +52,6 @@ class StreamChangeCacheTests(unittest.TestCase): self.assertTrue(cache.has_entity_changed("user@foo.com", 0)) self.assertTrue(cache.has_entity_changed("not@here.website", 0)) - @patch("synapse.util.caches.CACHE_SIZE_FACTOR", 1.0) def test_entity_has_changed_pops_off_start(self): """ StreamChangeCache.entity_has_changed will respect the max size and diff --git a/tests/utils.py b/tests/utils.py index f9be62b499..59c020a051 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -167,6 +167,7 @@ def default_config(name, parse=False): # disable user directory updates, because they get done in the # background, which upsets the test runner. "update_user_directory": False, + "caches": {"global_factor": 1}, } if parse: -- cgit 1.5.1