From 65e6c64d8317d3a10527a7e422753281f3e9ec81 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Wed, 26 May 2021 12:19:47 +0200 Subject: Add an admin API for unprotecting local media from quarantine (#10040) Signed-off-by: Dirk Klimpel dirk@klimpel.org --- tests/rest/admin/test_media.py | 99 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 99 insertions(+) (limited to 'tests') diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index ac7b219700..f741121ea2 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -16,6 +16,8 @@ import json import os from binascii import unhexlify +from parameterized import parameterized + import synapse.rest.admin from synapse.api.errors import Codes from synapse.rest.client.v1 import login, profile, room @@ -562,3 +564,100 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): ) # Test that the file is deleted self.assertFalse(os.path.exists(local_path)) + + +class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + synapse.rest.admin.register_servlets_for_media_repo, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + media_repo = hs.get_media_repository_resource() + self.store = hs.get_datastore() + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + # Create media + upload_resource = media_repo.children[b"upload"] + # file size is 67 Byte + image_data = unhexlify( + b"89504e470d0a1a0a0000000d4948445200000001000000010806" + b"0000001f15c4890000000a49444154789c63000100000500010d" + b"0a2db40000000049454e44ae426082" + ) + + # Upload some media into the room + response = self.helper.upload_media( + upload_resource, image_data, tok=self.admin_user_tok, expect_code=200 + ) + # Extract media ID from the response + server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' + self.media_id = server_and_media_id.split("/")[1] + + self.url = "/_synapse/admin/v1/media/%s/%s" + + @parameterized.expand(["protect", "unprotect"]) + def test_no_auth(self, action: str): + """ + Try to protect media without authentication. + """ + + channel = self.make_request("POST", self.url % (action, self.media_id), b"{}") + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + @parameterized.expand(["protect", "unprotect"]) + def test_requester_is_no_admin(self, action: str): + """ + If the user is not a server admin, an error is returned. + """ + self.other_user = self.register_user("user", "pass") + self.other_user_token = self.login("user", "pass") + + channel = self.make_request( + "POST", + self.url % (action, self.media_id), + access_token=self.other_user_token, + ) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_protect_media(self): + """ + Tests that protect and unprotect a media is successfully + """ + + media_info = self.get_success(self.store.get_local_media(self.media_id)) + self.assertFalse(media_info["safe_from_quarantine"]) + + # protect + channel = self.make_request( + "POST", + self.url % ("protect", self.media_id), + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertFalse(channel.json_body) + + media_info = self.get_success(self.store.get_local_media(self.media_id)) + self.assertTrue(media_info["safe_from_quarantine"]) + + # unprotect + channel = self.make_request( + "POST", + self.url % ("unprotect", self.media_id), + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertFalse(channel.json_body) + + media_info = self.get_success(self.store.get_local_media(self.media_id)) + self.assertFalse(media_info["safe_from_quarantine"]) -- cgit 1.5.1 From 224f2f949b1661094a64d1105efb64159ddf4aa0 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 27 May 2021 10:33:56 +0100 Subject: Combine `LruCache.invalidate` and `invalidate_many` (#9973) * Make `invalidate` and `invalidate_many` do the same thing ... so that we can do either over the invalidation replication stream, and also because they always confused me a bit. * Kill off `invalidate_many` * changelog --- changelog.d/9973.misc | 1 + synapse/replication/slave/storage/devices.py | 2 +- synapse/storage/databases/main/cache.py | 6 ++-- synapse/storage/databases/main/devices.py | 2 +- .../storage/databases/main/event_push_actions.py | 2 +- synapse/storage/databases/main/events.py | 8 ++--- synapse/storage/databases/main/receipts.py | 6 ++-- synapse/util/caches/deferred_cache.py | 42 +++++++++------------- synapse/util/caches/descriptors.py | 8 +++-- synapse/util/caches/lrucache.py | 18 ++++++---- synapse/util/caches/treecache.py | 3 ++ tests/util/caches/test_descriptors.py | 6 ++-- 12 files changed, 52 insertions(+), 52 deletions(-) create mode 100644 changelog.d/9973.misc (limited to 'tests') diff --git a/changelog.d/9973.misc b/changelog.d/9973.misc new file mode 100644 index 0000000000..7f22d42291 --- /dev/null +++ b/changelog.d/9973.misc @@ -0,0 +1 @@ +Make `LruCache.invalidate` support tree invalidation, and remove `invalidate_many`. diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 70207420a6..26bdead565 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -68,7 +68,7 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto if row.entity.startswith("@"): self._device_list_stream_cache.entity_has_changed(row.entity, token) self.get_cached_devices_for_user.invalidate((row.entity,)) - self._get_cached_user_device.invalidate_many((row.entity,)) + self._get_cached_user_device.invalidate((row.entity,)) self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,)) else: diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index ecc1f935e2..f7872501a0 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -171,7 +171,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self.get_latest_event_ids_in_room.invalidate((room_id,)) - self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,)) + self.get_unread_event_push_actions_by_room_for_user.invalidate((room_id,)) if not backfilled: self._events_stream_cache.entity_has_changed(room_id, stream_ordering) @@ -184,8 +184,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self.get_invited_rooms_for_local_user.invalidate((state_key,)) if relates_to: - self.get_relations_for_event.invalidate_many((relates_to,)) - self.get_aggregation_groups_for_event.invalidate_many((relates_to,)) + self.get_relations_for_event.invalidate((relates_to,)) + self.get_aggregation_groups_for_event.invalidate((relates_to,)) self.get_applicable_edit.invalidate((relates_to,)) async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]): diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index fd87ba71ab..18f07d96dc 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -1282,7 +1282,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) txn.call_after(self.get_cached_devices_for_user.invalidate, (user_id,)) - txn.call_after(self._get_cached_user_device.invalidate_many, (user_id,)) + txn.call_after(self._get_cached_user_device.invalidate, (user_id,)) txn.call_after( self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) ) diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 5845322118..d1237c65cc 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -860,7 +860,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): not be deleted. """ txn.call_after( - self.get_unread_event_push_actions_by_room_for_user.invalidate_many, + self.get_unread_event_push_actions_by_room_for_user.invalidate, (room_id, user_id), ) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index fd25c8112d..897fa06639 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1748,9 +1748,9 @@ class PersistEventsStore: }, ) - txn.call_after(self.store.get_relations_for_event.invalidate_many, (parent_id,)) + txn.call_after(self.store.get_relations_for_event.invalidate, (parent_id,)) txn.call_after( - self.store.get_aggregation_groups_for_event.invalidate_many, (parent_id,) + self.store.get_aggregation_groups_for_event.invalidate, (parent_id,) ) if rel_type == RelationTypes.REPLACE: @@ -1903,7 +1903,7 @@ class PersistEventsStore: for user_id in user_ids: txn.call_after( - self.store.get_unread_event_push_actions_by_room_for_user.invalidate_many, + self.store.get_unread_event_push_actions_by_room_for_user.invalidate, (room_id, user_id), ) @@ -1917,7 +1917,7 @@ class PersistEventsStore: def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id): # Sad that we have to blow away the cache for the whole room here txn.call_after( - self.store.get_unread_event_push_actions_by_room_for_user.invalidate_many, + self.store.get_unread_event_push_actions_by_room_for_user.invalidate, (room_id,), ) txn.execute( diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 3647276acb..edeaacd7a6 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -460,7 +460,7 @@ class ReceiptsWorkerStore(SQLBaseStore): def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id): self.get_receipts_for_user.invalidate((user_id, receipt_type)) - self._get_linearized_receipts_for_room.invalidate_many((room_id,)) + self._get_linearized_receipts_for_room.invalidate((room_id,)) self.get_last_receipt_event_id_for_user.invalidate( (user_id, room_id, receipt_type) ) @@ -659,9 +659,7 @@ class ReceiptsWorkerStore(SQLBaseStore): ) txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type)) # FIXME: This shouldn't invalidate the whole cache - txn.call_after( - self._get_linearized_receipts_for_room.invalidate_many, (room_id,) - ) + txn.call_after(self._get_linearized_receipts_for_room.invalidate, (room_id,)) self.db_pool.simple_delete_txn( txn, diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index 371e7e4dd0..1044139119 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -16,16 +16,7 @@ import enum import threading -from typing import ( - Callable, - Generic, - Iterable, - MutableMapping, - Optional, - TypeVar, - Union, - cast, -) +from typing import Callable, Generic, Iterable, MutableMapping, Optional, TypeVar, Union from prometheus_client import Gauge @@ -91,7 +82,7 @@ class DeferredCache(Generic[KT, VT]): # _pending_deferred_cache maps from the key value to a `CacheEntry` object. self._pending_deferred_cache = ( cache_type() - ) # type: MutableMapping[KT, CacheEntry] + ) # type: Union[TreeCache, MutableMapping[KT, CacheEntry]] def metrics_cb(): cache_pending_metric.labels(name).set(len(self._pending_deferred_cache)) @@ -287,8 +278,17 @@ class DeferredCache(Generic[KT, VT]): self.cache.set(key, value, callbacks=callbacks) def invalidate(self, key): + """Delete a key, or tree of entries + + If the cache is backed by a regular dict, then "key" must be of + the right type for this cache + + If the cache is backed by a TreeCache, then "key" must be a tuple, but + may be of lower cardinality than the TreeCache - in which case the whole + subtree is deleted. + """ self.check_thread() - self.cache.pop(key, None) + self.cache.del_multi(key) # if we have a pending lookup for this key, remove it from the # _pending_deferred_cache, which will (a) stop it being returned @@ -299,20 +299,10 @@ class DeferredCache(Generic[KT, VT]): # run the invalidation callbacks now, rather than waiting for the # deferred to resolve. if entry: - entry.invalidate() - - def invalidate_many(self, key: KT): - self.check_thread() - if not isinstance(key, tuple): - raise TypeError("The cache key must be a tuple not %r" % (type(key),)) - key = cast(KT, key) - self.cache.del_multi(key) - - # if we have a pending lookup for this key, remove it from the - # _pending_deferred_cache, as above - entry_dict = self._pending_deferred_cache.pop(key, None) - if entry_dict is not None: - for entry in iterate_tree_cache_entry(entry_dict): + # _pending_deferred_cache.pop should either return a CacheEntry, or, in the + # case of a TreeCache, a dict of keys to cache entries. Either way calling + # iterate_tree_cache_entry on it will do the right thing. + for entry in iterate_tree_cache_entry(entry): entry.invalidate() def invalidate_all(self): diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 2ac24a2f25..d77e8edeea 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -48,7 +48,6 @@ F = TypeVar("F", bound=Callable[..., Any]) class _CachedFunction(Generic[F]): invalidate = None # type: Any invalidate_all = None # type: Any - invalidate_many = None # type: Any prefill = None # type: Any cache = None # type: Any num_args = None # type: Any @@ -262,6 +261,11 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): ): super().__init__(orig, num_args=num_args, cache_context=cache_context) + if tree and self.num_args < 2: + raise RuntimeError( + "tree=True is nonsensical for cached functions with a single parameter" + ) + self.max_entries = max_entries self.tree = tree self.iterable = iterable @@ -302,11 +306,11 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): wrapped = cast(_CachedFunction, _wrapped) if self.num_args == 1: + assert not self.tree wrapped.invalidate = lambda key: cache.invalidate(key[0]) wrapped.prefill = lambda key, val: cache.prefill(key[0], val) else: wrapped.invalidate = cache.invalidate - wrapped.invalidate_many = cache.invalidate_many wrapped.prefill = cache.prefill wrapped.invalidate_all = cache.invalidate_all diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 54df407ff7..d89e9d9b1d 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -152,7 +152,6 @@ class LruCache(Generic[KT, VT]): """ Least-recently-used cache, supporting prometheus metrics and invalidation callbacks. - Supports del_multi only if cache_type=TreeCache If cache_type=TreeCache, all keys must be tuples. """ @@ -393,10 +392,16 @@ class LruCache(Generic[KT, VT]): @synchronized def cache_del_multi(key: KT) -> None: + """Delete an entry, or tree of entries + + If the LruCache is backed by a regular dict, then "key" must be of + the right type for this cache + + If the LruCache is backed by a TreeCache, then "key" must be a tuple, but + may be of lower cardinality than the TreeCache - in which case the whole + subtree is deleted. """ - This will only work if constructed with cache_type=TreeCache - """ - popped = cache.pop(key) + popped = cache.pop(key, None) if popped is None: return # for each deleted node, we now need to remove it from the linked list @@ -430,11 +435,10 @@ class LruCache(Generic[KT, VT]): self.set = cache_set self.setdefault = cache_set_default self.pop = cache_pop + self.del_multi = cache_del_multi # `invalidate` is exposed for consistency with DeferredCache, so that it can be # invalidated by the cache invalidation replication stream. - self.invalidate = cache_pop - if cache_type is TreeCache: - self.del_multi = cache_del_multi + self.invalidate = cache_del_multi self.len = synchronized(cache_len) self.contains = cache_contains self.clear = cache_clear diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py index 73502a8b06..a6df81ebff 100644 --- a/synapse/util/caches/treecache.py +++ b/synapse/util/caches/treecache.py @@ -89,6 +89,9 @@ class TreeCache: value. If the key is partial, the TreeCacheNode corresponding to the part of the tree that was removed. """ + if not isinstance(key, tuple): + raise TypeError("The cache key must be a tuple not %r" % (type(key),)) + # a list of the nodes we have touched on the way down the tree nodes = [] diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index bbbc276697..0277998cbe 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -622,17 +622,17 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): self.assertEquals(callcount2[0], 1) a.func2.invalidate(("foo",)) - self.assertEquals(a.func2.cache.cache.pop.call_count, 1) + self.assertEquals(a.func2.cache.cache.del_multi.call_count, 1) yield a.func2("foo") a.func2.invalidate(("foo",)) - self.assertEquals(a.func2.cache.cache.pop.call_count, 2) + self.assertEquals(a.func2.cache.cache.del_multi.call_count, 2) self.assertEquals(callcount[0], 1) self.assertEquals(callcount2[0], 2) a.func.invalidate(("foo",)) - self.assertEquals(a.func2.cache.cache.pop.call_count, 3) + self.assertEquals(a.func2.cache.cache.del_multi.call_count, 3) yield a.func("foo") self.assertEquals(callcount[0], 2) -- cgit 1.5.1 From fe5dad46b0da00e9757ed54eb23304ed3c6ceadf Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 27 May 2021 10:34:24 +0100 Subject: Remove redundant code to reload tls cert (#10054) we don't need to reload the tls cert if we don't have any tls listeners. Follow-up to #9280. --- changelog.d/10054.misc | 1 + synapse/app/_base.py | 5 +---- synapse/config/tls.py | 22 +++------------------- tests/config/test_tls.py | 3 +-- 4 files changed, 6 insertions(+), 25 deletions(-) create mode 100644 changelog.d/10054.misc (limited to 'tests') diff --git a/changelog.d/10054.misc b/changelog.d/10054.misc new file mode 100644 index 0000000000..cebe39ce54 --- /dev/null +++ b/changelog.d/10054.misc @@ -0,0 +1 @@ +Remove some dead code regarding TLS certificate handling. diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 59918d789e..1329af2e2b 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -261,13 +261,10 @@ def refresh_certificate(hs): Refresh the TLS certificates that Synapse is using by re-reading them from disk and updating the TLS context factories to use them. """ - if not hs.config.has_tls_listener(): - # attempt to reload the certs for the good of the tls_fingerprints - hs.config.read_certificate_from_disk(require_cert_and_key=False) return - hs.config.read_certificate_from_disk(require_cert_and_key=True) + hs.config.read_certificate_from_disk() hs.tls_server_context_factory = context_factory.ServerContextFactory(hs.config) if hs._listening_services: diff --git a/synapse/config/tls.py b/synapse/config/tls.py index 26f1150ca5..0e9bba53c9 100644 --- a/synapse/config/tls.py +++ b/synapse/config/tls.py @@ -215,28 +215,12 @@ class TlsConfig(Config): days_remaining = (expires_on - now).days return days_remaining - def read_certificate_from_disk(self, require_cert_and_key: bool): + def read_certificate_from_disk(self): """ Read the certificates and private key from disk. - - Args: - require_cert_and_key: set to True to throw an error if the certificate - and key file are not given """ - if require_cert_and_key: - self.tls_private_key = self.read_tls_private_key() - self.tls_certificate = self.read_tls_certificate() - elif self.tls_certificate_file: - # we only need the certificate for the tls_fingerprints. Reload it if we - # can, but it's not a fatal error if we can't. - try: - self.tls_certificate = self.read_tls_certificate() - except Exception as e: - logger.info( - "Unable to read TLS certificate (%s). Ignoring as no " - "tls listeners enabled.", - e, - ) + self.tls_private_key = self.read_tls_private_key() + self.tls_certificate = self.read_tls_certificate() def generate_config_section( self, diff --git a/tests/config/test_tls.py b/tests/config/test_tls.py index 183034f7d4..dcf336416c 100644 --- a/tests/config/test_tls.py +++ b/tests/config/test_tls.py @@ -74,12 +74,11 @@ s4niecZKPBizL6aucT59CsunNmmb5Glq8rlAcU+1ZTZZzGYqVYhF6axB9Qg= config = { "tls_certificate_path": os.path.join(config_dir, "cert.pem"), - "tls_fingerprints": [], } t = TestConfig() t.read_config(config, config_dir_path="", data_dir_path="") - t.read_certificate_from_disk(require_cert_and_key=False) + t.read_tls_certificate() warnings = self.flushWarnings() self.assertEqual(len(warnings), 1) -- cgit 1.5.1 From 78b5102ae71f828deb851eca8e677381710bf716 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 27 May 2021 14:32:31 +0100 Subject: Fix up `BatchingQueue` (#10078) Fixes #10068 --- changelog.d/10078.misc | 1 + synapse/util/batching_queue.py | 70 ++++++++++++++++++++++++----------- tests/util/test_batching_queue.py | 78 ++++++++++++++++++++++++++++++++++++++- 3 files changed, 125 insertions(+), 24 deletions(-) create mode 100644 changelog.d/10078.misc (limited to 'tests') diff --git a/changelog.d/10078.misc b/changelog.d/10078.misc new file mode 100644 index 0000000000..a4b089d0fd --- /dev/null +++ b/changelog.d/10078.misc @@ -0,0 +1 @@ +Fix up `BatchingQueue` implementation. diff --git a/synapse/util/batching_queue.py b/synapse/util/batching_queue.py index 44bbb7b1a8..8fd5bfb69b 100644 --- a/synapse/util/batching_queue.py +++ b/synapse/util/batching_queue.py @@ -25,10 +25,11 @@ from typing import ( TypeVar, ) +from prometheus_client import Gauge + from twisted.internet import defer from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable -from synapse.metrics import LaterGauge from synapse.metrics.background_process_metrics import run_as_background_process from synapse.util import Clock @@ -38,6 +39,24 @@ logger = logging.getLogger(__name__) V = TypeVar("V") R = TypeVar("R") +number_queued = Gauge( + "synapse_util_batching_queue_number_queued", + "The number of items waiting in the queue across all keys", + labelnames=("name",), +) + +number_in_flight = Gauge( + "synapse_util_batching_queue_number_pending", + "The number of items across all keys either being processed or waiting in a queue", + labelnames=("name",), +) + +number_of_keys = Gauge( + "synapse_util_batching_queue_number_of_keys", + "The number of distinct keys that have items queued", + labelnames=("name",), +) + class BatchingQueue(Generic[V, R]): """A queue that batches up work, calling the provided processing function @@ -48,10 +67,20 @@ class BatchingQueue(Generic[V, R]): called, and will keep being called until the queue has been drained (for the given key). + If the processing function raises an exception then the exception is proxied + through to the callers waiting on that batch of work. + Note that the return value of `add_to_queue` will be the return value of the processing function that processed the given item. This means that the returned value will likely include data for other items that were in the batch. + + Args: + name: A name for the queue, used for logging contexts and metrics. + This must be unique, otherwise the metrics will be wrong. + clock: The clock to use to schedule work. + process_batch_callback: The callback to to be run to process a batch of + work. """ def __init__( @@ -73,19 +102,15 @@ class BatchingQueue(Generic[V, R]): # The function to call with batches of values. self._process_batch_callback = process_batch_callback - LaterGauge( - "synapse_util_batching_queue_number_queued", - "The number of items waiting in the queue across all keys", - labels=("name",), - caller=lambda: sum(len(v) for v in self._next_values.values()), + number_queued.labels(self._name).set_function( + lambda: sum(len(q) for q in self._next_values.values()) ) - LaterGauge( - "synapse_util_batching_queue_number_of_keys", - "The number of distinct keys that have items queued", - labels=("name",), - caller=lambda: len(self._next_values), - ) + number_of_keys.labels(self._name).set_function(lambda: len(self._next_values)) + + self._number_in_flight_metric = number_in_flight.labels( + self._name + ) # type: Gauge async def add_to_queue(self, value: V, key: Hashable = ()) -> R: """Adds the value to the queue with the given key, returning the result @@ -107,17 +132,18 @@ class BatchingQueue(Generic[V, R]): if key not in self._processing_keys: run_as_background_process(self._name, self._process_queue, key) - return await make_deferred_yieldable(d) + with self._number_in_flight_metric.track_inprogress(): + return await make_deferred_yieldable(d) async def _process_queue(self, key: Hashable) -> None: """A background task to repeatedly pull things off the queue for the given key and call the `self._process_batch_callback` with the values. """ - try: - if key in self._processing_keys: - return + if key in self._processing_keys: + return + try: self._processing_keys.add(key) while True: @@ -137,16 +163,16 @@ class BatchingQueue(Generic[V, R]): values = [value for value, _ in next_values] results = await self._process_batch_callback(values) - for _, deferred in next_values: - with PreserveLoggingContext(): + with PreserveLoggingContext(): + for _, deferred in next_values: deferred.callback(results) except Exception as e: - for _, deferred in next_values: - if deferred.called: - continue + with PreserveLoggingContext(): + for _, deferred in next_values: + if deferred.called: + continue - with PreserveLoggingContext(): deferred.errback(e) finally: diff --git a/tests/util/test_batching_queue.py b/tests/util/test_batching_queue.py index 5def1e56c9..edf29e5b96 100644 --- a/tests/util/test_batching_queue.py +++ b/tests/util/test_batching_queue.py @@ -14,7 +14,12 @@ from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable -from synapse.util.batching_queue import BatchingQueue +from synapse.util.batching_queue import ( + BatchingQueue, + number_in_flight, + number_of_keys, + number_queued, +) from tests.server import get_clock from tests.unittest import TestCase @@ -24,6 +29,14 @@ class BatchingQueueTestCase(TestCase): def setUp(self): self.clock, hs_clock = get_clock() + # We ensure that we remove any existing metrics for "test_queue". + try: + number_queued.remove("test_queue") + number_of_keys.remove("test_queue") + number_in_flight.remove("test_queue") + except KeyError: + pass + self._pending_calls = [] self.queue = BatchingQueue("test_queue", hs_clock, self._process_queue) @@ -32,6 +45,41 @@ class BatchingQueueTestCase(TestCase): self._pending_calls.append((values, d)) return await make_deferred_yieldable(d) + def _assert_metrics(self, queued, keys, in_flight): + """Assert that the metrics are correct""" + + self.assertEqual(len(number_queued.collect()), 1) + self.assertEqual(len(number_queued.collect()[0].samples), 1) + self.assertEqual( + number_queued.collect()[0].samples[0].labels, + {"name": self.queue._name}, + ) + self.assertEqual( + number_queued.collect()[0].samples[0].value, + queued, + "number_queued", + ) + + self.assertEqual(len(number_of_keys.collect()), 1) + self.assertEqual(len(number_of_keys.collect()[0].samples), 1) + self.assertEqual( + number_queued.collect()[0].samples[0].labels, {"name": self.queue._name} + ) + self.assertEqual( + number_of_keys.collect()[0].samples[0].value, keys, "number_of_keys" + ) + + self.assertEqual(len(number_in_flight.collect()), 1) + self.assertEqual(len(number_in_flight.collect()[0].samples), 1) + self.assertEqual( + number_queued.collect()[0].samples[0].labels, {"name": self.queue._name} + ) + self.assertEqual( + number_in_flight.collect()[0].samples[0].value, + in_flight, + "number_in_flight", + ) + def test_simple(self): """Tests the basic case of calling `add_to_queue` once and having `_process_queue` return. @@ -41,6 +89,8 @@ class BatchingQueueTestCase(TestCase): queue_d = defer.ensureDeferred(self.queue.add_to_queue("foo")) + self._assert_metrics(queued=1, keys=1, in_flight=1) + # The queue should wait a reactor tick before calling the processing # function. self.assertFalse(self._pending_calls) @@ -52,12 +102,15 @@ class BatchingQueueTestCase(TestCase): self.assertEqual(len(self._pending_calls), 1) self.assertEqual(self._pending_calls[0][0], ["foo"]) self.assertFalse(queue_d.called) + self._assert_metrics(queued=0, keys=0, in_flight=1) # Return value of the `_process_queue` should be propagated back. self._pending_calls.pop()[1].callback("bar") self.assertEqual(self.successResultOf(queue_d), "bar") + self._assert_metrics(queued=0, keys=0, in_flight=0) + def test_batching(self): """Test that multiple calls at the same time get batched up into one call to `_process_queue`. @@ -68,6 +121,8 @@ class BatchingQueueTestCase(TestCase): queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1")) queue_d2 = defer.ensureDeferred(self.queue.add_to_queue("foo2")) + self._assert_metrics(queued=2, keys=1, in_flight=2) + self.clock.pump([0]) # We should see only *one* call to `_process_queue` @@ -75,12 +130,14 @@ class BatchingQueueTestCase(TestCase): self.assertEqual(self._pending_calls[0][0], ["foo1", "foo2"]) self.assertFalse(queue_d1.called) self.assertFalse(queue_d2.called) + self._assert_metrics(queued=0, keys=0, in_flight=2) # Return value of the `_process_queue` should be propagated back to both. self._pending_calls.pop()[1].callback("bar") self.assertEqual(self.successResultOf(queue_d1), "bar") self.assertEqual(self.successResultOf(queue_d2), "bar") + self._assert_metrics(queued=0, keys=0, in_flight=0) def test_queuing(self): """Test that we queue up requests while a `_process_queue` is being @@ -92,13 +149,20 @@ class BatchingQueueTestCase(TestCase): queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1")) self.clock.pump([0]) + self.assertEqual(len(self._pending_calls), 1) + + # We queue up work after the process function has been called, testing + # that they get correctly queued up. queue_d2 = defer.ensureDeferred(self.queue.add_to_queue("foo2")) + queue_d3 = defer.ensureDeferred(self.queue.add_to_queue("foo3")) # We should see only *one* call to `_process_queue` self.assertEqual(len(self._pending_calls), 1) self.assertEqual(self._pending_calls[0][0], ["foo1"]) self.assertFalse(queue_d1.called) self.assertFalse(queue_d2.called) + self.assertFalse(queue_d3.called) + self._assert_metrics(queued=2, keys=1, in_flight=3) # Return value of the `_process_queue` should be propagated back to the # first. @@ -106,18 +170,24 @@ class BatchingQueueTestCase(TestCase): self.assertEqual(self.successResultOf(queue_d1), "bar1") self.assertFalse(queue_d2.called) + self.assertFalse(queue_d3.called) + self._assert_metrics(queued=2, keys=1, in_flight=2) # We should now see a second call to `_process_queue` self.clock.pump([0]) self.assertEqual(len(self._pending_calls), 1) - self.assertEqual(self._pending_calls[0][0], ["foo2"]) + self.assertEqual(self._pending_calls[0][0], ["foo2", "foo3"]) self.assertFalse(queue_d2.called) + self.assertFalse(queue_d3.called) + self._assert_metrics(queued=0, keys=0, in_flight=2) # Return value of the `_process_queue` should be propagated back to the # second. self._pending_calls.pop()[1].callback("bar2") self.assertEqual(self.successResultOf(queue_d2), "bar2") + self.assertEqual(self.successResultOf(queue_d3), "bar2") + self._assert_metrics(queued=0, keys=0, in_flight=0) def test_different_keys(self): """Test that calls to different keys get processed in parallel.""" @@ -140,6 +210,7 @@ class BatchingQueueTestCase(TestCase): self.assertFalse(queue_d1.called) self.assertFalse(queue_d2.called) self.assertFalse(queue_d3.called) + self._assert_metrics(queued=1, keys=1, in_flight=3) # Return value of the `_process_queue` should be propagated back to the # first. @@ -148,6 +219,7 @@ class BatchingQueueTestCase(TestCase): self.assertEqual(self.successResultOf(queue_d1), "bar1") self.assertFalse(queue_d2.called) self.assertFalse(queue_d3.called) + self._assert_metrics(queued=1, keys=1, in_flight=2) # Return value of the `_process_queue` should be propagated back to the # second. @@ -161,9 +233,11 @@ class BatchingQueueTestCase(TestCase): self.assertEqual(len(self._pending_calls), 1) self.assertEqual(self._pending_calls[0][0], ["foo3"]) self.assertFalse(queue_d3.called) + self._assert_metrics(queued=0, keys=0, in_flight=1) # Return value of the `_process_queue` should be propagated back to the # third deferred. self._pending_calls.pop()[1].callback("bar4") self.assertEqual(self.successResultOf(queue_d3), "bar4") + self._assert_metrics(queued=0, keys=0, in_flight=0) -- cgit 1.5.1 From 8fb9af570f942d2057e8acb4a047d61ed7048f58 Mon Sep 17 00:00:00 2001 From: Callum Brown Date: Thu, 27 May 2021 18:42:23 +0100 Subject: Make reason and score optional for report_event (#10077) Implements MSC2414: https://github.com/matrix-org/matrix-doc/pull/2414 See #8551 Signed-off-by: Callum Brown --- changelog.d/10077.feature | 1 + docs/admin_api/event_reports.md | 4 +- synapse/rest/client/v2_alpha/report_event.py | 13 ++-- synapse/storage/databases/main/room.py | 2 +- tests/rest/admin/test_event_reports.py | 15 ++++- tests/rest/client/v2_alpha/test_report_event.py | 83 +++++++++++++++++++++++++ 6 files changed, 105 insertions(+), 13 deletions(-) create mode 100644 changelog.d/10077.feature create mode 100644 tests/rest/client/v2_alpha/test_report_event.py (limited to 'tests') diff --git a/changelog.d/10077.feature b/changelog.d/10077.feature new file mode 100644 index 0000000000..808feb2215 --- /dev/null +++ b/changelog.d/10077.feature @@ -0,0 +1 @@ +Make reason and score parameters optional for reporting content. Implements [MSC2414](https://github.com/matrix-org/matrix-doc/pull/2414). Contributed by Callum Brown. diff --git a/docs/admin_api/event_reports.md b/docs/admin_api/event_reports.md index 0159098138..bfec06f755 100644 --- a/docs/admin_api/event_reports.md +++ b/docs/admin_api/event_reports.md @@ -75,9 +75,9 @@ The following fields are returned in the JSON response body: * `name`: string - The name of the room. * `event_id`: string - The ID of the reported event. * `user_id`: string - This is the user who reported the event and wrote the reason. -* `reason`: string - Comment made by the `user_id` in this report. May be blank. +* `reason`: string - Comment made by the `user_id` in this report. May be blank or `null`. * `score`: integer - Content is reported based upon a negative score, where -100 is - "most offensive" and 0 is "inoffensive". + "most offensive" and 0 is "inoffensive". May be `null`. * `sender`: string - This is the ID of the user who sent the original message/event that was reported. * `canonical_alias`: string - The canonical alias of the room. `null` if the room does not diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py index 2c169abbf3..07ea39a8a3 100644 --- a/synapse/rest/client/v2_alpha/report_event.py +++ b/synapse/rest/client/v2_alpha/report_event.py @@ -16,11 +16,7 @@ import logging from http import HTTPStatus from synapse.api.errors import Codes, SynapseError -from synapse.http.servlet import ( - RestServlet, - assert_params_in_dict, - parse_json_object_from_request, -) +from synapse.http.servlet import RestServlet, parse_json_object_from_request from ._base import client_patterns @@ -42,15 +38,14 @@ class ReportEventRestServlet(RestServlet): user_id = requester.user.to_string() body = parse_json_object_from_request(request) - assert_params_in_dict(body, ("reason", "score")) - if not isinstance(body["reason"], str): + if not isinstance(body.get("reason", ""), str): raise SynapseError( HTTPStatus.BAD_REQUEST, "Param 'reason' must be a string", Codes.BAD_JSON, ) - if not isinstance(body["score"], int): + if not isinstance(body.get("score", 0), int): raise SynapseError( HTTPStatus.BAD_REQUEST, "Param 'score' must be an integer", @@ -61,7 +56,7 @@ class ReportEventRestServlet(RestServlet): room_id=room_id, event_id=event_id, user_id=user_id, - reason=body["reason"], + reason=body.get("reason"), content=body, received_ts=self.clock.time_msec(), ) diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 5f38634f48..0cf450f81d 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -1498,7 +1498,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): room_id: str, event_id: str, user_id: str, - reason: str, + reason: Optional[str], content: JsonDict, received_ts: int, ) -> None: diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py index 29341bc6e9..f15d1cf6f7 100644 --- a/tests/rest/admin/test_event_reports.py +++ b/tests/rest/admin/test_event_reports.py @@ -64,7 +64,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): user_tok=self.admin_user_tok, ) for _ in range(5): - self._create_event_and_report( + self._create_event_and_report_without_parameters( room_id=self.room_id2, user_tok=self.admin_user_tok, ) @@ -378,6 +378,19 @@ class EventReportsTestCase(unittest.HomeserverTestCase): ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + def _create_event_and_report_without_parameters(self, room_id, user_tok): + """Create and report an event, but omit reason and score""" + resp = self.helper.send(room_id, tok=user_tok) + event_id = resp["event_id"] + + channel = self.make_request( + "POST", + "rooms/%s/report/%s" % (room_id, event_id), + json.dumps({}), + access_token=user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + def _check_fields(self, content): """Checks that all attributes are present in an event report""" for c in content: diff --git a/tests/rest/client/v2_alpha/test_report_event.py b/tests/rest/client/v2_alpha/test_report_event.py new file mode 100644 index 0000000000..1ec6b05e5b --- /dev/null +++ b/tests/rest/client/v2_alpha/test_report_event.py @@ -0,0 +1,83 @@ +# Copyright 2021 Callum Brown +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +import synapse.rest.admin +from synapse.rest.client.v1 import login, room +from synapse.rest.client.v2_alpha import report_event + +from tests import unittest + + +class ReportEventTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + report_event.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + self.other_user = self.register_user("user", "pass") + self.other_user_tok = self.login("user", "pass") + + self.room_id = self.helper.create_room_as( + self.other_user, tok=self.other_user_tok, is_public=True + ) + self.helper.join(self.room_id, user=self.admin_user, tok=self.admin_user_tok) + resp = self.helper.send(self.room_id, tok=self.admin_user_tok) + self.event_id = resp["event_id"] + self.report_path = "rooms/{}/report/{}".format(self.room_id, self.event_id) + + def test_reason_str_and_score_int(self): + data = {"reason": "this makes me sad", "score": -100} + self._assert_status(200, data) + + def test_no_reason(self): + data = {"score": 0} + self._assert_status(200, data) + + def test_no_score(self): + data = {"reason": "this makes me sad"} + self._assert_status(200, data) + + def test_no_reason_and_no_score(self): + data = {} + self._assert_status(200, data) + + def test_reason_int_and_score_str(self): + data = {"reason": 10, "score": "string"} + self._assert_status(400, data) + + def test_reason_zero_and_score_blank(self): + data = {"reason": 0, "score": ""} + self._assert_status(400, data) + + def test_reason_and_score_null(self): + data = {"reason": None, "score": None} + self._assert_status(400, data) + + def _assert_status(self, response_status, data): + channel = self.make_request( + "POST", + self.report_path, + json.dumps(data), + access_token=self.other_user_tok, + ) + self.assertEqual( + response_status, int(channel.result["code"]), msg=channel.result["body"] + ) -- cgit 1.5.1 From b4b2fd2ecee26214fa6b322bcb62bec1ea324c1a Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 1 Jun 2021 12:04:47 +0100 Subject: add a cache to have_seen_event (#9953) Empirically, this helped my server considerably when handling gaps in Matrix HQ. The problem was that we would repeatedly call have_seen_events for the same set of (50K or so) auth_events, each of which would take many minutes to complete, even though it's only an index scan. --- changelog.d/9953.feature | 1 + changelog.d/9973.feature | 1 + changelog.d/9973.misc | 1 - synapse/handlers/federation.py | 12 +-- synapse/storage/databases/main/cache.py | 1 + synapse/storage/databases/main/events_worker.py | 61 ++++++++++++-- synapse/storage/databases/main/purge_events.py | 26 ++++-- tests/storage/databases/__init__.py | 13 +++ tests/storage/databases/main/__init__.py | 13 +++ tests/storage/databases/main/test_events_worker.py | 96 ++++++++++++++++++++++ 10 files changed, 205 insertions(+), 20 deletions(-) create mode 100644 changelog.d/9953.feature create mode 100644 changelog.d/9973.feature delete mode 100644 changelog.d/9973.misc create mode 100644 tests/storage/databases/__init__.py create mode 100644 tests/storage/databases/main/__init__.py create mode 100644 tests/storage/databases/main/test_events_worker.py (limited to 'tests') diff --git a/changelog.d/9953.feature b/changelog.d/9953.feature new file mode 100644 index 0000000000..6b3d1adc70 --- /dev/null +++ b/changelog.d/9953.feature @@ -0,0 +1 @@ +Improve performance of incoming federation transactions in large rooms. diff --git a/changelog.d/9973.feature b/changelog.d/9973.feature new file mode 100644 index 0000000000..6b3d1adc70 --- /dev/null +++ b/changelog.d/9973.feature @@ -0,0 +1 @@ +Improve performance of incoming federation transactions in large rooms. diff --git a/changelog.d/9973.misc b/changelog.d/9973.misc deleted file mode 100644 index 7f22d42291..0000000000 --- a/changelog.d/9973.misc +++ /dev/null @@ -1 +0,0 @@ -Make `LruCache.invalidate` support tree invalidation, and remove `invalidate_many`. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index bf11315251..49ed7cabcc 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -577,7 +577,9 @@ class FederationHandler(BaseHandler): # Fetch the state events from the DB, and check we have the auth events. event_map = await self.store.get_events(state_event_ids, allow_rejected=True) - auth_events_in_store = await self.store.have_seen_events(auth_event_ids) + auth_events_in_store = await self.store.have_seen_events( + room_id, auth_event_ids + ) # Check for missing events. We handle state and auth event seperately, # as we want to pull the state from the DB, but we don't for the auth @@ -610,7 +612,7 @@ class FederationHandler(BaseHandler): if missing_auth_events: auth_events_in_store = await self.store.have_seen_events( - missing_auth_events + room_id, missing_auth_events ) missing_auth_events.difference_update(auth_events_in_store) @@ -710,7 +712,7 @@ class FederationHandler(BaseHandler): missing_auth_events = set(auth_event_ids) - fetched_events.keys() missing_auth_events.difference_update( - await self.store.have_seen_events(missing_auth_events) + await self.store.have_seen_events(room_id, missing_auth_events) ) logger.debug("We are also missing %i auth events", len(missing_auth_events)) @@ -2475,7 +2477,7 @@ class FederationHandler(BaseHandler): # # we start by checking if they are in the store, and then try calling /event_auth/. if missing_auth: - have_events = await self.store.have_seen_events(missing_auth) + have_events = await self.store.have_seen_events(event.room_id, missing_auth) logger.debug("Events %s are in the store", have_events) missing_auth.difference_update(have_events) @@ -2494,7 +2496,7 @@ class FederationHandler(BaseHandler): return context seen_remotes = await self.store.have_seen_events( - [e.event_id for e in remote_auth_chain] + event.room_id, [e.event_id for e in remote_auth_chain] ) for e in remote_auth_chain: diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index f7872501a0..c57ae5ef15 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -168,6 +168,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): backfilled, ): self._invalidate_get_event_cache(event_id) + self.have_seen_event.invalidate((room_id, event_id)) self.get_latest_event_ids_in_room.invalidate((room_id,)) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 6963bbf7f4..403a5ddaba 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -22,6 +22,7 @@ from typing import ( Iterable, List, Optional, + Set, Tuple, overload, ) @@ -55,7 +56,7 @@ from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.storage.util.sequence import build_sequence_generator from synapse.types import JsonDict, get_domain_from_id -from synapse.util.caches.descriptors import cached +from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.lrucache import LruCache from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -1045,32 +1046,74 @@ class EventsWorkerStore(SQLBaseStore): return {r["event_id"] for r in rows} - async def have_seen_events(self, event_ids): + async def have_seen_events( + self, room_id: str, event_ids: Iterable[str] + ) -> Set[str]: """Given a list of event ids, check if we have already processed them. + The room_id is only used to structure the cache (so that it can later be + invalidated by room_id) - there is no guarantee that the events are actually + in the room in question. + Args: - event_ids (iterable[str]): + room_id: Room we are polling + event_ids: events we are looking for Returns: set[str]: The events we have already seen. """ + res = await self._have_seen_events_dict( + (room_id, event_id) for event_id in event_ids + ) + return {eid for ((_rid, eid), have_event) in res.items() if have_event} + + @cachedList("have_seen_event", "keys") + async def _have_seen_events_dict( + self, keys: Iterable[Tuple[str, str]] + ) -> Dict[Tuple[str, str], bool]: + """Helper for have_seen_events + + Returns: + a dict {(room_id, event_id)-> bool} + """ # if the event cache contains the event, obviously we've seen it. - results = {x for x in event_ids if self._get_event_cache.contains(x)} - def have_seen_events_txn(txn, chunk): - sql = "SELECT event_id FROM events as e WHERE " + cache_results = { + (rid, eid) for (rid, eid) in keys if self._get_event_cache.contains((eid,)) + } + results = {x: True for x in cache_results} + + def have_seen_events_txn(txn, chunk: Tuple[Tuple[str, str], ...]): + # we deliberately do *not* query the database for room_id, to make the + # query an index-only lookup on `events_event_id_key`. + # + # We therefore pull the events from the database into a set... + + sql = "SELECT event_id FROM events AS e WHERE " clause, args = make_in_list_sql_clause( - txn.database_engine, "e.event_id", chunk + txn.database_engine, "e.event_id", [eid for (_rid, eid) in chunk] ) txn.execute(sql + clause, args) - results.update(row[0] for row in txn) + found_events = {eid for eid, in txn} - for chunk in batch_iter((x for x in event_ids if x not in results), 100): + # ... and then we can update the results for each row in the batch + results.update({(rid, eid): (eid in found_events) for (rid, eid) in chunk}) + + # each batch requires its own index scan, so we make the batches as big as + # possible. + for chunk in batch_iter((k for k in keys if k not in cache_results), 500): await self.db_pool.runInteraction( "have_seen_events", have_seen_events_txn, chunk ) + return results + @cached(max_entries=100000, tree=True) + async def have_seen_event(self, room_id: str, event_id: str): + # this only exists for the benefit of the @cachedList descriptor on + # _have_seen_events_dict + raise NotImplementedError() + def _get_current_state_event_counts_txn(self, txn, room_id): """ See get_current_state_event_counts. diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index 8f83748b5e..7fb7780d0f 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -16,14 +16,14 @@ import logging from typing import Any, List, Set, Tuple from synapse.api.errors import SynapseError -from synapse.storage._base import SQLBaseStore +from synapse.storage.databases.main import CacheInvalidationWorkerStore from synapse.storage.databases.main.state import StateGroupWorkerStore from synapse.types import RoomStreamToken logger = logging.getLogger(__name__) -class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): +class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): async def purge_history( self, room_id: str, token: str, delete_local_events: bool ) -> Set[int]: @@ -203,8 +203,6 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): "DELETE FROM event_to_state_groups " "WHERE event_id IN (SELECT event_id from events_to_purge)" ) - for event_id, _ in event_rows: - txn.call_after(self._get_state_group_for_event.invalidate, (event_id,)) # Delete all remote non-state events for table in ( @@ -283,6 +281,20 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): # so make sure to keep this actually last. txn.execute("DROP TABLE events_to_purge") + for event_id, should_delete in event_rows: + self._invalidate_cache_and_stream( + txn, self._get_state_group_for_event, (event_id,) + ) + + # XXX: This is racy, since have_seen_events could be called between the + # transaction completing and the invalidation running. On the other hand, + # that's no different to calling `have_seen_events` just before the + # event is deleted from the database. + if should_delete: + self._invalidate_cache_and_stream( + txn, self.have_seen_event, (room_id, event_id) + ) + logger.info("[purge] done") return referenced_state_groups @@ -422,7 +434,11 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): # index on them. In any case we should be clearing out 'stream' tables # periodically anyway (#5888) - # TODO: we could probably usefully do a bunch of cache invalidation here + # TODO: we could probably usefully do a bunch more cache invalidation here + + # XXX: as with purge_history, this is racy, but no worse than other races + # that already exist. + self._invalidate_cache_and_stream(txn, self.have_seen_event, (room_id,)) logger.info("[purge] done") diff --git a/tests/storage/databases/__init__.py b/tests/storage/databases/__init__.py new file mode 100644 index 0000000000..c24c7ecd92 --- /dev/null +++ b/tests/storage/databases/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/storage/databases/main/__init__.py b/tests/storage/databases/main/__init__.py new file mode 100644 index 0000000000..c24c7ecd92 --- /dev/null +++ b/tests/storage/databases/main/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py new file mode 100644 index 0000000000..932970fd9a --- /dev/null +++ b/tests/storage/databases/main/test_events_worker.py @@ -0,0 +1,96 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json + +from synapse.logging.context import LoggingContext +from synapse.storage.databases.main.events_worker import EventsWorkerStore + +from tests import unittest + + +class HaveSeenEventsTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, hs): + self.store: EventsWorkerStore = hs.get_datastore() + + # insert some test data + for rid in ("room1", "room2"): + self.get_success( + self.store.db_pool.simple_insert( + "rooms", + {"room_id": rid, "room_version": 4}, + ) + ) + + for idx, (rid, eid) in enumerate( + ( + ("room1", "event10"), + ("room1", "event11"), + ("room1", "event12"), + ("room2", "event20"), + ) + ): + self.get_success( + self.store.db_pool.simple_insert( + "events", + { + "event_id": eid, + "room_id": rid, + "topological_ordering": idx, + "stream_ordering": idx, + "type": "test", + "processed": True, + "outlier": False, + }, + ) + ) + self.get_success( + self.store.db_pool.simple_insert( + "event_json", + { + "event_id": eid, + "room_id": rid, + "json": json.dumps({"type": "test", "room_id": rid}), + "internal_metadata": "{}", + "format_version": 3, + }, + ) + ) + + def test_simple(self): + with LoggingContext(name="test") as ctx: + res = self.get_success( + self.store.have_seen_events("room1", ["event10", "event19"]) + ) + self.assertEquals(res, {"event10"}) + + # that should result in a single db query + self.assertEquals(ctx.get_resource_usage().db_txn_count, 1) + + # a second lookup of the same events should cause no queries + with LoggingContext(name="test") as ctx: + res = self.get_success( + self.store.have_seen_events("room1", ["event10", "event19"]) + ) + self.assertEquals(res, {"event10"}) + self.assertEquals(ctx.get_resource_usage().db_txn_count, 0) + + def test_query_via_event_cache(self): + # fetch an event into the event cache + self.get_success(self.store.get_event("event10")) + + # looking it up should now cause no db hits + with LoggingContext(name="test") as ctx: + res = self.get_success(self.store.have_seen_events("room1", ["event10"])) + self.assertEquals(res, {"event10"}) + self.assertEquals(ctx.get_resource_usage().db_txn_count, 0) -- cgit 1.5.1 From fc3d2dc269a79e0404d0a9867e5042354d59147f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 2 Jun 2021 16:37:59 +0100 Subject: Rewrite the KeyRing (#10035) --- changelog.d/10035.feature | 1 + synapse/crypto/keyring.py | 642 +++++++++++--------------- synapse/federation/transport/server.py | 4 +- synapse/groups/attestations.py | 4 +- synapse/rest/key/v2/remote_key_resource.py | 9 +- tests/crypto/test_keyring.py | 170 +++---- tests/rest/key/v2/test_remote_key_resource.py | 18 +- tests/util/test_batching_queue.py | 37 +- 8 files changed, 393 insertions(+), 492 deletions(-) create mode 100644 changelog.d/10035.feature (limited to 'tests') diff --git a/changelog.d/10035.feature b/changelog.d/10035.feature new file mode 100644 index 0000000000..68052b5a7e --- /dev/null +++ b/changelog.d/10035.feature @@ -0,0 +1 @@ +Rewrite logic around verifying JSON object and fetching server keys to be more performant and use less memory. diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 6fc0712978..c840ffca71 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -16,8 +16,7 @@ import abc import logging import urllib -from collections import defaultdict -from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Callable, Dict, Iterable, List, Optional, Tuple import attr from signedjson.key import ( @@ -44,17 +43,12 @@ from synapse.api.errors import ( from synapse.config.key import TrustedKeyServer from synapse.events import EventBase from synapse.events.utils import prune_event_dict -from synapse.logging.context import ( - PreserveLoggingContext, - make_deferred_yieldable, - preserve_fn, - run_in_background, -) +from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.storage.keys import FetchKeyResult from synapse.types import JsonDict from synapse.util import unwrapFirstError from synapse.util.async_helpers import yieldable_gather_results -from synapse.util.metrics import Measure +from synapse.util.batching_queue import BatchingQueue from synapse.util.retryutils import NotRetryingDestination if TYPE_CHECKING: @@ -80,32 +74,19 @@ class VerifyJsonRequest: minimum_valid_until_ts: time at which we require the signing key to be valid. (0 implies we don't care) - request_name: The name of the request. - key_ids: The set of key_ids to that could be used to verify the JSON object - - key_ready (Deferred[str, str, nacl.signing.VerifyKey]): - A deferred (server_name, key_id, verify_key) tuple that resolves when - a verify key has been fetched. The deferreds' callbacks are run with no - logcontext. - - If we are unable to find a key which satisfies the request, the deferred - errbacks with an M_UNAUTHORIZED SynapseError. """ server_name = attr.ib(type=str) get_json_object = attr.ib(type=Callable[[], JsonDict]) minimum_valid_until_ts = attr.ib(type=int) - request_name = attr.ib(type=str) key_ids = attr.ib(type=List[str]) - key_ready = attr.ib(default=attr.Factory(defer.Deferred), type=defer.Deferred) @staticmethod def from_json_object( server_name: str, json_object: JsonDict, minimum_valid_until_ms: int, - request_name: str, ): """Create a VerifyJsonRequest to verify all signatures on a signed JSON object for the given server. @@ -115,7 +96,6 @@ class VerifyJsonRequest: server_name, lambda: json_object, minimum_valid_until_ms, - request_name=request_name, key_ids=key_ids, ) @@ -135,16 +115,48 @@ class VerifyJsonRequest: # memory than the Event object itself. lambda: prune_event_dict(event.room_version, event.get_pdu_json()), minimum_valid_until_ms, - request_name=event.event_id, key_ids=key_ids, ) + def to_fetch_key_request(self) -> "_FetchKeyRequest": + """Create a key fetch request for all keys needed to satisfy the + verification request. + """ + return _FetchKeyRequest( + server_name=self.server_name, + minimum_valid_until_ts=self.minimum_valid_until_ts, + key_ids=self.key_ids, + ) + class KeyLookupError(ValueError): pass +@attr.s(slots=True) +class _FetchKeyRequest: + """A request for keys for a given server. + + We will continue to try and fetch until we have all the keys listed under + `key_ids` (with an appropriate `valid_until_ts` property) or we run out of + places to fetch keys from. + + Attributes: + server_name: The name of the server that owns the keys. + minimum_valid_until_ts: The timestamp which the keys must be valid until. + key_ids: The IDs of the keys to attempt to fetch + """ + + server_name = attr.ib(type=str) + minimum_valid_until_ts = attr.ib(type=int) + key_ids = attr.ib(type=List[str]) + + class Keyring: + """Handles verifying signed JSON objects and fetching the keys needed to do + so. + """ + def __init__( self, hs: "HomeServer", key_fetchers: "Optional[Iterable[KeyFetcher]]" = None ): @@ -158,22 +170,22 @@ class Keyring: ) self._key_fetchers = key_fetchers - # map from server name to Deferred. Has an entry for each server with - # an ongoing key download; the Deferred completes once the download - # completes. - # - # These are regular, logcontext-agnostic Deferreds. - self.key_downloads = {} # type: Dict[str, defer.Deferred] + self._server_queue = BatchingQueue( + "keyring_server", + clock=hs.get_clock(), + process_batch_callback=self._inner_fetch_key_requests, + ) # type: BatchingQueue[_FetchKeyRequest, Dict[str, Dict[str, FetchKeyResult]]] - def verify_json_for_server( + async def verify_json_for_server( self, server_name: str, json_object: JsonDict, validity_time: int, - request_name: str, - ) -> defer.Deferred: + ) -> None: """Verify that a JSON object has been signed by a given server + Completes if the the object was correctly signed, otherwise raises. + Args: server_name: name of the server which must have signed this object @@ -181,52 +193,45 @@ class Keyring: validity_time: timestamp at which we require the signing key to be valid. (0 implies we don't care) - - request_name: an identifier for this json object (eg, an event id) - for logging. - - Returns: - Deferred[None]: completes if the the object was correctly signed, otherwise - errbacks with an error """ request = VerifyJsonRequest.from_json_object( server_name, json_object, validity_time, - request_name, ) - requests = (request,) - return make_deferred_yieldable(self._verify_objects(requests)[0]) + return await self.process_request(request) def verify_json_objects_for_server( - self, server_and_json: Iterable[Tuple[str, dict, int, str]] + self, server_and_json: Iterable[Tuple[str, dict, int]] ) -> List[defer.Deferred]: """Bulk verifies signatures of json objects, bulk fetching keys as necessary. Args: server_and_json: - Iterable of (server_name, json_object, validity_time, request_name) + Iterable of (server_name, json_object, validity_time) tuples. validity_time is a timestamp at which the signing key must be valid. - request_name is an identifier for this json object (eg, an event id) - for logging. - Returns: List: for each input triplet, a deferred indicating success or failure to verify each json object's signature for the given server_name. The deferreds run their callbacks in the sentinel logcontext. """ - return self._verify_objects( - VerifyJsonRequest.from_json_object( - server_name, json_object, validity_time, request_name + return [ + run_in_background( + self.process_request, + VerifyJsonRequest.from_json_object( + server_name, + json_object, + validity_time, + ), ) - for server_name, json_object, validity_time, request_name in server_and_json - ) + for server_name, json_object, validity_time in server_and_json + ] def verify_events_for_server( self, server_and_events: Iterable[Tuple[str, EventBase, int]] @@ -252,321 +257,223 @@ class Keyring: server_name. The deferreds run their callbacks in the sentinel logcontext. """ - return self._verify_objects( - VerifyJsonRequest.from_event(server_name, event, validity_time) + return [ + run_in_background( + self.process_request, + VerifyJsonRequest.from_event( + server_name, + event, + validity_time, + ), + ) for server_name, event, validity_time in server_and_events - ) - - def _verify_objects( - self, verify_requests: Iterable[VerifyJsonRequest] - ) -> List[defer.Deferred]: - """Does the work of verify_json_[objects_]for_server - - - Args: - verify_requests: Iterable of verification requests. + ] - Returns: - List: for each input item, a deferred indicating success - or failure to verify each json object's signature for the given - server_name. The deferreds run their callbacks in the sentinel - logcontext. + async def process_request(self, verify_request: VerifyJsonRequest) -> None: + """Processes the `VerifyJsonRequest`. Raises if the object is not signed + by the server, the signatures don't match or we failed to fetch the + necessary keys. """ - # a list of VerifyJsonRequests which are awaiting a key lookup - key_lookups = [] - handle = preserve_fn(_handle_key_deferred) - - def process(verify_request: VerifyJsonRequest) -> defer.Deferred: - """Process an entry in the request list - - Adds a key request to key_lookups, and returns a deferred which - will complete or fail (in the sentinel context) when verification completes. - """ - if not verify_request.key_ids: - return defer.fail( - SynapseError( - 400, - "Not signed by %s" % (verify_request.server_name,), - Codes.UNAUTHORIZED, - ) - ) - logger.debug( - "Verifying %s for %s with key_ids %s, min_validity %i", - verify_request.request_name, - verify_request.server_name, - verify_request.key_ids, - verify_request.minimum_valid_until_ts, + if not verify_request.key_ids: + raise SynapseError( + 400, + f"Not signed by {verify_request.server_name}", + Codes.UNAUTHORIZED, ) - # add the key request to the queue, but don't start it off yet. - key_lookups.append(verify_request) - - # now run _handle_key_deferred, which will wait for the key request - # to complete and then do the verification. - # - # We want _handle_key_request to log to the right context, so we - # wrap it with preserve_fn (aka run_in_background) - return handle(verify_request) - - results = [process(r) for r in verify_requests] - - if key_lookups: - run_in_background(self._start_key_lookups, key_lookups) - - return results - - async def _start_key_lookups( - self, verify_requests: List[VerifyJsonRequest] - ) -> None: - """Sets off the key fetches for each verify request - - Once each fetch completes, verify_request.key_ready will be resolved. - - Args: - verify_requests: - """ - - try: - # map from server name to a set of outstanding request ids - server_to_request_ids = {} # type: Dict[str, Set[int]] - - for verify_request in verify_requests: - server_name = verify_request.server_name - request_id = id(verify_request) - server_to_request_ids.setdefault(server_name, set()).add(request_id) - - # Wait for any previous lookups to complete before proceeding. - await self.wait_for_previous_lookups(server_to_request_ids.keys()) - - # take out a lock on each of the servers by sticking a Deferred in - # key_downloads - for server_name in server_to_request_ids.keys(): - self.key_downloads[server_name] = defer.Deferred() - logger.debug("Got key lookup lock on %s", server_name) - - # When we've finished fetching all the keys for a given server_name, - # drop the lock by resolving the deferred in key_downloads. - def drop_server_lock(server_name): - d = self.key_downloads.pop(server_name) - d.callback(None) - - def lookup_done(res, verify_request): - server_name = verify_request.server_name - server_requests = server_to_request_ids[server_name] - server_requests.remove(id(verify_request)) - - # if there are no more requests for this server, we can drop the lock. - if not server_requests: - logger.debug("Releasing key lookup lock on %s", server_name) - drop_server_lock(server_name) - - return res + # Add the keys we need to verify to the queue for retrieval. We queue + # up requests for the same server so we don't end up with many in flight + # requests for the same keys. + key_request = verify_request.to_fetch_key_request() + found_keys_by_server = await self._server_queue.add_to_queue( + key_request, key=verify_request.server_name + ) - for verify_request in verify_requests: - verify_request.key_ready.addBoth(lookup_done, verify_request) + # Since we batch up requests the returned set of keys may contain keys + # from other servers, so we pull out only the ones we care about.s + found_keys = found_keys_by_server.get(verify_request.server_name, {}) - # Actually start fetching keys. - self._get_server_verify_keys(verify_requests) - except Exception: - logger.exception("Error starting key lookups") + # Verify each signature we got valid keys for, raising if we can't + # verify any of them. + verified = False + for key_id in verify_request.key_ids: + key_result = found_keys.get(key_id) + if not key_result: + continue - async def wait_for_previous_lookups(self, server_names: Iterable[str]) -> None: - """Waits for any previous key lookups for the given servers to finish. + if key_result.valid_until_ts < verify_request.minimum_valid_until_ts: + continue - Args: - server_names: list of servers which we want to look up + verify_key = key_result.verify_key + json_object = verify_request.get_json_object() + try: + verify_signed_json( + json_object, + verify_request.server_name, + verify_key, + ) + verified = True + except SignatureVerifyException as e: + logger.debug( + "Error verifying signature for %s:%s:%s with key %s: %s", + verify_request.server_name, + verify_key.alg, + verify_key.version, + encode_verify_key_base64(verify_key), + str(e), + ) + raise SynapseError( + 401, + "Invalid signature for server %s with key %s:%s: %s" + % ( + verify_request.server_name, + verify_key.alg, + verify_key.version, + str(e), + ), + Codes.UNAUTHORIZED, + ) - Returns: - Resolves once all key lookups for the given servers have - completed. Follows the synapse rules of logcontext preservation. - """ - loop_count = 1 - while True: - wait_on = [ - (server_name, self.key_downloads[server_name]) - for server_name in server_names - if server_name in self.key_downloads - ] - if not wait_on: - break - logger.info( - "Waiting for existing lookups for %s to complete [loop %i]", - [w[0] for w in wait_on], - loop_count, + if not verified: + raise SynapseError( + 401, + f"Failed to find any key to satisfy: {key_request}", + Codes.UNAUTHORIZED, ) - with PreserveLoggingContext(): - await defer.DeferredList((w[1] for w in wait_on)) - loop_count += 1 + async def _inner_fetch_key_requests( + self, requests: List[_FetchKeyRequest] + ) -> Dict[str, Dict[str, FetchKeyResult]]: + """Processing function for the queue of `_FetchKeyRequest`.""" + + logger.debug("Starting fetch for %s", requests) + + # First we need to deduplicate requests for the same key. We do this by + # taking the *maximum* requested `minimum_valid_until_ts` for each pair + # of server name/key ID. + server_to_key_to_ts = {} # type: Dict[str, Dict[str, int]] + for request in requests: + by_server = server_to_key_to_ts.setdefault(request.server_name, {}) + for key_id in request.key_ids: + existing_ts = by_server.get(key_id, 0) + by_server[key_id] = max(request.minimum_valid_until_ts, existing_ts) + + deduped_requests = [ + _FetchKeyRequest(server_name, minimum_valid_ts, [key_id]) + for server_name, by_server in server_to_key_to_ts.items() + for key_id, minimum_valid_ts in by_server.items() + ] + + logger.debug("Deduplicated key requests to %s", deduped_requests) + + # For each key we call `_inner_verify_request` which will handle + # fetching each key. Note these shouldn't throw if we fail to contact + # other servers etc. + results_per_request = await yieldable_gather_results( + self._inner_fetch_key_request, + deduped_requests, + ) - def _get_server_verify_keys(self, verify_requests: List[VerifyJsonRequest]) -> None: - """Tries to find at least one key for each verify request + # We now convert the returned list of results into a map from server + # name to key ID to FetchKeyResult, to return. + to_return = {} # type: Dict[str, Dict[str, FetchKeyResult]] + for (request, results) in zip(deduped_requests, results_per_request): + to_return_by_server = to_return.setdefault(request.server_name, {}) + for key_id, key_result in results.items(): + existing = to_return_by_server.get(key_id) + if not existing or existing.valid_until_ts < key_result.valid_until_ts: + to_return_by_server[key_id] = key_result - For each verify_request, verify_request.key_ready is called back with - params (server_name, key_id, VerifyKey) if a key is found, or errbacked - with a SynapseError if none of the keys are found. + return to_return - Args: - verify_requests: list of verify requests + async def _inner_fetch_key_request( + self, verify_request: _FetchKeyRequest + ) -> Dict[str, FetchKeyResult]: + """Attempt to fetch the given key by calling each key fetcher one by + one. """ + logger.debug("Starting fetch for %s", verify_request) - remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called} + found_keys: Dict[str, FetchKeyResult] = {} + missing_key_ids = set(verify_request.key_ids) - async def do_iterations(): - try: - with Measure(self.clock, "get_server_verify_keys"): - for f in self._key_fetchers: - if not remaining_requests: - return - await self._attempt_key_fetches_with_fetcher( - f, remaining_requests - ) - - # look for any requests which weren't satisfied - while remaining_requests: - verify_request = remaining_requests.pop() - rq_str = ( - "VerifyJsonRequest(server=%s, key_ids=%s, min_valid=%i)" - % ( - verify_request.server_name, - verify_request.key_ids, - verify_request.minimum_valid_until_ts, - ) - ) - - # If we run the errback immediately, it may cancel our - # loggingcontext while we are still in it, so instead we - # schedule it for the next time round the reactor. - # - # (this also ensures that we don't get a stack overflow if we - # has a massive queue of lookups waiting for this server). - self.clock.call_later( - 0, - verify_request.key_ready.errback, - SynapseError( - 401, - "Failed to find any key to satisfy %s" % (rq_str,), - Codes.UNAUTHORIZED, - ), - ) - except Exception as err: - # we don't really expect to get here, because any errors should already - # have been caught and logged. But if we do, let's log the error and make - # sure that all of the deferreds are resolved. - logger.error("Unexpected error in _get_server_verify_keys: %s", err) - with PreserveLoggingContext(): - for verify_request in remaining_requests: - if not verify_request.key_ready.called: - verify_request.key_ready.errback(err) - - run_in_background(do_iterations) - - async def _attempt_key_fetches_with_fetcher( - self, fetcher: "KeyFetcher", remaining_requests: Set[VerifyJsonRequest] - ): - """Use a key fetcher to attempt to satisfy some key requests + for fetcher in self._key_fetchers: + if not missing_key_ids: + break - Args: - fetcher: fetcher to use to fetch the keys - remaining_requests: outstanding key requests. - Any successfully-completed requests will be removed from the list. - """ - # The keys to fetch. - # server_name -> key_id -> min_valid_ts - missing_keys = defaultdict(dict) # type: Dict[str, Dict[str, int]] - - for verify_request in remaining_requests: - # any completed requests should already have been removed - assert not verify_request.key_ready.called - keys_for_server = missing_keys[verify_request.server_name] - - for key_id in verify_request.key_ids: - # If we have several requests for the same key, then we only need to - # request that key once, but we should do so with the greatest - # min_valid_until_ts of the requests, so that we can satisfy all of - # the requests. - keys_for_server[key_id] = max( - keys_for_server.get(key_id, -1), - verify_request.minimum_valid_until_ts, - ) + logger.debug("Getting keys from %s for %s", fetcher, verify_request) + keys = await fetcher.get_keys( + verify_request.server_name, + list(missing_key_ids), + verify_request.minimum_valid_until_ts, + ) - results = await fetcher.get_keys(missing_keys) + for key_id, key in keys.items(): + if not key: + continue - completed = [] - for verify_request in remaining_requests: - server_name = verify_request.server_name + # If we already have a result for the given key ID we keep the + # one with the highest `valid_until_ts`. + existing_key = found_keys.get(key_id) + if existing_key: + if key.valid_until_ts <= existing_key.valid_until_ts: + continue - # see if any of the keys we got this time are sufficient to - # complete this VerifyJsonRequest. - result_keys = results.get(server_name, {}) - for key_id in verify_request.key_ids: - fetch_key_result = result_keys.get(key_id) - if not fetch_key_result: - # we didn't get a result for this key - continue + # We always store the returned key even if it doesn't the + # `minimum_valid_until_ts` requirement, as some verification + # requests may still be able to be satisfied by it. + # + # We still keep looking for the key from other fetchers in that + # case though. + found_keys[key_id] = key - if ( - fetch_key_result.valid_until_ts - < verify_request.minimum_valid_until_ts - ): - # key was not valid at this point + if key.valid_until_ts < verify_request.minimum_valid_until_ts: continue - # we have a valid key for this request. If we run the callback - # immediately, it may cancel our loggingcontext while we are still in - # it, so instead we schedule it for the next time round the reactor. - # - # (this also ensures that we don't get a stack overflow if we had - # a massive queue of lookups waiting for this server). - logger.debug( - "Found key %s:%s for %s", - server_name, - key_id, - verify_request.request_name, - ) - self.clock.call_later( - 0, - verify_request.key_ready.callback, - (server_name, key_id, fetch_key_result.verify_key), - ) - completed.append(verify_request) - break + missing_key_ids.discard(key_id) - remaining_requests.difference_update(completed) + return found_keys class KeyFetcher(metaclass=abc.ABCMeta): - @abc.abstractmethod + def __init__(self, hs: "HomeServer"): + self._queue = BatchingQueue( + self.__class__.__name__, hs.get_clock(), self._fetch_keys + ) + async def get_keys( - self, keys_to_fetch: Dict[str, Dict[str, int]] - ) -> Dict[str, Dict[str, FetchKeyResult]]: - """ - Args: - keys_to_fetch: - the keys to be fetched. server_name -> key_id -> min_valid_ts + self, server_name: str, key_ids: List[str], minimum_valid_until_ts: int + ) -> Dict[str, FetchKeyResult]: + results = await self._queue.add_to_queue( + _FetchKeyRequest( + server_name=server_name, + key_ids=key_ids, + minimum_valid_until_ts=minimum_valid_until_ts, + ) + ) + return results.get(server_name, {}) - Returns: - Map from server_name -> key_id -> FetchKeyResult - """ - raise NotImplementedError + @abc.abstractmethod + async def _fetch_keys( + self, keys_to_fetch: List[_FetchKeyRequest] + ) -> Dict[str, Dict[str, FetchKeyResult]]: + pass class StoreKeyFetcher(KeyFetcher): """KeyFetcher impl which fetches keys from our data store""" def __init__(self, hs: "HomeServer"): - self.store = hs.get_datastore() + super().__init__(hs) - async def get_keys( - self, keys_to_fetch: Dict[str, Dict[str, int]] - ) -> Dict[str, Dict[str, FetchKeyResult]]: - """see KeyFetcher.get_keys""" + self.store = hs.get_datastore() + async def _fetch_keys(self, keys_to_fetch: List[_FetchKeyRequest]): key_ids_to_fetch = ( - (server_name, key_id) - for server_name, keys_for_server in keys_to_fetch.items() - for key_id in keys_for_server.keys() + (queue_value.server_name, key_id) + for queue_value in keys_to_fetch + for key_id in queue_value.key_ids ) res = await self.store.get_server_verify_keys(key_ids_to_fetch) @@ -578,6 +485,8 @@ class StoreKeyFetcher(KeyFetcher): class BaseV2KeyFetcher(KeyFetcher): def __init__(self, hs: "HomeServer"): + super().__init__(hs) + self.store = hs.get_datastore() self.config = hs.config @@ -685,10 +594,10 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): self.client = hs.get_federation_http_client() self.key_servers = self.config.key_servers - async def get_keys( - self, keys_to_fetch: Dict[str, Dict[str, int]] + async def _fetch_keys( + self, keys_to_fetch: List[_FetchKeyRequest] ) -> Dict[str, Dict[str, FetchKeyResult]]: - """see KeyFetcher.get_keys""" + """see KeyFetcher._fetch_keys""" async def get_key(key_server: TrustedKeyServer) -> Dict: try: @@ -724,12 +633,12 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): return union_of_keys async def get_server_verify_key_v2_indirect( - self, keys_to_fetch: Dict[str, Dict[str, int]], key_server: TrustedKeyServer + self, keys_to_fetch: List[_FetchKeyRequest], key_server: TrustedKeyServer ) -> Dict[str, Dict[str, FetchKeyResult]]: """ Args: keys_to_fetch: - the keys to be fetched. server_name -> key_id -> min_valid_ts + the keys to be fetched. key_server: notary server to query for the keys @@ -743,7 +652,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): perspective_name = key_server.server_name logger.info( "Requesting keys %s from notary server %s", - keys_to_fetch.items(), + keys_to_fetch, perspective_name, ) @@ -753,11 +662,13 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): path="/_matrix/key/v2/query", data={ "server_keys": { - server_name: { - key_id: {"minimum_valid_until_ts": min_valid_ts} - for key_id, min_valid_ts in server_keys.items() + queue_value.server_name: { + key_id: { + "minimum_valid_until_ts": queue_value.minimum_valid_until_ts, + } + for key_id in queue_value.key_ids } - for server_name, server_keys in keys_to_fetch.items() + for queue_value in keys_to_fetch } }, ) @@ -858,7 +769,20 @@ class ServerKeyFetcher(BaseV2KeyFetcher): self.client = hs.get_federation_http_client() async def get_keys( - self, keys_to_fetch: Dict[str, Dict[str, int]] + self, server_name: str, key_ids: List[str], minimum_valid_until_ts: int + ) -> Dict[str, FetchKeyResult]: + results = await self._queue.add_to_queue( + _FetchKeyRequest( + server_name=server_name, + key_ids=key_ids, + minimum_valid_until_ts=minimum_valid_until_ts, + ), + key=server_name, + ) + return results.get(server_name, {}) + + async def _fetch_keys( + self, keys_to_fetch: List[_FetchKeyRequest] ) -> Dict[str, Dict[str, FetchKeyResult]]: """ Args: @@ -871,8 +795,10 @@ class ServerKeyFetcher(BaseV2KeyFetcher): results = {} - async def get_key(key_to_fetch_item: Tuple[str, Dict[str, int]]) -> None: - server_name, key_ids = key_to_fetch_item + async def get_key(key_to_fetch_item: _FetchKeyRequest) -> None: + server_name = key_to_fetch_item.server_name + key_ids = key_to_fetch_item.key_ids + try: keys = await self.get_server_verify_key_v2_direct(server_name, key_ids) results[server_name] = keys @@ -883,7 +809,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher): except Exception: logger.exception("Error getting keys %s from %s", key_ids, server_name) - await yieldable_gather_results(get_key, keys_to_fetch.items()) + await yieldable_gather_results(get_key, keys_to_fetch) return results async def get_server_verify_key_v2_direct( @@ -955,37 +881,3 @@ class ServerKeyFetcher(BaseV2KeyFetcher): keys.update(response_keys) return keys - - -async def _handle_key_deferred(verify_request: VerifyJsonRequest) -> None: - """Waits for the key to become available, and then performs a verification - - Args: - verify_request: - - Raises: - SynapseError if there was a problem performing the verification - """ - server_name = verify_request.server_name - with PreserveLoggingContext(): - _, key_id, verify_key = await verify_request.key_ready - - json_object = verify_request.get_json_object() - - try: - verify_signed_json(json_object, server_name, verify_key) - except SignatureVerifyException as e: - logger.debug( - "Error verifying signature for %s:%s:%s with key %s: %s", - server_name, - verify_key.alg, - verify_key.version, - encode_verify_key_base64(verify_key), - str(e), - ) - raise SynapseError( - 401, - "Invalid signature for server %s with key %s:%s: %s" - % (server_name, verify_key.alg, verify_key.version, str(e)), - Codes.UNAUTHORIZED, - ) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index fdeaa0f37c..5756fcb551 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -152,7 +152,9 @@ class Authenticator: ) await self.keyring.verify_json_for_server( - origin, json_request, now, "Incoming request" + origin, + json_request, + now, ) logger.debug("Request from %s", origin) diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py index d2fc8be5f5..ff8372c4e9 100644 --- a/synapse/groups/attestations.py +++ b/synapse/groups/attestations.py @@ -108,7 +108,9 @@ class GroupAttestationSigning: assert server_name is not None await self.keyring.verify_json_for_server( - server_name, attestation, now, "Group attestation" + server_name, + attestation, + now, ) def create_attestation(self, group_id: str, user_id: str) -> JsonDict: diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index aba1734a55..d56a1ae482 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -22,6 +22,7 @@ from synapse.crypto.keyring import ServerKeyFetcher from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.servlet import parse_integer, parse_json_object_from_request from synapse.util import json_decoder +from synapse.util.async_helpers import yieldable_gather_results logger = logging.getLogger(__name__) @@ -210,7 +211,13 @@ class RemoteKey(DirectServeJsonResource): # If there is a cache miss, request the missing keys, then recurse (and # ensure the result is sent). if cache_misses and query_remote_on_cache_miss: - await self.fetcher.get_keys(cache_misses) + await yieldable_gather_results( + lambda t: self.fetcher.get_keys(*t), + ( + (server_name, list(keys), 0) + for server_name, keys in cache_misses.items() + ), + ) await self.query_keys(request, query, query_remote_on_cache_miss=False) else: signed_keys = [] diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 2775dfd880..745c295d3b 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import time +from typing import Dict, List from unittest.mock import Mock import attr @@ -21,7 +22,6 @@ import signedjson.sign from nacl.signing import SigningKey from signedjson.key import encode_verify_key_base64, get_verify_key -from twisted.internet import defer from twisted.internet.defer import Deferred, ensureDeferred from synapse.api.errors import SynapseError @@ -92,23 +92,23 @@ class KeyringTestCase(unittest.HomeserverTestCase): # deferred completes. first_lookup_deferred = Deferred() - async def first_lookup_fetch(keys_to_fetch): - self.assertEquals(current_context().request.id, "context_11") - self.assertEqual(keys_to_fetch, {"server10": {get_key_id(key1): 0}}) + async def first_lookup_fetch( + server_name: str, key_ids: List[str], minimum_valid_until_ts: int + ) -> Dict[str, FetchKeyResult]: + # self.assertEquals(current_context().request.id, "context_11") + self.assertEqual(server_name, "server10") + self.assertEqual(key_ids, [get_key_id(key1)]) + self.assertEqual(minimum_valid_until_ts, 0) await make_deferred_yieldable(first_lookup_deferred) - return { - "server10": { - get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100) - } - } + return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)} mock_fetcher.get_keys.side_effect = first_lookup_fetch async def first_lookup(): with LoggingContext("context_11", request=FakeRequest("context_11")): res_deferreds = kr.verify_json_objects_for_server( - [("server10", json1, 0, "test10"), ("server11", {}, 0, "test11")] + [("server10", json1, 0), ("server11", {}, 0)] ) # the unsigned json should be rejected pretty quickly @@ -126,18 +126,18 @@ class KeyringTestCase(unittest.HomeserverTestCase): d0 = ensureDeferred(first_lookup()) + self.pump() + mock_fetcher.get_keys.assert_called_once() # a second request for a server with outstanding requests # should block rather than start a second call - async def second_lookup_fetch(keys_to_fetch): - self.assertEquals(current_context().request.id, "context_12") - return { - "server10": { - get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100) - } - } + async def second_lookup_fetch( + server_name: str, key_ids: List[str], minimum_valid_until_ts: int + ) -> Dict[str, FetchKeyResult]: + # self.assertEquals(current_context().request.id, "context_12") + return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)} mock_fetcher.get_keys.reset_mock() mock_fetcher.get_keys.side_effect = second_lookup_fetch @@ -146,7 +146,13 @@ class KeyringTestCase(unittest.HomeserverTestCase): async def second_lookup(): with LoggingContext("context_12", request=FakeRequest("context_12")): res_deferreds_2 = kr.verify_json_objects_for_server( - [("server10", json1, 0, "test")] + [ + ( + "server10", + json1, + 0, + ) + ] ) res_deferreds_2[0].addBoth(self.check_context, None) second_lookup_state[0] = 1 @@ -183,11 +189,11 @@ class KeyringTestCase(unittest.HomeserverTestCase): signedjson.sign.sign_json(json1, "server9", key1) # should fail immediately on an unsigned object - d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned") + d = kr.verify_json_for_server("server9", {}, 0) self.get_failure(d, SynapseError) # should succeed on a signed object - d = _verify_json_for_server(kr, "server9", json1, 500, "test signed") + d = kr.verify_json_for_server("server9", json1, 500) # self.assertFalse(d.called) self.get_success(d) @@ -214,24 +220,24 @@ class KeyringTestCase(unittest.HomeserverTestCase): signedjson.sign.sign_json(json1, "server9", key1) # should fail immediately on an unsigned object - d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned") + d = kr.verify_json_for_server("server9", {}, 0) self.get_failure(d, SynapseError) # should fail on a signed object with a non-zero minimum_valid_until_ms, # as it tries to refetch the keys and fails. - d = _verify_json_for_server( - kr, "server9", json1, 500, "test signed non-zero min" - ) + d = kr.verify_json_for_server("server9", json1, 500) self.get_failure(d, SynapseError) # We expect the keyring tried to refetch the key once. mock_fetcher.get_keys.assert_called_once_with( - {"server9": {get_key_id(key1): 500}} + "server9", [get_key_id(key1)], 500 ) # should succeed on a signed object with a 0 minimum_valid_until_ms - d = _verify_json_for_server( - kr, "server9", json1, 0, "test signed with zero min" + d = kr.verify_json_for_server( + "server9", + json1, + 0, ) self.get_success(d) @@ -239,15 +245,15 @@ class KeyringTestCase(unittest.HomeserverTestCase): """Two requests for the same key should be deduped.""" key1 = signedjson.key.generate_signing_key(1) - async def get_keys(keys_to_fetch): + async def get_keys( + server_name: str, key_ids: List[str], minimum_valid_until_ts: int + ) -> Dict[str, FetchKeyResult]: # there should only be one request object (with the max validity) - self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}}) + self.assertEqual(server_name, "server1") + self.assertEqual(key_ids, [get_key_id(key1)]) + self.assertEqual(minimum_valid_until_ts, 1500) - return { - "server1": { - get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200) - } - } + return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)} mock_fetcher = Mock() mock_fetcher.get_keys = Mock(side_effect=get_keys) @@ -259,7 +265,14 @@ class KeyringTestCase(unittest.HomeserverTestCase): # the first request should succeed; the second should fail because the key # has expired results = kr.verify_json_objects_for_server( - [("server1", json1, 500, "test1"), ("server1", json1, 1500, "test2")] + [ + ( + "server1", + json1, + 500, + ), + ("server1", json1, 1500), + ] ) self.assertEqual(len(results), 2) self.get_success(results[0]) @@ -274,19 +287,21 @@ class KeyringTestCase(unittest.HomeserverTestCase): """If the first fetcher cannot provide a recent enough key, we fall back""" key1 = signedjson.key.generate_signing_key(1) - async def get_keys1(keys_to_fetch): - self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}}) - return { - "server1": {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)} - } - - async def get_keys2(keys_to_fetch): - self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}}) - return { - "server1": { - get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200) - } - } + async def get_keys1( + server_name: str, key_ids: List[str], minimum_valid_until_ts: int + ) -> Dict[str, FetchKeyResult]: + self.assertEqual(server_name, "server1") + self.assertEqual(key_ids, [get_key_id(key1)]) + self.assertEqual(minimum_valid_until_ts, 1500) + return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)} + + async def get_keys2( + server_name: str, key_ids: List[str], minimum_valid_until_ts: int + ) -> Dict[str, FetchKeyResult]: + self.assertEqual(server_name, "server1") + self.assertEqual(key_ids, [get_key_id(key1)]) + self.assertEqual(minimum_valid_until_ts, 1500) + return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)} mock_fetcher1 = Mock() mock_fetcher1.get_keys = Mock(side_effect=get_keys1) @@ -298,7 +313,18 @@ class KeyringTestCase(unittest.HomeserverTestCase): signedjson.sign.sign_json(json1, "server1", key1) results = kr.verify_json_objects_for_server( - [("server1", json1, 1200, "test1"), ("server1", json1, 1500, "test2")] + [ + ( + "server1", + json1, + 1200, + ), + ( + "server1", + json1, + 1500, + ), + ] ) self.assertEqual(len(results), 2) self.get_success(results[0]) @@ -349,9 +375,8 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase): self.http_client.get_json.side_effect = get_json - keys_to_fetch = {SERVER_NAME: {"key1": 0}} - keys = self.get_success(fetcher.get_keys(keys_to_fetch)) - k = keys[SERVER_NAME][testverifykey_id] + keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0)) + k = keys[testverifykey_id] self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS) self.assertEqual(k.verify_key, testverifykey) self.assertEqual(k.verify_key.alg, "ed25519") @@ -378,7 +403,7 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase): # change the server name: the result should be ignored response["server_name"] = "OTHER_SERVER" - keys = self.get_success(fetcher.get_keys(keys_to_fetch)) + keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0)) self.assertEqual(keys, {}) @@ -465,10 +490,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): self.expect_outgoing_key_query(SERVER_NAME, "key1", response) - keys_to_fetch = {SERVER_NAME: {"key1": 0}} - keys = self.get_success(fetcher.get_keys(keys_to_fetch)) - self.assertIn(SERVER_NAME, keys) - k = keys[SERVER_NAME][testverifykey_id] + keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0)) + self.assertIn(testverifykey_id, keys) + k = keys[testverifykey_id] self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS) self.assertEqual(k.verify_key, testverifykey) self.assertEqual(k.verify_key.alg, "ed25519") @@ -515,10 +539,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): self.expect_outgoing_key_query(SERVER_NAME, "key1", response) - keys_to_fetch = {SERVER_NAME: {"key1": 0}} - keys = self.get_success(fetcher.get_keys(keys_to_fetch)) - self.assertIn(SERVER_NAME, keys) - k = keys[SERVER_NAME][testverifykey_id] + keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0)) + self.assertIn(testverifykey_id, keys) + k = keys[testverifykey_id] self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS) self.assertEqual(k.verify_key, testverifykey) self.assertEqual(k.verify_key.alg, "ed25519") @@ -559,14 +582,13 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): def get_key_from_perspectives(response): fetcher = PerspectivesKeyFetcher(self.hs) - keys_to_fetch = {SERVER_NAME: {"key1": 0}} self.expect_outgoing_key_query(SERVER_NAME, "key1", response) - return self.get_success(fetcher.get_keys(keys_to_fetch)) + return self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0)) # start with a valid response so we can check we are testing the right thing response = build_response() keys = get_key_from_perspectives(response) - k = keys[SERVER_NAME][testverifykey_id] + k = keys[testverifykey_id] self.assertEqual(k.verify_key, testverifykey) # remove the perspectives server's signature @@ -585,23 +607,3 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): def get_key_id(key): """Get the matrix ID tag for a given SigningKey or VerifyKey""" return "%s:%s" % (key.alg, key.version) - - -@defer.inlineCallbacks -def run_in_context(f, *args, **kwargs): - with LoggingContext("testctx"): - rv = yield f(*args, **kwargs) - return rv - - -def _verify_json_for_server(kr, *args): - """thin wrapper around verify_json_for_server which makes sure it is wrapped - with the patched defer.inlineCallbacks. - """ - - @defer.inlineCallbacks - def v(): - rv1 = yield kr.verify_json_for_server(*args) - return rv1 - - return run_in_context(v) diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py index 3b275bc23b..a75c0ea3f0 100644 --- a/tests/rest/key/v2/test_remote_key_resource.py +++ b/tests/rest/key/v2/test_remote_key_resource.py @@ -208,10 +208,10 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): keyid = "ed25519:%s" % (testkey.version,) fetcher = PerspectivesKeyFetcher(self.hs2) - d = fetcher.get_keys({"targetserver": {keyid: 1000}}) + d = fetcher.get_keys("targetserver", [keyid], 1000) res = self.get_success(d) - self.assertIn("targetserver", res) - keyres = res["targetserver"][keyid] + self.assertIn(keyid, res) + keyres = res[keyid] assert isinstance(keyres, FetchKeyResult) self.assertEqual( signedjson.key.encode_verify_key_base64(keyres.verify_key), @@ -230,10 +230,10 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): keyid = "ed25519:%s" % (testkey.version,) fetcher = PerspectivesKeyFetcher(self.hs2) - d = fetcher.get_keys({self.hs.hostname: {keyid: 1000}}) + d = fetcher.get_keys(self.hs.hostname, [keyid], 1000) res = self.get_success(d) - self.assertIn(self.hs.hostname, res) - keyres = res[self.hs.hostname][keyid] + self.assertIn(keyid, res) + keyres = res[keyid] assert isinstance(keyres, FetchKeyResult) self.assertEqual( signedjson.key.encode_verify_key_base64(keyres.verify_key), @@ -247,10 +247,10 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): keyid = "ed25519:%s" % (self.hs_signing_key.version,) fetcher = PerspectivesKeyFetcher(self.hs2) - d = fetcher.get_keys({self.hs.hostname: {keyid: 1000}}) + d = fetcher.get_keys(self.hs.hostname, [keyid], 1000) res = self.get_success(d) - self.assertIn(self.hs.hostname, res) - keyres = res[self.hs.hostname][keyid] + self.assertIn(keyid, res) + keyres = res[keyid] assert isinstance(keyres, FetchKeyResult) self.assertEqual( signedjson.key.encode_verify_key_base64(keyres.verify_key), diff --git a/tests/util/test_batching_queue.py b/tests/util/test_batching_queue.py index edf29e5b96..07be57d72c 100644 --- a/tests/util/test_batching_queue.py +++ b/tests/util/test_batching_queue.py @@ -45,37 +45,32 @@ class BatchingQueueTestCase(TestCase): self._pending_calls.append((values, d)) return await make_deferred_yieldable(d) + def _get_sample_with_name(self, metric, name) -> int: + """For a prometheus metric get the value of the sample that has a + matching "name" label. + """ + for sample in metric.collect()[0].samples: + if sample.labels.get("name") == name: + return sample.value + + self.fail("Found no matching sample") + def _assert_metrics(self, queued, keys, in_flight): """Assert that the metrics are correct""" - self.assertEqual(len(number_queued.collect()), 1) - self.assertEqual(len(number_queued.collect()[0].samples), 1) + sample = self._get_sample_with_name(number_queued, self.queue._name) self.assertEqual( - number_queued.collect()[0].samples[0].labels, - {"name": self.queue._name}, - ) - self.assertEqual( - number_queued.collect()[0].samples[0].value, + sample, queued, "number_queued", ) - self.assertEqual(len(number_of_keys.collect()), 1) - self.assertEqual(len(number_of_keys.collect()[0].samples), 1) - self.assertEqual( - number_queued.collect()[0].samples[0].labels, {"name": self.queue._name} - ) - self.assertEqual( - number_of_keys.collect()[0].samples[0].value, keys, "number_of_keys" - ) + sample = self._get_sample_with_name(number_of_keys, self.queue._name) + self.assertEqual(sample, keys, "number_of_keys") - self.assertEqual(len(number_in_flight.collect()), 1) - self.assertEqual(len(number_in_flight.collect()[0].samples), 1) - self.assertEqual( - number_queued.collect()[0].samples[0].labels, {"name": self.queue._name} - ) + sample = self._get_sample_with_name(number_in_flight, self.queue._name) self.assertEqual( - number_in_flight.collect()[0].samples[0].value, + sample, in_flight, "number_in_flight", ) -- cgit 1.5.1 From 0284d2a2976e3d58e9970fdb7590f98a2556326d Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Wed, 2 Jun 2021 19:50:35 +0200 Subject: Add new admin APIs to remove media by media ID from quarantine. (#10044) Related to: #6681, #5956, #10040 Signed-off-by: Dirk Klimpel dirk@klimpel.org --- changelog.d/10044.feature | 1 + docs/admin_api/media_admin_api.md | 22 ++++++ synapse/rest/admin/media.py | 30 ++++++++ synapse/storage/databases/main/room.py | 30 +++++--- tests/rest/admin/test_media.py | 128 +++++++++++++++++++++++++++++++++ 5 files changed, 201 insertions(+), 10 deletions(-) create mode 100644 changelog.d/10044.feature (limited to 'tests') diff --git a/changelog.d/10044.feature b/changelog.d/10044.feature new file mode 100644 index 0000000000..70c0a3851e --- /dev/null +++ b/changelog.d/10044.feature @@ -0,0 +1 @@ +Add new admin APIs to remove media by media ID from quarantine. Contributed by @dkimpel. diff --git a/docs/admin_api/media_admin_api.md b/docs/admin_api/media_admin_api.md index d1b7e390d5..7709f3d8c7 100644 --- a/docs/admin_api/media_admin_api.md +++ b/docs/admin_api/media_admin_api.md @@ -4,6 +4,7 @@ * [List all media uploaded by a user](#list-all-media-uploaded-by-a-user) - [Quarantine media](#quarantine-media) * [Quarantining media by ID](#quarantining-media-by-id) + * [Remove media from quarantine by ID](#remove-media-from-quarantine-by-id) * [Quarantining media in a room](#quarantining-media-in-a-room) * [Quarantining all media of a user](#quarantining-all-media-of-a-user) * [Protecting media from being quarantined](#protecting-media-from-being-quarantined) @@ -77,6 +78,27 @@ Response: {} ``` +## Remove media from quarantine by ID + +This API removes a single piece of local or remote media from quarantine. + +Request: + +``` +POST /_synapse/admin/v1/media/unquarantine// + +{} +``` + +Where `server_name` is in the form of `example.org`, and `media_id` is in the +form of `abcdefg12345...`. + +Response: + +```json +{} +``` + ## Quarantining media in a room This API quarantines all local and remote media in a room. diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index 2c71af4279..b68db2c57c 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -120,6 +120,35 @@ class QuarantineMediaByID(RestServlet): return 200, {} +class UnquarantineMediaByID(RestServlet): + """Quarantines local or remote media by a given ID so that no one can download + it via this server. + """ + + PATTERNS = admin_patterns( + "/media/unquarantine/(?P[^/]+)/(?P[^/]+)" + ) + + def __init__(self, hs: "HomeServer"): + self.store = hs.get_datastore() + self.auth = hs.get_auth() + + async def on_POST( + self, request: SynapseRequest, server_name: str, media_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) + + logging.info( + "Remove from quarantine local media by ID: %s/%s", server_name, media_id + ) + + # Remove from quarantine this media id + await self.store.quarantine_media_by_id(server_name, media_id, None) + + return 200, {} + + class ProtectMediaByID(RestServlet): """Protect local media from being quarantined.""" @@ -290,6 +319,7 @@ def register_servlets_for_media_repo(hs: "HomeServer", http_server): PurgeMediaCacheRestServlet(hs).register(http_server) QuarantineMediaInRoom(hs).register(http_server) QuarantineMediaByID(hs).register(http_server) + UnquarantineMediaByID(hs).register(http_server) QuarantineMediaByUser(hs).register(http_server) ProtectMediaByID(hs).register(http_server) UnprotectMediaByID(hs).register(http_server) diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 0cf450f81d..2a96bcd314 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -764,14 +764,15 @@ class RoomWorkerStore(SQLBaseStore): self, server_name: str, media_id: str, - quarantined_by: str, + quarantined_by: Optional[str], ) -> int: - """quarantines a single local or remote media id + """quarantines or unquarantines a single local or remote media id Args: server_name: The name of the server that holds this media media_id: The ID of the media to be quarantined quarantined_by: The user ID that initiated the quarantine request + If it is `None` media will be removed from quarantine """ logger.info("Quarantining media: %s/%s", server_name, media_id) is_local = server_name == self.config.server_name @@ -838,9 +839,9 @@ class RoomWorkerStore(SQLBaseStore): txn, local_mxcs: List[str], remote_mxcs: List[Tuple[str, str]], - quarantined_by: str, + quarantined_by: Optional[str], ) -> int: - """Quarantine local and remote media items + """Quarantine and unquarantine local and remote media items Args: txn (cursor) @@ -848,18 +849,27 @@ class RoomWorkerStore(SQLBaseStore): remote_mxcs: A list of (remote server, media id) tuples representing remote mxc URLs quarantined_by: The ID of the user who initiated the quarantine request + If it is `None` media will be removed from quarantine Returns: The total number of media items quarantined """ + # Update all the tables to set the quarantined_by flag - txn.executemany( - """ + sql = """ UPDATE local_media_repository SET quarantined_by = ? - WHERE media_id = ? AND safe_from_quarantine = ? - """, - ((quarantined_by, media_id, False) for media_id in local_mxcs), - ) + WHERE media_id = ? + """ + + # set quarantine + if quarantined_by is not None: + sql += "AND safe_from_quarantine = ?" + rows = [(quarantined_by, media_id, False) for media_id in local_mxcs] + # remove from quarantine + else: + rows = [(quarantined_by, media_id) for media_id in local_mxcs] + + txn.executemany(sql, rows) # Note that a rowcount of -1 can be used to indicate no rows were affected. total_media_quarantined = txn.rowcount if txn.rowcount > 0 else 0 diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index f741121ea2..6fee0f95b6 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -566,6 +566,134 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.assertFalse(os.path.exists(local_path)) +class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + synapse.rest.admin.register_servlets_for_media_repo, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + media_repo = hs.get_media_repository_resource() + self.store = hs.get_datastore() + self.server_name = hs.hostname + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + # Create media + upload_resource = media_repo.children[b"upload"] + # file size is 67 Byte + image_data = unhexlify( + b"89504e470d0a1a0a0000000d4948445200000001000000010806" + b"0000001f15c4890000000a49444154789c63000100000500010d" + b"0a2db40000000049454e44ae426082" + ) + + # Upload some media into the room + response = self.helper.upload_media( + upload_resource, image_data, tok=self.admin_user_tok, expect_code=200 + ) + # Extract media ID from the response + server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' + self.media_id = server_and_media_id.split("/")[1] + + self.url = "/_synapse/admin/v1/media/%s/%s/%s" + + @parameterized.expand(["quarantine", "unquarantine"]) + def test_no_auth(self, action: str): + """ + Try to protect media without authentication. + """ + + channel = self.make_request( + "POST", + self.url % (action, self.server_name, self.media_id), + b"{}", + ) + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + @parameterized.expand(["quarantine", "unquarantine"]) + def test_requester_is_no_admin(self, action: str): + """ + If the user is not a server admin, an error is returned. + """ + self.other_user = self.register_user("user", "pass") + self.other_user_token = self.login("user", "pass") + + channel = self.make_request( + "POST", + self.url % (action, self.server_name, self.media_id), + access_token=self.other_user_token, + ) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_quarantine_media(self): + """ + Tests that quarantining and remove from quarantine a media is successfully + """ + + media_info = self.get_success(self.store.get_local_media(self.media_id)) + self.assertFalse(media_info["quarantined_by"]) + + # quarantining + channel = self.make_request( + "POST", + self.url % ("quarantine", self.server_name, self.media_id), + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertFalse(channel.json_body) + + media_info = self.get_success(self.store.get_local_media(self.media_id)) + self.assertTrue(media_info["quarantined_by"]) + + # remove from quarantine + channel = self.make_request( + "POST", + self.url % ("unquarantine", self.server_name, self.media_id), + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertFalse(channel.json_body) + + media_info = self.get_success(self.store.get_local_media(self.media_id)) + self.assertFalse(media_info["quarantined_by"]) + + def test_quarantine_protected_media(self): + """ + Tests that quarantining from protected media fails + """ + + # protect + self.get_success(self.store.mark_local_media_as_safe(self.media_id, safe=True)) + + # verify protection + media_info = self.get_success(self.store.get_local_media(self.media_id)) + self.assertTrue(media_info["safe_from_quarantine"]) + + # quarantining + channel = self.make_request( + "POST", + self.url % ("quarantine", self.server_name, self.media_id), + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertFalse(channel.json_body) + + # verify that is not in quarantine + media_info = self.get_success(self.store.get_local_media(self.media_id)) + self.assertFalse(media_info["quarantined_by"]) + + class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): servlets = [ -- cgit 1.5.1 From 5325f0308c5937d2e4447d2c64c8819b3c148d9c Mon Sep 17 00:00:00 2001 From: Travis Ralston Date: Thu, 3 Jun 2021 06:50:49 -0600 Subject: r0.6.1 support: /rooms/:roomId/aliases endpoint (#9224) [MSC2432](https://github.com/matrix-org/matrix-doc/pull/2432) added this endpoint originally but it has since been included in the spec for nearly a year. This is progress towards https://github.com/matrix-org/synapse/issues/8334 --- changelog.d/9224.feature | 1 + synapse/rest/client/v1/room.py | 2 +- tests/rest/client/v1/test_rooms.py | 3 +-- 3 files changed, 3 insertions(+), 3 deletions(-) create mode 100644 changelog.d/9224.feature (limited to 'tests') diff --git a/changelog.d/9224.feature b/changelog.d/9224.feature new file mode 100644 index 0000000000..76519c23e2 --- /dev/null +++ b/changelog.d/9224.feature @@ -0,0 +1 @@ +Add new endpoint `/_matrix/client/r0/rooms/{roomId}/aliases` from Client-Server API r0.6.1 (previously [MSC2432](https://github.com/matrix-org/matrix-doc/pull/2432)). diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 70286b0ff7..5a9c27f75f 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -910,7 +910,7 @@ class RoomAliasListServlet(RestServlet): r"^/_matrix/client/unstable/org\.matrix\.msc2432" r"/rooms/(?P[^/]*)/aliases" ), - ] + ] + list(client_patterns("/rooms/(?P[^/]*)/aliases$", unstable=False)) def __init__(self, hs: "HomeServer"): super().__init__() diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 7c4bdcdfdd..5b1096d091 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -1880,8 +1880,7 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase): """Calls the endpoint under test. returns the json response object.""" channel = self.make_request( "GET", - "/_matrix/client/unstable/org.matrix.msc2432/rooms/%s/aliases" - % (self.room_id,), + "/_matrix/client/r0/rooms/%s/aliases" % (self.room_id,), access_token=access_token, ) self.assertEqual(channel.code, expected_code, channel.result) -- cgit 1.5.1 From 8942e23a6941dc740f6f703f7e353f273874f104 Mon Sep 17 00:00:00 2001 From: 14mRh4X0r <14mRh4X0r@gmail.com> Date: Mon, 7 Jun 2021 16:42:05 +0200 Subject: Always update AS last_pos, even on no events (#10107) Fixes #1834. `get_new_events_for_appservice` internally calls `get_events_as_list`, which will filter out any rejected events. If all returned events are filtered out, `_notify_interested_services` will return without updating the last handled stream position. If there are 100 consecutive such events, processing will halt altogether. Breaking the loop is now done by checking whether we're up-to-date with `current_max` in the loop condition, instead of relying on an empty `events` list. Signed-off-by: Willem Mulder <14mRh4X0r@gmail.com> --- changelog.d/10107.bugfix | 1 + synapse/handlers/appservice.py | 25 ++++++++++++------------- tests/handlers/test_appservice.py | 6 ++---- 3 files changed, 15 insertions(+), 17 deletions(-) create mode 100644 changelog.d/10107.bugfix (limited to 'tests') diff --git a/changelog.d/10107.bugfix b/changelog.d/10107.bugfix new file mode 100644 index 0000000000..80030efab2 --- /dev/null +++ b/changelog.d/10107.bugfix @@ -0,0 +1 @@ +Fixed a bug that could cause Synapse to stop notifying application services. Contributed by Willem Mulder. diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 177310f0be..862638cc4f 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -87,7 +87,8 @@ class ApplicationServicesHandler: self.is_processing = True try: limit = 100 - while True: + upper_bound = -1 + while upper_bound < self.current_max: ( upper_bound, events, @@ -95,9 +96,6 @@ class ApplicationServicesHandler: self.current_max, limit ) - if not events: - break - events_by_room = {} # type: Dict[str, List[EventBase]] for event in events: events_by_room.setdefault(event.room_id, []).append(event) @@ -153,9 +151,6 @@ class ApplicationServicesHandler: await self.store.set_appservice_last_pos(upper_bound) - now = self.clock.time_msec() - ts = await self.store.get_received_ts(events[-1].event_id) - synapse.metrics.event_processing_positions.labels( "appservice_sender" ).set(upper_bound) @@ -168,12 +163,16 @@ class ApplicationServicesHandler: event_processing_loop_counter.labels("appservice_sender").inc() - synapse.metrics.event_processing_lag.labels( - "appservice_sender" - ).set(now - ts) - synapse.metrics.event_processing_last_ts.labels( - "appservice_sender" - ).set(ts) + if events: + now = self.clock.time_msec() + ts = await self.store.get_received_ts(events[-1].event_id) + + synapse.metrics.event_processing_lag.labels( + "appservice_sender" + ).set(now - ts) + synapse.metrics.event_processing_last_ts.labels( + "appservice_sender" + ).set(ts) finally: self.is_processing = False diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index b037b12a0f..5d6cc2885f 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -57,10 +57,10 @@ class AppServiceHandlerTestCase(unittest.TestCase): sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar" ) self.mock_store.get_new_events_for_appservice.side_effect = [ - make_awaitable((0, [event])), make_awaitable((0, [])), + make_awaitable((1, [event])), ] - self.handler.notify_interested_services(RoomStreamToken(None, 0)) + self.handler.notify_interested_services(RoomStreamToken(None, 1)) self.mock_scheduler.submit_event_for_as.assert_called_once_with( interested_service, event @@ -77,7 +77,6 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.mock_as_api.query_user.return_value = make_awaitable(True) self.mock_store.get_new_events_for_appservice.side_effect = [ make_awaitable((0, [event])), - make_awaitable((0, [])), ] self.handler.notify_interested_services(RoomStreamToken(None, 0)) @@ -95,7 +94,6 @@ class AppServiceHandlerTestCase(unittest.TestCase): self.mock_as_api.query_user.return_value = make_awaitable(True) self.mock_store.get_new_events_for_appservice.side_effect = [ make_awaitable((0, [event])), - make_awaitable((0, [])), ] self.handler.notify_interested_services(RoomStreamToken(None, 0)) -- cgit 1.5.1