From e0804ef898374c71850a30ce38c23cd1274a1c97 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Wed, 21 Sep 2022 14:40:34 +0200 Subject: Improve the `synapse.api.auth.Auth` mock used in unit tests. (#13809) To return the proper type (`Requester`) instead of a `dict`. --- tests/unittest.py | 36 ++++++++++-------------------------- 1 file changed, 10 insertions(+), 26 deletions(-) (limited to 'tests') diff --git a/tests/unittest.py b/tests/unittest.py index 975b0a23a7..00cb023198 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -300,47 +300,31 @@ class HomeserverTestCase(TestCase): if hasattr(self, "user_id"): if self.hijack_auth: assert self.helper.auth_user_id is not None + token = "some_fake_token" # We need a valid token ID to satisfy foreign key constraints. token_id = self.get_success( self.hs.get_datastores().main.add_access_token_to_user( self.helper.auth_user_id, - "some_fake_token", + token, None, None, ) ) - async def get_user_by_access_token( - token: Optional[str] = None, allow_guest: bool = False - ) -> JsonDict: - assert self.helper.auth_user_id is not None - return { - "user": UserID.from_string(self.helper.auth_user_id), - "token_id": token_id, - "is_guest": False, - } - - async def get_user_by_req( - request: SynapseRequest, - allow_guest: bool = False, - allow_expired: bool = False, - ) -> Requester: + # This has to be a function and not just a Mock, because + # `self.helper.auth_user_id` is temporarily reassigned in some tests + async def get_requester(*args, **kwargs) -> Requester: assert self.helper.auth_user_id is not None return create_requester( - UserID.from_string(self.helper.auth_user_id), - token_id, - False, - False, - None, + user_id=UserID.from_string(self.helper.auth_user_id), + access_token_id=token_id, ) # Type ignore: mypy doesn't like us assigning to methods. - self.hs.get_auth().get_user_by_req = get_user_by_req # type: ignore[assignment] - self.hs.get_auth().get_user_by_access_token = get_user_by_access_token # type: ignore[assignment] - self.hs.get_auth().get_access_token_from_request = Mock( # type: ignore[assignment] - return_value="1234" - ) + self.hs.get_auth().get_user_by_req = get_requester # type: ignore[assignment] + self.hs.get_auth().get_user_by_access_token = get_requester # type: ignore[assignment] + self.hs.get_auth().get_access_token_from_request = Mock(return_value=token) # type: ignore[assignment] if self.needs_threadpool: self.reactor.threadpool = ThreadPool() # type: ignore[assignment] -- cgit 1.5.1 From 6bd8763804dc0987c7ecd37bcb5ebff465fffa29 Mon Sep 17 00:00:00 2001 From: Mathieu Velten Date: Wed, 21 Sep 2022 15:32:01 +0200 Subject: Add cache invalidation across workers to module API (#13667) Signed-off-by: Mathieu Velten --- changelog.d/13667.feature | 1 + scripts-dev/mypy_synapse_plugin.py | 4 +- synapse/module_api/__init__.py | 33 ++++++++- synapse/storage/_base.py | 23 +++++-- synapse/storage/databases/main/cache.py | 20 ++++-- synapse/util/caches/descriptors.py | 14 ++-- .../replication/test_module_cache_invalidation.py | 79 ++++++++++++++++++++++ 7 files changed, 153 insertions(+), 21 deletions(-) create mode 100644 changelog.d/13667.feature create mode 100644 tests/replication/test_module_cache_invalidation.py (limited to 'tests') diff --git a/changelog.d/13667.feature b/changelog.d/13667.feature new file mode 100644 index 0000000000..a0b3cfe18c --- /dev/null +++ b/changelog.d/13667.feature @@ -0,0 +1 @@ +Add cache invalidation across workers to module API. diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py index d08517a953..2c377533c0 100644 --- a/scripts-dev/mypy_synapse_plugin.py +++ b/scripts-dev/mypy_synapse_plugin.py @@ -29,7 +29,7 @@ class SynapsePlugin(Plugin): self, fullname: str ) -> Optional[Callable[[MethodSigContext], CallableType]]: if fullname.startswith( - "synapse.util.caches.descriptors._CachedFunction.__call__" + "synapse.util.caches.descriptors.CachedFunction.__call__" ) or fullname.startswith( "synapse.util.caches.descriptors._LruCachedFunction.__call__" ): @@ -38,7 +38,7 @@ class SynapsePlugin(Plugin): def cached_function_method_signature(ctx: MethodSigContext) -> CallableType: - """Fixes the `_CachedFunction.__call__` signature to be correct. + """Fixes the `CachedFunction.__call__` signature to be correct. It already has *almost* the correct signature, except: diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 9287c0fb8d..59755bff6d 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -125,7 +125,7 @@ from synapse.types import ( ) from synapse.util import Clock from synapse.util.async_helpers import maybe_awaitable -from synapse.util.caches.descriptors import cached +from synapse.util.caches.descriptors import CachedFunction, cached from synapse.util.frozenutils import freeze if TYPE_CHECKING: @@ -836,6 +836,37 @@ class ModuleApi: self._store.db_pool.runInteraction(desc, func, *args, **kwargs) # type: ignore[arg-type] ) + def register_cached_function(self, cached_func: CachedFunction) -> None: + """Register a cached function that should be invalidated across workers. + Invalidation local to a worker can be done directly using `cached_func.invalidate`, + however invalidation that needs to go to other workers needs to call `invalidate_cache` + on the module API instead. + + Args: + cached_function: The cached function that will be registered to receive invalidation + locally and from other workers. + """ + self._store.register_external_cached_function( + f"{cached_func.__module__}.{cached_func.__name__}", cached_func + ) + + async def invalidate_cache( + self, cached_func: CachedFunction, keys: Tuple[Any, ...] + ) -> None: + """Invalidate a cache entry of a cached function across workers. The cached function + needs to be registered on all workers first with `register_cached_function`. + + Args: + cached_function: The cached function that needs an invalidation + keys: keys of the entry to invalidate, usually matching the arguments of the + cached function. + """ + cached_func.invalidate(keys) + await self._store.send_invalidation_to_replication( + f"{cached_func.__module__}.{cached_func.__name__}", + keys, + ) + async def complete_sso_login_async( self, registered_user_id: str, diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index e30f9c76d4..303a5d5298 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -15,12 +15,13 @@ # limitations under the License. import logging from abc import ABCMeta -from typing import TYPE_CHECKING, Any, Collection, Iterable, Optional, Union +from typing import TYPE_CHECKING, Any, Collection, Dict, Iterable, Optional, Union from synapse.storage.database import make_in_list_sql_clause # noqa: F401; noqa: F401 from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.types import get_domain_from_id from synapse.util import json_decoder +from synapse.util.caches.descriptors import CachedFunction if TYPE_CHECKING: from synapse.server import HomeServer @@ -47,6 +48,8 @@ class SQLBaseStore(metaclass=ABCMeta): self.database_engine = database.engine self.db_pool = database + self.external_cached_functions: Dict[str, CachedFunction] = {} + def process_replication_rows( self, stream_name: str, @@ -95,7 +98,7 @@ class SQLBaseStore(metaclass=ABCMeta): def _attempt_to_invalidate_cache( self, cache_name: str, key: Optional[Collection[Any]] - ) -> None: + ) -> bool: """Attempts to invalidate the cache of the given name, ignoring if the cache doesn't exist. Mainly used for invalidating caches on workers, where they may not have the cache. @@ -113,9 +116,12 @@ class SQLBaseStore(metaclass=ABCMeta): try: cache = getattr(self, cache_name) except AttributeError: - # We probably haven't pulled in the cache in this worker, - # which is fine. - return + # Check if an externally defined module cache has been registered + cache = self.external_cached_functions.get(cache_name) + if not cache: + # We probably haven't pulled in the cache in this worker, + # which is fine. + return False if key is None: cache.invalidate_all() @@ -125,6 +131,13 @@ class SQLBaseStore(metaclass=ABCMeta): invalidate_method = getattr(cache, "invalidate_local", cache.invalidate) invalidate_method(tuple(key)) + return True + + def register_external_cached_function( + self, cache_name: str, func: CachedFunction + ) -> None: + self.external_cached_functions[cache_name] = func + def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any: """ diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 12e9a42382..2c421151c1 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -33,7 +33,7 @@ from synapse.storage.database import ( ) from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator -from synapse.util.caches.descriptors import _CachedFunction +from synapse.util.caches.descriptors import CachedFunction from synapse.util.iterutils import batch_iter if TYPE_CHECKING: @@ -269,9 +269,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): return cache_func.invalidate(keys) - await self.db_pool.runInteraction( - "invalidate_cache_and_stream", - self._send_invalidation_to_replication, + await self.send_invalidation_to_replication( cache_func.__name__, keys, ) @@ -279,7 +277,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): def _invalidate_cache_and_stream( self, txn: LoggingTransaction, - cache_func: _CachedFunction, + cache_func: CachedFunction, keys: Tuple[Any, ...], ) -> None: """Invalidates the cache and adds it to the cache stream so slaves @@ -293,7 +291,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._send_invalidation_to_replication(txn, cache_func.__name__, keys) def _invalidate_all_cache_and_stream( - self, txn: LoggingTransaction, cache_func: _CachedFunction + self, txn: LoggingTransaction, cache_func: CachedFunction ) -> None: """Invalidates the entire cache and adds it to the cache stream so slaves will know to invalidate their caches. @@ -334,6 +332,16 @@ class CacheInvalidationWorkerStore(SQLBaseStore): txn, CURRENT_STATE_CACHE_NAME, [room_id] ) + async def send_invalidation_to_replication( + self, cache_name: str, keys: Optional[Collection[Any]] + ) -> None: + await self.db_pool.runInteraction( + "send_invalidation_to_replication", + self._send_invalidation_to_replication, + cache_name, + keys, + ) + def _send_invalidation_to_replication( self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]] ) -> None: diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 10aff4d04a..3909f1caea 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -53,7 +53,7 @@ CacheKey = Union[Tuple, Any] F = TypeVar("F", bound=Callable[..., Any]) -class _CachedFunction(Generic[F]): +class CachedFunction(Generic[F]): invalidate: Any = None invalidate_all: Any = None prefill: Any = None @@ -242,7 +242,7 @@ class LruCacheDescriptor(_CacheDescriptorBase): return ret2 - wrapped = cast(_CachedFunction, _wrapped) + wrapped = cast(CachedFunction, _wrapped) wrapped.cache = cache obj.__dict__[self.name] = wrapped @@ -363,7 +363,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): return make_deferred_yieldable(ret) - wrapped = cast(_CachedFunction, _wrapped) + wrapped = cast(CachedFunction, _wrapped) if self.num_args == 1: assert not self.tree @@ -572,7 +572,7 @@ def cached( iterable: bool = False, prune_unread_entries: bool = True, name: Optional[str] = None, -) -> Callable[[F], _CachedFunction[F]]: +) -> Callable[[F], CachedFunction[F]]: func = lambda orig: DeferredCacheDescriptor( orig, max_entries=max_entries, @@ -585,7 +585,7 @@ def cached( name=name, ) - return cast(Callable[[F], _CachedFunction[F]], func) + return cast(Callable[[F], CachedFunction[F]], func) def cachedList( @@ -594,7 +594,7 @@ def cachedList( list_name: str, num_args: Optional[int] = None, name: Optional[str] = None, -) -> Callable[[F], _CachedFunction[F]]: +) -> Callable[[F], CachedFunction[F]]: """Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`. Used to do batch lookups for an already created cache. One of the arguments @@ -631,7 +631,7 @@ def cachedList( name=name, ) - return cast(Callable[[F], _CachedFunction[F]], func) + return cast(Callable[[F], CachedFunction[F]], func) def _get_cache_key_builder( diff --git a/tests/replication/test_module_cache_invalidation.py b/tests/replication/test_module_cache_invalidation.py new file mode 100644 index 0000000000..b93cae67d3 --- /dev/null +++ b/tests/replication/test_module_cache_invalidation.py @@ -0,0 +1,79 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +import synapse +from synapse.module_api import cached + +from tests.replication._base import BaseMultiWorkerStreamTestCase + +logger = logging.getLogger(__name__) + +FIRST_VALUE = "one" +SECOND_VALUE = "two" + +KEY = "mykey" + + +class TestCache: + current_value = FIRST_VALUE + + @cached() + async def cached_function(self, user_id: str) -> str: + return self.current_value + + +class ModuleCacheInvalidationTestCase(BaseMultiWorkerStreamTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + ] + + def test_module_cache_full_invalidation(self): + main_cache = TestCache() + self.hs.get_module_api().register_cached_function(main_cache.cached_function) + + worker_hs = self.make_worker_hs("synapse.app.generic_worker") + + worker_cache = TestCache() + worker_hs.get_module_api().register_cached_function( + worker_cache.cached_function + ) + + self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY))) + self.assertEqual( + FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY)) + ) + + main_cache.current_value = SECOND_VALUE + worker_cache.current_value = SECOND_VALUE + # No invalidation yet, should return the cached value on both the main process and the worker + self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY))) + self.assertEqual( + FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY)) + ) + + # Full invalidation on the main process, should be replicated on the worker that + # should returned the updated value too + self.get_success( + self.hs.get_module_api().invalidate_cache( + main_cache.cached_function, (KEY,) + ) + ) + + self.assertEqual( + SECOND_VALUE, self.get_success(main_cache.cached_function(KEY)) + ) + self.assertEqual( + SECOND_VALUE, self.get_success(worker_cache.cached_function(KEY)) + ) -- cgit 1.5.1 From 8ae42ab8fa3c6b52d74c24daa7ca75a478fa4fbb Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 21 Sep 2022 15:39:01 +0100 Subject: Support enabling/disabling pushers (from MSC3881) (#13799) Partial implementation of MSC3881 --- changelog.d/13799.feature | 1 + synapse/_scripts/synapse_port_db.py | 1 + synapse/config/experimental.py | 3 + synapse/handlers/register.py | 4 +- synapse/push/__init__.py | 2 + synapse/push/pusherpool.py | 81 ++++++++--- synapse/replication/tcp/client.py | 10 +- synapse/rest/admin/users.py | 4 +- synapse/rest/client/pusher.py | 18 ++- synapse/storage/databases/main/pusher.py | 69 ++++++---- .../schema/main/delta/73/02add_pusher_enabled.sql | 16 +++ tests/push/test_email.py | 4 +- tests/push/test_http.py | 148 +++++++++++++++++++-- tests/replication/test_pusher_shard.py | 2 +- tests/rest/admin/test_user.py | 2 +- 15 files changed, 294 insertions(+), 71 deletions(-) create mode 100644 changelog.d/13799.feature create mode 100644 synapse/storage/schema/main/delta/73/02add_pusher_enabled.sql (limited to 'tests') diff --git a/changelog.d/13799.feature b/changelog.d/13799.feature new file mode 100644 index 0000000000..6c8e5cffe2 --- /dev/null +++ b/changelog.d/13799.feature @@ -0,0 +1 @@ +Add experimental support for [MSC3881: Remotely toggle push notifications for another client](https://github.com/matrix-org/matrix-spec-proposals/pull/3881). diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index 30983c47fb..450ba462ba 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -111,6 +111,7 @@ BOOLEAN_COLUMNS = { "e2e_fallback_keys_json": ["used"], "access_tokens": ["used"], "device_lists_changes_in_room": ["converted_to_destinations"], + "pushers": ["enabled"], } diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 702b81e636..f4541a8db0 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -93,3 +93,6 @@ class ExperimentalConfig(Config): # MSC3852: Expose last seen user agent field on /_matrix/client/v3/devices. self.msc3852_enabled: bool = experimental.get("msc3852_enabled", False) + + # MSC3881: Remotely toggle push notifications for another client + self.msc3881_enabled: bool = experimental.get("msc3881_enabled", False) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 20ec22105a..cfcadb34db 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -997,7 +997,7 @@ class RegistrationHandler: assert user_tuple token_id = user_tuple.token_id - await self.pusher_pool.add_pusher( + await self.pusher_pool.add_or_update_pusher( user_id=user_id, access_token=token_id, kind="email", @@ -1005,7 +1005,7 @@ class RegistrationHandler: app_display_name="Email Notifications", device_display_name=threepid["address"], pushkey=threepid["address"], - lang=None, # We don't know a user's language here + lang=None, data={}, ) diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py index 57c4d70466..ac99d35a7e 100644 --- a/synapse/push/__init__.py +++ b/synapse/push/__init__.py @@ -116,6 +116,7 @@ class PusherConfig: last_stream_ordering: int last_success: Optional[int] failing_since: Optional[int] + enabled: bool def as_dict(self) -> Dict[str, Any]: """Information that can be retrieved about a pusher after creation.""" @@ -128,6 +129,7 @@ class PusherConfig: "lang": self.lang, "profile_tag": self.profile_tag, "pushkey": self.pushkey, + "enabled": self.enabled, } diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 1e0ef44fc7..2597898cf4 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -94,7 +94,7 @@ class PusherPool: return run_as_background_process("start_pushers", self._start_pushers) - async def add_pusher( + async def add_or_update_pusher( self, user_id: str, access_token: Optional[int], @@ -106,6 +106,7 @@ class PusherPool: lang: Optional[str], data: JsonDict, profile_tag: str = "", + enabled: bool = True, ) -> Optional[Pusher]: """Creates a new pusher and adds it to the pool @@ -147,9 +148,20 @@ class PusherPool: last_stream_ordering=last_stream_ordering, last_success=None, failing_since=None, + enabled=enabled, ) ) + # Before we actually persist the pusher, we check if the user already has one + # for this app ID and pushkey. If so, we want to keep the access token in place, + # since this could be one device modifying (e.g. enabling/disabling) another + # device's pusher. + existing_config = await self._get_pusher_config_for_user_by_app_id_and_pushkey( + user_id, app_id, pushkey + ) + if existing_config: + access_token = existing_config.access_token + await self.store.add_pusher( user_id=user_id, access_token=access_token, @@ -163,8 +175,9 @@ class PusherPool: data=data, last_stream_ordering=last_stream_ordering, profile_tag=profile_tag, + enabled=enabled, ) - pusher = await self.start_pusher_by_id(app_id, pushkey, user_id) + pusher = await self.process_pusher_change_by_id(app_id, pushkey, user_id) return pusher @@ -276,10 +289,25 @@ class PusherPool: except Exception: logger.exception("Exception in pusher on_new_receipts") - async def start_pusher_by_id( + async def _get_pusher_config_for_user_by_app_id_and_pushkey( + self, user_id: str, app_id: str, pushkey: str + ) -> Optional[PusherConfig]: + resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) + + pusher_config = None + for r in resultlist: + if r.user_name == user_id: + pusher_config = r + + return pusher_config + + async def process_pusher_change_by_id( self, app_id: str, pushkey: str, user_id: str ) -> Optional[Pusher]: - """Look up the details for the given pusher, and start it + """Look up the details for the given pusher, and either start it if its + "enabled" flag is True, or try to stop it otherwise. + + If the pusher is new and its "enabled" flag is False, the stop is a noop. Returns: The pusher started, if any @@ -290,12 +318,13 @@ class PusherPool: if not self._pusher_shard_config.should_handle(self._instance_name, user_id): return None - resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) + pusher_config = await self._get_pusher_config_for_user_by_app_id_and_pushkey( + user_id, app_id, pushkey + ) - pusher_config = None - for r in resultlist: - if r.user_name == user_id: - pusher_config = r + if pusher_config and not pusher_config.enabled: + self.maybe_stop_pusher(app_id, pushkey, user_id) + return None pusher = None if pusher_config: @@ -305,7 +334,7 @@ class PusherPool: async def _start_pushers(self) -> None: """Start all the pushers""" - pushers = await self.store.get_all_pushers() + pushers = await self.store.get_enabled_pushers() # Stagger starting up the pushers so we don't completely drown the # process on start up. @@ -363,6 +392,8 @@ class PusherPool: synapse_pushers.labels(type(pusher).__name__, pusher.app_id).inc() + logger.info("Starting pusher %s / %s", pusher.user_id, appid_pushkey) + # Check if there *may* be push to process. We do this as this check is a # lot cheaper to do than actually fetching the exact rows we need to # push. @@ -382,16 +413,7 @@ class PusherPool: return pusher async def remove_pusher(self, app_id: str, pushkey: str, user_id: str) -> None: - appid_pushkey = "%s:%s" % (app_id, pushkey) - - byuser = self.pushers.get(user_id, {}) - - if appid_pushkey in byuser: - logger.info("Stopping pusher %s / %s", user_id, appid_pushkey) - pusher = byuser.pop(appid_pushkey) - pusher.on_stop() - - synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec() + self.maybe_stop_pusher(app_id, pushkey, user_id) # We can only delete pushers on master. if self._remove_pusher_client: @@ -402,3 +424,22 @@ class PusherPool: await self.store.delete_pusher_by_app_id_pushkey_user_id( app_id, pushkey, user_id ) + + def maybe_stop_pusher(self, app_id: str, pushkey: str, user_id: str) -> None: + """Stops a pusher with the given app ID and push key if one is running. + + Args: + app_id: the pusher's app ID. + pushkey: the pusher's push key. + user_id: the user the pusher belongs to. Only used for logging. + """ + appid_pushkey = "%s:%s" % (app_id, pushkey) + + byuser = self.pushers.get(user_id, {}) + + if appid_pushkey in byuser: + logger.info("Stopping pusher %s / %s", user_id, appid_pushkey) + pusher = byuser.pop(appid_pushkey) + pusher.on_stop() + + synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec() diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index e4f2201c92..cf9cd6833b 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -189,7 +189,9 @@ class ReplicationDataHandler: if row.deleted: self.stop_pusher(row.user_id, row.app_id, row.pushkey) else: - await self.start_pusher(row.user_id, row.app_id, row.pushkey) + await self.process_pusher_change( + row.user_id, row.app_id, row.pushkey + ) elif stream_name == EventsStream.NAME: # We shouldn't get multiple rows per token for events stream, so # we don't need to optimise this for multiple rows. @@ -334,13 +336,15 @@ class ReplicationDataHandler: logger.info("Stopping pusher %r / %r", user_id, key) pusher.on_stop() - async def start_pusher(self, user_id: str, app_id: str, pushkey: str) -> None: + async def process_pusher_change( + self, user_id: str, app_id: str, pushkey: str + ) -> None: if not self._notify_pushers: return key = "%s:%s" % (app_id, pushkey) logger.info("Starting pusher %r / %r", user_id, key) - await self._pusher_pool.start_pusher_by_id(app_id, pushkey, user_id) + await self._pusher_pool.process_pusher_change_by_id(app_id, pushkey, user_id) class FederationSenderHandler: diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 2ca6b2d08a..1274773d7e 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -375,7 +375,7 @@ class UserRestServletV2(RestServlet): and self.hs.config.email.email_notif_for_new_users and medium == "email" ): - await self.pusher_pool.add_pusher( + await self.pusher_pool.add_or_update_pusher( user_id=user_id, access_token=None, kind="email", @@ -383,7 +383,7 @@ class UserRestServletV2(RestServlet): app_display_name="Email Notifications", device_display_name=address, pushkey=address, - lang=None, # We don't know a user's language here + lang=None, data={}, ) diff --git a/synapse/rest/client/pusher.py b/synapse/rest/client/pusher.py index 9a1f10f4be..c9f76125dc 100644 --- a/synapse/rest/client/pusher.py +++ b/synapse/rest/client/pusher.py @@ -42,6 +42,7 @@ class PushersRestServlet(RestServlet): super().__init__() self.hs = hs self.auth = hs.get_auth() + self._msc3881_enabled = self.hs.config.experimental.msc3881_enabled async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) @@ -51,9 +52,14 @@ class PushersRestServlet(RestServlet): user.to_string() ) - filtered_pushers = [p.as_dict() for p in pushers] + pusher_dicts = [p.as_dict() for p in pushers] - return 200, {"pushers": filtered_pushers} + for pusher in pusher_dicts: + if self._msc3881_enabled: + pusher["org.matrix.msc3881.enabled"] = pusher["enabled"] + del pusher["enabled"] + + return 200, {"pushers": pusher_dicts} class PushersSetRestServlet(RestServlet): @@ -65,6 +71,7 @@ class PushersSetRestServlet(RestServlet): self.auth = hs.get_auth() self.notifier = hs.get_notifier() self.pusher_pool = self.hs.get_pusherpool() + self._msc3881_enabled = self.hs.config.experimental.msc3881_enabled async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) @@ -103,6 +110,10 @@ class PushersSetRestServlet(RestServlet): if "append" in content: append = content["append"] + enabled = True + if self._msc3881_enabled and "org.matrix.msc3881.enabled" in content: + enabled = content["org.matrix.msc3881.enabled"] + if not append: await self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user( app_id=content["app_id"], @@ -111,7 +122,7 @@ class PushersSetRestServlet(RestServlet): ) try: - await self.pusher_pool.add_pusher( + await self.pusher_pool.add_or_update_pusher( user_id=user.to_string(), access_token=requester.access_token_id, kind=content["kind"], @@ -122,6 +133,7 @@ class PushersSetRestServlet(RestServlet): lang=content["lang"], data=content["data"], profile_tag=content.get("profile_tag", ""), + enabled=enabled, ) except PusherConfigException as pce: raise SynapseError( diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index bd0cfa7f32..ee55b8c4a9 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -89,6 +89,11 @@ class PusherWorkerStore(SQLBaseStore): ) continue + # If we're using SQLite, then boolean values are integers. This is + # troublesome since some code using the return value of this method might + # expect it to be a boolean, or will expose it to clients (in responses). + r["enabled"] = bool(r["enabled"]) + yield PusherConfig(**r) async def get_pushers_by_app_id_and_pushkey( @@ -100,38 +105,52 @@ class PusherWorkerStore(SQLBaseStore): return await self.get_pushers_by({"user_name": user_id}) async def get_pushers_by(self, keyvalues: Dict[str, Any]) -> Iterator[PusherConfig]: - ret = await self.db_pool.simple_select_list( - "pushers", - keyvalues, - [ - "id", - "user_name", - "access_token", - "profile_tag", - "kind", - "app_id", - "app_display_name", - "device_display_name", - "pushkey", - "ts", - "lang", - "data", - "last_stream_ordering", - "last_success", - "failing_since", - ], + """Retrieve pushers that match the given criteria. + + Args: + keyvalues: A {column: value} dictionary. + + Returns: + The pushers for which the given columns have the given values. + """ + + def get_pushers_by_txn(txn: LoggingTransaction) -> List[Dict[str, Any]]: + # We could technically use simple_select_list here, but we need to call + # COALESCE on the 'enabled' column. While it is technically possible to give + # simple_select_list the whole `COALESCE(...) AS ...` as a column name, it + # feels a bit hacky, so it's probably better to just inline the query. + sql = """ + SELECT + id, user_name, access_token, profile_tag, kind, app_id, + app_display_name, device_display_name, pushkey, ts, lang, data, + last_stream_ordering, last_success, failing_since, + COALESCE(enabled, TRUE) AS enabled + FROM pushers + """ + + sql += "WHERE %s" % (" AND ".join("%s = ?" % (k,) for k in keyvalues),) + + txn.execute(sql, list(keyvalues.values())) + + return self.db_pool.cursor_to_dict(txn) + + ret = await self.db_pool.runInteraction( desc="get_pushers_by", + func=get_pushers_by_txn, ) + return self._decode_pushers_rows(ret) - async def get_all_pushers(self) -> Iterator[PusherConfig]: - def get_pushers(txn: LoggingTransaction) -> Iterator[PusherConfig]: - txn.execute("SELECT * FROM pushers") + async def get_enabled_pushers(self) -> Iterator[PusherConfig]: + def get_enabled_pushers_txn(txn: LoggingTransaction) -> Iterator[PusherConfig]: + txn.execute("SELECT * FROM pushers WHERE COALESCE(enabled, TRUE)") rows = self.db_pool.cursor_to_dict(txn) return self._decode_pushers_rows(rows) - return await self.db_pool.runInteraction("get_all_pushers", get_pushers) + return await self.db_pool.runInteraction( + "get_enabled_pushers", get_enabled_pushers_txn + ) async def get_all_updated_pushers_rows( self, instance_name: str, last_id: int, current_id: int, limit: int @@ -476,6 +495,7 @@ class PusherStore(PusherWorkerStore): data: Optional[JsonDict], last_stream_ordering: int, profile_tag: str = "", + enabled: bool = True, ) -> None: async with self._pushers_id_gen.get_next() as stream_id: # no need to lock because `pushers` has a unique key on @@ -494,6 +514,7 @@ class PusherStore(PusherWorkerStore): "last_stream_ordering": last_stream_ordering, "profile_tag": profile_tag, "id": stream_id, + "enabled": enabled, }, desc="add_pusher", lock=False, diff --git a/synapse/storage/schema/main/delta/73/02add_pusher_enabled.sql b/synapse/storage/schema/main/delta/73/02add_pusher_enabled.sql new file mode 100644 index 0000000000..dba3b4900b --- /dev/null +++ b/synapse/storage/schema/main/delta/73/02add_pusher_enabled.sql @@ -0,0 +1,16 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE pushers ADD COLUMN enabled BOOLEAN; \ No newline at end of file diff --git a/tests/push/test_email.py b/tests/push/test_email.py index 7a3b0d6755..fd14568f55 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -114,7 +114,7 @@ class EmailPusherTests(HomeserverTestCase): ) self.pusher = self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=self.user_id, access_token=self.token_id, kind="email", @@ -136,7 +136,7 @@ class EmailPusherTests(HomeserverTestCase): """ with self.assertRaises(SynapseError) as cm: self.get_success_or_raise( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=self.user_id, access_token=self.token_id, kind="email", diff --git a/tests/push/test_http.py b/tests/push/test_http.py index d9c68cdd2d..af67d84463 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -19,8 +19,8 @@ from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.logging.context import make_deferred_yieldable -from synapse.push import PusherConfigException -from synapse.rest.client import login, push_rule, receipts, room +from synapse.push import PusherConfig, PusherConfigException +from synapse.rest.client import login, push_rule, pusher, receipts, room from synapse.server import HomeServer from synapse.types import JsonDict from synapse.util import Clock @@ -35,6 +35,7 @@ class HTTPPusherTests(HomeserverTestCase): login.register_servlets, receipts.register_servlets, push_rule.register_servlets, + pusher.register_servlets, ] user_id = True hijack_auth = False @@ -74,7 +75,7 @@ class HTTPPusherTests(HomeserverTestCase): def test_data(data: Optional[JsonDict]) -> None: self.get_failure( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -119,7 +120,7 @@ class HTTPPusherTests(HomeserverTestCase): token_id = user_tuple.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -235,7 +236,7 @@ class HTTPPusherTests(HomeserverTestCase): token_id = user_tuple.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -355,7 +356,7 @@ class HTTPPusherTests(HomeserverTestCase): token_id = user_tuple.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -441,7 +442,7 @@ class HTTPPusherTests(HomeserverTestCase): token_id = user_tuple.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -518,7 +519,7 @@ class HTTPPusherTests(HomeserverTestCase): token_id = user_tuple.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -624,7 +625,7 @@ class HTTPPusherTests(HomeserverTestCase): token_id = user_tuple.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -728,18 +729,38 @@ class HTTPPusherTests(HomeserverTestCase): ) self.assertEqual(channel.code, 200, channel.json_body) - def _make_user_with_pusher(self, username: str) -> Tuple[str, str]: + def _make_user_with_pusher( + self, username: str, enabled: bool = True + ) -> Tuple[str, str]: + """Registers a user and creates a pusher for them. + + Args: + username: the localpart of the new user's Matrix ID. + enabled: whether to create the pusher in an enabled or disabled state. + """ user_id = self.register_user(username, "pass") access_token = self.login(username, "pass") # Register the pusher + self._set_pusher(user_id, access_token, enabled) + + return user_id, access_token + + def _set_pusher(self, user_id: str, access_token: str, enabled: bool) -> None: + """Creates or updates the pusher for the given user. + + Args: + user_id: the user's Matrix ID. + access_token: the access token associated with the pusher. + enabled: whether to enable or disable the pusher. + """ user_tuple = self.get_success( self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_tuple.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -749,11 +770,10 @@ class HTTPPusherTests(HomeserverTestCase): pushkey="a@example.com", lang=None, data={"url": "http://example.com/_matrix/push/v1/notify"}, + enabled=enabled, ) ) - return user_id, access_token - def test_dont_notify_rule_overrides_message(self) -> None: """ The override push rule will suppress notification @@ -791,3 +811,105 @@ class HTTPPusherTests(HomeserverTestCase): # The user sends a message back (sends a notification) self.helper.send(room, body="Hello", tok=access_token) self.assertEqual(len(self.push_attempts), 1) + + @override_config({"experimental_features": {"msc3881_enabled": True}}) + def test_disable(self) -> None: + """Tests that disabling a pusher means it's not pushed to anymore.""" + user_id, access_token = self._make_user_with_pusher("user") + other_user_id, other_access_token = self._make_user_with_pusher("otheruser") + + room = self.helper.create_room_as(user_id, tok=access_token) + self.helper.join(room=room, user=other_user_id, tok=other_access_token) + + # Send a message and check that it generated a push. + self.helper.send(room, body="Hi!", tok=other_access_token) + self.assertEqual(len(self.push_attempts), 1) + + # Disable the pusher. + self._set_pusher(user_id, access_token, enabled=False) + + # Send another message and check that it did not generate a push. + self.helper.send(room, body="Hi!", tok=other_access_token) + self.assertEqual(len(self.push_attempts), 1) + + # Get the pushers for the user and check that it is marked as disabled. + channel = self.make_request("GET", "/pushers", access_token=access_token) + self.assertEqual(channel.code, 200) + self.assertEqual(len(channel.json_body["pushers"]), 1) + + enabled = channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"] + self.assertFalse(enabled) + self.assertTrue(isinstance(enabled, bool)) + + @override_config({"experimental_features": {"msc3881_enabled": True}}) + def test_enable(self) -> None: + """Tests that enabling a disabled pusher means it gets pushed to.""" + # Create the user with the pusher already disabled. + user_id, access_token = self._make_user_with_pusher("user", enabled=False) + other_user_id, other_access_token = self._make_user_with_pusher("otheruser") + + room = self.helper.create_room_as(user_id, tok=access_token) + self.helper.join(room=room, user=other_user_id, tok=other_access_token) + + # Send a message and check that it did not generate a push. + self.helper.send(room, body="Hi!", tok=other_access_token) + self.assertEqual(len(self.push_attempts), 0) + + # Enable the pusher. + self._set_pusher(user_id, access_token, enabled=True) + + # Send another message and check that it did generate a push. + self.helper.send(room, body="Hi!", tok=other_access_token) + self.assertEqual(len(self.push_attempts), 1) + + # Get the pushers for the user and check that it is marked as enabled. + channel = self.make_request("GET", "/pushers", access_token=access_token) + self.assertEqual(channel.code, 200) + self.assertEqual(len(channel.json_body["pushers"]), 1) + + enabled = channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"] + self.assertTrue(enabled) + self.assertTrue(isinstance(enabled, bool)) + + @override_config({"experimental_features": {"msc3881_enabled": True}}) + def test_null_enabled(self) -> None: + """Tests that a pusher that has an 'enabled' column set to NULL (eg pushers + created before the column was introduced) is considered enabled. + """ + # We intentionally set 'enabled' to None so that it's stored as NULL in the + # database. + user_id, access_token = self._make_user_with_pusher("user", enabled=None) # type: ignore[arg-type] + + channel = self.make_request("GET", "/pushers", access_token=access_token) + self.assertEqual(channel.code, 200) + self.assertEqual(len(channel.json_body["pushers"]), 1) + self.assertTrue(channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"]) + + def test_update_different_device_access_token(self) -> None: + """Tests that if we create a pusher from one device, the update it from another + device, the access token associated with the pusher stays the same. + """ + # Create a user with a pusher. + user_id, access_token = self._make_user_with_pusher("user") + + # Get the token ID for the current access token, since that's what we store in + # the pushers table. + user_tuple = self.get_success( + self.hs.get_datastores().main.get_user_by_access_token(access_token) + ) + token_id = user_tuple.token_id + + # Generate a new access token, and update the pusher with it. + new_token = self.login("user", "pass") + self._set_pusher(user_id, new_token, enabled=False) + + # Get the current list of pushers for the user. + ret = self.get_success( + self.hs.get_datastores().main.get_pushers_by({"user_name": user_id}) + ) + pushers: List[PusherConfig] = list(ret) + + # Check that we still have one pusher, and that the access token associated with + # it didn't change. + self.assertEqual(len(pushers), 1) + self.assertEqual(pushers[0].access_token, token_id) diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py index 8f4f6688ce..59fea93e49 100644 --- a/tests/replication/test_pusher_shard.py +++ b/tests/replication/test_pusher_shard.py @@ -55,7 +55,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): token_id = user_dict.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 9f536ceeb3..1847e6ad6b 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -2839,7 +2839,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): token_id = user_tuple.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=self.other_user, access_token=token_id, kind="http", -- cgit 1.5.1 From 0fd2f2d46064efd37284a36d5b478815d69ddd96 Mon Sep 17 00:00:00 2001 From: Hugh Nimmo-Smith Date: Wed, 21 Sep 2022 16:12:29 +0100 Subject: Implementation of MSC3882 login token request (#13722) --- changelog.d/13722.feature | 1 + synapse/config/experimental.py | 7 ++ synapse/rest/__init__.py | 2 + synapse/rest/client/login_token_request.py | 94 ++++++++++++++++++ synapse/rest/client/versions.py | 2 + tests/rest/client/test_login_token_request.py | 132 ++++++++++++++++++++++++++ 6 files changed, 238 insertions(+) create mode 100644 changelog.d/13722.feature create mode 100644 synapse/rest/client/login_token_request.py create mode 100644 tests/rest/client/test_login_token_request.py (limited to 'tests') diff --git a/changelog.d/13722.feature b/changelog.d/13722.feature new file mode 100644 index 0000000000..588d143c0f --- /dev/null +++ b/changelog.d/13722.feature @@ -0,0 +1 @@ +Experimental implementation of MSC3882 to allow an existing device/session to generate a login token for use on a new device/session. diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index f4541a8db0..bf27f6c101 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -96,3 +96,10 @@ class ExperimentalConfig(Config): # MSC3881: Remotely toggle push notifications for another client self.msc3881_enabled: bool = experimental.get("msc3881_enabled", False) + + # MSC3882: Allow an existing session to sign in a new session + self.msc3882_enabled: bool = experimental.get("msc3882_enabled", False) + self.msc3882_ui_auth: bool = experimental.get("msc3882_ui_auth", True) + self.msc3882_token_timeout = self.parse_duration( + experimental.get("msc3882_token_timeout", "5m") + ) diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index b712215112..9a2ab99ede 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -30,6 +30,7 @@ from synapse.rest.client import ( keys, knock, login as v1_login, + login_token_request, logout, mutual_rooms, notifications, @@ -130,3 +131,4 @@ class ClientRestResource(JsonResource): # unstable mutual_rooms.register_servlets(hs, client_resource) + login_token_request.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/login_token_request.py b/synapse/rest/client/login_token_request.py new file mode 100644 index 0000000000..ca5c54bf17 --- /dev/null +++ b/synapse/rest/client/login_token_request.py @@ -0,0 +1,94 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import TYPE_CHECKING, Tuple + +from synapse.http.server import HttpServer +from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.http.site import SynapseRequest +from synapse.rest.client._base import client_patterns, interactive_auth_handler +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class LoginTokenRequestServlet(RestServlet): + """ + Get a token that can be used with `m.login.token` to log in a second device. + + Request: + + POST /login/token HTTP/1.1 + Content-Type: application/json + + {} + + Response: + + HTTP/1.1 200 OK + { + "login_token": "ABDEFGH", + "expires_in": 3600, + } + """ + + PATTERNS = client_patterns("/login/token$") + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastores().main + self.clock = hs.get_clock() + self.server_name = hs.config.server.server_name + self.macaroon_gen = hs.get_macaroon_generator() + self.auth_handler = hs.get_auth_handler() + self.token_timeout = hs.config.experimental.msc3882_token_timeout + self.ui_auth = hs.config.experimental.msc3882_ui_auth + + @interactive_auth_handler + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + body = parse_json_object_from_request(request) + + if self.ui_auth: + await self.auth_handler.validate_user_via_ui_auth( + requester, + request, + body, + "issue a new access token for your account", + can_skip_ui_auth=False, # Don't allow skipping of UI auth + ) + + login_token = self.macaroon_gen.generate_short_term_login_token( + user_id=requester.user.to_string(), + auth_provider_id="org.matrix.msc3882.login_token_request", + duration_in_ms=self.token_timeout, + ) + + return ( + 200, + { + "login_token": login_token, + "expires_in": self.token_timeout // 1000, + }, + ) + + +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: + if hs.config.experimental.msc3882_enabled: + LoginTokenRequestServlet(hs).register(http_server) diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index c516cda95d..c3488f4330 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -105,6 +105,8 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc3440.stable": True, # TODO: remove when "v1.3" is added above # Allows moderators to fetch redacted event content as described in MSC2815 "fi.mau.msc2815": self.config.experimental.msc2815_enabled, + # Adds support for login token requests as per MSC3882 + "org.matrix.msc3882": self.config.experimental.msc3882_enabled, }, }, ) diff --git a/tests/rest/client/test_login_token_request.py b/tests/rest/client/test_login_token_request.py new file mode 100644 index 0000000000..d5bb16c98d --- /dev/null +++ b/tests/rest/client/test_login_token_request.py @@ -0,0 +1,132 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.rest import admin +from synapse.rest.client import login, login_token_request +from synapse.server import HomeServer +from synapse.util import Clock + +from tests import unittest +from tests.unittest import override_config + + +class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase): + + servlets = [ + login.register_servlets, + admin.register_servlets, + login_token_request.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + self.hs = self.setup_test_homeserver() + self.hs.config.registration.enable_registration = True + self.hs.config.registration.registrations_require_3pid = [] + self.hs.config.registration.auto_join_rooms = [] + self.hs.config.captcha.enable_registration_captcha = False + + return self.hs + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.user = "user123" + self.password = "password" + + def test_disabled(self) -> None: + channel = self.make_request("POST", "/login/token", {}, access_token=None) + self.assertEqual(channel.code, 400) + + self.register_user(self.user, self.password) + token = self.login(self.user, self.password) + + channel = self.make_request("POST", "/login/token", {}, access_token=token) + self.assertEqual(channel.code, 400) + + @override_config({"experimental_features": {"msc3882_enabled": True}}) + def test_require_auth(self) -> None: + channel = self.make_request("POST", "/login/token", {}, access_token=None) + self.assertEqual(channel.code, 401) + + @override_config({"experimental_features": {"msc3882_enabled": True}}) + def test_uia_on(self) -> None: + user_id = self.register_user(self.user, self.password) + token = self.login(self.user, self.password) + + channel = self.make_request("POST", "/login/token", {}, access_token=token) + self.assertEqual(channel.code, 401) + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + + session = channel.json_body["session"] + + uia = { + "auth": { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": self.user}, + "password": self.password, + "session": session, + }, + } + + channel = self.make_request("POST", "/login/token", uia, access_token=token) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["expires_in"], 300) + + login_token = channel.json_body["login_token"] + + channel = self.make_request( + "POST", + "/login", + content={"type": "m.login.token", "token": login_token}, + ) + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.json_body["user_id"], user_id) + + @override_config( + {"experimental_features": {"msc3882_enabled": True, "msc3882_ui_auth": False}} + ) + def test_uia_off(self) -> None: + user_id = self.register_user(self.user, self.password) + token = self.login(self.user, self.password) + + channel = self.make_request("POST", "/login/token", {}, access_token=token) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["expires_in"], 300) + + login_token = channel.json_body["login_token"] + + channel = self.make_request( + "POST", + "/login", + content={"type": "m.login.token", "token": login_token}, + ) + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.json_body["user_id"], user_id) + + @override_config( + { + "experimental_features": { + "msc3882_enabled": True, + "msc3882_ui_auth": False, + "msc3882_token_timeout": "15s", + } + } + ) + def test_expires_in(self) -> None: + self.register_user(self.user, self.password) + token = self.login(self.user, self.password) + + channel = self.make_request("POST", "/login/token", {}, access_token=token) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["expires_in"], 15) -- cgit 1.5.1 From ccca14140a019c2e0430f95d78fa075efd8d535f Mon Sep 17 00:00:00 2001 From: Brendan Abolivier Date: Wed, 21 Sep 2022 16:31:53 +0100 Subject: Track device IDs for pushers (#13831) Second half of the MSC3881 implementation --- changelog.d/13831.feature | 1 + synapse/push/__init__.py | 2 + synapse/push/pusherpool.py | 10 ++- synapse/rest/client/pusher.py | 3 + synapse/storage/databases/main/pusher.py | 73 +++++++++++++++++++++- .../schema/main/delta/73/03pusher_device_id.sql | 20 ++++++ tests/push/test_http.py | 55 ++++++++++++++-- 7 files changed, 154 insertions(+), 10 deletions(-) create mode 100644 changelog.d/13831.feature create mode 100644 synapse/storage/schema/main/delta/73/03pusher_device_id.sql (limited to 'tests') diff --git a/changelog.d/13831.feature b/changelog.d/13831.feature new file mode 100644 index 0000000000..6c8e5cffe2 --- /dev/null +++ b/changelog.d/13831.feature @@ -0,0 +1 @@ +Add experimental support for [MSC3881: Remotely toggle push notifications for another client](https://github.com/matrix-org/matrix-spec-proposals/pull/3881). diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py index ac99d35a7e..a0c760239d 100644 --- a/synapse/push/__init__.py +++ b/synapse/push/__init__.py @@ -117,6 +117,7 @@ class PusherConfig: last_success: Optional[int] failing_since: Optional[int] enabled: bool + device_id: Optional[str] def as_dict(self) -> Dict[str, Any]: """Information that can be retrieved about a pusher after creation.""" @@ -130,6 +131,7 @@ class PusherConfig: "profile_tag": self.profile_tag, "pushkey": self.pushkey, "enabled": self.enabled, + "device_id": self.device_id, } diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 2597898cf4..e2648cbc93 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -107,6 +107,7 @@ class PusherPool: data: JsonDict, profile_tag: str = "", enabled: bool = True, + device_id: Optional[str] = None, ) -> Optional[Pusher]: """Creates a new pusher and adds it to the pool @@ -149,18 +150,20 @@ class PusherPool: last_success=None, failing_since=None, enabled=enabled, + device_id=device_id, ) ) # Before we actually persist the pusher, we check if the user already has one - # for this app ID and pushkey. If so, we want to keep the access token in place, - # since this could be one device modifying (e.g. enabling/disabling) another - # device's pusher. + # this app ID and pushkey. If so, we want to keep the access token and device ID + # in place, since this could be one device modifying (e.g. enabling/disabling) + # another device's pusher. existing_config = await self._get_pusher_config_for_user_by_app_id_and_pushkey( user_id, app_id, pushkey ) if existing_config: access_token = existing_config.access_token + device_id = existing_config.device_id await self.store.add_pusher( user_id=user_id, @@ -176,6 +179,7 @@ class PusherPool: last_stream_ordering=last_stream_ordering, profile_tag=profile_tag, enabled=enabled, + device_id=device_id, ) pusher = await self.process_pusher_change_by_id(app_id, pushkey, user_id) diff --git a/synapse/rest/client/pusher.py b/synapse/rest/client/pusher.py index c9f76125dc..975eef2144 100644 --- a/synapse/rest/client/pusher.py +++ b/synapse/rest/client/pusher.py @@ -57,7 +57,9 @@ class PushersRestServlet(RestServlet): for pusher in pusher_dicts: if self._msc3881_enabled: pusher["org.matrix.msc3881.enabled"] = pusher["enabled"] + pusher["org.matrix.msc3881.device_id"] = pusher["device_id"] del pusher["enabled"] + del pusher["device_id"] return 200, {"pushers": pusher_dicts} @@ -134,6 +136,7 @@ class PushersSetRestServlet(RestServlet): data=content["data"], profile_tag=content.get("profile_tag", ""), enabled=enabled, + device_id=requester.device_id, ) except PusherConfigException as pce: raise SynapseError( diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index ee55b8c4a9..01206950a9 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -124,7 +124,7 @@ class PusherWorkerStore(SQLBaseStore): id, user_name, access_token, profile_tag, kind, app_id, app_display_name, device_display_name, pushkey, ts, lang, data, last_stream_ordering, last_success, failing_since, - COALESCE(enabled, TRUE) AS enabled + COALESCE(enabled, TRUE) AS enabled, device_id FROM pushers """ @@ -477,7 +477,74 @@ class PusherWorkerStore(SQLBaseStore): return number_deleted -class PusherStore(PusherWorkerStore): +class PusherBackgroundUpdatesStore(SQLBaseStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + + self.db_pool.updates.register_background_update_handler( + "set_device_id_for_pushers", self._set_device_id_for_pushers + ) + + async def _set_device_id_for_pushers( + self, progress: JsonDict, batch_size: int + ) -> int: + """Background update to populate the device_id column of the pushers table.""" + last_pusher_id = progress.get("pusher_id", 0) + + def set_device_id_for_pushers_txn(txn: LoggingTransaction) -> int: + txn.execute( + """ + SELECT p.id, at.device_id + FROM pushers AS p + INNER JOIN access_tokens AS at + ON p.access_token = at.id + WHERE + p.access_token IS NOT NULL + AND at.device_id IS NOT NULL + AND p.id > ? + ORDER BY p.id + LIMIT ? + """, + (last_pusher_id, batch_size), + ) + + rows = self.db_pool.cursor_to_dict(txn) + if len(rows) == 0: + return 0 + + self.db_pool.simple_update_many_txn( + txn=txn, + table="pushers", + key_names=("id",), + key_values=[(row["id"],) for row in rows], + value_names=("device_id",), + value_values=[(row["device_id"],) for row in rows], + ) + + self.db_pool.updates._background_update_progress_txn( + txn, "set_device_id_for_pushers", {"pusher_id": rows[-1]["id"]} + ) + + return len(rows) + + nb_processed = await self.db_pool.runInteraction( + "set_device_id_for_pushers", set_device_id_for_pushers_txn + ) + + if nb_processed < batch_size: + await self.db_pool.updates._end_background_update( + "set_device_id_for_pushers" + ) + + return nb_processed + + +class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore): def get_pushers_stream_token(self) -> int: return self._pushers_id_gen.get_current_token() @@ -496,6 +563,7 @@ class PusherStore(PusherWorkerStore): last_stream_ordering: int, profile_tag: str = "", enabled: bool = True, + device_id: Optional[str] = None, ) -> None: async with self._pushers_id_gen.get_next() as stream_id: # no need to lock because `pushers` has a unique key on @@ -515,6 +583,7 @@ class PusherStore(PusherWorkerStore): "profile_tag": profile_tag, "id": stream_id, "enabled": enabled, + "device_id": device_id, }, desc="add_pusher", lock=False, diff --git a/synapse/storage/schema/main/delta/73/03pusher_device_id.sql b/synapse/storage/schema/main/delta/73/03pusher_device_id.sql new file mode 100644 index 0000000000..1b4ffbeebe --- /dev/null +++ b/synapse/storage/schema/main/delta/73/03pusher_device_id.sql @@ -0,0 +1,20 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Add a device_id column to track the device ID that created the pusher. It's NULLable +-- on purpose, because a) it might not be possible to track down the device that created +-- old pushers (pushers.access_token and access_tokens.device_id are both NULLable), and +-- b) access tokens retrieved via the admin API don't have a device associated to them. +ALTER TABLE pushers ADD COLUMN device_id TEXT; \ No newline at end of file diff --git a/tests/push/test_http.py b/tests/push/test_http.py index af67d84463..b383b8401f 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -22,6 +22,7 @@ from synapse.logging.context import make_deferred_yieldable from synapse.push import PusherConfig, PusherConfigException from synapse.rest.client import login, push_rule, pusher, receipts, room from synapse.server import HomeServer +from synapse.storage.databases.main.registration import TokenLookupResult from synapse.types import JsonDict from synapse.util import Clock @@ -771,6 +772,7 @@ class HTTPPusherTests(HomeserverTestCase): lang=None, data={"url": "http://example.com/_matrix/push/v1/notify"}, enabled=enabled, + device_id=user_tuple.device_id, ) ) @@ -885,19 +887,21 @@ class HTTPPusherTests(HomeserverTestCase): self.assertEqual(len(channel.json_body["pushers"]), 1) self.assertTrue(channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"]) - def test_update_different_device_access_token(self) -> None: + def test_update_different_device_access_token_device_id(self) -> None: """Tests that if we create a pusher from one device, the update it from another - device, the access token associated with the pusher stays the same. + device, the access token and device ID associated with the pusher stays the + same. """ # Create a user with a pusher. user_id, access_token = self._make_user_with_pusher("user") # Get the token ID for the current access token, since that's what we store in - # the pushers table. + # the pushers table. Also get the device ID from it. user_tuple = self.get_success( self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_tuple.token_id + device_id = user_tuple.device_id # Generate a new access token, and update the pusher with it. new_token = self.login("user", "pass") @@ -909,7 +913,48 @@ class HTTPPusherTests(HomeserverTestCase): ) pushers: List[PusherConfig] = list(ret) - # Check that we still have one pusher, and that the access token associated with - # it didn't change. + # Check that we still have one pusher, and that the access token and device ID + # associated with it didn't change. self.assertEqual(len(pushers), 1) self.assertEqual(pushers[0].access_token, token_id) + self.assertEqual(pushers[0].device_id, device_id) + + @override_config({"experimental_features": {"msc3881_enabled": True}}) + def test_device_id(self) -> None: + """Tests that a pusher created with a given device ID shows that device ID in + GET /pushers requests. + """ + self.register_user("user", "pass") + access_token = self.login("user", "pass") + + # We create the pusher with an HTTP request rather than with + # _make_user_with_pusher so that we can test the device ID is correctly set when + # creating a pusher via an API call. + self.make_request( + method="POST", + path="/pushers/set", + content={ + "kind": "http", + "app_id": "m.http", + "app_display_name": "HTTP Push Notifications", + "device_display_name": "pushy push", + "pushkey": "a@example.com", + "lang": "en", + "data": {"url": "http://example.com/_matrix/push/v1/notify"}, + }, + access_token=access_token, + ) + + # Look up the user info for the access token so we can compare the device ID. + lookup_result: TokenLookupResult = self.get_success( + self.hs.get_datastores().main.get_user_by_access_token(access_token) + ) + + # Get the user's devices and check it has the correct device ID. + channel = self.make_request("GET", "/pushers", access_token=access_token) + self.assertEqual(channel.code, 200) + self.assertEqual(len(channel.json_body["pushers"]), 1) + self.assertEqual( + channel.json_body["pushers"][0]["org.matrix.msc3881.device_id"], + lookup_result.device_id, + ) -- cgit 1.5.1 From a847a35921519abc17c98a831266236a86d566f8 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Wed, 21 Sep 2022 18:02:45 -0500 Subject: Fix `have_seen_event` cache not being invalidated Fix https://github.com/matrix-org/synapse/issues/13856 `_invalidate_caches_for_event` doesn't run in monolith mode which means we never even tried to clear the `have_seen_event` and other caches. And even in worker mode, it only runs on the workers, not the master (AFAICT). Additionally there is bug with the key being wrong so `_invalidate_caches_for_event` never invalidates the `have_seen_event` cache even when it does run. Wrong: ```py self.have_seen_event.invalidate((room_id, event_id)) ``` Correct: ```py self.have_seen_event.invalidate(((room_id, event_id),)) ``` --- synapse/storage/controllers/persist_events.py | 34 +++++- synapse/storage/databases/main/cache.py | 2 +- tests/storage/databases/main/test_events_worker.py | 120 ++++++++++++--------- 3 files changed, 106 insertions(+), 50 deletions(-) (limited to 'tests') diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index 501dbbc990..094628ec18 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -43,7 +43,7 @@ from prometheus_client import Counter, Histogram from twisted.internet import defer from synapse.api.constants import EventTypes, Membership -from synapse.events import EventBase +from synapse.events import EventBase, relation_from_event from synapse.events.snapshot import EventContext from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.logging.opentracing import ( @@ -431,6 +431,22 @@ class EventsPersistenceStorageController: else: events.append(event) + # We expect events to be persisted by this point and this makes + # mypy happy about `stream_ordering` not being optional below + assert event.internal_metadata.stream_ordering + # Invalidate related caches after we persist a new event + relation = relation_from_event(event) + self.main_store._invalidate_caches_for_event( + stream_ordering=event.internal_metadata.stream_ordering, + event_id=event.event_id, + room_id=event.room_id, + etype=event.type, + state_key=event.state_key if hasattr(event, "state_key") else None, + redacts=event.redacts, + relates_to=relation.parent_id if relation else None, + backfilled=backfilled, + ) + return ( events, self.main_store.get_room_max_token(), @@ -463,6 +479,22 @@ class EventsPersistenceStorageController: replaced_event = replaced_events.get(event.event_id) if replaced_event: event = await self.main_store.get_event(replaced_event) + else: + # We expect events to be persisted by this point and this makes + # mypy happy about `stream_ordering` not being optional below + assert event.internal_metadata.stream_ordering + # Invalidate related caches after we persist a new event + relation = relation_from_event(event) + self.main_store._invalidate_caches_for_event( + stream_ordering=event.internal_metadata.stream_ordering, + event_id=event.event_id, + room_id=event.room_id, + etype=event.type, + state_key=event.state_key if hasattr(event, "state_key") else None, + redacts=event.redacts, + relates_to=relation.parent_id if relation else None, + backfilled=backfilled, + ) event_stream_id = event.internal_metadata.stream_ordering # stream ordering should have been assigned by now diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 2c421151c1..2dbf93a4e5 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -223,7 +223,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): # process triggering the invalidation is responsible for clearing any external # cached objects. self._invalidate_local_get_event_cache(event_id) - self.have_seen_event.invalidate((room_id, event_id)) + self.have_seen_event.invalidate(((room_id, event_id),)) self.get_latest_event_ids_in_room.invalidate((room_id,)) diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py index 67401272ac..158ad1f439 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py @@ -35,66 +35,45 @@ from synapse.util import Clock from synapse.util.async_helpers import yieldable_gather_results from tests import unittest +from tests.test_utils.event_injection import create_event, inject_event class HaveSeenEventsTestCase(unittest.HomeserverTestCase): + servlets = [ + admin.register_servlets, + room.register_servlets, + login.register_servlets, + ] + def prepare(self, reactor, clock, hs): + self.hs = hs self.store: EventsWorkerStore = hs.get_datastores().main - # insert some test data - for rid in ("room1", "room2"): - self.get_success( - self.store.db_pool.simple_insert( - "rooms", - {"room_id": rid, "room_version": 4}, - ) - ) + self.user = self.register_user("user", "pass") + self.token = self.login(self.user, "pass") + self.room_id = self.helper.create_room_as(self.user, tok=self.token) self.event_ids: List[str] = [] - for idx, rid in enumerate( - ( - "room1", - "room1", - "room1", - "room2", - ) - ): - event_json = {"type": f"test {idx}", "room_id": rid} - event = make_event_from_dict(event_json, room_version=RoomVersions.V4) - event_id = event.event_id - - self.get_success( - self.store.db_pool.simple_insert( - "events", - { - "event_id": event_id, - "room_id": rid, - "topological_ordering": idx, - "stream_ordering": idx, - "type": event.type, - "processed": True, - "outlier": False, - }, - ) - ) - self.get_success( - self.store.db_pool.simple_insert( - "event_json", - { - "event_id": event_id, - "room_id": rid, - "json": json.dumps(event_json), - "internal_metadata": "{}", - "format_version": 3, - }, + for i in range(3): + event = self.get_success( + inject_event( + hs, + room_version=RoomVersions.V7.identifier, + room_id=self.room_id, + sender=self.user, + type="test_event_type", + content={"body": f"foobarbaz{i}"}, ) ) - self.event_ids.append(event_id) + + self.event_ids.append(event.event_id) def test_simple(self): with LoggingContext(name="test") as ctx: res = self.get_success( - self.store.have_seen_events("room1", [self.event_ids[0], "event19"]) + self.store.have_seen_events( + self.room_id, [self.event_ids[0], "eventdoesnotexist"] + ) ) self.assertEqual(res, {self.event_ids[0]}) @@ -104,7 +83,9 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): # a second lookup of the same events should cause no queries with LoggingContext(name="test") as ctx: res = self.get_success( - self.store.have_seen_events("room1", [self.event_ids[0], "event19"]) + self.store.have_seen_events( + self.room_id, [self.event_ids[0], "eventdoesnotexist"] + ) ) self.assertEqual(res, {self.event_ids[0]}) self.assertEqual(ctx.get_resource_usage().db_txn_count, 0) @@ -116,11 +97,54 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): # looking it up should now cause no db hits with LoggingContext(name="test") as ctx: res = self.get_success( - self.store.have_seen_events("room1", [self.event_ids[0]]) + self.store.have_seen_events(self.room_id, [self.event_ids[0]]) ) self.assertEqual(res, {self.event_ids[0]}) self.assertEqual(ctx.get_resource_usage().db_txn_count, 0) + def test_persisting_event_invalidates_cache(self): + event, event_context = self.get_success( + create_event( + self.hs, + room_id=self.room_id, + sender=self.user, + type="test_event_type", + content={"body": "garply"}, + ) + ) + + with LoggingContext(name="test") as ctx: + # First, check `have_seen_event` for an event we have not seen yet + # to prime the cache with a `false` value. + res = self.get_success( + self.store.have_seen_events(event.room_id, [event.event_id]) + ) + self.assertEqual(res, set()) + + # That should result in a single db query to lookup + self.assertEqual(ctx.get_resource_usage().db_txn_count, 1) + + # Persist the event which should invalidate or prefill the + # `have_seen_event` cache so we don't return stale values. + persistence = self.hs.get_storage_controllers().persistence + self.get_success( + persistence.persist_event( + event, + event_context, + ) + ) + + with LoggingContext(name="test") as ctx: + # Check `have_seen_event` again and we should see the updated fact + # that we have now seen the event after persisting it. + res = self.get_success( + self.store.have_seen_events(event.room_id, [event.event_id]) + ) + self.assertEqual(res, {event.event_id}) + + # That should result in a single db query to lookup + self.assertEqual(ctx.get_resource_usage().db_txn_count, 1) + class EventCacheTestCase(unittest.HomeserverTestCase): """Test that the various layers of event cache works.""" -- cgit 1.5.1 From b7272b73aa38dcb19c9b075514f963390358113d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 22 Sep 2022 08:47:49 -0400 Subject: Properly paginate forward in the /relations API. (#13840) This fixes a bug where the `/relations` API with `dir=f` would skip the first item of each page (except the first page), causing incomplete data to be returned to the client. --- changelog.d/13840.bugfix | 1 + synapse/storage/databases/main/relations.py | 38 +++++++++++++++++++++-------- synapse/storage/databases/main/stream.py | 6 ++--- tests/rest/client/test_relations.py | 29 +++++++++++++++++++++- 4 files changed, 60 insertions(+), 14 deletions(-) create mode 100644 changelog.d/13840.bugfix (limited to 'tests') diff --git a/changelog.d/13840.bugfix b/changelog.d/13840.bugfix new file mode 100644 index 0000000000..0f014439a8 --- /dev/null +++ b/changelog.d/13840.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse v1.53.0 where the experimental implementation of [MSC3715](https://github.com/matrix-org/matrix-spec-proposals/pull/3715) would give incorrect results when paginating forward. diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 7bd27790eb..898947af95 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -51,6 +51,8 @@ class _RelatedEvent: event_id: str # The sender of the related event. sender: str + topological_ordering: Optional[int] + stream_ordering: int class RelationsWorkerStore(SQLBaseStore): @@ -91,6 +93,9 @@ class RelationsWorkerStore(SQLBaseStore): # it. The `event_id` must match the `event.event_id`. assert event.event_id == event_id + # Ensure bad limits aren't being passed in. + assert limit >= 0 + where_clause = ["relates_to_id = ?", "room_id = ?"] where_args: List[Union[str, int]] = [event.event_id, room_id] is_redacted = event.internal_metadata.is_redacted() @@ -139,21 +144,34 @@ class RelationsWorkerStore(SQLBaseStore): ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]: txn.execute(sql, where_args + [limit + 1]) - last_topo_id = None - last_stream_id = None events = [] - for row in txn: + for event_id, relation_type, sender, topo_ordering, stream_ordering in txn: # Do not include edits for redacted events as they leak event # content. - if not is_redacted or row[1] != RelationTypes.REPLACE: - events.append(_RelatedEvent(row[0], row[2])) - last_topo_id = row[3] - last_stream_id = row[4] + if not is_redacted or relation_type != RelationTypes.REPLACE: + events.append( + _RelatedEvent(event_id, sender, topo_ordering, stream_ordering) + ) - # If there are more events, generate the next pagination key. + # If there are more events, generate the next pagination key from the + # last event returned. next_token = None - if len(events) > limit and last_topo_id and last_stream_id: - next_key = RoomStreamToken(last_topo_id, last_stream_id) + if len(events) > limit: + # Instead of using the last row (which tells us there is more + # data), use the last row to be returned. + events = events[:limit] + + topo = events[-1].topological_ordering + token = events[-1].stream_ordering + if direction == "b": + # Tokens are positions between events. + # This token points *after* the last event in the chunk. + # We need it to point to the event before it in the chunk + # when we are going backwards so we subtract one from the + # stream part. + token -= 1 + next_key = RoomStreamToken(topo, token) + if from_token: next_token = from_token.copy_and_replace( StreamKeyType.ROOM, next_key diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 3f9bfaeac5..530f04e149 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -1334,15 +1334,15 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): if rows: topo = rows[-1].topological_ordering - toke = rows[-1].stream_ordering + token = rows[-1].stream_ordering if direction == "b": # Tokens are positions between events. # This token points *after* the last event in the chunk. # We need it to point to the event before it in the chunk # when we are going backwards so we subtract one from the # stream part. - toke -= 1 - next_token = RoomStreamToken(topo, toke) + token -= 1 + next_token = RoomStreamToken(topo, token) else: # TODO (erikj): We should work out what to do here instead. next_token = to_token if to_token else from_token diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 651f4f415d..d33e34d829 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -788,6 +788,7 @@ class RelationPaginationTestCase(BaseRelationsTestCase): channel.json_body["chunk"][0], ) + @unittest.override_config({"experimental_features": {"msc3715_enabled": True}}) def test_repeated_paginate_relations(self) -> None: """Test that if we paginate using a limit and tokens then we get the expected events. @@ -809,7 +810,7 @@ class RelationPaginationTestCase(BaseRelationsTestCase): channel = self.make_request( "GET", - f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}", + f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}?limit=3{from_token}", access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) @@ -827,6 +828,32 @@ class RelationPaginationTestCase(BaseRelationsTestCase): found_event_ids.reverse() self.assertEqual(found_event_ids, expected_event_ids) + # Test forward pagination. + prev_token = "" + found_event_ids = [] + for _ in range(20): + from_token = "" + if prev_token: + from_token = "&from=" + prev_token + + channel = self.make_request( + "GET", + f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}?org.matrix.msc3715.dir=f&limit=3{from_token}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + + found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) + next_batch = channel.json_body.get("next_batch") + + self.assertNotEqual(prev_token, next_batch) + prev_token = next_batch + + if not prev_token: + break + + self.assertEqual(found_event_ids, expected_event_ids) + def test_pagination_from_sync_and_messages(self) -> None: """Pagination tokens from /sync and /messages can be used to paginate /relations.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A") -- cgit 1.5.1 From 5b9b645400c9c4fdc4054e625930ffa697a271a2 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Thu, 22 Sep 2022 21:51:51 -0500 Subject: Add test description --- tests/storage/databases/main/test_events_worker.py | 5 +++++ 1 file changed, 5 insertions(+) (limited to 'tests') diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py index 158ad1f439..47ec189684 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py @@ -103,6 +103,11 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): self.assertEqual(ctx.get_resource_usage().db_txn_count, 0) def test_persisting_event_invalidates_cache(self): + """ + Test to make sure that the `have_seen_event` cache + is invalided after we persist an event and returns + the updated value. + """ event, event_context = self.get_success( create_event( self.hs, -- cgit 1.5.1 From 4fa8f0534486061f59f8e34f0ea581cca0833c4f Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Thu, 22 Sep 2022 22:28:56 -0500 Subject: Add test to make sure we can actually clear entries just by room_id --- tests/storage/databases/main/test_events_worker.py | 29 +++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py index 47ec189684..32a798d74b 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py @@ -105,7 +105,7 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): def test_persisting_event_invalidates_cache(self): """ Test to make sure that the `have_seen_event` cache - is invalided after we persist an event and returns + is invalidated after we persist an event and returns the updated value. """ event, event_context = self.get_success( @@ -150,6 +150,33 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): # That should result in a single db query to lookup self.assertEqual(ctx.get_resource_usage().db_txn_count, 1) + def test_invalidate_cache_by_room_id(self): + """ + Test to make sure that all events associated with the given `(room_id,)` + are invalidated in the `have_seen_event` cache. + """ + with LoggingContext(name="test") as ctx: + # Prime the cache with some values + res = self.get_success( + self.store.have_seen_events(self.room_id, self.event_ids) + ) + self.assertEqual(res, set(self.event_ids)) + + # That should result in a single db query to lookup + self.assertEqual(ctx.get_resource_usage().db_txn_count, 1) + + # Clear the cache with any events associated with the `room_id` + self.store.have_seen_event.invalidate((self.room_id,)) + + with LoggingContext(name="test") as ctx: + res = self.get_success( + self.store.have_seen_events(self.room_id, self.event_ids) + ) + self.assertEqual(res, set(self.event_ids)) + + # Since we cleared the cache, it should result in another db query to lookup + self.assertEqual(ctx.get_resource_usage().db_txn_count, 1) + class EventCacheTestCase(unittest.HomeserverTestCase): """Test that the various layers of event cache works.""" -- cgit 1.5.1 From f8dc17b539748ae28bc34d6f7ef76497b5eac5e6 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Thu, 22 Sep 2022 23:05:51 -0500 Subject: Add test to ensure the safety works --- tests/util/caches/test_descriptors.py | 33 ++++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) (limited to 'tests') diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 48e616ac74..90861fe522 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Set +from typing import Iterable, Set, Tuple from unittest import mock from twisted.internet import defer, reactor @@ -1008,3 +1008,34 @@ class CachedListDescriptorTestCase(unittest.TestCase): obj.inner_context_was_finished, "Tried to restart a finished logcontext" ) self.assertEqual(current_context(), SENTINEL_CONTEXT) + + def test_num_args_mismatch(self): + """ + Make sure someone does not accidentally use @cachedList on a method with + a mismatch in the number args to the underlying single cache method. + """ + + class Cls: + @descriptors.cached(tree=True) + def fn(self, room_id, event_id): + pass + + # This is wrong ❌. `@cachedList` expects to be given the same number + # of arguments as the underlying cached function, just with one of + # the arguments being an iterable + @descriptors.cachedList(cached_method_name="fn", list_name="keys") + def list_fn(self, keys: Iterable[Tuple[str, str]]): + pass + + # Corrected syntax ✅ + # + # @cachedList(cached_method_name="fn", list_name="event_ids") + # async def list_fn( + # self, room_id: str, event_ids: Collection[str], + # ) + + obj = Cls() + + # Make sure this raises an error about the arg mismatch + with self.assertRaises(Exception): + obj.list_fn([("foo", "bar")]) -- cgit 1.5.1