From 0b7830e457359ce651b293c8748bf636973404a9 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Wed, 19 Oct 2022 19:38:24 +0000 Subject: Bump flake8-bugbear from 21.3.2 to 22.9.23 (#14042) Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> Co-authored-by: Erik Johnston Co-authored-by: David Robertson --- tests/util/caches/test_descriptors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests/util') diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 90861fe522..78fd7b6961 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -1037,5 +1037,5 @@ class CachedListDescriptorTestCase(unittest.TestCase): obj = Cls() # Make sure this raises an error about the arg mismatch - with self.assertRaises(Exception): + with self.assertRaises(TypeError): obj.list_fn([("foo", "bar")]) -- cgit 1.5.1 From c9dffd5b330553c5803784be5bc0e2479fab79b0 Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Tue, 25 Oct 2022 11:39:25 +0100 Subject: Remove unused `@lru_cache` decorator (#13595) * Remove unused `@lru_cache` decorator Spotted this working on something else. Co-authored-by: David Robertson --- changelog.d/13595.misc | 1 + synapse/util/caches/descriptors.py | 104 ---------------------------------- tests/util/caches/test_descriptors.py | 40 ++----------- 3 files changed, 5 insertions(+), 140 deletions(-) create mode 100644 changelog.d/13595.misc (limited to 'tests/util') diff --git a/changelog.d/13595.misc b/changelog.d/13595.misc new file mode 100644 index 0000000000..71959a6ee7 --- /dev/null +++ b/changelog.d/13595.misc @@ -0,0 +1 @@ +Remove unused `@lru_cache` decorator. diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index b3c748ef44..75428d19ba 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import enum import functools import inspect import logging @@ -146,109 +145,6 @@ class _CacheDescriptorBase: ) -class _LruCachedFunction(Generic[F]): - cache: LruCache[CacheKey, Any] - __call__: F - - -def lru_cache( - *, max_entries: int = 1000, cache_context: bool = False -) -> Callable[[F], _LruCachedFunction[F]]: - """A method decorator that applies a memoizing cache around the function. - - This is more-or-less a drop-in equivalent to functools.lru_cache, although note - that the signature is slightly different. - - The main differences with functools.lru_cache are: - (a) the size of the cache can be controlled via the cache_factor mechanism - (b) the wrapped function can request a "cache_context" which provides a - callback mechanism to indicate that the result is no longer valid - (c) prometheus metrics are exposed automatically. - - The function should take zero or more arguments, which are used as the key for the - cache. Single-argument functions use that argument as the cache key; otherwise the - arguments are built into a tuple. - - Cached functions can be "chained" (i.e. a cached function can call other cached - functions and get appropriately invalidated when they called caches are - invalidated) by adding a special "cache_context" argument to the function - and passing that as a kwarg to all caches called. For example: - - @lru_cache(cache_context=True) - def foo(self, key, cache_context): - r1 = self.bar1(key, on_invalidate=cache_context.invalidate) - r2 = self.bar2(key, on_invalidate=cache_context.invalidate) - return r1 + r2 - - The wrapped function also has a 'cache' property which offers direct access to the - underlying LruCache. - """ - - def func(orig: F) -> _LruCachedFunction[F]: - desc = LruCacheDescriptor( - orig, - max_entries=max_entries, - cache_context=cache_context, - ) - return cast(_LruCachedFunction[F], desc) - - return func - - -class LruCacheDescriptor(_CacheDescriptorBase): - """Helper for @lru_cache""" - - class _Sentinel(enum.Enum): - sentinel = object() - - def __init__( - self, - orig: Callable[..., Any], - max_entries: int = 1000, - cache_context: bool = False, - ): - super().__init__( - orig, num_args=None, uncached_args=None, cache_context=cache_context - ) - self.max_entries = max_entries - - def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]: - cache: LruCache[CacheKey, Any] = LruCache( - cache_name=self.name, - max_size=self.max_entries, - ) - - get_cache_key = self.cache_key_builder - sentinel = LruCacheDescriptor._Sentinel.sentinel - - @functools.wraps(self.orig) - def _wrapped(*args: Any, **kwargs: Any) -> Any: - invalidate_callback = kwargs.pop("on_invalidate", None) - callbacks = (invalidate_callback,) if invalidate_callback else () - - cache_key = get_cache_key(args, kwargs) - - ret = cache.get(cache_key, default=sentinel, callbacks=callbacks) - if ret != sentinel: - return ret - - # Add our own `cache_context` to argument list if the wrapped function - # has asked for one - if self.add_cache_context: - kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key) - - ret2 = self.orig(obj, *args, **kwargs) - cache.set(cache_key, ret2, callbacks=callbacks) - - return ret2 - - wrapped = cast(CachedFunction, _wrapped) - wrapped.cache = cache - obj.__dict__[self.name] = wrapped - - return wrapped - - class DeferredCacheDescriptor(_CacheDescriptorBase): """A method decorator that applies a memoizing cache around the function. diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 78fd7b6961..43475a307f 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -28,7 +28,7 @@ from synapse.logging.context import ( make_deferred_yieldable, ) from synapse.util.caches import descriptors -from synapse.util.caches.descriptors import cached, cachedList, lru_cache +from synapse.util.caches.descriptors import cached, cachedList from tests import unittest from tests.test_utils import get_awaitable_result @@ -36,38 +36,6 @@ from tests.test_utils import get_awaitable_result logger = logging.getLogger(__name__) -class LruCacheDecoratorTestCase(unittest.TestCase): - def test_base(self): - class Cls: - def __init__(self): - self.mock = mock.Mock() - - @lru_cache() - def fn(self, arg1, arg2): - return self.mock(arg1, arg2) - - obj = Cls() - obj.mock.return_value = "fish" - r = obj.fn(1, 2) - self.assertEqual(r, "fish") - obj.mock.assert_called_once_with(1, 2) - obj.mock.reset_mock() - - # a call with different params should call the mock again - obj.mock.return_value = "chips" - r = obj.fn(1, 3) - self.assertEqual(r, "chips") - obj.mock.assert_called_once_with(1, 3) - obj.mock.reset_mock() - - # the two values should now be cached - r = obj.fn(1, 2) - self.assertEqual(r, "fish") - r = obj.fn(1, 3) - self.assertEqual(r, "chips") - obj.mock.assert_not_called() - - def run_on_reactor(): d = defer.Deferred() reactor.callLater(0, d.callback, 0) @@ -478,10 +446,10 @@ class DescriptorTestCase(unittest.TestCase): @cached(cache_context=True) async def func2(self, key, cache_context): - return self.func3(key, on_invalidate=cache_context.invalidate) + return await self.func3(key, on_invalidate=cache_context.invalidate) - @lru_cache(cache_context=True) - def func3(self, key, cache_context): + @cached(cache_context=True) + async def func3(self, key, cache_context): self.invalidate = cache_context.invalidate return 42 -- cgit 1.5.1 From 8756d5c87efc5637da55c9e21d2a4eb2369ba693 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Wed, 26 Oct 2022 12:45:41 +0200 Subject: Save login tokens in database (#13844) * Save login tokens in database Signed-off-by: Quentin Gliech * Add upgrade notes * Track login token reuse in a Prometheus metric Signed-off-by: Quentin Gliech --- changelog.d/13844.misc | 1 + docs/upgrade.md | 9 ++ synapse/handlers/auth.py | 64 +++++++-- synapse/module_api/__init__.py | 41 +----- synapse/rest/client/login.py | 3 +- synapse/rest/client/login_token_request.py | 5 +- synapse/storage/databases/main/registration.py | 156 ++++++++++++++++++++- .../schema/main/delta/73/10login_tokens.sql | 35 +++++ synapse/util/macaroons.py | 87 +----------- tests/handlers/test_auth.py | 135 ++++++++++-------- tests/util/test_macaroons.py | 28 ---- 11 files changed, 337 insertions(+), 227 deletions(-) create mode 100644 changelog.d/13844.misc create mode 100644 synapse/storage/schema/main/delta/73/10login_tokens.sql (limited to 'tests/util') diff --git a/changelog.d/13844.misc b/changelog.d/13844.misc new file mode 100644 index 0000000000..66f4414df7 --- /dev/null +++ b/changelog.d/13844.misc @@ -0,0 +1 @@ +Save login tokens in database and prevent login token reuse. diff --git a/docs/upgrade.md b/docs/upgrade.md index b81385b191..78c34d0c15 100644 --- a/docs/upgrade.md +++ b/docs/upgrade.md @@ -88,6 +88,15 @@ process, for example: dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb ``` +# Upgrading to v1.71.0 + +## Removal of the `generate_short_term_login_token` module API method + +As announced with the release of [Synapse 1.69.0](#deprecation-of-the-generate_short_term_login_token-module-api-method), the deprecated `generate_short_term_login_token` module method has been removed. + +Modules relying on it can instead use the `create_login_token` method. + + # Upgrading to v1.69.0 ## Changes to the receipts replication streams diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index f5f0e0e7a7..8b9ef25d29 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -38,6 +38,7 @@ from typing import ( import attr import bcrypt import unpaddedbase64 +from prometheus_client import Counter from twisted.internet.defer import CancelledError from twisted.web.server import Request @@ -48,6 +49,7 @@ from synapse.api.errors import ( Codes, InteractiveAuthIncompleteError, LoginError, + NotFoundError, StoreError, SynapseError, UserDeactivatedError, @@ -63,10 +65,14 @@ from synapse.http.server import finish_request, respond_with_html from synapse.http.site import SynapseRequest from synapse.logging.context import defer_to_thread from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.storage.databases.main.registration import ( + LoginTokenExpired, + LoginTokenLookupResult, + LoginTokenReused, +) from synapse.types import JsonDict, Requester, UserID from synapse.util import stringutils as stringutils from synapse.util.async_helpers import delay_cancellation, maybe_awaitable -from synapse.util.macaroons import LoginTokenAttributes from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.stringutils import base62_encode from synapse.util.threepids import canonicalise_email @@ -80,6 +86,12 @@ logger = logging.getLogger(__name__) INVALID_USERNAME_OR_PASSWORD = "Invalid username or password" +invalid_login_token_counter = Counter( + "synapse_user_login_invalid_login_tokens", + "Counts the number of rejected m.login.token on /login", + ["reason"], +) + def convert_client_dict_legacy_fields_to_identifier( submission: JsonDict, @@ -883,6 +895,25 @@ class AuthHandler: return True + async def create_login_token_for_user_id( + self, + user_id: str, + duration_ms: int = (2 * 60 * 1000), + auth_provider_id: Optional[str] = None, + auth_provider_session_id: Optional[str] = None, + ) -> str: + login_token = self.generate_login_token() + now = self._clock.time_msec() + expiry_ts = now + duration_ms + await self.store.add_login_token_to_user( + user_id=user_id, + token=login_token, + expiry_ts=expiry_ts, + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, + ) + return login_token + async def create_refresh_token_for_user_id( self, user_id: str, @@ -1401,6 +1432,18 @@ class AuthHandler: return None return user_id + def generate_login_token(self) -> str: + """Generates an opaque string, for use as an short-term login token""" + + # we use the following format for access tokens: + # syl__ + + random_string = stringutils.random_string(20) + base = f"syl_{random_string}" + + crc = base62_encode(crc32(base.encode("ascii")), minwidth=6) + return f"{base}_{crc}" + def generate_access_token(self, for_user: UserID) -> str: """Generates an opaque string, for use as an access token""" @@ -1427,16 +1470,17 @@ class AuthHandler: crc = base62_encode(crc32(base.encode("ascii")), minwidth=6) return f"{base}_{crc}" - async def validate_short_term_login_token( - self, login_token: str - ) -> LoginTokenAttributes: + async def consume_login_token(self, login_token: str) -> LoginTokenLookupResult: try: - res = self.macaroon_gen.verify_short_term_login_token(login_token) - except Exception: - raise AuthError(403, "Invalid login token", errcode=Codes.FORBIDDEN) + return await self.store.consume_login_token(login_token) + except LoginTokenExpired: + invalid_login_token_counter.labels("expired").inc() + except LoginTokenReused: + invalid_login_token_counter.labels("reused").inc() + except NotFoundError: + invalid_login_token_counter.labels("not found").inc() - await self.auth_blocking.check_auth_blocking(res.user_id) - return res + raise AuthError(403, "Invalid login token", errcode=Codes.FORBIDDEN) async def delete_access_token(self, access_token: str) -> None: """Invalidate a single access token @@ -1711,7 +1755,7 @@ class AuthHandler: ) # Create a login token - login_token = self.macaroon_gen.generate_short_term_login_token( + login_token = await self.create_login_token_for_user_id( registered_user_id, auth_provider_id=auth_provider_id, auth_provider_session_id=auth_provider_session_id, diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 6a6ae208d1..30e689d00d 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -771,50 +771,11 @@ class ModuleApi: auth_provider_session_id: The session ID got during login from the SSO IdP, if any. """ - # The deprecated `generate_short_term_login_token` method defaulted to an empty - # string for the `auth_provider_id` because of how the underlying macaroon was - # generated. This will change to a proper NULL-able field when the tokens get - # moved to the database. - return self._hs.get_macaroon_generator().generate_short_term_login_token( + return await self._hs.get_auth_handler().create_login_token_for_user_id( user_id, - auth_provider_id or "", - auth_provider_session_id, duration_in_ms, - ) - - def generate_short_term_login_token( - self, - user_id: str, - duration_in_ms: int = (2 * 60 * 1000), - auth_provider_id: str = "", - auth_provider_session_id: Optional[str] = None, - ) -> str: - """Generate a login token suitable for m.login.token authentication - - Added in Synapse v1.9.0. - - This was deprecated in Synapse v1.69.0 in favor of create_login_token, and will - be removed in Synapse 1.71.0. - - Args: - user_id: gives the ID of the user that the token is for - - duration_in_ms: the time that the token will be valid for - - auth_provider_id: the ID of the SSO IdP that the user used to authenticate - to get this token, if any. This is encoded in the token so that - /login can report stats on number of successful logins by IdP. - """ - logger.warn( - "A module configured on this server uses ModuleApi.generate_short_term_login_token(), " - "which is deprecated in favor of ModuleApi.create_login_token(), and will be removed in " - "Synapse 1.71.0", - ) - return self._hs.get_macaroon_generator().generate_short_term_login_token( - user_id, auth_provider_id, auth_provider_session_id, - duration_in_ms, ) @defer.inlineCallbacks diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index f554586ac3..7774f1967d 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -436,8 +436,7 @@ class LoginRestServlet(RestServlet): The body of the JSON response. """ token = login_submission["token"] - auth_handler = self.auth_handler - res = await auth_handler.validate_short_term_login_token(token) + res = await self.auth_handler.consume_login_token(token) return await self._complete_login( res.user_id, diff --git a/synapse/rest/client/login_token_request.py b/synapse/rest/client/login_token_request.py index 277b20fb63..43ea21d5e6 100644 --- a/synapse/rest/client/login_token_request.py +++ b/synapse/rest/client/login_token_request.py @@ -57,7 +57,6 @@ class LoginTokenRequestServlet(RestServlet): self.store = hs.get_datastores().main self.clock = hs.get_clock() self.server_name = hs.config.server.server_name - self.macaroon_gen = hs.get_macaroon_generator() self.auth_handler = hs.get_auth_handler() self.token_timeout = hs.config.experimental.msc3882_token_timeout self.ui_auth = hs.config.experimental.msc3882_ui_auth @@ -76,10 +75,10 @@ class LoginTokenRequestServlet(RestServlet): can_skip_ui_auth=False, # Don't allow skipping of UI auth ) - login_token = self.macaroon_gen.generate_short_term_login_token( + login_token = await self.auth_handler.create_login_token_for_user_id( user_id=requester.user.to_string(), auth_provider_id="org.matrix.msc3882.login_token_request", - duration_in_ms=self.token_timeout, + duration_ms=self.token_timeout, ) return ( diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 2996d6bb4d..0255295317 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -21,7 +21,13 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast import attr from synapse.api.constants import UserTypes -from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError +from synapse.api.errors import ( + Codes, + NotFoundError, + StoreError, + SynapseError, + ThreepidValidationError, +) from synapse.config.homeserver import HomeServerConfig from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage.database import ( @@ -50,6 +56,14 @@ class ExternalIDReuseException(Exception): because this external id is given to an other user.""" +class LoginTokenExpired(Exception): + """Exception if the login token sent expired""" + + +class LoginTokenReused(Exception): + """Exception if the login token sent was already used""" + + @attr.s(frozen=True, slots=True, auto_attribs=True) class TokenLookupResult: """Result of looking up an access token. @@ -115,6 +129,20 @@ class RefreshTokenLookupResult: If None, the session can be refreshed indefinitely.""" +@attr.s(auto_attribs=True, frozen=True, slots=True) +class LoginTokenLookupResult: + """Result of looking up a login token.""" + + user_id: str + """The user this token belongs to.""" + + auth_provider_id: Optional[str] + """The SSO Identity Provider that the user authenticated with, to get this token.""" + + auth_provider_session_id: Optional[str] + """The session ID advertised by the SSO Identity Provider.""" + + class RegistrationWorkerStore(CacheInvalidationWorkerStore): def __init__( self, @@ -1789,6 +1817,109 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): "replace_refresh_token", _replace_refresh_token_txn ) + async def add_login_token_to_user( + self, + user_id: str, + token: str, + expiry_ts: int, + auth_provider_id: Optional[str], + auth_provider_session_id: Optional[str], + ) -> None: + """Adds a short-term login token for the given user. + + Args: + user_id: The user ID. + token: The new login token to add. + expiry_ts (milliseconds since the epoch): Time after which the login token + cannot be used. + auth_provider_id: The SSO Identity Provider that the user authenticated with + to get this token, if any + auth_provider_session_id: The session ID advertised by the SSO Identity + Provider, if any. + """ + await self.db_pool.simple_insert( + "login_tokens", + { + "token": token, + "user_id": user_id, + "expiry_ts": expiry_ts, + "auth_provider_id": auth_provider_id, + "auth_provider_session_id": auth_provider_session_id, + }, + desc="add_login_token_to_user", + ) + + def _consume_login_token( + self, + txn: LoggingTransaction, + token: str, + ts: int, + ) -> LoginTokenLookupResult: + values = self.db_pool.simple_select_one_txn( + txn, + "login_tokens", + keyvalues={"token": token}, + retcols=( + "user_id", + "expiry_ts", + "used_ts", + "auth_provider_id", + "auth_provider_session_id", + ), + allow_none=True, + ) + + if values is None: + raise NotFoundError() + + self.db_pool.simple_update_one_txn( + txn, + "login_tokens", + keyvalues={"token": token}, + updatevalues={"used_ts": ts}, + ) + user_id = values["user_id"] + expiry_ts = values["expiry_ts"] + used_ts = values["used_ts"] + auth_provider_id = values["auth_provider_id"] + auth_provider_session_id = values["auth_provider_session_id"] + + # Token was already used + if used_ts is not None: + raise LoginTokenReused() + + # Token expired + if ts > int(expiry_ts): + raise LoginTokenExpired() + + return LoginTokenLookupResult( + user_id=user_id, + auth_provider_id=auth_provider_id, + auth_provider_session_id=auth_provider_session_id, + ) + + async def consume_login_token(self, token: str) -> LoginTokenLookupResult: + """Lookup a login token and consume it. + + Args: + token: The login token. + + Returns: + The data stored with that token, including the `user_id`. Returns `None` if + the token does not exist or if it expired. + + Raises: + NotFound if the login token was not found in database + LoginTokenExpired if the login token expired + LoginTokenReused if the login token was already used + """ + return await self.db_pool.runInteraction( + "consume_login_token", + self._consume_login_token, + token, + self._clock.time_msec(), + ) + @cached() async def is_guest(self, user_id: str) -> bool: res = await self.db_pool.simple_select_one_onecol( @@ -2019,6 +2150,12 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): and hs.config.experimental.msc3866.require_approval_for_new_accounts ) + # Create a background job for removing expired login tokens + if hs.config.worker.run_background_tasks: + self._clock.looping_call( + self._delete_expired_login_tokens, THIRTY_MINUTES_IN_MS + ) + async def add_access_token_to_user( self, user_id: str, @@ -2617,6 +2754,23 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): approved, ) + @wrap_as_background_process("delete_expired_login_tokens") + async def _delete_expired_login_tokens(self) -> None: + """Remove login tokens with expiry dates that have passed.""" + + def _delete_expired_login_tokens_txn(txn: LoggingTransaction, ts: int) -> None: + sql = "DELETE FROM login_tokens WHERE expiry_ts <= ?" + txn.execute(sql, (ts,)) + + # We keep the expired tokens for an extra 5 minutes so we can measure how many + # times a token is being used after its expiry + now = self._clock.time_msec() + await self.db_pool.runInteraction( + "delete_expired_login_tokens", + _delete_expired_login_tokens_txn, + now - (5 * 60 * 1000), + ) + def find_max_generated_user_id_localpart(cur: Cursor) -> int: """ diff --git a/synapse/storage/schema/main/delta/73/10login_tokens.sql b/synapse/storage/schema/main/delta/73/10login_tokens.sql new file mode 100644 index 0000000000..a39b7bcece --- /dev/null +++ b/synapse/storage/schema/main/delta/73/10login_tokens.sql @@ -0,0 +1,35 @@ +/* + * Copyright 2022 The Matrix.org Foundation C.I.C. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Login tokens are short-lived tokens that are used for the m.login.token +-- login method, mainly during SSO logins +CREATE TABLE login_tokens ( + token TEXT PRIMARY KEY, + user_id TEXT NOT NULL, + expiry_ts BIGINT NOT NULL, + used_ts BIGINT, + auth_provider_id TEXT, + auth_provider_session_id TEXT +); + +-- We're sometimes querying them by their session ID we got from their IDP +CREATE INDEX login_tokens_auth_provider_idx + ON login_tokens (auth_provider_id, auth_provider_session_id); + +-- We're deleting them by their expiration time +CREATE INDEX login_tokens_expiry_time_idx + ON login_tokens (expiry_ts); + diff --git a/synapse/util/macaroons.py b/synapse/util/macaroons.py index df77edcce2..5df03d3ddc 100644 --- a/synapse/util/macaroons.py +++ b/synapse/util/macaroons.py @@ -24,7 +24,7 @@ from typing_extensions import Literal from synapse.util import Clock, stringutils -MacaroonType = Literal["access", "delete_pusher", "session", "login"] +MacaroonType = Literal["access", "delete_pusher", "session"] def get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str: @@ -111,19 +111,6 @@ class OidcSessionData: """The session ID of the ongoing UI Auth ("" if this is a login)""" -@attr.s(slots=True, frozen=True, auto_attribs=True) -class LoginTokenAttributes: - """Data we store in a short-term login token""" - - user_id: str - - auth_provider_id: str - """The SSO Identity Provider that the user authenticated with, to get this token.""" - - auth_provider_session_id: Optional[str] - """The session ID advertised by the SSO Identity Provider.""" - - class MacaroonGenerator: def __init__(self, clock: Clock, location: str, secret_key: bytes): self._clock = clock @@ -165,35 +152,6 @@ class MacaroonGenerator: macaroon.add_first_party_caveat(f"pushkey = {pushkey}") return macaroon.serialize() - def generate_short_term_login_token( - self, - user_id: str, - auth_provider_id: str, - auth_provider_session_id: Optional[str] = None, - duration_in_ms: int = (2 * 60 * 1000), - ) -> str: - """Generate a short-term login token used during SSO logins - - Args: - user_id: The user for which the token is valid. - auth_provider_id: The SSO IdP the user used. - auth_provider_session_id: The session ID got during login from the SSO IdP. - - Returns: - A signed token valid for using as a ``m.login.token`` token. - """ - now = self._clock.time_msec() - expiry = now + duration_in_ms - macaroon = self._generate_base_macaroon("login") - macaroon.add_first_party_caveat(f"user_id = {user_id}") - macaroon.add_first_party_caveat(f"time < {expiry}") - macaroon.add_first_party_caveat(f"auth_provider_id = {auth_provider_id}") - if auth_provider_session_id is not None: - macaroon.add_first_party_caveat( - f"auth_provider_session_id = {auth_provider_session_id}" - ) - return macaroon.serialize() - def generate_oidc_session_token( self, state: str, @@ -233,49 +191,6 @@ class MacaroonGenerator: return macaroon.serialize() - def verify_short_term_login_token(self, token: str) -> LoginTokenAttributes: - """Verify a short-term-login macaroon - - Checks that the given token is a valid, unexpired short-term-login token - minted by this server. - - Args: - token: The login token to verify. - - Returns: - A set of attributes carried by this token, including the - ``user_id`` and informations about the SSO IDP used during that - login. - - Raises: - MacaroonVerificationFailedException if the verification failed - """ - macaroon = pymacaroons.Macaroon.deserialize(token) - - v = self._base_verifier("login") - v.satisfy_general(lambda c: c.startswith("user_id = ")) - v.satisfy_general(lambda c: c.startswith("auth_provider_id = ")) - v.satisfy_general(lambda c: c.startswith("auth_provider_session_id = ")) - satisfy_expiry(v, self._clock.time_msec) - v.verify(macaroon, self._secret_key) - - user_id = get_value_from_macaroon(macaroon, "user_id") - auth_provider_id = get_value_from_macaroon(macaroon, "auth_provider_id") - - auth_provider_session_id: Optional[str] = None - try: - auth_provider_session_id = get_value_from_macaroon( - macaroon, "auth_provider_session_id" - ) - except MacaroonVerificationFailedException: - pass - - return LoginTokenAttributes( - user_id=user_id, - auth_provider_id=auth_provider_id, - auth_provider_session_id=auth_provider_session_id, - ) - def verify_guest_token(self, token: str) -> str: """Verify a guest access token macaroon diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py index 7106799d44..036dbbc45b 100644 --- a/tests/handlers/test_auth.py +++ b/tests/handlers/test_auth.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional from unittest.mock import Mock import pymacaroons @@ -19,6 +20,7 @@ from twisted.test.proto_helpers import MemoryReactor from synapse.api.errors import AuthError, ResourceLimitError from synapse.rest import admin +from synapse.rest.client import login from synapse.server import HomeServer from synapse.util import Clock @@ -29,6 +31,7 @@ from tests.test_utils import make_awaitable class AuthTestCase(unittest.HomeserverTestCase): servlets = [ admin.register_servlets, + login.register_servlets, ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: @@ -46,6 +49,23 @@ class AuthTestCase(unittest.HomeserverTestCase): self.user1 = self.register_user("a_user", "pass") + def token_login(self, token: str) -> Optional[str]: + body = { + "type": "m.login.token", + "token": token, + } + + channel = self.make_request( + "POST", + "/_matrix/client/v3/login", + body, + ) + + if channel.code == 200: + return channel.json_body["user_id"] + + return None + def test_macaroon_caveats(self) -> None: token = self.macaroon_generator.generate_guest_access_token("a_user") macaroon = pymacaroons.Macaroon.deserialize(token) @@ -73,49 +93,62 @@ class AuthTestCase(unittest.HomeserverTestCase): v.satisfy_general(verify_guest) v.verify(macaroon, self.hs.config.key.macaroon_secret_key) - def test_short_term_login_token_gives_user_id(self) -> None: - token = self.macaroon_generator.generate_short_term_login_token( - self.user1, "", duration_in_ms=5000 + def test_login_token_gives_user_id(self) -> None: + token = self.get_success( + self.auth_handler.create_login_token_for_user_id( + self.user1, + duration_ms=(5 * 1000), + ) ) - res = self.get_success(self.auth_handler.validate_short_term_login_token(token)) + + res = self.get_success(self.auth_handler.consume_login_token(token)) self.assertEqual(self.user1, res.user_id) - self.assertEqual("", res.auth_provider_id) + self.assertEqual(None, res.auth_provider_id) - # when we advance the clock, the token should be rejected - self.reactor.advance(6) - self.get_failure( - self.auth_handler.validate_short_term_login_token(token), - AuthError, + def test_login_token_reuse_fails(self) -> None: + token = self.get_success( + self.auth_handler.create_login_token_for_user_id( + self.user1, + duration_ms=(5 * 1000), + ) ) - def test_short_term_login_token_gives_auth_provider(self) -> None: - token = self.macaroon_generator.generate_short_term_login_token( - self.user1, auth_provider_id="my_idp" - ) - res = self.get_success(self.auth_handler.validate_short_term_login_token(token)) - self.assertEqual(self.user1, res.user_id) - self.assertEqual("my_idp", res.auth_provider_id) + self.get_success(self.auth_handler.consume_login_token(token)) - def test_short_term_login_token_cannot_replace_user_id(self) -> None: - token = self.macaroon_generator.generate_short_term_login_token( - self.user1, "", duration_in_ms=5000 + self.get_failure( + self.auth_handler.consume_login_token(token), + AuthError, ) - macaroon = pymacaroons.Macaroon.deserialize(token) - res = self.get_success( - self.auth_handler.validate_short_term_login_token(macaroon.serialize()) + def test_login_token_expires(self) -> None: + token = self.get_success( + self.auth_handler.create_login_token_for_user_id( + self.user1, + duration_ms=(5 * 1000), + ) ) - self.assertEqual(self.user1, res.user_id) - - # add another "user_id" caveat, which might allow us to override the - # user_id. - macaroon.add_first_party_caveat("user_id = b_user") + # when we advance the clock, the token should be rejected + self.reactor.advance(6) self.get_failure( - self.auth_handler.validate_short_term_login_token(macaroon.serialize()), + self.auth_handler.consume_login_token(token), AuthError, ) + def test_login_token_gives_auth_provider(self) -> None: + token = self.get_success( + self.auth_handler.create_login_token_for_user_id( + self.user1, + auth_provider_id="my_idp", + auth_provider_session_id="11-22-33-44", + duration_ms=(5 * 1000), + ) + ) + res = self.get_success(self.auth_handler.consume_login_token(token)) + self.assertEqual(self.user1, res.user_id) + self.assertEqual("my_idp", res.auth_provider_id) + self.assertEqual("11-22-33-44", res.auth_provider_session_id) + def test_mau_limits_disabled(self) -> None: self.auth_blocking._limit_usage_by_mau = False # Ensure does not throw exception @@ -125,12 +158,12 @@ class AuthTestCase(unittest.HomeserverTestCase): ) ) - self.get_success( - self.auth_handler.validate_short_term_login_token( - self._get_macaroon().serialize() - ) + token = self.get_success( + self.auth_handler.create_login_token_for_user_id(self.user1) ) + self.assertIsNotNone(self.token_login(token)) + def test_mau_limits_exceeded_large(self) -> None: self.auth_blocking._limit_usage_by_mau = True self.hs.get_datastores().main.get_monthly_active_count = Mock( @@ -147,12 +180,10 @@ class AuthTestCase(unittest.HomeserverTestCase): self.hs.get_datastores().main.get_monthly_active_count = Mock( return_value=make_awaitable(self.large_number_of_users) ) - self.get_failure( - self.auth_handler.validate_short_term_login_token( - self._get_macaroon().serialize() - ), - ResourceLimitError, + token = self.get_success( + self.auth_handler.create_login_token_for_user_id(self.user1) ) + self.assertIsNone(self.token_login(token)) def test_mau_limits_parity(self) -> None: # Ensure we're not at the unix epoch. @@ -171,12 +202,10 @@ class AuthTestCase(unittest.HomeserverTestCase): ), ResourceLimitError, ) - self.get_failure( - self.auth_handler.validate_short_term_login_token( - self._get_macaroon().serialize() - ), - ResourceLimitError, + token = self.get_success( + self.auth_handler.create_login_token_for_user_id(self.user1) ) + self.assertIsNone(self.token_login(token)) # If in monthly active cohort self.hs.get_datastores().main.user_last_seen_monthly_active = Mock( @@ -187,11 +216,10 @@ class AuthTestCase(unittest.HomeserverTestCase): self.user1, device_id=None, valid_until_ms=None ) ) - self.get_success( - self.auth_handler.validate_short_term_login_token( - self._get_macaroon().serialize() - ) + token = self.get_success( + self.auth_handler.create_login_token_for_user_id(self.user1) ) + self.assertIsNotNone(self.token_login(token)) def test_mau_limits_not_exceeded(self) -> None: self.auth_blocking._limit_usage_by_mau = True @@ -209,14 +237,7 @@ class AuthTestCase(unittest.HomeserverTestCase): self.hs.get_datastores().main.get_monthly_active_count = Mock( return_value=make_awaitable(self.small_number_of_users) ) - self.get_success( - self.auth_handler.validate_short_term_login_token( - self._get_macaroon().serialize() - ) - ) - - def _get_macaroon(self) -> pymacaroons.Macaroon: - token = self.macaroon_generator.generate_short_term_login_token( - self.user1, "", duration_in_ms=5000 + token = self.get_success( + self.auth_handler.create_login_token_for_user_id(self.user1) ) - return pymacaroons.Macaroon.deserialize(token) + self.assertIsNotNone(self.token_login(token)) diff --git a/tests/util/test_macaroons.py b/tests/util/test_macaroons.py index 32125f7bb7..40754a4711 100644 --- a/tests/util/test_macaroons.py +++ b/tests/util/test_macaroons.py @@ -84,34 +84,6 @@ class MacaroonGeneratorTestCase(TestCase): ) self.assertEqual(user_id, "@user:tesths") - def test_short_term_login_token(self): - """Test the generation and verification of short-term login tokens""" - token = self.macaroon_generator.generate_short_term_login_token( - user_id="@user:tesths", - auth_provider_id="oidc", - auth_provider_session_id="sid", - duration_in_ms=2 * 60 * 1000, - ) - - info = self.macaroon_generator.verify_short_term_login_token(token) - self.assertEqual(info.user_id, "@user:tesths") - self.assertEqual(info.auth_provider_id, "oidc") - self.assertEqual(info.auth_provider_session_id, "sid") - - # Raises with another secret key - with self.assertRaises(MacaroonVerificationFailedException): - self.other_macaroon_generator.verify_short_term_login_token(token) - - # Wait a minute - self.reactor.pump([60]) - # Shouldn't raise - self.macaroon_generator.verify_short_term_login_token(token) - # Wait another minute - self.reactor.pump([60]) - # Should raise since it expired - with self.assertRaises(MacaroonVerificationFailedException): - self.macaroon_generator.verify_short_term_login_token(token) - def test_oidc_session_token(self): """Test the generation and verification of OIDC session cookies""" state = "arandomstate" -- cgit 1.5.1 From 4ae967cf6308e80b03da749f0cbaed36988e235e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 22 Nov 2022 17:35:54 -0500 Subject: Add missing type hints to test.util.caches (#14529) --- changelog.d/14529.misc | 1 + mypy.ini | 11 +++--- tests/util/caches/test_cached_call.py | 23 ++++++------ tests/util/caches/test_deferred_cache.py | 61 ++++++++++++++++---------------- tests/util/caches/test_descriptors.py | 22 +++++++----- tests/util/caches/test_response_cache.py | 16 ++++----- tests/util/caches/test_ttlcache.py | 8 ++--- 7 files changed, 76 insertions(+), 66 deletions(-) create mode 100644 changelog.d/14529.misc (limited to 'tests/util') diff --git a/changelog.d/14529.misc b/changelog.d/14529.misc new file mode 100644 index 0000000000..d44571b731 --- /dev/null +++ b/changelog.d/14529.misc @@ -0,0 +1 @@ +Add missing type hints. diff --git a/mypy.ini b/mypy.ini index 4cd61e0484..25b3c93748 100644 --- a/mypy.ini +++ b/mypy.ini @@ -59,11 +59,6 @@ exclude = (?x) |tests/server_notices/test_resource_limits_server_notices.py |tests/test_state.py |tests/test_terms_auth.py - |tests/util/caches/test_cached_call.py - |tests/util/caches/test_deferred_cache.py - |tests/util/caches/test_descriptors.py - |tests/util/caches/test_response_cache.py - |tests/util/caches/test_ttlcache.py |tests/util/test_async_helpers.py |tests/util/test_batching_queue.py |tests/util/test_dict_cache.py @@ -133,6 +128,12 @@ disallow_untyped_defs = True [mypy-tests.federation.transport.test_client] disallow_untyped_defs = True +[mypy-tests.util.caches.*] +disallow_untyped_defs = True + +[mypy-tests.util.caches.test_descriptors] +disallow_untyped_defs = False + [mypy-tests.utils] disallow_untyped_defs = True diff --git a/tests/util/caches/test_cached_call.py b/tests/util/caches/test_cached_call.py index 80b97167ba..9266f12590 100644 --- a/tests/util/caches/test_cached_call.py +++ b/tests/util/caches/test_cached_call.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import NoReturn from unittest.mock import Mock from twisted.internet import defer @@ -23,14 +24,14 @@ from tests.unittest import TestCase class CachedCallTestCase(TestCase): - def test_get(self): + def test_get(self) -> None: """ Happy-path test case: makes a couple of calls and makes sure they behave correctly """ - d = Deferred() + d: "Deferred[int]" = Deferred() - async def f(): + async def f() -> int: return await d slow_call = Mock(side_effect=f) @@ -43,7 +44,7 @@ class CachedCallTestCase(TestCase): # now fire off a couple of calls completed_results = [] - async def r(): + async def r() -> None: res = await cached_call.get() completed_results.append(res) @@ -69,12 +70,12 @@ class CachedCallTestCase(TestCase): self.assertEqual(r3, 123) slow_call.assert_not_called() - def test_fast_call(self): + def test_fast_call(self) -> None: """ Test the behaviour when the underlying function completes immediately """ - async def f(): + async def f() -> int: return 12 fast_call = Mock(side_effect=f) @@ -92,12 +93,12 @@ class CachedCallTestCase(TestCase): class RetryOnExceptionCachedCallTestCase(TestCase): - def test_get(self): + def test_get(self) -> None: # set up the RetryOnExceptionCachedCall around a function which will fail # (after a while) - d = Deferred() + d: "Deferred[int]" = Deferred() - async def f1(): + async def f1() -> NoReturn: await d raise ValueError("moo") @@ -110,7 +111,7 @@ class RetryOnExceptionCachedCallTestCase(TestCase): # now fire off a couple of calls completed_results = [] - async def r(): + async def r() -> None: try: await cached_call.get() except Exception as e1: @@ -137,7 +138,7 @@ class RetryOnExceptionCachedCallTestCase(TestCase): # to the getter d = Deferred() - async def f2(): + async def f2() -> int: return await d slow_call.reset_mock() diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py index 02b99b466a..f74d82b1dc 100644 --- a/tests/util/caches/test_deferred_cache.py +++ b/tests/util/caches/test_deferred_cache.py @@ -13,6 +13,7 @@ # limitations under the License. from functools import partial +from typing import List, Tuple from twisted.internet import defer @@ -22,20 +23,20 @@ from tests.unittest import TestCase class DeferredCacheTestCase(TestCase): - def test_empty(self): - cache = DeferredCache("test") + def test_empty(self) -> None: + cache: DeferredCache[str, int] = DeferredCache("test") with self.assertRaises(KeyError): cache.get("foo") - def test_hit(self): - cache = DeferredCache("test") + def test_hit(self) -> None: + cache: DeferredCache[str, int] = DeferredCache("test") cache.prefill("foo", 123) self.assertEqual(self.successResultOf(cache.get("foo")), 123) - def test_hit_deferred(self): - cache = DeferredCache("test") - origin_d = defer.Deferred() + def test_hit_deferred(self) -> None: + cache: DeferredCache[str, int] = DeferredCache("test") + origin_d: "defer.Deferred[int]" = defer.Deferred() set_d = cache.set("k1", origin_d) # get should return an incomplete deferred @@ -43,7 +44,7 @@ class DeferredCacheTestCase(TestCase): self.assertFalse(get_d.called) # add a callback that will make sure that the set_d gets called before the get_d - def check1(r): + def check1(r: str) -> str: self.assertTrue(set_d.called) return r @@ -55,16 +56,16 @@ class DeferredCacheTestCase(TestCase): self.assertEqual(self.successResultOf(set_d), 99) self.assertEqual(self.successResultOf(get_d), 99) - def test_callbacks(self): + def test_callbacks(self) -> None: """Invalidation callbacks are called at the right time""" - cache = DeferredCache("test") + cache: DeferredCache[str, int] = DeferredCache("test") callbacks = set() # start with an entry, with a callback cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill")) # now replace that entry with a pending result - origin_d = defer.Deferred() + origin_d: "defer.Deferred[int]" = defer.Deferred() set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set")) # ... and also make a get request @@ -89,15 +90,15 @@ class DeferredCacheTestCase(TestCase): cache.prefill("k1", 30) self.assertEqual(callbacks, {"set", "get"}) - def test_set_fail(self): - cache = DeferredCache("test") + def test_set_fail(self) -> None: + cache: DeferredCache[str, int] = DeferredCache("test") callbacks = set() # start with an entry, with a callback cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill")) # now replace that entry with a pending result - origin_d = defer.Deferred() + origin_d: defer.Deferred = defer.Deferred() set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set")) # ... and also make a get request @@ -126,9 +127,9 @@ class DeferredCacheTestCase(TestCase): cache.prefill("k1", 30) self.assertEqual(callbacks, {"prefill", "get2"}) - def test_get_immediate(self): - cache = DeferredCache("test") - d1 = defer.Deferred() + def test_get_immediate(self) -> None: + cache: DeferredCache[str, int] = DeferredCache("test") + d1: "defer.Deferred[int]" = defer.Deferred() cache.set("key1", d1) # get_immediate should return default @@ -142,27 +143,27 @@ class DeferredCacheTestCase(TestCase): v = cache.get_immediate("key1", 1) self.assertEqual(v, 2) - def test_invalidate(self): - cache = DeferredCache("test") + def test_invalidate(self) -> None: + cache: DeferredCache[Tuple[str], int] = DeferredCache("test") cache.prefill(("foo",), 123) cache.invalidate(("foo",)) with self.assertRaises(KeyError): cache.get(("foo",)) - def test_invalidate_all(self): - cache = DeferredCache("testcache") + def test_invalidate_all(self) -> None: + cache: DeferredCache[str, str] = DeferredCache("testcache") callback_record = [False, False] - def record_callback(idx): + def record_callback(idx: int) -> None: callback_record[idx] = True # add a couple of pending entries - d1 = defer.Deferred() + d1: "defer.Deferred[str]" = defer.Deferred() cache.set("key1", d1, partial(record_callback, 0)) - d2 = defer.Deferred() + d2: "defer.Deferred[str]" = defer.Deferred() cache.set("key2", d2, partial(record_callback, 1)) # lookup should return pending deferreds @@ -193,8 +194,8 @@ class DeferredCacheTestCase(TestCase): with self.assertRaises(KeyError): cache.get("key1", None) - def test_eviction(self): - cache = DeferredCache( + def test_eviction(self) -> None: + cache: DeferredCache[int, str] = DeferredCache( "test", max_entries=2, apply_cache_factor_from_config=False ) @@ -208,8 +209,8 @@ class DeferredCacheTestCase(TestCase): cache.get(2) cache.get(3) - def test_eviction_lru(self): - cache = DeferredCache( + def test_eviction_lru(self) -> None: + cache: DeferredCache[int, str] = DeferredCache( "test", max_entries=2, apply_cache_factor_from_config=False ) @@ -227,8 +228,8 @@ class DeferredCacheTestCase(TestCase): cache.get(1) cache.get(3) - def test_eviction_iterable(self): - cache = DeferredCache( + def test_eviction_iterable(self) -> None: + cache: DeferredCache[int, List[str]] = DeferredCache( "test", max_entries=3, apply_cache_factor_from_config=False, diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 43475a307f..13f1edd533 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -13,11 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Iterable, Set, Tuple +from typing import Iterable, Set, Tuple, cast from unittest import mock from twisted.internet import defer, reactor from twisted.internet.defer import CancelledError, Deferred +from twisted.internet.interfaces import IReactorTime from synapse.api.errors import SynapseError from synapse.logging.context import ( @@ -37,8 +38,8 @@ logger = logging.getLogger(__name__) def run_on_reactor(): - d = defer.Deferred() - reactor.callLater(0, d.callback, 0) + d: "Deferred[int]" = defer.Deferred() + cast(IReactorTime, reactor).callLater(0, d.callback, 0) return make_deferred_yieldable(d) @@ -224,7 +225,8 @@ class DescriptorTestCase(unittest.TestCase): callbacks: Set[str] = set() # set off an asynchronous request - obj.result = origin_d = defer.Deferred() + origin_d: Deferred = defer.Deferred() + obj.result = origin_d d1 = obj.fn(1, on_invalidate=lambda: callbacks.add("d1")) self.assertFalse(d1.called) @@ -262,7 +264,7 @@ class DescriptorTestCase(unittest.TestCase): """Check that logcontexts are set and restored correctly when using the cache.""" - complete_lookup = defer.Deferred() + complete_lookup: Deferred = defer.Deferred() class Cls: @descriptors.cached() @@ -772,10 +774,14 @@ class CachedListDescriptorTestCase(unittest.TestCase): @descriptors.cachedList(cached_method_name="fn", list_name="args1") async def list_fn(self, args1, arg2): - assert current_context().name == "c1" + context = current_context() + assert isinstance(context, LoggingContext) + assert context.name == "c1" # we want this to behave like an asynchronous function await run_on_reactor() - assert current_context().name == "c1" + context = current_context() + assert isinstance(context, LoggingContext) + assert context.name == "c1" return self.mock(args1, arg2) with LoggingContext("c1") as c1: @@ -834,7 +840,7 @@ class CachedListDescriptorTestCase(unittest.TestCase): return self.mock(args1) obj = Cls() - deferred_result = Deferred() + deferred_result: "Deferred[dict]" = Deferred() obj.mock.return_value = deferred_result # start off several concurrent lookups of the same key diff --git a/tests/util/caches/test_response_cache.py b/tests/util/caches/test_response_cache.py index 025b73e32f..f09eeecada 100644 --- a/tests/util/caches/test_response_cache.py +++ b/tests/util/caches/test_response_cache.py @@ -35,7 +35,7 @@ class ResponseCacheTestCase(TestCase): (These have cache with a short timeout_ms=, shorter than will be tested through advancing the clock) """ - def setUp(self): + def setUp(self) -> None: self.reactor, self.clock = get_clock() def with_cache(self, name: str, ms: int = 0) -> ResponseCache: @@ -49,7 +49,7 @@ class ResponseCacheTestCase(TestCase): await self.clock.sleep(1) return o - def test_cache_hit(self): + def test_cache_hit(self) -> None: cache = self.with_cache("keeping_cache", ms=9001) expected_result = "howdy" @@ -74,7 +74,7 @@ class ResponseCacheTestCase(TestCase): "cache should still have the result", ) - def test_cache_miss(self): + def test_cache_miss(self) -> None: cache = self.with_cache("trashing_cache", ms=0) expected_result = "howdy" @@ -90,7 +90,7 @@ class ResponseCacheTestCase(TestCase): ) self.assertCountEqual([], cache.keys(), "cache should not have the result now") - def test_cache_expire(self): + def test_cache_expire(self) -> None: cache = self.with_cache("short_cache", ms=1000) expected_result = "howdy" @@ -115,7 +115,7 @@ class ResponseCacheTestCase(TestCase): self.reactor.pump((2,)) self.assertCountEqual([], cache.keys(), "cache should not have the result now") - def test_cache_wait_hit(self): + def test_cache_wait_hit(self) -> None: cache = self.with_cache("neutral_cache") expected_result = "howdy" @@ -131,7 +131,7 @@ class ResponseCacheTestCase(TestCase): self.assertEqual(expected_result, self.successResultOf(wrap_d)) - def test_cache_wait_expire(self): + def test_cache_wait_expire(self) -> None: cache = self.with_cache("medium_cache", ms=3000) expected_result = "howdy" @@ -162,7 +162,7 @@ class ResponseCacheTestCase(TestCase): self.assertCountEqual([], cache.keys(), "cache should not have the result now") @parameterized.expand([(True,), (False,)]) - def test_cache_context_nocache(self, should_cache: bool): + def test_cache_context_nocache(self, should_cache: bool) -> None: """If the callback clears the should_cache bit, the result should not be cached""" cache = self.with_cache("medium_cache", ms=3000) @@ -170,7 +170,7 @@ class ResponseCacheTestCase(TestCase): call_count = 0 - async def non_caching(o: str, cache_context: ResponseCacheContext[int]): + async def non_caching(o: str, cache_context: ResponseCacheContext[int]) -> str: nonlocal call_count call_count += 1 await self.clock.sleep(1) diff --git a/tests/util/caches/test_ttlcache.py b/tests/util/caches/test_ttlcache.py index fe8314057d..679d1eb36b 100644 --- a/tests/util/caches/test_ttlcache.py +++ b/tests/util/caches/test_ttlcache.py @@ -20,11 +20,11 @@ from tests import unittest class CacheTestCase(unittest.TestCase): - def setUp(self): + def setUp(self) -> None: self.mock_timer = Mock(side_effect=lambda: 100.0) - self.cache = TTLCache("test_cache", self.mock_timer) + self.cache: TTLCache[str, str] = TTLCache("test_cache", self.mock_timer) - def test_get(self): + def test_get(self) -> None: """simple set/get tests""" self.cache.set("one", "1", 10) self.cache.set("two", "2", 20) @@ -59,7 +59,7 @@ class CacheTestCase(unittest.TestCase): self.assertEqual(self.cache._metrics.hits, 4) self.assertEqual(self.cache._metrics.misses, 5) - def test_expiry(self): + def test_expiry(self) -> None: self.cache.set("one", "1", 10) self.cache.set("two", "2", 20) self.cache.set("three", "3", 30) -- cgit 1.5.1 From acea4d7a2ff61b5beda420b54a8451088060a8cd Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 2 Dec 2022 12:58:56 -0500 Subject: Add missing types to tests.util. (#14597) Removes files under tests.util from the ignored by list, then fully types all tests/util/*.py files. --- changelog.d/14597.misc | 1 + mypy.ini | 13 +--- tests/util/test_async_helpers.py | 118 ++++++++++++++++++--------------- tests/util/test_batching_queue.py | 30 +++++---- tests/util/test_check_dependencies.py | 29 ++++++-- tests/util/test_dict_cache.py | 20 +++--- tests/util/test_expiring_cache.py | 26 +++++--- tests/util/test_file_consumer.py | 103 ++++++++++++++++------------ tests/util/test_itertools.py | 24 +++---- tests/util/test_logcontext.py | 86 +++++++++++++++--------- tests/util/test_logformatter.py | 2 +- tests/util/test_lrucache.py | 80 +++++++++++----------- tests/util/test_macaroons.py | 8 +-- tests/util/test_ratelimitutils.py | 15 +++-- tests/util/test_retryutils.py | 4 +- tests/util/test_rwlock.py | 14 ++-- tests/util/test_stream_change_cache.py | 14 ++-- tests/util/test_stringutils.py | 4 +- tests/util/test_threepids.py | 16 ++--- tests/util/test_treecache.py | 14 ++-- tests/util/test_wheel_timer.py | 16 ++--- 21 files changed, 361 insertions(+), 276 deletions(-) create mode 100644 changelog.d/14597.misc (limited to 'tests/util') diff --git a/changelog.d/14597.misc b/changelog.d/14597.misc new file mode 100644 index 0000000000..d44571b731 --- /dev/null +++ b/changelog.d/14597.misc @@ -0,0 +1 @@ +Add missing type hints. diff --git a/mypy.ini b/mypy.ini index 0b6e7df267..c3fbd1a955 100644 --- a/mypy.ini +++ b/mypy.ini @@ -59,16 +59,6 @@ exclude = (?x) |tests/server_notices/test_resource_limits_server_notices.py |tests/test_state.py |tests/test_terms_auth.py - |tests/util/test_async_helpers.py - |tests/util/test_batching_queue.py - |tests/util/test_dict_cache.py - |tests/util/test_expiring_cache.py - |tests/util/test_file_consumer.py - |tests/util/test_linearizer.py - |tests/util/test_logcontext.py - |tests/util/test_lrucache.py - |tests/util/test_rwlock.py - |tests/util/test_wheel_timer.py )$ [mypy-synapse.federation.transport.client] @@ -137,6 +127,9 @@ disallow_untyped_defs = True [mypy-tests.util.caches.test_descriptors] disallow_untyped_defs = False +[mypy-tests.util.*] +disallow_untyped_defs = True + [mypy-tests.utils] disallow_untyped_defs = True diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py index 9d5010bf92..91cac9822a 100644 --- a/tests/util/test_async_helpers.py +++ b/tests/util/test_async_helpers.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import traceback +from typing import Generator, List, NoReturn, Optional from parameterized import parameterized_class @@ -41,8 +42,8 @@ from tests.unittest import TestCase class ObservableDeferredTest(TestCase): - def test_succeed(self): - origin_d = Deferred() + def test_succeed(self) -> None: + origin_d: "Deferred[int]" = Deferred() observable = ObservableDeferred(origin_d) observer1 = observable.observe() @@ -52,16 +53,18 @@ class ObservableDeferredTest(TestCase): self.assertFalse(observer2.called) # check the first observer is called first - def check_called_first(res): + def check_called_first(res: int) -> int: self.assertFalse(observer2.called) return res observer1.addBoth(check_called_first) # store the results - results = [None, None] + results: List[Optional[ObservableDeferred[int]]] = [None, None] - def check_val(res, idx): + def check_val( + res: ObservableDeferred[int], idx: int + ) -> ObservableDeferred[int]: results[idx] = res return res @@ -72,8 +75,8 @@ class ObservableDeferredTest(TestCase): self.assertEqual(results[0], 123, "observer 1 callback result") self.assertEqual(results[1], 123, "observer 2 callback result") - def test_failure(self): - origin_d = Deferred() + def test_failure(self) -> None: + origin_d: Deferred = Deferred() observable = ObservableDeferred(origin_d, consumeErrors=True) observer1 = observable.observe() @@ -83,16 +86,16 @@ class ObservableDeferredTest(TestCase): self.assertFalse(observer2.called) # check the first observer is called first - def check_called_first(res): + def check_called_first(res: int) -> int: self.assertFalse(observer2.called) return res observer1.addBoth(check_called_first) # store the results - results = [None, None] + results: List[Optional[ObservableDeferred[str]]] = [None, None] - def check_val(res, idx): + def check_val(res: ObservableDeferred[str], idx: int) -> None: results[idx] = res return None @@ -103,10 +106,12 @@ class ObservableDeferredTest(TestCase): raise Exception("gah!") except Exception as e: origin_d.errback(e) + assert results[0] is not None self.assertEqual(str(results[0].value), "gah!", "observer 1 errback result") + assert results[1] is not None self.assertEqual(str(results[1].value), "gah!", "observer 2 errback result") - def test_cancellation(self): + def test_cancellation(self) -> None: """Test that cancelling an observer does not affect other observers.""" origin_d: "Deferred[int]" = Deferred() observable = ObservableDeferred(origin_d, consumeErrors=True) @@ -136,37 +141,38 @@ class ObservableDeferredTest(TestCase): class TimeoutDeferredTest(TestCase): - def setUp(self): + def setUp(self) -> None: self.clock = Clock() - def test_times_out(self): + def test_times_out(self) -> None: """Basic test case that checks that the original deferred is cancelled and that the timing-out deferred is errbacked """ - cancelled = [False] + cancelled = False - def canceller(_d): - cancelled[0] = True + def canceller(_d: Deferred) -> None: + nonlocal cancelled + cancelled = True - non_completing_d = Deferred(canceller) + non_completing_d: Deferred = Deferred(canceller) timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock) self.assertNoResult(timing_out_d) - self.assertFalse(cancelled[0], "deferred was cancelled prematurely") + self.assertFalse(cancelled, "deferred was cancelled prematurely") self.clock.pump((1.0,)) - self.assertTrue(cancelled[0], "deferred was not cancelled by timeout") + self.assertTrue(cancelled, "deferred was not cancelled by timeout") self.failureResultOf(timing_out_d, defer.TimeoutError) - def test_times_out_when_canceller_throws(self): + def test_times_out_when_canceller_throws(self) -> None: """Test that we have successfully worked around https://twistedmatrix.com/trac/ticket/9534""" - def canceller(_d): + def canceller(_d: Deferred) -> None: raise Exception("can't cancel this deferred") - non_completing_d = Deferred(canceller) + non_completing_d: Deferred = Deferred(canceller) timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock) self.assertNoResult(timing_out_d) @@ -175,22 +181,24 @@ class TimeoutDeferredTest(TestCase): self.failureResultOf(timing_out_d, defer.TimeoutError) - def test_logcontext_is_preserved_on_cancellation(self): - blocking_was_cancelled = [False] + def test_logcontext_is_preserved_on_cancellation(self) -> None: + blocking_was_cancelled = False @defer.inlineCallbacks - def blocking(): - non_completing_d = Deferred() + def blocking() -> Generator["Deferred[object]", object, None]: + nonlocal blocking_was_cancelled + + non_completing_d: Deferred = Deferred() with PreserveLoggingContext(): try: yield non_completing_d except CancelledError: - blocking_was_cancelled[0] = True + blocking_was_cancelled = True raise with LoggingContext("one") as context_one: # the errbacks should be run in the test logcontext - def errback(res, deferred_name): + def errback(res: Failure, deferred_name: str) -> Failure: self.assertIs( current_context(), context_one, @@ -209,7 +217,7 @@ class TimeoutDeferredTest(TestCase): self.clock.pump((1.0,)) self.assertTrue( - blocking_was_cancelled[0], "non-completing deferred was not cancelled" + blocking_was_cancelled, "non-completing deferred was not cancelled" ) self.failureResultOf(timing_out_d, defer.TimeoutError) self.assertIs(current_context(), context_one) @@ -220,13 +228,13 @@ class _TestException(Exception): class ConcurrentlyExecuteTest(TestCase): - def test_limits_runners(self): + def test_limits_runners(self) -> None: """If we have more tasks than runners, we should get the limit of runners""" started = 0 waiters = [] processed = [] - async def callback(v): + async def callback(v: int) -> None: # when we first enter, bump the start count nonlocal started started += 1 @@ -235,7 +243,7 @@ class ConcurrentlyExecuteTest(TestCase): processed.append(v) # wait for the goahead before returning - d2 = Deferred() + d2: "Deferred[int]" = Deferred() waiters.append(d2) await d2 @@ -265,16 +273,16 @@ class ConcurrentlyExecuteTest(TestCase): self.assertCountEqual(processed, [1, 2, 3, 4, 5]) self.successResultOf(d2) - def test_preserves_stacktraces(self): + def test_preserves_stacktraces(self) -> None: """Test that the stacktrace from an exception thrown in the callback is preserved""" - d1 = Deferred() + d1: "Deferred[int]" = Deferred() - async def callback(v): + async def callback(v: int) -> None: # alas, this doesn't work at all without an await here await d1 raise _TestException("bah") - async def caller(): + async def caller() -> None: try: await concurrently_execute(callback, [1], 2) except _TestException as e: @@ -290,17 +298,17 @@ class ConcurrentlyExecuteTest(TestCase): d1.callback(0) self.successResultOf(d2) - def test_preserves_stacktraces_on_preformed_failure(self): + def test_preserves_stacktraces_on_preformed_failure(self) -> None: """Test that the stacktrace on a Failure returned by the callback is preserved""" - d1 = Deferred() + d1: "Deferred[int]" = Deferred() f = Failure(_TestException("bah")) - async def callback(v): + async def callback(v: int) -> None: # alas, this doesn't work at all without an await here await d1 await defer.fail(f) - async def caller(): + async def caller() -> None: try: await concurrently_execute(callback, [1], 2) except _TestException as e: @@ -336,7 +344,7 @@ class CancellationWrapperTests(TestCase): else: raise ValueError(f"Unsupported wrapper type: {self.wrapper}") - def test_succeed(self): + def test_succeed(self) -> None: """Test that the new `Deferred` receives the result.""" deferred: "Deferred[str]" = Deferred() wrapper_deferred = self.wrap_deferred(deferred) @@ -346,7 +354,7 @@ class CancellationWrapperTests(TestCase): self.assertTrue(wrapper_deferred.called) self.assertEqual("success", self.successResultOf(wrapper_deferred)) - def test_failure(self): + def test_failure(self) -> None: """Test that the new `Deferred` receives the `Failure`.""" deferred: "Deferred[str]" = Deferred() wrapper_deferred = self.wrap_deferred(deferred) @@ -361,7 +369,7 @@ class CancellationWrapperTests(TestCase): class StopCancellationTests(TestCase): """Tests for the `stop_cancellation` function.""" - def test_cancellation(self): + def test_cancellation(self) -> None: """Test that cancellation of the new `Deferred` leaves the original running.""" deferred: "Deferred[str]" = Deferred() wrapper_deferred = stop_cancellation(deferred) @@ -384,7 +392,7 @@ class StopCancellationTests(TestCase): class DelayCancellationTests(TestCase): """Tests for the `delay_cancellation` function.""" - def test_deferred_cancellation(self): + def test_deferred_cancellation(self) -> None: """Test that cancellation of the new `Deferred` waits for the original.""" deferred: "Deferred[str]" = Deferred() wrapper_deferred = delay_cancellation(deferred) @@ -405,12 +413,12 @@ class DelayCancellationTests(TestCase): # Now that the original `Deferred` has failed, we should get a `CancelledError`. self.failureResultOf(wrapper_deferred, CancelledError) - def test_coroutine_cancellation(self): + def test_coroutine_cancellation(self) -> None: """Test that cancellation of the new `Deferred` waits for the original.""" blocking_deferred: "Deferred[None]" = Deferred() completion_deferred: "Deferred[None]" = Deferred() - async def task(): + async def task() -> NoReturn: await blocking_deferred completion_deferred.callback(None) # Raise an exception. Twisted should consume it, otherwise unwanted @@ -434,7 +442,7 @@ class DelayCancellationTests(TestCase): # Now that the original coroutine has failed, we should get a `CancelledError`. self.failureResultOf(wrapper_deferred, CancelledError) - def test_suppresses_second_cancellation(self): + def test_suppresses_second_cancellation(self) -> None: """Test that a second cancellation is suppressed. Identical to `test_cancellation` except the new `Deferred` is cancelled twice. @@ -459,7 +467,7 @@ class DelayCancellationTests(TestCase): # Now that the original `Deferred` has failed, we should get a `CancelledError`. self.failureResultOf(wrapper_deferred, CancelledError) - def test_propagates_cancelled_error(self): + def test_propagates_cancelled_error(self) -> None: """Test that a `CancelledError` from the original `Deferred` gets propagated.""" deferred: "Deferred[str]" = Deferred() wrapper_deferred = delay_cancellation(deferred) @@ -472,14 +480,14 @@ class DelayCancellationTests(TestCase): self.assertTrue(wrapper_deferred.called) self.assertIs(cancelled_error, self.failureResultOf(wrapper_deferred).value) - def test_preserves_logcontext(self): + def test_preserves_logcontext(self) -> None: """Test that logging contexts are preserved.""" blocking_d: "Deferred[None]" = Deferred() - async def inner(): + async def inner() -> None: await make_deferred_yieldable(blocking_d) - async def outer(): + async def outer() -> None: with LoggingContext("c") as c: try: await delay_cancellation(inner()) @@ -503,7 +511,7 @@ class DelayCancellationTests(TestCase): class AwakenableSleeperTests(TestCase): "Tests AwakenableSleeper" - def test_sleep(self): + def test_sleep(self) -> None: reactor, _ = get_clock() sleeper = AwakenableSleeper(reactor) @@ -518,7 +526,7 @@ class AwakenableSleeperTests(TestCase): reactor.advance(0.6) self.assertTrue(d.called) - def test_explicit_wake(self): + def test_explicit_wake(self) -> None: reactor, _ = get_clock() sleeper = AwakenableSleeper(reactor) @@ -535,7 +543,7 @@ class AwakenableSleeperTests(TestCase): reactor.advance(0.6) - def test_multiple_sleepers_timeout(self): + def test_multiple_sleepers_timeout(self) -> None: reactor, _ = get_clock() sleeper = AwakenableSleeper(reactor) @@ -555,7 +563,7 @@ class AwakenableSleeperTests(TestCase): reactor.advance(0.6) self.assertTrue(d2.called) - def test_multiple_sleepers_wake(self): + def test_multiple_sleepers_wake(self) -> None: reactor, _ = get_clock() sleeper = AwakenableSleeper(reactor) diff --git a/tests/util/test_batching_queue.py b/tests/util/test_batching_queue.py index 07be57d72c..94ef91f645 100644 --- a/tests/util/test_batching_queue.py +++ b/tests/util/test_batching_queue.py @@ -11,6 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import List, Tuple + +from prometheus_client import Gauge + from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable @@ -26,7 +30,7 @@ from tests.unittest import TestCase class BatchingQueueTestCase(TestCase): - def setUp(self): + def setUp(self) -> None: self.clock, hs_clock = get_clock() # We ensure that we remove any existing metrics for "test_queue". @@ -37,25 +41,27 @@ class BatchingQueueTestCase(TestCase): except KeyError: pass - self._pending_calls = [] - self.queue = BatchingQueue("test_queue", hs_clock, self._process_queue) + self._pending_calls: List[Tuple[List[str], defer.Deferred]] = [] + self.queue: BatchingQueue[str, str] = BatchingQueue( + "test_queue", hs_clock, self._process_queue + ) - async def _process_queue(self, values): - d = defer.Deferred() + async def _process_queue(self, values: List[str]) -> str: + d: "defer.Deferred[str]" = defer.Deferred() self._pending_calls.append((values, d)) return await make_deferred_yieldable(d) - def _get_sample_with_name(self, metric, name) -> int: + def _get_sample_with_name(self, metric: Gauge, name: str) -> float: """For a prometheus metric get the value of the sample that has a matching "name" label. """ - for sample in metric.collect()[0].samples: + for sample in next(iter(metric.collect())).samples: if sample.labels.get("name") == name: return sample.value self.fail("Found no matching sample") - def _assert_metrics(self, queued, keys, in_flight): + def _assert_metrics(self, queued: int, keys: int, in_flight: int) -> None: """Assert that the metrics are correct""" sample = self._get_sample_with_name(number_queued, self.queue._name) @@ -75,7 +81,7 @@ class BatchingQueueTestCase(TestCase): "number_in_flight", ) - def test_simple(self): + def test_simple(self) -> None: """Tests the basic case of calling `add_to_queue` once and having `_process_queue` return. """ @@ -106,7 +112,7 @@ class BatchingQueueTestCase(TestCase): self._assert_metrics(queued=0, keys=0, in_flight=0) - def test_batching(self): + def test_batching(self) -> None: """Test that multiple calls at the same time get batched up into one call to `_process_queue`. """ @@ -134,7 +140,7 @@ class BatchingQueueTestCase(TestCase): self.assertEqual(self.successResultOf(queue_d2), "bar") self._assert_metrics(queued=0, keys=0, in_flight=0) - def test_queuing(self): + def test_queuing(self) -> None: """Test that we queue up requests while a `_process_queue` is being called. """ @@ -184,7 +190,7 @@ class BatchingQueueTestCase(TestCase): self.assertEqual(self.successResultOf(queue_d3), "bar2") self._assert_metrics(queued=0, keys=0, in_flight=0) - def test_different_keys(self): + def test_different_keys(self) -> None: """Test that calls to different keys get processed in parallel.""" self.assertFalse(self._pending_calls) diff --git a/tests/util/test_check_dependencies.py b/tests/util/test_check_dependencies.py index 6913de24b9..aa20fe6780 100644 --- a/tests/util/test_check_dependencies.py +++ b/tests/util/test_check_dependencies.py @@ -1,5 +1,20 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from contextlib import contextmanager -from typing import Generator, Optional +from os import PathLike +from typing import Generator, Optional, Union from unittest.mock import patch from synapse.util.check_dependencies import ( @@ -12,17 +27,17 @@ from tests.unittest import TestCase class DummyDistribution(metadata.Distribution): - def __init__(self, version: object): + def __init__(self, version: str): self._version = version @property - def version(self): + def version(self) -> str: return self._version - def locate_file(self, path): + def locate_file(self, path: Union[str, PathLike]) -> PathLike: raise NotImplementedError() - def read_text(self, filename): + def read_text(self, filename: str) -> None: raise NotImplementedError() @@ -30,7 +45,7 @@ old = DummyDistribution("0.1.2") old_release_candidate = DummyDistribution("0.1.2rc3") new = DummyDistribution("1.2.3") new_release_candidate = DummyDistribution("1.2.3rc4") -distribution_with_no_version = DummyDistribution(None) +distribution_with_no_version = DummyDistribution(None) # type: ignore[arg-type] # could probably use stdlib TestCase --- no need for twisted here @@ -45,7 +60,7 @@ class TestDependencyChecker(TestCase): If `distribution = None`, we pretend that the package is not installed. """ - def mock_distribution(name: str): + def mock_distribution(name: str) -> DummyDistribution: if distribution is None: raise metadata.PackageNotFoundError else: diff --git a/tests/util/test_dict_cache.py b/tests/util/test_dict_cache.py index e8b6246ab5..acb251bfea 100644 --- a/tests/util/test_dict_cache.py +++ b/tests/util/test_dict_cache.py @@ -19,10 +19,12 @@ from tests import unittest class DictCacheTestCase(unittest.TestCase): - def setUp(self): - self.cache = DictionaryCache("foobar", max_entries=10) + def setUp(self) -> None: + self.cache: DictionaryCache[str, str, str] = DictionaryCache( + "foobar", max_entries=10 + ) - def test_simple_cache_hit_full(self): + def test_simple_cache_hit_full(self) -> None: key = "test_simple_cache_hit_full" v = self.cache.get(key) @@ -37,7 +39,7 @@ class DictCacheTestCase(unittest.TestCase): c = self.cache.get(key) self.assertEqual(test_value, c.value) - def test_simple_cache_hit_partial(self): + def test_simple_cache_hit_partial(self) -> None: key = "test_simple_cache_hit_partial" seq = self.cache.sequence @@ -47,7 +49,7 @@ class DictCacheTestCase(unittest.TestCase): c = self.cache.get(key, ["test"]) self.assertEqual(test_value, c.value) - def test_simple_cache_miss_partial(self): + def test_simple_cache_miss_partial(self) -> None: key = "test_simple_cache_miss_partial" seq = self.cache.sequence @@ -57,7 +59,7 @@ class DictCacheTestCase(unittest.TestCase): c = self.cache.get(key, ["test2"]) self.assertEqual({}, c.value) - def test_simple_cache_hit_miss_partial(self): + def test_simple_cache_hit_miss_partial(self) -> None: key = "test_simple_cache_hit_miss_partial" seq = self.cache.sequence @@ -71,7 +73,7 @@ class DictCacheTestCase(unittest.TestCase): c = self.cache.get(key, ["test2"]) self.assertEqual({"test2": "test_simple_cache_hit_miss_partial2"}, c.value) - def test_multi_insert(self): + def test_multi_insert(self) -> None: key = "test_simple_cache_hit_miss_partial" seq = self.cache.sequence @@ -92,7 +94,7 @@ class DictCacheTestCase(unittest.TestCase): ) self.assertEqual(c.full, False) - def test_invalidation(self): + def test_invalidation(self) -> None: """Test that the partial dict and full dicts get invalidated separately. """ @@ -106,7 +108,7 @@ class DictCacheTestCase(unittest.TestCase): # entry for "a" warm. for i in range(20): self.cache.get(key, ["a"]) - self.cache.update(seq, f"key{i}", {1: 2}) + self.cache.update(seq, f"key{i}", {"1": "2"}) # We should have evicted the full dict... r = self.cache.get(key) diff --git a/tests/util/test_expiring_cache.py b/tests/util/test_expiring_cache.py index 7f60aae5ba..9cf920daf8 100644 --- a/tests/util/test_expiring_cache.py +++ b/tests/util/test_expiring_cache.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import List, cast +from synapse.util import Clock from synapse.util.caches.expiringcache import ExpiringCache from tests.utils import MockClock @@ -21,17 +23,21 @@ from .. import unittest class ExpiringCacheTestCase(unittest.HomeserverTestCase): - def test_get_set(self): + def test_get_set(self) -> None: clock = MockClock() - cache = ExpiringCache("test", clock, max_len=1) + cache: ExpiringCache[str, str] = ExpiringCache( + "test", cast(Clock, clock), max_len=1 + ) cache["key"] = "value" self.assertEqual(cache.get("key"), "value") self.assertEqual(cache["key"], "value") - def test_eviction(self): + def test_eviction(self) -> None: clock = MockClock() - cache = ExpiringCache("test", clock, max_len=2) + cache: ExpiringCache[str, str] = ExpiringCache( + "test", cast(Clock, clock), max_len=2 + ) cache["key"] = "value" cache["key2"] = "value2" @@ -43,9 +49,11 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase): self.assertEqual(cache.get("key2"), "value2") self.assertEqual(cache.get("key3"), "value3") - def test_iterable_eviction(self): + def test_iterable_eviction(self) -> None: clock = MockClock() - cache = ExpiringCache("test", clock, max_len=5, iterable=True) + cache: ExpiringCache[str, List[int]] = ExpiringCache( + "test", cast(Clock, clock), max_len=5, iterable=True + ) cache["key"] = [1] cache["key2"] = [2, 3] @@ -61,9 +69,11 @@ class ExpiringCacheTestCase(unittest.HomeserverTestCase): self.assertEqual(cache.get("key3"), [4, 5]) self.assertEqual(cache.get("key4"), [6, 7]) - def test_time_eviction(self): + def test_time_eviction(self) -> None: clock = MockClock() - cache = ExpiringCache("test", clock, expiry_ms=1000) + cache: ExpiringCache[str, int] = ExpiringCache( + "test", cast(Clock, clock), expiry_ms=1000 + ) cache["key"] = 1 clock.advance_time(0.5) diff --git a/tests/util/test_file_consumer.py b/tests/util/test_file_consumer.py index 3bb4695405..4f3c983c15 100644 --- a/tests/util/test_file_consumer.py +++ b/tests/util/test_file_consumer.py @@ -12,22 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. - import threading -from io import StringIO +from io import BytesIO +from typing import BinaryIO, Generator, Optional, cast from unittest.mock import NonCallableMock -from twisted.internet import defer, reactor +from zope.interface import implementer + +from twisted.internet import defer, reactor as _reactor +from twisted.internet.interfaces import IPullProducer +from synapse.types import ISynapseReactor from synapse.util.file_consumer import BackgroundFileConsumer from tests import unittest +reactor = cast(ISynapseReactor, _reactor) + class FileConsumerTests(unittest.TestCase): @defer.inlineCallbacks - def test_pull_consumer(self): - string_file = StringIO() + def test_pull_consumer(self) -> Generator["defer.Deferred[object]", object, None]: + string_file = BytesIO() consumer = BackgroundFileConsumer(string_file, reactor=reactor) try: @@ -35,55 +41,57 @@ class FileConsumerTests(unittest.TestCase): yield producer.register_with_consumer(consumer) - yield producer.write_and_wait("Foo") + yield producer.write_and_wait(b"Foo") - self.assertEqual(string_file.getvalue(), "Foo") + self.assertEqual(string_file.getvalue(), b"Foo") - yield producer.write_and_wait("Bar") + yield producer.write_and_wait(b"Bar") - self.assertEqual(string_file.getvalue(), "FooBar") + self.assertEqual(string_file.getvalue(), b"FooBar") finally: consumer.unregisterProducer() - yield consumer.wait() + yield consumer.wait() # type: ignore[misc] self.assertTrue(string_file.closed) @defer.inlineCallbacks - def test_push_consumer(self): - string_file = BlockingStringWrite() - consumer = BackgroundFileConsumer(string_file, reactor=reactor) + def test_push_consumer(self) -> Generator["defer.Deferred[object]", object, None]: + string_file = BlockingBytesWrite() + consumer = BackgroundFileConsumer(cast(BinaryIO, string_file), reactor=reactor) try: producer = NonCallableMock(spec_set=[]) consumer.registerProducer(producer, True) - consumer.write("Foo") - yield string_file.wait_for_n_writes(1) + consumer.write(b"Foo") + yield string_file.wait_for_n_writes(1) # type: ignore[misc] - self.assertEqual(string_file.buffer, "Foo") + self.assertEqual(string_file.buffer, b"Foo") - consumer.write("Bar") - yield string_file.wait_for_n_writes(2) + consumer.write(b"Bar") + yield string_file.wait_for_n_writes(2) # type: ignore[misc] - self.assertEqual(string_file.buffer, "FooBar") + self.assertEqual(string_file.buffer, b"FooBar") finally: consumer.unregisterProducer() - yield consumer.wait() + yield consumer.wait() # type: ignore[misc] self.assertTrue(string_file.closed) @defer.inlineCallbacks - def test_push_producer_feedback(self): - string_file = BlockingStringWrite() - consumer = BackgroundFileConsumer(string_file, reactor=reactor) + def test_push_producer_feedback( + self, + ) -> Generator["defer.Deferred[object]", object, None]: + string_file = BlockingBytesWrite() + consumer = BackgroundFileConsumer(cast(BinaryIO, string_file), reactor=reactor) try: producer = NonCallableMock(spec_set=["pauseProducing", "resumeProducing"]) - resume_deferred = defer.Deferred() + resume_deferred: defer.Deferred = defer.Deferred() producer.resumeProducing.side_effect = lambda: resume_deferred.callback( None ) @@ -93,65 +101,72 @@ class FileConsumerTests(unittest.TestCase): number_writes = 0 with string_file.write_lock: for _ in range(consumer._PAUSE_ON_QUEUE_SIZE): - consumer.write("Foo") + consumer.write(b"Foo") number_writes += 1 producer.pauseProducing.assert_called_once() - yield string_file.wait_for_n_writes(number_writes) + yield string_file.wait_for_n_writes(number_writes) # type: ignore[misc] yield resume_deferred producer.resumeProducing.assert_called_once() finally: consumer.unregisterProducer() - yield consumer.wait() + yield consumer.wait() # type: ignore[misc] self.assertTrue(string_file.closed) +@implementer(IPullProducer) class DummyPullProducer: - def __init__(self): - self.consumer = None - self.deferred = defer.Deferred() + def __init__(self) -> None: + self.consumer: Optional[BackgroundFileConsumer] = None + self.deferred: "defer.Deferred[object]" = defer.Deferred() - def resumeProducing(self): + def resumeProducing(self) -> None: d = self.deferred self.deferred = defer.Deferred() d.callback(None) - def write_and_wait(self, bytes): + def stopProducing(self) -> None: + raise RuntimeError("Unexpected call") + + def write_and_wait(self, write_bytes: bytes) -> "defer.Deferred[object]": + assert self.consumer is not None d = self.deferred - self.consumer.write(bytes) + self.consumer.write(write_bytes) return d - def register_with_consumer(self, consumer): + def register_with_consumer( + self, consumer: BackgroundFileConsumer + ) -> "defer.Deferred[object]": d = self.deferred self.consumer = consumer self.consumer.registerProducer(self, False) return d -class BlockingStringWrite: - def __init__(self): - self.buffer = "" +class BlockingBytesWrite: + def __init__(self) -> None: + self.buffer = b"" self.closed = False self.write_lock = threading.Lock() - self._notify_write_deferred = None + self._notify_write_deferred: Optional[defer.Deferred] = None self._number_of_writes = 0 - def write(self, bytes): + def write(self, write_bytes: bytes) -> None: with self.write_lock: - self.buffer += bytes + self.buffer += write_bytes self._number_of_writes += 1 reactor.callFromThread(self._notify_write) - def close(self): + def close(self) -> None: self.closed = True - def _notify_write(self): + def _notify_write(self) -> None: "Called by write to indicate a write happened" with self.write_lock: if not self._notify_write_deferred: @@ -161,7 +176,9 @@ class BlockingStringWrite: d.callback(None) @defer.inlineCallbacks - def wait_for_n_writes(self, n): + def wait_for_n_writes( + self, n: int + ) -> Generator["defer.Deferred[object]", object, None]: "Wait for n writes to have happened" while True: with self.write_lock: diff --git a/tests/util/test_itertools.py b/tests/util/test_itertools.py index 3c0ddd4f18..406c16cdcf 100644 --- a/tests/util/test_itertools.py +++ b/tests/util/test_itertools.py @@ -19,7 +19,7 @@ from tests.unittest import TestCase class ChunkSeqTests(TestCase): - def test_short_seq(self): + def test_short_seq(self) -> None: parts = chunk_seq("123", 8) self.assertEqual( @@ -27,7 +27,7 @@ class ChunkSeqTests(TestCase): ["123"], ) - def test_long_seq(self): + def test_long_seq(self) -> None: parts = chunk_seq("abcdefghijklmnop", 8) self.assertEqual( @@ -35,7 +35,7 @@ class ChunkSeqTests(TestCase): ["abcdefgh", "ijklmnop"], ) - def test_uneven_parts(self): + def test_uneven_parts(self) -> None: parts = chunk_seq("abcdefghijklmnop", 5) self.assertEqual( @@ -43,7 +43,7 @@ class ChunkSeqTests(TestCase): ["abcde", "fghij", "klmno", "p"], ) - def test_empty_input(self): + def test_empty_input(self) -> None: parts: Iterable[Sequence] = chunk_seq([], 5) self.assertEqual( @@ -53,13 +53,13 @@ class ChunkSeqTests(TestCase): class SortTopologically(TestCase): - def test_empty(self): + def test_empty(self) -> None: "Test that an empty graph works correctly" graph: Dict[int, List[int]] = {} self.assertEqual(list(sorted_topologically([], graph)), []) - def test_handle_empty_graph(self): + def test_handle_empty_graph(self) -> None: "Test that a graph where a node doesn't have an entry is treated as empty" graph: Dict[int, List[int]] = {} @@ -67,7 +67,7 @@ class SortTopologically(TestCase): # For disconnected nodes the output is simply sorted. self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2]) - def test_disconnected(self): + def test_disconnected(self) -> None: "Test that a graph with no edges work" graph: Dict[int, List[int]] = {1: [], 2: []} @@ -75,20 +75,20 @@ class SortTopologically(TestCase): # For disconnected nodes the output is simply sorted. self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2]) - def test_linear(self): + def test_linear(self) -> None: "Test that a simple `4 -> 3 -> 2 -> 1` graph works" graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3]} self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4]) - def test_subset(self): + def test_subset(self) -> None: "Test that only sorting a subset of the graph works" graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3]} self.assertEqual(list(sorted_topologically([4, 3], graph)), [3, 4]) - def test_fork(self): + def test_fork(self) -> None: "Test that a forked graph works" graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [1], 4: [2, 3]} @@ -96,13 +96,13 @@ class SortTopologically(TestCase): # always get the same one. self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4]) - def test_duplicates(self): + def test_duplicates(self) -> None: "Test that a graph with duplicate edges work" graph: Dict[int, List[int]] = {1: [], 2: [1, 1], 3: [2, 2], 4: [3]} self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4]) - def test_multiple_paths(self): + def test_multiple_paths(self) -> None: "Test that a graph with multiple paths between two nodes work" graph: Dict[int, List[int]] = {1: [], 2: [1], 3: [2], 4: [3, 2, 1]} diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py index 2ad321e184..d64c162e1d 100644 --- a/tests/util/test_logcontext.py +++ b/tests/util/test_logcontext.py @@ -1,5 +1,21 @@ +# Copyright 2014-2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Generator, cast + import twisted.python.failure -from twisted.internet import defer, reactor +from twisted.internet import defer, reactor as _reactor from synapse.logging.context import ( SENTINEL_CONTEXT, @@ -10,25 +26,30 @@ from synapse.logging.context import ( nested_logging_context, run_in_background, ) +from synapse.types import ISynapseReactor from synapse.util import Clock from .. import unittest +reactor = cast(ISynapseReactor, _reactor) + class LoggingContextTestCase(unittest.TestCase): - def _check_test_key(self, value): - self.assertEqual(current_context().name, value) + def _check_test_key(self, value: str) -> None: + context = current_context() + assert isinstance(context, LoggingContext) + self.assertEqual(context.name, value) - def test_with_context(self): + def test_with_context(self) -> None: with LoggingContext("test"): self._check_test_key("test") @defer.inlineCallbacks - def test_sleep(self): + def test_sleep(self) -> Generator["defer.Deferred[object]", object, None]: clock = Clock(reactor) @defer.inlineCallbacks - def competing_callback(): + def competing_callback() -> Generator["defer.Deferred[object]", object, None]: with LoggingContext("competing"): yield clock.sleep(0) self._check_test_key("competing") @@ -39,17 +60,18 @@ class LoggingContextTestCase(unittest.TestCase): yield clock.sleep(0) self._check_test_key("one") - def _test_run_in_background(self, function): + def _test_run_in_background(self, function: Callable[[], object]) -> defer.Deferred: sentinel_context = current_context() - callback_completed = [False] + callback_completed = False with LoggingContext("one"): # fire off function, but don't wait on it. d2 = run_in_background(function) - def cb(res): - callback_completed[0] = True + def cb(res: object) -> object: + nonlocal callback_completed + callback_completed = True return res d2.addCallback(cb) @@ -60,8 +82,8 @@ class LoggingContextTestCase(unittest.TestCase): # the logcontext is left in a sane state. d2 = defer.Deferred() - def check_logcontext(): - if not callback_completed[0]: + def check_logcontext() -> None: + if not callback_completed: reactor.callLater(0.01, check_logcontext) return @@ -78,31 +100,31 @@ class LoggingContextTestCase(unittest.TestCase): # test is done once d2 finishes return d2 - def test_run_in_background_with_blocking_fn(self): + def test_run_in_background_with_blocking_fn(self) -> defer.Deferred: @defer.inlineCallbacks - def blocking_function(): + def blocking_function() -> Generator["defer.Deferred[object]", object, None]: yield Clock(reactor).sleep(0) return self._test_run_in_background(blocking_function) - def test_run_in_background_with_non_blocking_fn(self): + def test_run_in_background_with_non_blocking_fn(self) -> defer.Deferred: @defer.inlineCallbacks - def nonblocking_function(): + def nonblocking_function() -> Generator["defer.Deferred[object]", object, None]: with PreserveLoggingContext(): yield defer.succeed(None) return self._test_run_in_background(nonblocking_function) - def test_run_in_background_with_chained_deferred(self): + def test_run_in_background_with_chained_deferred(self) -> defer.Deferred: # a function which returns a deferred which looks like it has been # called, but is actually paused - def testfunc(): + def testfunc() -> defer.Deferred: return make_deferred_yieldable(_chained_deferred_function()) return self._test_run_in_background(testfunc) - def test_run_in_background_with_coroutine(self): - async def testfunc(): + def test_run_in_background_with_coroutine(self) -> defer.Deferred: + async def testfunc() -> None: self._check_test_key("one") d = Clock(reactor).sleep(0) self.assertIs(current_context(), SENTINEL_CONTEXT) @@ -111,18 +133,20 @@ class LoggingContextTestCase(unittest.TestCase): return self._test_run_in_background(testfunc) - def test_run_in_background_with_nonblocking_coroutine(self): - async def testfunc(): + def test_run_in_background_with_nonblocking_coroutine(self) -> defer.Deferred: + async def testfunc() -> None: self._check_test_key("one") return self._test_run_in_background(testfunc) @defer.inlineCallbacks - def test_make_deferred_yieldable(self): + def test_make_deferred_yieldable( + self, + ) -> Generator["defer.Deferred[object]", object, None]: # a function which returns an incomplete deferred, but doesn't follow # the synapse rules. - def blocking_function(): - d = defer.Deferred() + def blocking_function() -> defer.Deferred: + d: defer.Deferred = defer.Deferred() reactor.callLater(0, d.callback, None) return d @@ -139,7 +163,9 @@ class LoggingContextTestCase(unittest.TestCase): self._check_test_key("one") @defer.inlineCallbacks - def test_make_deferred_yieldable_with_chained_deferreds(self): + def test_make_deferred_yieldable_with_chained_deferreds( + self, + ) -> Generator["defer.Deferred[object]", object, None]: sentinel_context = current_context() with LoggingContext("one"): @@ -152,7 +178,7 @@ class LoggingContextTestCase(unittest.TestCase): # now it should be restored self._check_test_key("one") - def test_nested_logging_context(self): + def test_nested_logging_context(self) -> None: with LoggingContext("foo"): nested_context = nested_logging_context(suffix="bar") self.assertEqual(nested_context.name, "foo-bar") @@ -161,11 +187,11 @@ class LoggingContextTestCase(unittest.TestCase): # a function which returns a deferred which has been "called", but # which had a function which returned another incomplete deferred on # its callback list, so won't yet call any other new callbacks. -def _chained_deferred_function(): +def _chained_deferred_function() -> defer.Deferred: d = defer.succeed(None) - def cb(res): - d2 = defer.Deferred() + def cb(res: object) -> defer.Deferred: + d2: defer.Deferred = defer.Deferred() reactor.callLater(0, d2.callback, res) return d2 diff --git a/tests/util/test_logformatter.py b/tests/util/test_logformatter.py index a2e08281e6..0dee69a6fe 100644 --- a/tests/util/test_logformatter.py +++ b/tests/util/test_logformatter.py @@ -23,7 +23,7 @@ class TestException(Exception): class LogFormatterTestCase(unittest.TestCase): - def test_formatter(self): + def test_formatter(self) -> None: formatter = LogFormatter() try: diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index 67173a4f5b..1fc5a473f0 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -13,10 +13,11 @@ # limitations under the License. -from typing import List +from typing import List, Tuple from unittest.mock import Mock, patch from synapse.metrics.jemalloc import JemallocStats +from synapse.types import JsonDict from synapse.util.caches.lrucache import LruCache, setup_expire_lru_cache_entries from synapse.util.caches.treecache import TreeCache @@ -25,14 +26,14 @@ from tests.unittest import override_config class LruCacheTestCase(unittest.HomeserverTestCase): - def test_get_set(self): - cache = LruCache(1) + def test_get_set(self) -> None: + cache: LruCache[str, str] = LruCache(1) cache["key"] = "value" self.assertEqual(cache.get("key"), "value") self.assertEqual(cache["key"], "value") - def test_eviction(self): - cache = LruCache(2) + def test_eviction(self) -> None: + cache: LruCache[int, int] = LruCache(2) cache[1] = 1 cache[2] = 2 @@ -45,8 +46,8 @@ class LruCacheTestCase(unittest.HomeserverTestCase): self.assertEqual(cache.get(2), 2) self.assertEqual(cache.get(3), 3) - def test_setdefault(self): - cache = LruCache(1) + def test_setdefault(self) -> None: + cache: LruCache[str, int] = LruCache(1) self.assertEqual(cache.setdefault("key", 1), 1) self.assertEqual(cache.get("key"), 1) self.assertEqual(cache.setdefault("key", 2), 1) @@ -54,14 +55,15 @@ class LruCacheTestCase(unittest.HomeserverTestCase): cache["key"] = 2 # Make sure overriding works. self.assertEqual(cache.get("key"), 2) - def test_pop(self): - cache = LruCache(1) + def test_pop(self) -> None: + cache: LruCache[str, int] = LruCache(1) cache["key"] = 1 self.assertEqual(cache.pop("key"), 1) self.assertEqual(cache.pop("key"), None) - def test_del_multi(self): - cache = LruCache(4, cache_type=TreeCache) + def test_del_multi(self) -> None: + # The type here isn't quite correct as they don't handle TreeCache well. + cache: LruCache[Tuple[str, str], str] = LruCache(4, cache_type=TreeCache) cache[("animal", "cat")] = "mew" cache[("animal", "dog")] = "woof" cache[("vehicles", "car")] = "vroom" @@ -71,7 +73,7 @@ class LruCacheTestCase(unittest.HomeserverTestCase): self.assertEqual(cache.get(("animal", "cat")), "mew") self.assertEqual(cache.get(("vehicles", "car")), "vroom") - cache.del_multi(("animal",)) + cache.del_multi(("animal",)) # type: ignore[arg-type] self.assertEqual(len(cache), 2) self.assertEqual(cache.get(("animal", "cat")), None) self.assertEqual(cache.get(("animal", "dog")), None) @@ -79,22 +81,22 @@ class LruCacheTestCase(unittest.HomeserverTestCase): self.assertEqual(cache.get(("vehicles", "train")), "chuff") # Man from del_multi say "Yes". - def test_clear(self): - cache = LruCache(1) + def test_clear(self) -> None: + cache: LruCache[str, int] = LruCache(1) cache["key"] = 1 cache.clear() self.assertEqual(len(cache), 0) @override_config({"caches": {"per_cache_factors": {"mycache": 10}}}) - def test_special_size(self): - cache = LruCache(10, "mycache") + def test_special_size(self) -> None: + cache: LruCache = LruCache(10, "mycache") self.assertEqual(cache.max_size, 100) class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): - def test_get(self): + def test_get(self) -> None: m = Mock() - cache = LruCache(1) + cache: LruCache[str, str] = LruCache(1) cache.set("key", "value") self.assertFalse(m.called) @@ -111,9 +113,9 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): cache.set("key", "value") self.assertEqual(m.call_count, 1) - def test_multi_get(self): + def test_multi_get(self) -> None: m = Mock() - cache = LruCache(1) + cache: LruCache[str, str] = LruCache(1) cache.set("key", "value") self.assertFalse(m.called) @@ -130,9 +132,9 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): cache.set("key", "value") self.assertEqual(m.call_count, 1) - def test_set(self): + def test_set(self) -> None: m = Mock() - cache = LruCache(1) + cache: LruCache[str, str] = LruCache(1) cache.set("key", "value", callbacks=[m]) self.assertFalse(m.called) @@ -146,9 +148,9 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): cache.set("key", "value") self.assertEqual(m.call_count, 1) - def test_pop(self): + def test_pop(self) -> None: m = Mock() - cache = LruCache(1) + cache: LruCache[str, str] = LruCache(1) cache.set("key", "value", callbacks=[m]) self.assertFalse(m.called) @@ -162,12 +164,13 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): cache.pop("key") self.assertEqual(m.call_count, 1) - def test_del_multi(self): + def test_del_multi(self) -> None: m1 = Mock() m2 = Mock() m3 = Mock() m4 = Mock() - cache = LruCache(4, cache_type=TreeCache) + # The type here isn't quite correct as they don't handle TreeCache well. + cache: LruCache[Tuple[str, str], str] = LruCache(4, cache_type=TreeCache) cache.set(("a", "1"), "value", callbacks=[m1]) cache.set(("a", "2"), "value", callbacks=[m2]) @@ -179,17 +182,17 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): self.assertEqual(m3.call_count, 0) self.assertEqual(m4.call_count, 0) - cache.del_multi(("a",)) + cache.del_multi(("a",)) # type: ignore[arg-type] self.assertEqual(m1.call_count, 1) self.assertEqual(m2.call_count, 1) self.assertEqual(m3.call_count, 0) self.assertEqual(m4.call_count, 0) - def test_clear(self): + def test_clear(self) -> None: m1 = Mock() m2 = Mock() - cache = LruCache(5) + cache: LruCache[str, str] = LruCache(5) cache.set("key1", "value", callbacks=[m1]) cache.set("key2", "value", callbacks=[m2]) @@ -202,11 +205,11 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): self.assertEqual(m1.call_count, 1) self.assertEqual(m2.call_count, 1) - def test_eviction(self): + def test_eviction(self) -> None: m1 = Mock(name="m1") m2 = Mock(name="m2") m3 = Mock(name="m3") - cache = LruCache(2) + cache: LruCache[str, str] = LruCache(2) cache.set("key1", "value", callbacks=[m1]) cache.set("key2", "value", callbacks=[m2]) @@ -241,8 +244,8 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): class LruCacheSizedTestCase(unittest.HomeserverTestCase): - def test_evict(self): - cache = LruCache(5, size_callback=len) + def test_evict(self) -> None: + cache: LruCache[str, List[int]] = LruCache(5, size_callback=len) cache["key1"] = [0] cache["key2"] = [1, 2] cache["key3"] = [3] @@ -269,6 +272,7 @@ class LruCacheSizedTestCase(unittest.HomeserverTestCase): cache["key1"] = [] self.assertEqual(len(cache), 0) + assert isinstance(cache.cache, dict) cache.cache["key1"].drop_from_cache() self.assertIsNone( cache.pop("key1"), "Cache entry should have been evicted but wasn't" @@ -278,17 +282,17 @@ class LruCacheSizedTestCase(unittest.HomeserverTestCase): class TimeEvictionTestCase(unittest.HomeserverTestCase): """Test that time based eviction works correctly.""" - def default_config(self): + def default_config(self) -> JsonDict: config = super().default_config() config.setdefault("caches", {})["expiry_time"] = "30m" return config - def test_evict(self): + def test_evict(self) -> None: setup_expire_lru_cache_entries(self.hs) - cache = LruCache(5, clock=self.hs.get_clock()) + cache: LruCache[str, int] = LruCache(5, clock=self.hs.get_clock()) # Check that we evict entries we haven't accessed for 30 minutes. cache["key1"] = 1 @@ -332,7 +336,7 @@ class MemoryEvictionTestCase(unittest.HomeserverTestCase): } ) @patch("synapse.util.caches.lrucache.get_jemalloc_stats") - def test_evict_memory(self, jemalloc_interface) -> None: + def test_evict_memory(self, jemalloc_interface: Mock) -> None: mock_jemalloc_class = Mock(spec=JemallocStats) jemalloc_interface.return_value = mock_jemalloc_class @@ -340,7 +344,7 @@ class MemoryEvictionTestCase(unittest.HomeserverTestCase): mock_jemalloc_class.get_stat.return_value = 924288000 setup_expire_lru_cache_entries(self.hs) - cache = LruCache(4, clock=self.hs.get_clock()) + cache: LruCache[str, int] = LruCache(4, clock=self.hs.get_clock()) cache["key1"] = 1 cache["key2"] = 2 diff --git a/tests/util/test_macaroons.py b/tests/util/test_macaroons.py index 40754a4711..f68377a05a 100644 --- a/tests/util/test_macaroons.py +++ b/tests/util/test_macaroons.py @@ -21,14 +21,14 @@ from tests.unittest import TestCase class MacaroonGeneratorTestCase(TestCase): - def setUp(self): + def setUp(self) -> None: self.reactor, hs_clock = get_clock() self.macaroon_generator = MacaroonGenerator(hs_clock, "tesths", b"verysecret") self.other_macaroon_generator = MacaroonGenerator( hs_clock, "tesths", b"anothersecretkey" ) - def test_guest_access_token(self): + def test_guest_access_token(self) -> None: """Test the generation and verification of guest access tokens""" token = self.macaroon_generator.generate_guest_access_token("@user:tesths") user_id = self.macaroon_generator.verify_guest_token(token) @@ -47,7 +47,7 @@ class MacaroonGeneratorTestCase(TestCase): with self.assertRaises(MacaroonVerificationFailedException): self.macaroon_generator.verify_guest_token(token) - def test_delete_pusher_token(self): + def test_delete_pusher_token(self) -> None: """Test the generation and verification of delete_pusher tokens""" token = self.macaroon_generator.generate_delete_pusher_token( "@user:tesths", "m.mail", "john@example.com" @@ -84,7 +84,7 @@ class MacaroonGeneratorTestCase(TestCase): ) self.assertEqual(user_id, "@user:tesths") - def test_oidc_session_token(self): + def test_oidc_session_token(self) -> None: """Test the generation and verification of OIDC session cookies""" state = "arandomstate" session_data = OidcSessionData( diff --git a/tests/util/test_ratelimitutils.py b/tests/util/test_ratelimitutils.py index 89d8656634..5b327b390e 100644 --- a/tests/util/test_ratelimitutils.py +++ b/tests/util/test_ratelimitutils.py @@ -13,16 +13,19 @@ # limitations under the License. from typing import Optional +from twisted.internet.defer import Deferred + from synapse.config.homeserver import HomeServerConfig +from synapse.config.ratelimiting import FederationRatelimitSettings from synapse.util.ratelimitutils import FederationRateLimiter -from tests.server import get_clock +from tests.server import ThreadedMemoryReactorClock, get_clock from tests.unittest import TestCase from tests.utils import default_config class FederationRateLimiterTestCase(TestCase): - def test_ratelimit(self): + def test_ratelimit(self) -> None: """A simple test with the default values""" reactor, clock = get_clock() rc_config = build_rc_config() @@ -32,7 +35,7 @@ class FederationRateLimiterTestCase(TestCase): # shouldn't block self.successResultOf(d1) - def test_concurrent_limit(self): + def test_concurrent_limit(self) -> None: """Test what happens when we hit the concurrent limit""" reactor, clock = get_clock() rc_config = build_rc_config({"rc_federation": {"concurrent": 2}}) @@ -56,7 +59,7 @@ class FederationRateLimiterTestCase(TestCase): cm2.__exit__(None, None, None) self.successResultOf(d3) - def test_sleep_limit(self): + def test_sleep_limit(self) -> None: """Test what happens when we hit the sleep limit""" reactor, clock = get_clock() rc_config = build_rc_config( @@ -79,7 +82,7 @@ class FederationRateLimiterTestCase(TestCase): self.assertAlmostEqual(sleep_time, 500, places=3) -def _await_resolution(reactor, d): +def _await_resolution(reactor: ThreadedMemoryReactorClock, d: Deferred) -> float: """advance the clock until the deferred completes. Returns the number of milliseconds it took to complete. @@ -90,7 +93,7 @@ def _await_resolution(reactor, d): return (reactor.seconds() - start_time) * 1000 -def build_rc_config(settings: Optional[dict] = None): +def build_rc_config(settings: Optional[dict] = None) -> FederationRatelimitSettings: config_dict = default_config("test") config_dict.update(settings or {}) config = HomeServerConfig() diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py index 26cb71c640..9529ee53c8 100644 --- a/tests/util/test_retryutils.py +++ b/tests/util/test_retryutils.py @@ -22,7 +22,7 @@ from tests.unittest import HomeserverTestCase class RetryLimiterTestCase(HomeserverTestCase): - def test_new_destination(self): + def test_new_destination(self) -> None: """A happy-path case with a new destination and a successful operation""" store = self.hs.get_datastores().main limiter = self.get_success(get_retry_limiter("test_dest", self.clock, store)) @@ -36,7 +36,7 @@ class RetryLimiterTestCase(HomeserverTestCase): new_timings = self.get_success(store.get_destination_retry_timings("test_dest")) self.assertIsNone(new_timings) - def test_limiter(self): + def test_limiter(self) -> None: """General test case which walks through the process of a failing request""" store = self.hs.get_datastores().main diff --git a/tests/util/test_rwlock.py b/tests/util/test_rwlock.py index 5da04362a9..bc93de62eb 100644 --- a/tests/util/test_rwlock.py +++ b/tests/util/test_rwlock.py @@ -49,7 +49,7 @@ class ReadWriteLockTestCase(unittest.TestCase): acquired_d: "Deferred[None]" = Deferred() unblock_d: "Deferred[None]" = Deferred() - async def reader_or_writer(): + async def reader_or_writer() -> str: async with read_or_write(key): acquired_d.callback(None) await unblock_d @@ -134,7 +134,7 @@ class ReadWriteLockTestCase(unittest.TestCase): d.called, msg="deferred %d was unexpectedly resolved" % (i + n) ) - def test_rwlock(self): + def test_rwlock(self) -> None: rwlock = ReadWriteLock() key = "key" @@ -197,7 +197,7 @@ class ReadWriteLockTestCase(unittest.TestCase): _, acquired_d = self._start_nonblocking_reader(rwlock, key, "last reader") self.assertTrue(acquired_d.called) - def test_lock_handoff_to_nonblocking_writer(self): + def test_lock_handoff_to_nonblocking_writer(self) -> None: """Test a writer handing the lock to another writer that completes instantly.""" rwlock = ReadWriteLock() key = "key" @@ -216,7 +216,7 @@ class ReadWriteLockTestCase(unittest.TestCase): d3, _ = self._start_nonblocking_writer(rwlock, key, "write 3 completed") self.assertTrue(d3.called) - def test_cancellation_while_holding_read_lock(self): + def test_cancellation_while_holding_read_lock(self) -> None: """Test cancellation while holding a read lock. A waiting writer should be given the lock when the reader holding the lock is @@ -242,7 +242,7 @@ class ReadWriteLockTestCase(unittest.TestCase): ) self.assertEqual("write completed", self.successResultOf(writer_d)) - def test_cancellation_while_holding_write_lock(self): + def test_cancellation_while_holding_write_lock(self) -> None: """Test cancellation while holding a write lock. A waiting reader should be given the lock when the writer holding the lock is @@ -268,7 +268,7 @@ class ReadWriteLockTestCase(unittest.TestCase): ) self.assertEqual("read completed", self.successResultOf(reader_d)) - def test_cancellation_while_waiting_for_read_lock(self): + def test_cancellation_while_waiting_for_read_lock(self) -> None: """Test cancellation while waiting for a read lock. Tests that cancelling a waiting reader: @@ -319,7 +319,7 @@ class ReadWriteLockTestCase(unittest.TestCase): ) self.assertEqual("write 2 completed", self.successResultOf(writer2_d)) - def test_cancellation_while_waiting_for_write_lock(self): + def test_cancellation_while_waiting_for_write_lock(self) -> None: """Test cancellation while waiting for a write lock. Tests that cancelling a waiting writer: diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py index 9ed01f7e0c..1b0fa52ad1 100644 --- a/tests/util/test_stream_change_cache.py +++ b/tests/util/test_stream_change_cache.py @@ -8,7 +8,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): Tests for StreamChangeCache. """ - def test_prefilled_cache(self): + def test_prefilled_cache(self) -> None: """ Providing a prefilled cache to StreamChangeCache will result in a cache with the prefilled-cache entered in. @@ -16,7 +16,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): cache = StreamChangeCache("#test", 1, prefilled_cache={"user@foo.com": 2}) self.assertTrue(cache.has_entity_changed("user@foo.com", 1)) - def test_has_entity_changed(self): + def test_has_entity_changed(self) -> None: """ StreamChangeCache.entity_has_changed will mark entities as changed, and has_entity_changed will observe the changed entities. @@ -52,7 +52,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): self.assertTrue(cache.has_entity_changed("user@foo.com", 0)) self.assertTrue(cache.has_entity_changed("not@here.website", 0)) - def test_entity_has_changed_pops_off_start(self): + def test_entity_has_changed_pops_off_start(self) -> None: """ StreamChangeCache.entity_has_changed will respect the max size and purge the oldest items upon reaching that max size. @@ -86,7 +86,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): ) self.assertIsNone(cache.get_all_entities_changed(1)) - def test_get_all_entities_changed(self): + def test_get_all_entities_changed(self) -> None: """ StreamChangeCache.get_all_entities_changed will return all changed entities since the given position. If the position is before the start @@ -142,7 +142,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): r = cache.get_all_entities_changed(3) self.assertTrue(r == ok1 or r == ok2) - def test_has_any_entity_changed(self): + def test_has_any_entity_changed(self) -> None: """ StreamChangeCache.has_any_entity_changed will return True if any entities have been changed since the provided stream position, and @@ -168,7 +168,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): self.assertFalse(cache.has_any_entity_changed(2)) self.assertFalse(cache.has_any_entity_changed(3)) - def test_get_entities_changed(self): + def test_get_entities_changed(self) -> None: """ StreamChangeCache.get_entities_changed will return the entities in the given list that have changed since the provided stream ID. If the @@ -228,7 +228,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): {"bar@baz.net"}, ) - def test_max_pos(self): + def test_max_pos(self) -> None: """ StreamChangeCache.get_max_pos_of_last_change will return the most recent point where the entity could have changed. If the entity is not diff --git a/tests/util/test_stringutils.py b/tests/util/test_stringutils.py index ad4dd7f007..f137e05191 100644 --- a/tests/util/test_stringutils.py +++ b/tests/util/test_stringutils.py @@ -19,7 +19,7 @@ from .. import unittest class StringUtilsTestCase(unittest.TestCase): - def test_client_secret_regex(self): + def test_client_secret_regex(self) -> None: """Ensure that client_secret does not contain illegal characters""" good = [ "abcde12345", @@ -46,7 +46,7 @@ class StringUtilsTestCase(unittest.TestCase): with self.assertRaises(SynapseError): assert_valid_client_secret(client_secret) - def test_base62_encode(self): + def test_base62_encode(self) -> None: self.assertEqual("0", base62_encode(0)) self.assertEqual("10", base62_encode(62)) self.assertEqual("1c", base62_encode(100)) diff --git a/tests/util/test_threepids.py b/tests/util/test_threepids.py index d957b953bb..3b35b8e4ec 100644 --- a/tests/util/test_threepids.py +++ b/tests/util/test_threepids.py @@ -18,31 +18,31 @@ from tests.unittest import HomeserverTestCase class CanonicaliseEmailTests(HomeserverTestCase): - def test_no_at(self): + def test_no_at(self) -> None: with self.assertRaises(ValueError): canonicalise_email("address-without-at.bar") - def test_two_at(self): + def test_two_at(self) -> None: with self.assertRaises(ValueError): canonicalise_email("foo@foo@test.bar") - def test_bad_format(self): + def test_bad_format(self) -> None: with self.assertRaises(ValueError): canonicalise_email("user@bad.example.net@good.example.com") - def test_valid_format(self): + def test_valid_format(self) -> None: self.assertEqual(canonicalise_email("foo@test.bar"), "foo@test.bar") - def test_domain_to_lower(self): + def test_domain_to_lower(self) -> None: self.assertEqual(canonicalise_email("foo@TEST.BAR"), "foo@test.bar") - def test_domain_with_umlaut(self): + def test_domain_with_umlaut(self) -> None: self.assertEqual(canonicalise_email("foo@Öumlaut.com"), "foo@öumlaut.com") - def test_address_casefold(self): + def test_address_casefold(self) -> None: self.assertEqual( canonicalise_email("Strauß@Example.com"), "strauss@example.com" ) - def test_address_trim(self): + def test_address_trim(self) -> None: self.assertEqual(canonicalise_email(" foo@test.bar "), "foo@test.bar") diff --git a/tests/util/test_treecache.py b/tests/util/test_treecache.py index 567cb18468..fe3b4dc6a4 100644 --- a/tests/util/test_treecache.py +++ b/tests/util/test_treecache.py @@ -19,7 +19,7 @@ from .. import unittest class TreeCacheTestCase(unittest.TestCase): - def test_get_set_onelevel(self): + def test_get_set_onelevel(self) -> None: cache = TreeCache() cache[("a",)] = "A" cache[("b",)] = "B" @@ -27,7 +27,7 @@ class TreeCacheTestCase(unittest.TestCase): self.assertEqual(cache.get(("b",)), "B") self.assertEqual(len(cache), 2) - def test_pop_onelevel(self): + def test_pop_onelevel(self) -> None: cache = TreeCache() cache[("a",)] = "A" cache[("b",)] = "B" @@ -36,7 +36,7 @@ class TreeCacheTestCase(unittest.TestCase): self.assertEqual(cache.get(("b",)), "B") self.assertEqual(len(cache), 1) - def test_get_set_twolevel(self): + def test_get_set_twolevel(self) -> None: cache = TreeCache() cache[("a", "a")] = "AA" cache[("a", "b")] = "AB" @@ -46,7 +46,7 @@ class TreeCacheTestCase(unittest.TestCase): self.assertEqual(cache.get(("b", "a")), "BA") self.assertEqual(len(cache), 3) - def test_pop_twolevel(self): + def test_pop_twolevel(self) -> None: cache = TreeCache() cache[("a", "a")] = "AA" cache[("a", "b")] = "AB" @@ -58,7 +58,7 @@ class TreeCacheTestCase(unittest.TestCase): self.assertEqual(cache.pop(("b", "a")), None) self.assertEqual(len(cache), 1) - def test_pop_mixedlevel(self): + def test_pop_mixedlevel(self) -> None: cache = TreeCache() cache[("a", "a")] = "AA" cache[("a", "b")] = "AB" @@ -72,14 +72,14 @@ class TreeCacheTestCase(unittest.TestCase): self.assertEqual({"AA", "AB"}, set(iterate_tree_cache_entry(popped))) - def test_clear(self): + def test_clear(self) -> None: cache = TreeCache() cache[("a",)] = "A" cache[("b",)] = "B" cache.clear() self.assertEqual(len(cache), 0) - def test_contains(self): + def test_contains(self) -> None: cache = TreeCache() cache[("a",)] = "A" self.assertTrue(("a",) in cache) diff --git a/tests/util/test_wheel_timer.py b/tests/util/test_wheel_timer.py index 0d5039de04..c9d22b6d8c 100644 --- a/tests/util/test_wheel_timer.py +++ b/tests/util/test_wheel_timer.py @@ -18,8 +18,8 @@ from .. import unittest class WheelTimerTestCase(unittest.TestCase): - def test_single_insert_fetch(self): - wheel = WheelTimer(bucket_size=5) + def test_single_insert_fetch(self) -> None: + wheel: WheelTimer[object] = WheelTimer(bucket_size=5) obj = object() wheel.insert(100, obj, 150) @@ -32,8 +32,8 @@ class WheelTimerTestCase(unittest.TestCase): self.assertListEqual(wheel.fetch(156), [obj]) self.assertListEqual(wheel.fetch(170), []) - def test_multi_insert(self): - wheel = WheelTimer(bucket_size=5) + def test_multi_insert(self) -> None: + wheel: WheelTimer[object] = WheelTimer(bucket_size=5) obj1 = object() obj2 = object() @@ -50,15 +50,15 @@ class WheelTimerTestCase(unittest.TestCase): self.assertListEqual(wheel.fetch(200), [obj3]) self.assertListEqual(wheel.fetch(210), []) - def test_insert_past(self): - wheel = WheelTimer(bucket_size=5) + def test_insert_past(self) -> None: + wheel: WheelTimer[object] = WheelTimer(bucket_size=5) obj = object() wheel.insert(100, obj, 50) self.assertListEqual(wheel.fetch(120), [obj]) - def test_insert_past_multi(self): - wheel = WheelTimer(bucket_size=5) + def test_insert_past_multi(self) -> None: + wheel: WheelTimer[object] = WheelTimer(bucket_size=5) obj1 = object() obj2 = object() -- cgit 1.5.1 From 6a8310f3dfe77acf59df2fe3e88a71b85b9b3ecc Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 5 Dec 2022 09:00:59 -0500 Subject: Compare to the earliest known stream pos in the stream change cache. (#14435) The internal methods of the StreamChangeCache were inconsistently treating the earliest known stream position as valid. It is now treated as invalid, meaning the cache cannot determine if an entity at the earliest known stream position has changed or not. --- changelog.d/14435.bugfix | 1 + poetry.lock | 2 +- pyproject.toml | 3 +- synapse/util/caches/stream_change_cache.py | 142 +++++++++++++++++++++++------ tests/util/test_stream_change_cache.py | 38 +++----- 5 files changed, 133 insertions(+), 53 deletions(-) create mode 100644 changelog.d/14435.bugfix (limited to 'tests/util') diff --git a/changelog.d/14435.bugfix b/changelog.d/14435.bugfix new file mode 100644 index 0000000000..149ee99dd7 --- /dev/null +++ b/changelog.d/14435.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where a device list update might not be sent to clients in certain circumstances. diff --git a/poetry.lock b/poetry.lock index 8c63134578..90b363a548 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1639,7 +1639,7 @@ url-preview = ["lxml"] [metadata] lock-version = "1.1" python-versions = "^3.7.1" -content-hash = "27811bd21d56ceeb0f68ded5a00375efcd1a004928f0736f5b02927ce8594cb0" +content-hash = "8c44ceeb9df5c3ab43040400e0a6b895de49417e61293a1ba027640b34f03263" [metadata.files] attrs = [ diff --git a/pyproject.toml b/pyproject.toml index af5ce2aa03..1368e4e688 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -141,7 +141,8 @@ pyasn1 = ">=0.1.9" pyasn1-modules = ">=0.0.7" bcrypt = ">=3.1.7" Pillow = ">=5.4.0" -sortedcontainers = ">=1.4.4" +# We use SortedDict.peekitem(), which was added in sortedcontainers 1.5.2. +sortedcontainers = ">=1.5.2" pymacaroons = ">=0.13.0" msgpack = ">=0.5.2" phonenumbers = ">=8.2.0" diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index 666f4b6895..042de8d7c8 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -27,13 +27,17 @@ EntityType = str class StreamChangeCache: - """Keeps track of the stream positions of the latest change in a set of entities. + """ + Keeps track of the stream positions of the latest change in a set of entities. + + The entity will is typically a room ID or user ID, but can be any string. - Typically the entity will be a room or user id. + Can be queried for whether a specific entity has changed after a stream position + or for a list of changed entities after a stream position. See the individual + methods for more information. - Given a list of entities and a stream position, it will give a subset of - entities that may have changed since that position. If position key is too - old then the cache will simply return all given entities. + Only tracks to a maximum cache size, any position earlier than the earliest + known stream position must be treated as unknown. """ def __init__( @@ -45,16 +49,20 @@ class StreamChangeCache: ) -> None: self._original_max_size: int = max_size self._max_size = math.floor(max_size) - self._entity_to_key: Dict[EntityType, int] = {} - # map from stream id to the a set of entities which changed at that stream id. + # map from stream id to the set of entities which changed at that stream id. self._cache: SortedDict[int, Set[EntityType]] = SortedDict() + # map from entity to the stream ID of the latest change for that entity. + # + # Must be kept in sync with _cache. + self._entity_to_key: Dict[EntityType, int] = {} # the earliest stream_pos for which we can reliably answer # get_all_entities_changed. In other words, one less than the earliest # stream_pos for which we know _cache is valid. # self._earliest_known_stream_pos = current_stream_pos + self.name = name self.metrics = caches.register_cache( "cache", self.name, self._cache, resize_callback=self.set_cache_factor @@ -82,22 +90,46 @@ class StreamChangeCache: return False def has_entity_changed(self, entity: EntityType, stream_pos: int) -> bool: - """Returns True if the entity may have been updated since stream_pos""" + """ + Returns True if the entity may have been updated after stream_pos. + + Args: + entity: The entity to check for changes. + stream_pos: The stream position to check for changes after. + + Return: + True if the entity may have been updated, this happens if: + * The given stream position is at or earlier than the earliest + known stream position. + * The given stream position is earlier than the latest change for + the entity. + + False otherwise: + * The entity is unknown. + * The given stream position is at or later than the latest change + for the entity. + """ assert isinstance(stream_pos, int) - if stream_pos < self._earliest_known_stream_pos: + # _cache is not valid at or before the earliest known stream position, so + # return that the entity has changed. + if stream_pos <= self._earliest_known_stream_pos: self.metrics.inc_misses() return True + # If the entity is unknown, it hasn't changed. latest_entity_change_pos = self._entity_to_key.get(entity, None) if latest_entity_change_pos is None: self.metrics.inc_hits() return False + # This is a known entity, return true if the stream position is earlier + # than the last change. if stream_pos < latest_entity_change_pos: self.metrics.inc_misses() return True + # Otherwise, the stream position is after the latest change: return false. self.metrics.inc_hits() return False @@ -105,15 +137,27 @@ class StreamChangeCache: self, entities: Collection[EntityType], stream_pos: int ) -> Union[Set[EntityType], FrozenSet[EntityType]]: """ - Returns subset of entities that have had new things since the given - position. Entities unknown to the cache will be returned. If the - position is too old it will just return the given list. + Returns the subset of the given entities that have had changes after the given position. + + Entities unknown to the cache will be returned. + + If the position is too old it will just return the given list. + + Args: + entities: Entities to check for changes. + stream_pos: The stream position to check for changes after. + + Return: + A subset of entities which have changed after the given stream position. + + This will be all entities if the given stream position is at or earlier + than the earliest known stream position. """ changed_entities = self.get_all_entities_changed(stream_pos) if changed_entities is not None: # We now do an intersection, trying to do so in the most efficient # way possible (some of these sets are *large*). First check in the - # given iterable is already set that we can reuse, otherwise we + # given iterable is already a set that we can reuse, otherwise we # create a set of the *smallest* of the two iterables and call # `intersection(..)` on it (this can be twice as fast as the reverse). if isinstance(entities, (set, frozenset)): @@ -130,29 +174,57 @@ class StreamChangeCache: return result def has_any_entity_changed(self, stream_pos: int) -> bool: - """Returns if any entity has changed""" - assert type(stream_pos) is int + """ + Returns true if any entity has changed after the given stream position. + + Args: + stream_pos: The stream position to check for changes after. + + Return: + True if any entity has changed after the given stream position or + if the given stream position is at or earlier than the earliest + known stream position. + + False otherwise. + """ + assert isinstance(stream_pos, int) if not self._cache: # If the cache is empty, nothing can have changed. return False - if stream_pos >= self._earliest_known_stream_pos: - self.metrics.inc_hits() - return self._cache.bisect_right(stream_pos) < len(self._cache) - else: + # _cache is not valid at or before the earliest known stream position, so + # return that an entity has changed. + if stream_pos <= self._earliest_known_stream_pos: self.metrics.inc_misses() return True + self.metrics.inc_hits() + return stream_pos < self._cache.peekitem()[0] + def get_all_entities_changed(self, stream_pos: int) -> Optional[List[EntityType]]: - """Returns all entities that have had new things since the given - position. If the position is too old it will return None. + """ + Returns all entities that have had changes after the given position. + + If the stream change cache does not go far enough back, i.e. the position + is too old, it will return None. Returns the entities in the order that they were changed. + + Args: + stream_pos: The stream position to check for changes after. + + Return: + Entities which have changed after the given stream position. + + None if the given stream position is at or earlier than the earliest + known stream position. """ - assert type(stream_pos) is int + assert isinstance(stream_pos, int) - if stream_pos < self._earliest_known_stream_pos: + # _cache is not valid at or before the earliest known stream position, so + # return None to mark that it is unknown if an entity has changed. + if stream_pos <= self._earliest_known_stream_pos: return None changed_entities: List[EntityType] = [] @@ -162,11 +234,17 @@ class StreamChangeCache: return changed_entities def entity_has_changed(self, entity: EntityType, stream_pos: int) -> None: - """Informs the cache that the entity has been changed at the given - position. """ - assert type(stream_pos) is int + Informs the cache that the entity has been changed at the given position. + + Args: + entity: The entity to mark as changed. + stream_pos: The stream position to update the entity to. + """ + assert isinstance(stream_pos, int) + # For a change before _cache is valid (e.g. at or before the earliest known + # stream position) there's nothing to do. if stream_pos <= self._earliest_known_stream_pos: return @@ -189,6 +267,11 @@ class StreamChangeCache: self._evict() def _evict(self) -> None: + """ + Ensure the cache has not exceeded the maximum size. + + Evicts entries until it is at the maximum size. + """ # if the cache is too big, remove entries while len(self._cache) > self._max_size: k, r = self._cache.popitem(0) @@ -199,5 +282,12 @@ class StreamChangeCache: def get_max_pos_of_last_change(self, entity: EntityType) -> int: """Returns an upper bound of the stream id of the last change to an entity. + + Args: + entity: The entity to check. + + Return: + The stream position of the latest change for the given entity or + the earliest known stream position if the entitiy is unknown. """ return self._entity_to_key.get(entity, self._earliest_known_stream_pos) diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py index 1b0fa52ad1..a29cc872f9 100644 --- a/tests/util/test_stream_change_cache.py +++ b/tests/util/test_stream_change_cache.py @@ -51,6 +51,8 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): # return True, whether it's a known entity or not. self.assertTrue(cache.has_entity_changed("user@foo.com", 0)) self.assertTrue(cache.has_entity_changed("not@here.website", 0)) + self.assertTrue(cache.has_entity_changed("user@foo.com", 3)) + self.assertTrue(cache.has_entity_changed("not@here.website", 3)) def test_entity_has_changed_pops_off_start(self) -> None: """ @@ -65,15 +67,14 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): # The cache is at the max size, 2 self.assertEqual(len(cache._cache), 2) + # The cache's earliest known position is 2. + self.assertEqual(cache._earliest_known_stream_pos, 2) # The oldest item has been popped off self.assertTrue("user@foo.com" not in cache._entity_to_key) - self.assertEqual( - cache.get_all_entities_changed(2), - ["bar@baz.net", "user@elsewhere.org"], - ) - self.assertIsNone(cache.get_all_entities_changed(1)) + self.assertEqual(cache.get_all_entities_changed(3), ["user@elsewhere.org"]) + self.assertIsNone(cache.get_all_entities_changed(2)) # If we update an existing entity, it keeps the two existing entities cache.entity_has_changed("bar@baz.net", 5) @@ -81,10 +82,10 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): {"bar@baz.net", "user@elsewhere.org"}, set(cache._entity_to_key) ) self.assertEqual( - cache.get_all_entities_changed(2), + cache.get_all_entities_changed(3), ["user@elsewhere.org", "bar@baz.net"], ) - self.assertIsNone(cache.get_all_entities_changed(1)) + self.assertIsNone(cache.get_all_entities_changed(2)) def test_get_all_entities_changed(self) -> None: """ @@ -99,28 +100,15 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): cache.entity_has_changed("anotheruser@foo.com", 3) cache.entity_has_changed("user@elsewhere.org", 4) - r = cache.get_all_entities_changed(1) + r = cache.get_all_entities_changed(2) - # either of these are valid - ok1 = [ - "user@foo.com", - "bar@baz.net", - "anotheruser@foo.com", - "user@elsewhere.org", - ] - ok2 = [ - "user@foo.com", - "anotheruser@foo.com", - "bar@baz.net", - "user@elsewhere.org", - ] + # Results are ordered so either of these are valid. + ok1 = ["bar@baz.net", "anotheruser@foo.com", "user@elsewhere.org"] + ok2 = ["anotheruser@foo.com", "bar@baz.net", "user@elsewhere.org"] self.assertTrue(r == ok1 or r == ok2) - r = cache.get_all_entities_changed(2) - self.assertTrue(r == ok1[1:] or r == ok2[1:]) - self.assertEqual(cache.get_all_entities_changed(3), ["user@elsewhere.org"]) - self.assertEqual(cache.get_all_entities_changed(0), None) + self.assertEqual(cache.get_all_entities_changed(1), None) # ... later, things gest more updates cache.entity_has_changed("user@foo.com", 5) -- cgit 1.5.1 From cee9445884eb62c070fb0b03a112a862e8dea7c4 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 5 Dec 2022 20:19:14 +0000 Subject: Better return type for `get_all_entities_changed` (#14604) Help callers from using the return value incorrectly by ensuring that callers explicitly check if there was a cache hit or not. --- changelog.d/14604.bugfix | 1 + synapse/handlers/appservice.py | 4 +- synapse/handlers/presence.py | 12 ++-- synapse/handlers/sync.py | 6 +- synapse/handlers/typing.py | 8 +-- synapse/storage/databases/main/devices.py | 111 ++++++++++++++++++----------- synapse/util/caches/stream_change_cache.py | 52 ++++++++++---- tests/util/test_stream_change_cache.py | 20 +++--- 8 files changed, 138 insertions(+), 76 deletions(-) create mode 100644 changelog.d/14604.bugfix (limited to 'tests/util') diff --git a/changelog.d/14604.bugfix b/changelog.d/14604.bugfix new file mode 100644 index 0000000000..149ee99dd7 --- /dev/null +++ b/changelog.d/14604.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where a device list update might not be sent to clients in certain circumstances. diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 66f5b8d108..f68027aaed 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -615,8 +615,8 @@ class ApplicationServicesHandler: ) # Fetch the users who have modified their device list since then. - users_with_changed_device_lists = ( - await self.store.get_users_whose_devices_changed(from_key, to_key=new_key) + users_with_changed_device_lists = await self.store.get_all_devices_changed( + from_key, to_key=new_key ) # Filter out any users the application service is not interested in diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 1799174c2f..2af90b25a3 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -1692,10 +1692,12 @@ class PresenceEventSource(EventSource[int, UserPresenceState]): if from_key is not None: # First get all users that have had a presence update - updated_users = stream_change_cache.get_all_entities_changed(from_key) + result = stream_change_cache.get_all_entities_changed(from_key) # Cross-reference users we're interested in with those that have had updates. - if updated_users is not None: + if result.hit: + updated_users = result.entities + # If we have the full list of changes for presence we can # simply check which ones share a room with the user. get_updates_counter.labels("stream").inc() @@ -1767,9 +1769,9 @@ class PresenceEventSource(EventSource[int, UserPresenceState]): updated_users = None if from_key: # Only return updates since the last sync - updated_users = self.store.presence_stream_cache.get_all_entities_changed( - from_key - ) + result = self.store.presence_stream_cache.get_all_entities_changed(from_key) + if result.hit: + updated_users = result.entities if updated_users is not None: # Get the actual presence update for each change diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index c8858b22dd..0b395a104d 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1528,10 +1528,12 @@ class SyncHandler: # # If we don't have that info cached then we get all the users that # share a room with our user and check if those users have changed. - changed_users = self.store.get_cached_device_list_changes( + cache_result = self.store.get_cached_device_list_changes( since_token.device_list_key ) - if changed_users is not None: + if cache_result.hit: + changed_users = cache_result.entities + result = await self.store.get_rooms_for_users(changed_users) for changed_user_id, entries in result.items(): diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index a0ea719430..3f656ea4f5 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -420,11 +420,11 @@ class TypingWriterHandler(FollowerTypingHandler): if last_id == current_id: return [], current_id, False - changed_rooms: Optional[ - Iterable[str] - ] = self._typing_stream_change_cache.get_all_entities_changed(last_id) + result = self._typing_stream_change_cache.get_all_entities_changed(last_id) - if changed_rooms is None: + if result.hit: + changed_rooms: Iterable[str] = result.entities + else: changed_rooms = self._room_serials rows = [] diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index 8ba995df3b..a5bb4d404e 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -58,7 +58,10 @@ from synapse.types import JsonDict, get_verify_key_from_cross_signing_key from synapse.util import json_decoder, json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.lrucache import LruCache -from synapse.util.caches.stream_change_cache import StreamChangeCache +from synapse.util.caches.stream_change_cache import ( + AllEntitiesChangedResult, + StreamChangeCache, +) from synapse.util.cancellation import cancellable from synapse.util.iterutils import batch_iter from synapse.util.stringutils import shortstr @@ -799,18 +802,66 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): def get_cached_device_list_changes( self, from_key: int, - ) -> Optional[List[str]]: + ) -> AllEntitiesChangedResult: """Get set of users whose devices have changed since `from_key`, or None if that information is not in our cache. """ return self._device_list_stream_cache.get_all_entities_changed(from_key) + @cancellable + async def get_all_devices_changed( + self, + from_key: int, + to_key: int, + ) -> Set[str]: + """Get all users whose devices have changed in the given range. + + Args: + from_key: The minimum device lists stream token to query device list + changes for, exclusive. + to_key: The maximum device lists stream token to query device list + changes for, inclusive. + + Returns: + The set of user_ids whose devices have changed since `from_key` + (exclusive) until `to_key` (inclusive). + """ + + result = self._device_list_stream_cache.get_all_entities_changed(from_key) + + if result.hit: + # We know which users might have changed devices. + if not result.entities: + # If no users then we can return early. + return set() + + # Otherwise we need to filter down the list + return await self.get_users_whose_devices_changed( + from_key, result.entities, to_key + ) + + # If the cache didn't tell us anything, we just need to query the full + # range. + sql = """ + SELECT DISTINCT user_id FROM device_lists_stream + WHERE ? < stream_id AND stream_id <= ? + """ + + rows = await self.db_pool.execute( + "get_all_devices_changed", + None, + sql, + from_key, + to_key, + ) + return {u for u, in rows} + @cancellable async def get_users_whose_devices_changed( self, from_key: int, - user_ids: Optional[Collection[str]] = None, + user_ids: Collection[str], to_key: Optional[int] = None, ) -> Set[str]: """Get set of users whose devices have changed since `from_key` that @@ -830,52 +881,32 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): """ # Get set of users who *may* have changed. Users not in the returned # list have definitely not changed. - user_ids_to_check: Optional[Collection[str]] - if user_ids is None: - # Get set of all users that have had device list changes since 'from_key' - user_ids_to_check = self._device_list_stream_cache.get_all_entities_changed( - from_key - ) - else: - # The same as above, but filter results to only those users in 'user_ids' - user_ids_to_check = self._device_list_stream_cache.get_entities_changed( - user_ids, from_key - ) + user_ids_to_check = self._device_list_stream_cache.get_entities_changed( + user_ids, from_key + ) # If an empty set was returned, there's nothing to do. - if user_ids_to_check is not None and not user_ids_to_check: + if not user_ids_to_check: return set() - def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]: - stream_id_where_clause = "stream_id > ?" - sql_args = [from_key] - - if to_key: - stream_id_where_clause += " AND stream_id <= ?" - sql_args.append(to_key) + if to_key is None: + to_key = self._device_list_id_gen.get_current_token() - sql = f""" + def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]: + sql = """ SELECT DISTINCT user_id FROM device_lists_stream - WHERE {stream_id_where_clause} + WHERE ? < stream_id AND stream_id <= ? AND %s """ - # If the stream change cache gave us no information, fetch *all* - # users between the stream IDs. - if user_ids_to_check is None: - txn.execute(sql, sql_args) - return {user_id for user_id, in txn} + changes: Set[str] = set() - # Otherwise, fetch changes for the given users. - else: - changes: Set[str] = set() - - # Query device changes with a batch of users at a time - for chunk in batch_iter(user_ids_to_check, 100): - clause, args = make_in_list_sql_clause( - txn.database_engine, "user_id", chunk - ) - txn.execute(sql + " AND " + clause, sql_args + args) - changes.update(user_id for user_id, in txn) + # Query device changes with a batch of users at a time + for chunk in batch_iter(user_ids_to_check, 100): + clause, args = make_in_list_sql_clause( + txn.database_engine, "user_id", chunk + ) + txn.execute(sql % (clause,), [from_key, to_key] + args) + changes.update(user_id for user_id, in txn) return changes diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index 042de8d7c8..c8b17acb59 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -16,6 +16,7 @@ import logging import math from typing import Collection, Dict, FrozenSet, List, Mapping, Optional, Set, Union +import attr from sortedcontainers import SortedDict from synapse.util import caches @@ -26,6 +27,29 @@ logger = logging.getLogger(__name__) EntityType = str +@attr.s(auto_attribs=True, frozen=True, slots=True) +class AllEntitiesChangedResult: + """Return type of `get_all_entities_changed`. + + Callers must check that there was a cache hit, via `result.hit`, before + using the entities in `result.entities`. + + This specifically does *not* implement helpers such as `__bool__` to ensure + that callers do the correct checks. + """ + + _entities: Optional[List[EntityType]] + + @property + def hit(self) -> bool: + return self._entities is not None + + @property + def entities(self) -> List[EntityType]: + assert self._entities is not None + return self._entities + + class StreamChangeCache: """ Keeps track of the stream positions of the latest change in a set of entities. @@ -153,19 +177,19 @@ class StreamChangeCache: This will be all entities if the given stream position is at or earlier than the earliest known stream position. """ - changed_entities = self.get_all_entities_changed(stream_pos) - if changed_entities is not None: + cache_result = self.get_all_entities_changed(stream_pos) + if cache_result.hit: # We now do an intersection, trying to do so in the most efficient # way possible (some of these sets are *large*). First check in the # given iterable is already a set that we can reuse, otherwise we # create a set of the *smallest* of the two iterables and call # `intersection(..)` on it (this can be twice as fast as the reverse). if isinstance(entities, (set, frozenset)): - result = entities.intersection(changed_entities) - elif len(changed_entities) < len(entities): - result = set(changed_entities).intersection(entities) + result = entities.intersection(cache_result.entities) + elif len(cache_result.entities) < len(entities): + result = set(cache_result.entities).intersection(entities) else: - result = set(entities).intersection(changed_entities) + result = set(entities).intersection(cache_result.entities) self.metrics.inc_hits() else: result = set(entities) @@ -202,12 +226,12 @@ class StreamChangeCache: self.metrics.inc_hits() return stream_pos < self._cache.peekitem()[0] - def get_all_entities_changed(self, stream_pos: int) -> Optional[List[EntityType]]: + def get_all_entities_changed(self, stream_pos: int) -> AllEntitiesChangedResult: """ Returns all entities that have had changes after the given position. - If the stream change cache does not go far enough back, i.e. the position - is too old, it will return None. + If the stream change cache does not go far enough back, i.e. the + position is too old, it will return None. Returns the entities in the order that they were changed. @@ -215,23 +239,21 @@ class StreamChangeCache: stream_pos: The stream position to check for changes after. Return: - Entities which have changed after the given stream position. - - None if the given stream position is at or earlier than the earliest - known stream position. + A class indicating if we have the requested data cached, and if so + includes the entities in the order they were changed. """ assert isinstance(stream_pos, int) # _cache is not valid at or before the earliest known stream position, so # return None to mark that it is unknown if an entity has changed. if stream_pos <= self._earliest_known_stream_pos: - return None + return AllEntitiesChangedResult(None) changed_entities: List[EntityType] = [] for k in self._cache.islice(start=self._cache.bisect_right(stream_pos)): changed_entities.extend(self._cache[k]) - return changed_entities + return AllEntitiesChangedResult(changed_entities) def entity_has_changed(self, entity: EntityType, stream_pos: int) -> None: """ diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py index a29cc872f9..0305741c99 100644 --- a/tests/util/test_stream_change_cache.py +++ b/tests/util/test_stream_change_cache.py @@ -73,8 +73,10 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): # The oldest item has been popped off self.assertTrue("user@foo.com" not in cache._entity_to_key) - self.assertEqual(cache.get_all_entities_changed(3), ["user@elsewhere.org"]) - self.assertIsNone(cache.get_all_entities_changed(2)) + self.assertEqual( + cache.get_all_entities_changed(3).entities, ["user@elsewhere.org"] + ) + self.assertFalse(cache.get_all_entities_changed(2).hit) # If we update an existing entity, it keeps the two existing entities cache.entity_has_changed("bar@baz.net", 5) @@ -82,10 +84,10 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): {"bar@baz.net", "user@elsewhere.org"}, set(cache._entity_to_key) ) self.assertEqual( - cache.get_all_entities_changed(3), + cache.get_all_entities_changed(3).entities, ["user@elsewhere.org", "bar@baz.net"], ) - self.assertIsNone(cache.get_all_entities_changed(2)) + self.assertFalse(cache.get_all_entities_changed(2).hit) def test_get_all_entities_changed(self) -> None: """ @@ -105,10 +107,12 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): # Results are ordered so either of these are valid. ok1 = ["bar@baz.net", "anotheruser@foo.com", "user@elsewhere.org"] ok2 = ["anotheruser@foo.com", "bar@baz.net", "user@elsewhere.org"] - self.assertTrue(r == ok1 or r == ok2) + self.assertTrue(r.entities == ok1 or r.entities == ok2) - self.assertEqual(cache.get_all_entities_changed(3), ["user@elsewhere.org"]) - self.assertEqual(cache.get_all_entities_changed(1), None) + self.assertEqual( + cache.get_all_entities_changed(3).entities, ["user@elsewhere.org"] + ) + self.assertFalse(cache.get_all_entities_changed(1).hit) # ... later, things gest more updates cache.entity_has_changed("user@foo.com", 5) @@ -128,7 +132,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): "anotheruser@foo.com", ] r = cache.get_all_entities_changed(3) - self.assertTrue(r == ok1 or r == ok2) + self.assertTrue(r.entities == ok1 or r.entities == ok2) def test_has_any_entity_changed(self) -> None: """ -- cgit 1.5.1 From da777207528513c858395758bf4c023da2c2c1a3 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 8 Dec 2022 11:35:49 -0500 Subject: Check the stream position before checking if the cache is empty. (#14639) An empty cache does not mean the entity has no changed, if it is earlier than the earliest known stream position return that the entity *has* changed since the cache cannot accurately answer that query. --- changelog.d/14639.bugfix | 1 + synapse/util/caches/stream_change_cache.py | 9 +++++---- tests/util/test_stream_change_cache.py | 7 ++++--- 3 files changed, 10 insertions(+), 7 deletions(-) create mode 100644 changelog.d/14639.bugfix (limited to 'tests/util') diff --git a/changelog.d/14639.bugfix b/changelog.d/14639.bugfix new file mode 100644 index 0000000000..8730b10afe --- /dev/null +++ b/changelog.d/14639.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where the user directory and room/user stats might be out of sync. diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index c8b17acb59..1657459549 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -213,16 +213,17 @@ class StreamChangeCache: """ assert isinstance(stream_pos, int) - if not self._cache: - # If the cache is empty, nothing can have changed. - return False - # _cache is not valid at or before the earliest known stream position, so # return that an entity has changed. if stream_pos <= self._earliest_known_stream_pos: self.metrics.inc_misses() return True + # If the cache is empty, nothing can have changed. + if not self._cache: + self.metrics.inc_misses() + return False + self.metrics.inc_hits() return stream_pos < self._cache.peekitem()[0] diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py index 0305741c99..3df053493b 100644 --- a/tests/util/test_stream_change_cache.py +++ b/tests/util/test_stream_change_cache.py @@ -144,9 +144,10 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): """ cache = StreamChangeCache("#test", 1) - # With no entities, it returns False for the past, present, and future. - self.assertFalse(cache.has_any_entity_changed(0)) - self.assertFalse(cache.has_any_entity_changed(1)) + # With no entities, it returns True for the past, present, and False for + # the future. + self.assertTrue(cache.has_any_entity_changed(0)) + self.assertTrue(cache.has_any_entity_changed(1)) self.assertFalse(cache.has_any_entity_changed(2)) # We add an entity -- cgit 1.5.1