From 65e6c64d8317d3a10527a7e422753281f3e9ec81 Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Wed, 26 May 2021 12:19:47 +0200 Subject: Add an admin API for unprotecting local media from quarantine (#10040) Signed-off-by: Dirk Klimpel dirk@klimpel.org --- synapse/storage/databases/main/media_repository.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) (limited to 'synapse/storage/databases/main') diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index c584868188..2fa945d171 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -143,6 +143,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): "created_ts", "quarantined_by", "url_cache", + "safe_from_quarantine", ), allow_none=True, desc="get_local_media", @@ -296,12 +297,12 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): desc="store_local_media", ) - async def mark_local_media_as_safe(self, media_id: str) -> None: - """Mark a local media as safe from quarantining.""" + async def mark_local_media_as_safe(self, media_id: str, safe: bool = True) -> None: + """Mark a local media as safe or unsafe from quarantining.""" await self.db_pool.simple_update_one( table="local_media_repository", keyvalues={"media_id": media_id}, - updatevalues={"safe_from_quarantine": True}, + updatevalues={"safe_from_quarantine": safe}, desc="mark_local_media_as_safe", ) -- cgit 1.5.1 From 224f2f949b1661094a64d1105efb64159ddf4aa0 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 27 May 2021 10:33:56 +0100 Subject: Combine `LruCache.invalidate` and `invalidate_many` (#9973) * Make `invalidate` and `invalidate_many` do the same thing ... so that we can do either over the invalidation replication stream, and also because they always confused me a bit. * Kill off `invalidate_many` * changelog --- changelog.d/9973.misc | 1 + synapse/replication/slave/storage/devices.py | 2 +- synapse/storage/databases/main/cache.py | 6 ++-- synapse/storage/databases/main/devices.py | 2 +- .../storage/databases/main/event_push_actions.py | 2 +- synapse/storage/databases/main/events.py | 8 ++--- synapse/storage/databases/main/receipts.py | 6 ++-- synapse/util/caches/deferred_cache.py | 42 +++++++++------------- synapse/util/caches/descriptors.py | 8 +++-- synapse/util/caches/lrucache.py | 18 ++++++---- synapse/util/caches/treecache.py | 3 ++ tests/util/caches/test_descriptors.py | 6 ++-- 12 files changed, 52 insertions(+), 52 deletions(-) create mode 100644 changelog.d/9973.misc (limited to 'synapse/storage/databases/main') diff --git a/changelog.d/9973.misc b/changelog.d/9973.misc new file mode 100644 index 0000000000..7f22d42291 --- /dev/null +++ b/changelog.d/9973.misc @@ -0,0 +1 @@ +Make `LruCache.invalidate` support tree invalidation, and remove `invalidate_many`. diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 70207420a6..26bdead565 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -68,7 +68,7 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto if row.entity.startswith("@"): self._device_list_stream_cache.entity_has_changed(row.entity, token) self.get_cached_devices_for_user.invalidate((row.entity,)) - self._get_cached_user_device.invalidate_many((row.entity,)) + self._get_cached_user_device.invalidate((row.entity,)) self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,)) else: diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index ecc1f935e2..f7872501a0 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -171,7 +171,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self.get_latest_event_ids_in_room.invalidate((room_id,)) - self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,)) + self.get_unread_event_push_actions_by_room_for_user.invalidate((room_id,)) if not backfilled: self._events_stream_cache.entity_has_changed(room_id, stream_ordering) @@ -184,8 +184,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self.get_invited_rooms_for_local_user.invalidate((state_key,)) if relates_to: - self.get_relations_for_event.invalidate_many((relates_to,)) - self.get_aggregation_groups_for_event.invalidate_many((relates_to,)) + self.get_relations_for_event.invalidate((relates_to,)) + self.get_aggregation_groups_for_event.invalidate((relates_to,)) self.get_applicable_edit.invalidate((relates_to,)) async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]): diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index fd87ba71ab..18f07d96dc 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -1282,7 +1282,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): ) txn.call_after(self.get_cached_devices_for_user.invalidate, (user_id,)) - txn.call_after(self._get_cached_user_device.invalidate_many, (user_id,)) + txn.call_after(self._get_cached_user_device.invalidate, (user_id,)) txn.call_after( self.get_device_list_last_stream_id_for_remote.invalidate, (user_id,) ) diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 5845322118..d1237c65cc 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -860,7 +860,7 @@ class EventPushActionsWorkerStore(SQLBaseStore): not be deleted. """ txn.call_after( - self.get_unread_event_push_actions_by_room_for_user.invalidate_many, + self.get_unread_event_push_actions_by_room_for_user.invalidate, (room_id, user_id), ) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index fd25c8112d..897fa06639 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1748,9 +1748,9 @@ class PersistEventsStore: }, ) - txn.call_after(self.store.get_relations_for_event.invalidate_many, (parent_id,)) + txn.call_after(self.store.get_relations_for_event.invalidate, (parent_id,)) txn.call_after( - self.store.get_aggregation_groups_for_event.invalidate_many, (parent_id,) + self.store.get_aggregation_groups_for_event.invalidate, (parent_id,) ) if rel_type == RelationTypes.REPLACE: @@ -1903,7 +1903,7 @@ class PersistEventsStore: for user_id in user_ids: txn.call_after( - self.store.get_unread_event_push_actions_by_room_for_user.invalidate_many, + self.store.get_unread_event_push_actions_by_room_for_user.invalidate, (room_id, user_id), ) @@ -1917,7 +1917,7 @@ class PersistEventsStore: def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id): # Sad that we have to blow away the cache for the whole room here txn.call_after( - self.store.get_unread_event_push_actions_by_room_for_user.invalidate_many, + self.store.get_unread_event_push_actions_by_room_for_user.invalidate, (room_id,), ) txn.execute( diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 3647276acb..edeaacd7a6 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -460,7 +460,7 @@ class ReceiptsWorkerStore(SQLBaseStore): def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id): self.get_receipts_for_user.invalidate((user_id, receipt_type)) - self._get_linearized_receipts_for_room.invalidate_many((room_id,)) + self._get_linearized_receipts_for_room.invalidate((room_id,)) self.get_last_receipt_event_id_for_user.invalidate( (user_id, room_id, receipt_type) ) @@ -659,9 +659,7 @@ class ReceiptsWorkerStore(SQLBaseStore): ) txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type)) # FIXME: This shouldn't invalidate the whole cache - txn.call_after( - self._get_linearized_receipts_for_room.invalidate_many, (room_id,) - ) + txn.call_after(self._get_linearized_receipts_for_room.invalidate, (room_id,)) self.db_pool.simple_delete_txn( txn, diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index 371e7e4dd0..1044139119 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -16,16 +16,7 @@ import enum import threading -from typing import ( - Callable, - Generic, - Iterable, - MutableMapping, - Optional, - TypeVar, - Union, - cast, -) +from typing import Callable, Generic, Iterable, MutableMapping, Optional, TypeVar, Union from prometheus_client import Gauge @@ -91,7 +82,7 @@ class DeferredCache(Generic[KT, VT]): # _pending_deferred_cache maps from the key value to a `CacheEntry` object. self._pending_deferred_cache = ( cache_type() - ) # type: MutableMapping[KT, CacheEntry] + ) # type: Union[TreeCache, MutableMapping[KT, CacheEntry]] def metrics_cb(): cache_pending_metric.labels(name).set(len(self._pending_deferred_cache)) @@ -287,8 +278,17 @@ class DeferredCache(Generic[KT, VT]): self.cache.set(key, value, callbacks=callbacks) def invalidate(self, key): + """Delete a key, or tree of entries + + If the cache is backed by a regular dict, then "key" must be of + the right type for this cache + + If the cache is backed by a TreeCache, then "key" must be a tuple, but + may be of lower cardinality than the TreeCache - in which case the whole + subtree is deleted. + """ self.check_thread() - self.cache.pop(key, None) + self.cache.del_multi(key) # if we have a pending lookup for this key, remove it from the # _pending_deferred_cache, which will (a) stop it being returned @@ -299,20 +299,10 @@ class DeferredCache(Generic[KT, VT]): # run the invalidation callbacks now, rather than waiting for the # deferred to resolve. if entry: - entry.invalidate() - - def invalidate_many(self, key: KT): - self.check_thread() - if not isinstance(key, tuple): - raise TypeError("The cache key must be a tuple not %r" % (type(key),)) - key = cast(KT, key) - self.cache.del_multi(key) - - # if we have a pending lookup for this key, remove it from the - # _pending_deferred_cache, as above - entry_dict = self._pending_deferred_cache.pop(key, None) - if entry_dict is not None: - for entry in iterate_tree_cache_entry(entry_dict): + # _pending_deferred_cache.pop should either return a CacheEntry, or, in the + # case of a TreeCache, a dict of keys to cache entries. Either way calling + # iterate_tree_cache_entry on it will do the right thing. + for entry in iterate_tree_cache_entry(entry): entry.invalidate() def invalidate_all(self): diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 2ac24a2f25..d77e8edeea 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -48,7 +48,6 @@ F = TypeVar("F", bound=Callable[..., Any]) class _CachedFunction(Generic[F]): invalidate = None # type: Any invalidate_all = None # type: Any - invalidate_many = None # type: Any prefill = None # type: Any cache = None # type: Any num_args = None # type: Any @@ -262,6 +261,11 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): ): super().__init__(orig, num_args=num_args, cache_context=cache_context) + if tree and self.num_args < 2: + raise RuntimeError( + "tree=True is nonsensical for cached functions with a single parameter" + ) + self.max_entries = max_entries self.tree = tree self.iterable = iterable @@ -302,11 +306,11 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): wrapped = cast(_CachedFunction, _wrapped) if self.num_args == 1: + assert not self.tree wrapped.invalidate = lambda key: cache.invalidate(key[0]) wrapped.prefill = lambda key, val: cache.prefill(key[0], val) else: wrapped.invalidate = cache.invalidate - wrapped.invalidate_many = cache.invalidate_many wrapped.prefill = cache.prefill wrapped.invalidate_all = cache.invalidate_all diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 54df407ff7..d89e9d9b1d 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -152,7 +152,6 @@ class LruCache(Generic[KT, VT]): """ Least-recently-used cache, supporting prometheus metrics and invalidation callbacks. - Supports del_multi only if cache_type=TreeCache If cache_type=TreeCache, all keys must be tuples. """ @@ -393,10 +392,16 @@ class LruCache(Generic[KT, VT]): @synchronized def cache_del_multi(key: KT) -> None: + """Delete an entry, or tree of entries + + If the LruCache is backed by a regular dict, then "key" must be of + the right type for this cache + + If the LruCache is backed by a TreeCache, then "key" must be a tuple, but + may be of lower cardinality than the TreeCache - in which case the whole + subtree is deleted. """ - This will only work if constructed with cache_type=TreeCache - """ - popped = cache.pop(key) + popped = cache.pop(key, None) if popped is None: return # for each deleted node, we now need to remove it from the linked list @@ -430,11 +435,10 @@ class LruCache(Generic[KT, VT]): self.set = cache_set self.setdefault = cache_set_default self.pop = cache_pop + self.del_multi = cache_del_multi # `invalidate` is exposed for consistency with DeferredCache, so that it can be # invalidated by the cache invalidation replication stream. - self.invalidate = cache_pop - if cache_type is TreeCache: - self.del_multi = cache_del_multi + self.invalidate = cache_del_multi self.len = synchronized(cache_len) self.contains = cache_contains self.clear = cache_clear diff --git a/synapse/util/caches/treecache.py b/synapse/util/caches/treecache.py index 73502a8b06..a6df81ebff 100644 --- a/synapse/util/caches/treecache.py +++ b/synapse/util/caches/treecache.py @@ -89,6 +89,9 @@ class TreeCache: value. If the key is partial, the TreeCacheNode corresponding to the part of the tree that was removed. """ + if not isinstance(key, tuple): + raise TypeError("The cache key must be a tuple not %r" % (type(key),)) + # a list of the nodes we have touched on the way down the tree nodes = [] diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index bbbc276697..0277998cbe 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -622,17 +622,17 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): self.assertEquals(callcount2[0], 1) a.func2.invalidate(("foo",)) - self.assertEquals(a.func2.cache.cache.pop.call_count, 1) + self.assertEquals(a.func2.cache.cache.del_multi.call_count, 1) yield a.func2("foo") a.func2.invalidate(("foo",)) - self.assertEquals(a.func2.cache.cache.pop.call_count, 2) + self.assertEquals(a.func2.cache.cache.del_multi.call_count, 2) self.assertEquals(callcount[0], 1) self.assertEquals(callcount2[0], 2) a.func.invalidate(("foo",)) - self.assertEquals(a.func2.cache.cache.pop.call_count, 3) + self.assertEquals(a.func2.cache.cache.del_multi.call_count, 3) yield a.func("foo") self.assertEquals(callcount[0], 2) -- cgit 1.5.1 From 8fb9af570f942d2057e8acb4a047d61ed7048f58 Mon Sep 17 00:00:00 2001 From: Callum Brown Date: Thu, 27 May 2021 18:42:23 +0100 Subject: Make reason and score optional for report_event (#10077) Implements MSC2414: https://github.com/matrix-org/matrix-doc/pull/2414 See #8551 Signed-off-by: Callum Brown --- changelog.d/10077.feature | 1 + docs/admin_api/event_reports.md | 4 +- synapse/rest/client/v2_alpha/report_event.py | 13 ++-- synapse/storage/databases/main/room.py | 2 +- tests/rest/admin/test_event_reports.py | 15 ++++- tests/rest/client/v2_alpha/test_report_event.py | 83 +++++++++++++++++++++++++ 6 files changed, 105 insertions(+), 13 deletions(-) create mode 100644 changelog.d/10077.feature create mode 100644 tests/rest/client/v2_alpha/test_report_event.py (limited to 'synapse/storage/databases/main') diff --git a/changelog.d/10077.feature b/changelog.d/10077.feature new file mode 100644 index 0000000000..808feb2215 --- /dev/null +++ b/changelog.d/10077.feature @@ -0,0 +1 @@ +Make reason and score parameters optional for reporting content. Implements [MSC2414](https://github.com/matrix-org/matrix-doc/pull/2414). Contributed by Callum Brown. diff --git a/docs/admin_api/event_reports.md b/docs/admin_api/event_reports.md index 0159098138..bfec06f755 100644 --- a/docs/admin_api/event_reports.md +++ b/docs/admin_api/event_reports.md @@ -75,9 +75,9 @@ The following fields are returned in the JSON response body: * `name`: string - The name of the room. * `event_id`: string - The ID of the reported event. * `user_id`: string - This is the user who reported the event and wrote the reason. -* `reason`: string - Comment made by the `user_id` in this report. May be blank. +* `reason`: string - Comment made by the `user_id` in this report. May be blank or `null`. * `score`: integer - Content is reported based upon a negative score, where -100 is - "most offensive" and 0 is "inoffensive". + "most offensive" and 0 is "inoffensive". May be `null`. * `sender`: string - This is the ID of the user who sent the original message/event that was reported. * `canonical_alias`: string - The canonical alias of the room. `null` if the room does not diff --git a/synapse/rest/client/v2_alpha/report_event.py b/synapse/rest/client/v2_alpha/report_event.py index 2c169abbf3..07ea39a8a3 100644 --- a/synapse/rest/client/v2_alpha/report_event.py +++ b/synapse/rest/client/v2_alpha/report_event.py @@ -16,11 +16,7 @@ import logging from http import HTTPStatus from synapse.api.errors import Codes, SynapseError -from synapse.http.servlet import ( - RestServlet, - assert_params_in_dict, - parse_json_object_from_request, -) +from synapse.http.servlet import RestServlet, parse_json_object_from_request from ._base import client_patterns @@ -42,15 +38,14 @@ class ReportEventRestServlet(RestServlet): user_id = requester.user.to_string() body = parse_json_object_from_request(request) - assert_params_in_dict(body, ("reason", "score")) - if not isinstance(body["reason"], str): + if not isinstance(body.get("reason", ""), str): raise SynapseError( HTTPStatus.BAD_REQUEST, "Param 'reason' must be a string", Codes.BAD_JSON, ) - if not isinstance(body["score"], int): + if not isinstance(body.get("score", 0), int): raise SynapseError( HTTPStatus.BAD_REQUEST, "Param 'score' must be an integer", @@ -61,7 +56,7 @@ class ReportEventRestServlet(RestServlet): room_id=room_id, event_id=event_id, user_id=user_id, - reason=body["reason"], + reason=body.get("reason"), content=body, received_ts=self.clock.time_msec(), ) diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 5f38634f48..0cf450f81d 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -1498,7 +1498,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): room_id: str, event_id: str, user_id: str, - reason: str, + reason: Optional[str], content: JsonDict, received_ts: int, ) -> None: diff --git a/tests/rest/admin/test_event_reports.py b/tests/rest/admin/test_event_reports.py index 29341bc6e9..f15d1cf6f7 100644 --- a/tests/rest/admin/test_event_reports.py +++ b/tests/rest/admin/test_event_reports.py @@ -64,7 +64,7 @@ class EventReportsTestCase(unittest.HomeserverTestCase): user_tok=self.admin_user_tok, ) for _ in range(5): - self._create_event_and_report( + self._create_event_and_report_without_parameters( room_id=self.room_id2, user_tok=self.admin_user_tok, ) @@ -378,6 +378,19 @@ class EventReportsTestCase(unittest.HomeserverTestCase): ) self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + def _create_event_and_report_without_parameters(self, room_id, user_tok): + """Create and report an event, but omit reason and score""" + resp = self.helper.send(room_id, tok=user_tok) + event_id = resp["event_id"] + + channel = self.make_request( + "POST", + "rooms/%s/report/%s" % (room_id, event_id), + json.dumps({}), + access_token=user_tok, + ) + self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) + def _check_fields(self, content): """Checks that all attributes are present in an event report""" for c in content: diff --git a/tests/rest/client/v2_alpha/test_report_event.py b/tests/rest/client/v2_alpha/test_report_event.py new file mode 100644 index 0000000000..1ec6b05e5b --- /dev/null +++ b/tests/rest/client/v2_alpha/test_report_event.py @@ -0,0 +1,83 @@ +# Copyright 2021 Callum Brown +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json + +import synapse.rest.admin +from synapse.rest.client.v1 import login, room +from synapse.rest.client.v2_alpha import report_event + +from tests import unittest + + +class ReportEventTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + report_event.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + self.other_user = self.register_user("user", "pass") + self.other_user_tok = self.login("user", "pass") + + self.room_id = self.helper.create_room_as( + self.other_user, tok=self.other_user_tok, is_public=True + ) + self.helper.join(self.room_id, user=self.admin_user, tok=self.admin_user_tok) + resp = self.helper.send(self.room_id, tok=self.admin_user_tok) + self.event_id = resp["event_id"] + self.report_path = "rooms/{}/report/{}".format(self.room_id, self.event_id) + + def test_reason_str_and_score_int(self): + data = {"reason": "this makes me sad", "score": -100} + self._assert_status(200, data) + + def test_no_reason(self): + data = {"score": 0} + self._assert_status(200, data) + + def test_no_score(self): + data = {"reason": "this makes me sad"} + self._assert_status(200, data) + + def test_no_reason_and_no_score(self): + data = {} + self._assert_status(200, data) + + def test_reason_int_and_score_str(self): + data = {"reason": 10, "score": "string"} + self._assert_status(400, data) + + def test_reason_zero_and_score_blank(self): + data = {"reason": 0, "score": ""} + self._assert_status(400, data) + + def test_reason_and_score_null(self): + data = {"reason": None, "score": None} + self._assert_status(400, data) + + def _assert_status(self, response_status, data): + channel = self.make_request( + "POST", + self.report_path, + json.dumps(data), + access_token=self.other_user_tok, + ) + self.assertEqual( + response_status, int(channel.result["code"]), msg=channel.result["body"] + ) -- cgit 1.5.1 From b4b2fd2ecee26214fa6b322bcb62bec1ea324c1a Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 1 Jun 2021 12:04:47 +0100 Subject: add a cache to have_seen_event (#9953) Empirically, this helped my server considerably when handling gaps in Matrix HQ. The problem was that we would repeatedly call have_seen_events for the same set of (50K or so) auth_events, each of which would take many minutes to complete, even though it's only an index scan. --- changelog.d/9953.feature | 1 + changelog.d/9973.feature | 1 + changelog.d/9973.misc | 1 - synapse/handlers/federation.py | 12 +-- synapse/storage/databases/main/cache.py | 1 + synapse/storage/databases/main/events_worker.py | 61 ++++++++++++-- synapse/storage/databases/main/purge_events.py | 26 ++++-- tests/storage/databases/__init__.py | 13 +++ tests/storage/databases/main/__init__.py | 13 +++ tests/storage/databases/main/test_events_worker.py | 96 ++++++++++++++++++++++ 10 files changed, 205 insertions(+), 20 deletions(-) create mode 100644 changelog.d/9953.feature create mode 100644 changelog.d/9973.feature delete mode 100644 changelog.d/9973.misc create mode 100644 tests/storage/databases/__init__.py create mode 100644 tests/storage/databases/main/__init__.py create mode 100644 tests/storage/databases/main/test_events_worker.py (limited to 'synapse/storage/databases/main') diff --git a/changelog.d/9953.feature b/changelog.d/9953.feature new file mode 100644 index 0000000000..6b3d1adc70 --- /dev/null +++ b/changelog.d/9953.feature @@ -0,0 +1 @@ +Improve performance of incoming federation transactions in large rooms. diff --git a/changelog.d/9973.feature b/changelog.d/9973.feature new file mode 100644 index 0000000000..6b3d1adc70 --- /dev/null +++ b/changelog.d/9973.feature @@ -0,0 +1 @@ +Improve performance of incoming federation transactions in large rooms. diff --git a/changelog.d/9973.misc b/changelog.d/9973.misc deleted file mode 100644 index 7f22d42291..0000000000 --- a/changelog.d/9973.misc +++ /dev/null @@ -1 +0,0 @@ -Make `LruCache.invalidate` support tree invalidation, and remove `invalidate_many`. diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index bf11315251..49ed7cabcc 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -577,7 +577,9 @@ class FederationHandler(BaseHandler): # Fetch the state events from the DB, and check we have the auth events. event_map = await self.store.get_events(state_event_ids, allow_rejected=True) - auth_events_in_store = await self.store.have_seen_events(auth_event_ids) + auth_events_in_store = await self.store.have_seen_events( + room_id, auth_event_ids + ) # Check for missing events. We handle state and auth event seperately, # as we want to pull the state from the DB, but we don't for the auth @@ -610,7 +612,7 @@ class FederationHandler(BaseHandler): if missing_auth_events: auth_events_in_store = await self.store.have_seen_events( - missing_auth_events + room_id, missing_auth_events ) missing_auth_events.difference_update(auth_events_in_store) @@ -710,7 +712,7 @@ class FederationHandler(BaseHandler): missing_auth_events = set(auth_event_ids) - fetched_events.keys() missing_auth_events.difference_update( - await self.store.have_seen_events(missing_auth_events) + await self.store.have_seen_events(room_id, missing_auth_events) ) logger.debug("We are also missing %i auth events", len(missing_auth_events)) @@ -2475,7 +2477,7 @@ class FederationHandler(BaseHandler): # # we start by checking if they are in the store, and then try calling /event_auth/. if missing_auth: - have_events = await self.store.have_seen_events(missing_auth) + have_events = await self.store.have_seen_events(event.room_id, missing_auth) logger.debug("Events %s are in the store", have_events) missing_auth.difference_update(have_events) @@ -2494,7 +2496,7 @@ class FederationHandler(BaseHandler): return context seen_remotes = await self.store.have_seen_events( - [e.event_id for e in remote_auth_chain] + event.room_id, [e.event_id for e in remote_auth_chain] ) for e in remote_auth_chain: diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index f7872501a0..c57ae5ef15 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -168,6 +168,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): backfilled, ): self._invalidate_get_event_cache(event_id) + self.have_seen_event.invalidate((room_id, event_id)) self.get_latest_event_ids_in_room.invalidate((room_id,)) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 6963bbf7f4..403a5ddaba 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -22,6 +22,7 @@ from typing import ( Iterable, List, Optional, + Set, Tuple, overload, ) @@ -55,7 +56,7 @@ from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.storage.util.sequence import build_sequence_generator from synapse.types import JsonDict, get_domain_from_id -from synapse.util.caches.descriptors import cached +from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.lrucache import LruCache from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -1045,32 +1046,74 @@ class EventsWorkerStore(SQLBaseStore): return {r["event_id"] for r in rows} - async def have_seen_events(self, event_ids): + async def have_seen_events( + self, room_id: str, event_ids: Iterable[str] + ) -> Set[str]: """Given a list of event ids, check if we have already processed them. + The room_id is only used to structure the cache (so that it can later be + invalidated by room_id) - there is no guarantee that the events are actually + in the room in question. + Args: - event_ids (iterable[str]): + room_id: Room we are polling + event_ids: events we are looking for Returns: set[str]: The events we have already seen. """ + res = await self._have_seen_events_dict( + (room_id, event_id) for event_id in event_ids + ) + return {eid for ((_rid, eid), have_event) in res.items() if have_event} + + @cachedList("have_seen_event", "keys") + async def _have_seen_events_dict( + self, keys: Iterable[Tuple[str, str]] + ) -> Dict[Tuple[str, str], bool]: + """Helper for have_seen_events + + Returns: + a dict {(room_id, event_id)-> bool} + """ # if the event cache contains the event, obviously we've seen it. - results = {x for x in event_ids if self._get_event_cache.contains(x)} - def have_seen_events_txn(txn, chunk): - sql = "SELECT event_id FROM events as e WHERE " + cache_results = { + (rid, eid) for (rid, eid) in keys if self._get_event_cache.contains((eid,)) + } + results = {x: True for x in cache_results} + + def have_seen_events_txn(txn, chunk: Tuple[Tuple[str, str], ...]): + # we deliberately do *not* query the database for room_id, to make the + # query an index-only lookup on `events_event_id_key`. + # + # We therefore pull the events from the database into a set... + + sql = "SELECT event_id FROM events AS e WHERE " clause, args = make_in_list_sql_clause( - txn.database_engine, "e.event_id", chunk + txn.database_engine, "e.event_id", [eid for (_rid, eid) in chunk] ) txn.execute(sql + clause, args) - results.update(row[0] for row in txn) + found_events = {eid for eid, in txn} - for chunk in batch_iter((x for x in event_ids if x not in results), 100): + # ... and then we can update the results for each row in the batch + results.update({(rid, eid): (eid in found_events) for (rid, eid) in chunk}) + + # each batch requires its own index scan, so we make the batches as big as + # possible. + for chunk in batch_iter((k for k in keys if k not in cache_results), 500): await self.db_pool.runInteraction( "have_seen_events", have_seen_events_txn, chunk ) + return results + @cached(max_entries=100000, tree=True) + async def have_seen_event(self, room_id: str, event_id: str): + # this only exists for the benefit of the @cachedList descriptor on + # _have_seen_events_dict + raise NotImplementedError() + def _get_current_state_event_counts_txn(self, txn, room_id): """ See get_current_state_event_counts. diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index 8f83748b5e..7fb7780d0f 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -16,14 +16,14 @@ import logging from typing import Any, List, Set, Tuple from synapse.api.errors import SynapseError -from synapse.storage._base import SQLBaseStore +from synapse.storage.databases.main import CacheInvalidationWorkerStore from synapse.storage.databases.main.state import StateGroupWorkerStore from synapse.types import RoomStreamToken logger = logging.getLogger(__name__) -class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): +class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): async def purge_history( self, room_id: str, token: str, delete_local_events: bool ) -> Set[int]: @@ -203,8 +203,6 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): "DELETE FROM event_to_state_groups " "WHERE event_id IN (SELECT event_id from events_to_purge)" ) - for event_id, _ in event_rows: - txn.call_after(self._get_state_group_for_event.invalidate, (event_id,)) # Delete all remote non-state events for table in ( @@ -283,6 +281,20 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): # so make sure to keep this actually last. txn.execute("DROP TABLE events_to_purge") + for event_id, should_delete in event_rows: + self._invalidate_cache_and_stream( + txn, self._get_state_group_for_event, (event_id,) + ) + + # XXX: This is racy, since have_seen_events could be called between the + # transaction completing and the invalidation running. On the other hand, + # that's no different to calling `have_seen_events` just before the + # event is deleted from the database. + if should_delete: + self._invalidate_cache_and_stream( + txn, self.have_seen_event, (room_id, event_id) + ) + logger.info("[purge] done") return referenced_state_groups @@ -422,7 +434,11 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): # index on them. In any case we should be clearing out 'stream' tables # periodically anyway (#5888) - # TODO: we could probably usefully do a bunch of cache invalidation here + # TODO: we could probably usefully do a bunch more cache invalidation here + + # XXX: as with purge_history, this is racy, but no worse than other races + # that already exist. + self._invalidate_cache_and_stream(txn, self.have_seen_event, (room_id,)) logger.info("[purge] done") diff --git a/tests/storage/databases/__init__.py b/tests/storage/databases/__init__.py new file mode 100644 index 0000000000..c24c7ecd92 --- /dev/null +++ b/tests/storage/databases/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/storage/databases/main/__init__.py b/tests/storage/databases/main/__init__.py new file mode 100644 index 0000000000..c24c7ecd92 --- /dev/null +++ b/tests/storage/databases/main/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py new file mode 100644 index 0000000000..932970fd9a --- /dev/null +++ b/tests/storage/databases/main/test_events_worker.py @@ -0,0 +1,96 @@ +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import json + +from synapse.logging.context import LoggingContext +from synapse.storage.databases.main.events_worker import EventsWorkerStore + +from tests import unittest + + +class HaveSeenEventsTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, hs): + self.store: EventsWorkerStore = hs.get_datastore() + + # insert some test data + for rid in ("room1", "room2"): + self.get_success( + self.store.db_pool.simple_insert( + "rooms", + {"room_id": rid, "room_version": 4}, + ) + ) + + for idx, (rid, eid) in enumerate( + ( + ("room1", "event10"), + ("room1", "event11"), + ("room1", "event12"), + ("room2", "event20"), + ) + ): + self.get_success( + self.store.db_pool.simple_insert( + "events", + { + "event_id": eid, + "room_id": rid, + "topological_ordering": idx, + "stream_ordering": idx, + "type": "test", + "processed": True, + "outlier": False, + }, + ) + ) + self.get_success( + self.store.db_pool.simple_insert( + "event_json", + { + "event_id": eid, + "room_id": rid, + "json": json.dumps({"type": "test", "room_id": rid}), + "internal_metadata": "{}", + "format_version": 3, + }, + ) + ) + + def test_simple(self): + with LoggingContext(name="test") as ctx: + res = self.get_success( + self.store.have_seen_events("room1", ["event10", "event19"]) + ) + self.assertEquals(res, {"event10"}) + + # that should result in a single db query + self.assertEquals(ctx.get_resource_usage().db_txn_count, 1) + + # a second lookup of the same events should cause no queries + with LoggingContext(name="test") as ctx: + res = self.get_success( + self.store.have_seen_events("room1", ["event10", "event19"]) + ) + self.assertEquals(res, {"event10"}) + self.assertEquals(ctx.get_resource_usage().db_txn_count, 0) + + def test_query_via_event_cache(self): + # fetch an event into the event cache + self.get_success(self.store.get_event("event10")) + + # looking it up should now cause no db hits + with LoggingContext(name="test") as ctx: + res = self.get_success(self.store.have_seen_events("room1", ["event10"])) + self.assertEquals(res, {"event10"}) + self.assertEquals(ctx.get_resource_usage().db_txn_count, 0) -- cgit 1.5.1 From 0284d2a2976e3d58e9970fdb7590f98a2556326d Mon Sep 17 00:00:00 2001 From: Dirk Klimpel <5740567+dklimpel@users.noreply.github.com> Date: Wed, 2 Jun 2021 19:50:35 +0200 Subject: Add new admin APIs to remove media by media ID from quarantine. (#10044) Related to: #6681, #5956, #10040 Signed-off-by: Dirk Klimpel dirk@klimpel.org --- changelog.d/10044.feature | 1 + docs/admin_api/media_admin_api.md | 22 ++++++ synapse/rest/admin/media.py | 30 ++++++++ synapse/storage/databases/main/room.py | 30 +++++--- tests/rest/admin/test_media.py | 128 +++++++++++++++++++++++++++++++++ 5 files changed, 201 insertions(+), 10 deletions(-) create mode 100644 changelog.d/10044.feature (limited to 'synapse/storage/databases/main') diff --git a/changelog.d/10044.feature b/changelog.d/10044.feature new file mode 100644 index 0000000000..70c0a3851e --- /dev/null +++ b/changelog.d/10044.feature @@ -0,0 +1 @@ +Add new admin APIs to remove media by media ID from quarantine. Contributed by @dkimpel. diff --git a/docs/admin_api/media_admin_api.md b/docs/admin_api/media_admin_api.md index d1b7e390d5..7709f3d8c7 100644 --- a/docs/admin_api/media_admin_api.md +++ b/docs/admin_api/media_admin_api.md @@ -4,6 +4,7 @@ * [List all media uploaded by a user](#list-all-media-uploaded-by-a-user) - [Quarantine media](#quarantine-media) * [Quarantining media by ID](#quarantining-media-by-id) + * [Remove media from quarantine by ID](#remove-media-from-quarantine-by-id) * [Quarantining media in a room](#quarantining-media-in-a-room) * [Quarantining all media of a user](#quarantining-all-media-of-a-user) * [Protecting media from being quarantined](#protecting-media-from-being-quarantined) @@ -77,6 +78,27 @@ Response: {} ``` +## Remove media from quarantine by ID + +This API removes a single piece of local or remote media from quarantine. + +Request: + +``` +POST /_synapse/admin/v1/media/unquarantine// + +{} +``` + +Where `server_name` is in the form of `example.org`, and `media_id` is in the +form of `abcdefg12345...`. + +Response: + +```json +{} +``` + ## Quarantining media in a room This API quarantines all local and remote media in a room. diff --git a/synapse/rest/admin/media.py b/synapse/rest/admin/media.py index 2c71af4279..b68db2c57c 100644 --- a/synapse/rest/admin/media.py +++ b/synapse/rest/admin/media.py @@ -120,6 +120,35 @@ class QuarantineMediaByID(RestServlet): return 200, {} +class UnquarantineMediaByID(RestServlet): + """Quarantines local or remote media by a given ID so that no one can download + it via this server. + """ + + PATTERNS = admin_patterns( + "/media/unquarantine/(?P[^/]+)/(?P[^/]+)" + ) + + def __init__(self, hs: "HomeServer"): + self.store = hs.get_datastore() + self.auth = hs.get_auth() + + async def on_POST( + self, request: SynapseRequest, server_name: str, media_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) + + logging.info( + "Remove from quarantine local media by ID: %s/%s", server_name, media_id + ) + + # Remove from quarantine this media id + await self.store.quarantine_media_by_id(server_name, media_id, None) + + return 200, {} + + class ProtectMediaByID(RestServlet): """Protect local media from being quarantined.""" @@ -290,6 +319,7 @@ def register_servlets_for_media_repo(hs: "HomeServer", http_server): PurgeMediaCacheRestServlet(hs).register(http_server) QuarantineMediaInRoom(hs).register(http_server) QuarantineMediaByID(hs).register(http_server) + UnquarantineMediaByID(hs).register(http_server) QuarantineMediaByUser(hs).register(http_server) ProtectMediaByID(hs).register(http_server) UnprotectMediaByID(hs).register(http_server) diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 0cf450f81d..2a96bcd314 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -764,14 +764,15 @@ class RoomWorkerStore(SQLBaseStore): self, server_name: str, media_id: str, - quarantined_by: str, + quarantined_by: Optional[str], ) -> int: - """quarantines a single local or remote media id + """quarantines or unquarantines a single local or remote media id Args: server_name: The name of the server that holds this media media_id: The ID of the media to be quarantined quarantined_by: The user ID that initiated the quarantine request + If it is `None` media will be removed from quarantine """ logger.info("Quarantining media: %s/%s", server_name, media_id) is_local = server_name == self.config.server_name @@ -838,9 +839,9 @@ class RoomWorkerStore(SQLBaseStore): txn, local_mxcs: List[str], remote_mxcs: List[Tuple[str, str]], - quarantined_by: str, + quarantined_by: Optional[str], ) -> int: - """Quarantine local and remote media items + """Quarantine and unquarantine local and remote media items Args: txn (cursor) @@ -848,18 +849,27 @@ class RoomWorkerStore(SQLBaseStore): remote_mxcs: A list of (remote server, media id) tuples representing remote mxc URLs quarantined_by: The ID of the user who initiated the quarantine request + If it is `None` media will be removed from quarantine Returns: The total number of media items quarantined """ + # Update all the tables to set the quarantined_by flag - txn.executemany( - """ + sql = """ UPDATE local_media_repository SET quarantined_by = ? - WHERE media_id = ? AND safe_from_quarantine = ? - """, - ((quarantined_by, media_id, False) for media_id in local_mxcs), - ) + WHERE media_id = ? + """ + + # set quarantine + if quarantined_by is not None: + sql += "AND safe_from_quarantine = ?" + rows = [(quarantined_by, media_id, False) for media_id in local_mxcs] + # remove from quarantine + else: + rows = [(quarantined_by, media_id) for media_id in local_mxcs] + + txn.executemany(sql, rows) # Note that a rowcount of -1 can be used to indicate no rows were affected. total_media_quarantined = txn.rowcount if txn.rowcount > 0 else 0 diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index f741121ea2..6fee0f95b6 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -566,6 +566,134 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.assertFalse(os.path.exists(local_path)) +class QuarantineMediaByIDTestCase(unittest.HomeserverTestCase): + + servlets = [ + synapse.rest.admin.register_servlets, + synapse.rest.admin.register_servlets_for_media_repo, + login.register_servlets, + ] + + def prepare(self, reactor, clock, hs): + media_repo = hs.get_media_repository_resource() + self.store = hs.get_datastore() + self.server_name = hs.hostname + + self.admin_user = self.register_user("admin", "pass", admin=True) + self.admin_user_tok = self.login("admin", "pass") + + # Create media + upload_resource = media_repo.children[b"upload"] + # file size is 67 Byte + image_data = unhexlify( + b"89504e470d0a1a0a0000000d4948445200000001000000010806" + b"0000001f15c4890000000a49444154789c63000100000500010d" + b"0a2db40000000049454e44ae426082" + ) + + # Upload some media into the room + response = self.helper.upload_media( + upload_resource, image_data, tok=self.admin_user_tok, expect_code=200 + ) + # Extract media ID from the response + server_and_media_id = response["content_uri"][6:] # Cut off 'mxc://' + self.media_id = server_and_media_id.split("/")[1] + + self.url = "/_synapse/admin/v1/media/%s/%s/%s" + + @parameterized.expand(["quarantine", "unquarantine"]) + def test_no_auth(self, action: str): + """ + Try to protect media without authentication. + """ + + channel = self.make_request( + "POST", + self.url % (action, self.server_name, self.media_id), + b"{}", + ) + + self.assertEqual(401, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"]) + + @parameterized.expand(["quarantine", "unquarantine"]) + def test_requester_is_no_admin(self, action: str): + """ + If the user is not a server admin, an error is returned. + """ + self.other_user = self.register_user("user", "pass") + self.other_user_token = self.login("user", "pass") + + channel = self.make_request( + "POST", + self.url % (action, self.server_name, self.media_id), + access_token=self.other_user_token, + ) + + self.assertEqual(403, int(channel.result["code"]), msg=channel.result["body"]) + self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"]) + + def test_quarantine_media(self): + """ + Tests that quarantining and remove from quarantine a media is successfully + """ + + media_info = self.get_success(self.store.get_local_media(self.media_id)) + self.assertFalse(media_info["quarantined_by"]) + + # quarantining + channel = self.make_request( + "POST", + self.url % ("quarantine", self.server_name, self.media_id), + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertFalse(channel.json_body) + + media_info = self.get_success(self.store.get_local_media(self.media_id)) + self.assertTrue(media_info["quarantined_by"]) + + # remove from quarantine + channel = self.make_request( + "POST", + self.url % ("unquarantine", self.server_name, self.media_id), + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertFalse(channel.json_body) + + media_info = self.get_success(self.store.get_local_media(self.media_id)) + self.assertFalse(media_info["quarantined_by"]) + + def test_quarantine_protected_media(self): + """ + Tests that quarantining from protected media fails + """ + + # protect + self.get_success(self.store.mark_local_media_as_safe(self.media_id, safe=True)) + + # verify protection + media_info = self.get_success(self.store.get_local_media(self.media_id)) + self.assertTrue(media_info["safe_from_quarantine"]) + + # quarantining + channel = self.make_request( + "POST", + self.url % ("quarantine", self.server_name, self.media_id), + access_token=self.admin_user_tok, + ) + + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertFalse(channel.json_body) + + # verify that is not in quarantine + media_info = self.get_success(self.store.get_local_media(self.media_id)) + self.assertFalse(media_info["quarantined_by"]) + + class ProtectMediaByIDTestCase(unittest.HomeserverTestCase): servlets = [ -- cgit 1.5.1 From c955f22e2c88676944124a4a3c80112b35231035 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 11 Jun 2021 10:27:12 +0100 Subject: Fix bug when running presence off master (#10149) Hopefully fixes #10027. --- changelog.d/10149.bugfix | 1 + synapse/storage/databases/main/presence.py | 2 +- synapse/storage/util/id_generators.py | 15 +++++++++++++++ 3 files changed, 17 insertions(+), 1 deletion(-) create mode 100644 changelog.d/10149.bugfix (limited to 'synapse/storage/databases/main') diff --git a/changelog.d/10149.bugfix b/changelog.d/10149.bugfix new file mode 100644 index 0000000000..cb2d2eedb3 --- /dev/null +++ b/changelog.d/10149.bugfix @@ -0,0 +1 @@ +Fix a bug which caused presence updates to stop working some time after restart, when using a presence writer worker. diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py index 6a2baa7841..1388771c40 100644 --- a/synapse/storage/databases/main/presence.py +++ b/synapse/storage/databases/main/presence.py @@ -50,7 +50,7 @@ class PresenceStore(SQLBaseStore): instance_name=self._instance_name, tables=[("presence_stream", "instance_name", "stream_id")], sequence_name="presence_stream_sequence", - writers=hs.config.worker.writers.to_device, + writers=hs.config.worker.writers.presence, ) else: self._presence_id_gen = StreamIdGenerator( diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index b1bd3a52d9..f1e62f9e85 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -397,6 +397,11 @@ class MultiWriterIdGenerator: # ... persist event ... """ + # If we have a list of instances that are allowed to write to this + # stream, make sure we're in it. + if self._writers and self._instance_name not in self._writers: + raise Exception("Tried to allocate stream ID on non-writer") + return _MultiWriterCtxManager(self) def get_next_mult(self, n: int): @@ -406,6 +411,11 @@ class MultiWriterIdGenerator: # ... persist events ... """ + # If we have a list of instances that are allowed to write to this + # stream, make sure we're in it. + if self._writers and self._instance_name not in self._writers: + raise Exception("Tried to allocate stream ID on non-writer") + return _MultiWriterCtxManager(self, n) def get_next_txn(self, txn: LoggingTransaction): @@ -416,6 +426,11 @@ class MultiWriterIdGenerator: # ... persist event ... """ + # If we have a list of instances that are allowed to write to this + # stream, make sure we're in it. + if self._writers and self._instance_name not in self._writers: + raise Exception("Tried to allocate stream ID on non-writer") + next_id = self._load_next_id_txn(txn) with self._lock: -- cgit 1.5.1