From 6bd8763804dc0987c7ecd37bcb5ebff465fffa29 Mon Sep 17 00:00:00 2001 From: Mathieu Velten Date: Wed, 21 Sep 2022 15:32:01 +0200 Subject: Add cache invalidation across workers to module API (#13667) Signed-off-by: Mathieu Velten --- synapse/storage/databases/main/cache.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) (limited to 'synapse/storage/databases/main/cache.py') diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 12e9a42382..2c421151c1 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -33,7 +33,7 @@ from synapse.storage.database import ( ) from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator -from synapse.util.caches.descriptors import _CachedFunction +from synapse.util.caches.descriptors import CachedFunction from synapse.util.iterutils import batch_iter if TYPE_CHECKING: @@ -269,9 +269,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): return cache_func.invalidate(keys) - await self.db_pool.runInteraction( - "invalidate_cache_and_stream", - self._send_invalidation_to_replication, + await self.send_invalidation_to_replication( cache_func.__name__, keys, ) @@ -279,7 +277,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): def _invalidate_cache_and_stream( self, txn: LoggingTransaction, - cache_func: _CachedFunction, + cache_func: CachedFunction, keys: Tuple[Any, ...], ) -> None: """Invalidates the cache and adds it to the cache stream so slaves @@ -293,7 +291,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._send_invalidation_to_replication(txn, cache_func.__name__, keys) def _invalidate_all_cache_and_stream( - self, txn: LoggingTransaction, cache_func: _CachedFunction + self, txn: LoggingTransaction, cache_func: CachedFunction ) -> None: """Invalidates the entire cache and adds it to the cache stream so slaves will know to invalidate their caches. @@ -334,6 +332,16 @@ class CacheInvalidationWorkerStore(SQLBaseStore): txn, CURRENT_STATE_CACHE_NAME, [room_id] ) + async def send_invalidation_to_replication( + self, cache_name: str, keys: Optional[Collection[Any]] + ) -> None: + await self.db_pool.runInteraction( + "send_invalidation_to_replication", + self._send_invalidation_to_replication, + cache_name, + keys, + ) + def _send_invalidation_to_replication( self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]] ) -> None: -- cgit 1.5.1 From 6b4593a80fa2fd9ec8e1ec82fad74f3b7fbb9ba3 Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Mon, 26 Sep 2022 16:26:35 +0100 Subject: Simplify cache invalidation after event persist txn (#13796) This moves all the invalidations into a single place and de-duplicates the code involved in invalidating caches for a given event by using the base class method. --- changelog.d/13796.misc | 1 + synapse/storage/_base.py | 3 + synapse/storage/databases/main/cache.py | 34 +++++--- synapse/storage/databases/main/events.py | 133 +++++++------------------------ 4 files changed, 52 insertions(+), 119 deletions(-) create mode 100644 changelog.d/13796.misc (limited to 'synapse/storage/databases/main/cache.py') diff --git a/changelog.d/13796.misc b/changelog.d/13796.misc new file mode 100644 index 0000000000..9ed1662394 --- /dev/null +++ b/changelog.d/13796.misc @@ -0,0 +1 @@ +Use shared methods for cache invalidation when persisting events, remove duplicate codepaths. Contributed by Nick @ Beeper (@fizzadar). diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 303a5d5298..313e8aca7d 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -91,6 +91,9 @@ class SQLBaseStore(metaclass=ABCMeta): self._attempt_to_invalidate_cache( "get_user_in_room_with_profile", (room_id, user_id) ) + self._attempt_to_invalidate_cache( + "get_rooms_for_user_with_stream_ordering", (user_id,) + ) # Purge other caches based on room state. self._attempt_to_invalidate_cache("get_room_summary", (room_id,)) diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 2c421151c1..db6ce83a2b 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -223,15 +223,16 @@ class CacheInvalidationWorkerStore(SQLBaseStore): # process triggering the invalidation is responsible for clearing any external # cached objects. self._invalidate_local_get_event_cache(event_id) - self.have_seen_event.invalidate((room_id, event_id)) - self.get_latest_event_ids_in_room.invalidate((room_id,)) - - self.get_unread_event_push_actions_by_room_for_user.invalidate((room_id,)) + self._attempt_to_invalidate_cache("have_seen_event", (room_id, event_id)) + self._attempt_to_invalidate_cache("get_latest_event_ids_in_room", (room_id,)) + self._attempt_to_invalidate_cache( + "get_unread_event_push_actions_by_room_for_user", (room_id,) + ) # The `_get_membership_from_event_id` is immutable, except for the # case where we look up an event *before* persisting it. - self._get_membership_from_event_id.invalidate((event_id,)) + self._attempt_to_invalidate_cache("_get_membership_from_event_id", (event_id,)) if not backfilled: self._events_stream_cache.entity_has_changed(room_id, stream_ordering) @@ -240,19 +241,26 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._invalidate_local_get_event_cache(redacts) # Caches which might leak edits must be invalidated for the event being # redacted. - self.get_relations_for_event.invalidate((redacts,)) - self.get_applicable_edit.invalidate((redacts,)) + self._attempt_to_invalidate_cache("get_relations_for_event", (redacts,)) + self._attempt_to_invalidate_cache("get_applicable_edit", (redacts,)) if etype == EventTypes.Member: self._membership_stream_cache.entity_has_changed(state_key, stream_ordering) - self.get_invited_rooms_for_local_user.invalidate((state_key,)) + self._attempt_to_invalidate_cache( + "get_invited_rooms_for_local_user", (state_key,) + ) if relates_to: - self.get_relations_for_event.invalidate((relates_to,)) - self.get_aggregation_groups_for_event.invalidate((relates_to,)) - self.get_applicable_edit.invalidate((relates_to,)) - self.get_thread_summary.invalidate((relates_to,)) - self.get_thread_participated.invalidate((relates_to,)) + self._attempt_to_invalidate_cache("get_relations_for_event", (relates_to,)) + self._attempt_to_invalidate_cache( + "get_aggregation_groups_for_event", (relates_to,) + ) + self._attempt_to_invalidate_cache("get_applicable_edit", (relates_to,)) + self._attempt_to_invalidate_cache("get_thread_summary", (relates_to,)) + self._attempt_to_invalidate_cache("get_thread_participated", (relates_to,)) + self._attempt_to_invalidate_cache( + "get_mutual_event_relations_for_rel_type", (relates_to,) + ) async def invalidate_cache_and_stream( self, cache_name: str, keys: Tuple[Any, ...] diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 1b54a2eb57..2e156a4a11 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -35,7 +35,7 @@ import attr from prometheus_client import Counter import synapse.metrics -from synapse.api.constants import EventContentFields, EventTypes, RelationTypes +from synapse.api.constants import EventContentFields, EventTypes from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersions from synapse.events import EventBase, relation_from_event @@ -410,6 +410,31 @@ class PersistEventsStore: assert min_stream_order assert max_stream_order + # Once the txn completes, invalidate all of the relevant caches. Note that we do this + # up here because it captures all the events_and_contexts before any are removed. + for event, _ in events_and_contexts: + self.store.invalidate_get_event_cache_after_txn(txn, event.event_id) + if event.redacts: + self.store.invalidate_get_event_cache_after_txn(txn, event.redacts) + + relates_to = None + relation = relation_from_event(event) + if relation: + relates_to = relation.parent_id + + assert event.internal_metadata.stream_ordering is not None + txn.call_after( + self.store._invalidate_caches_for_event, + event.internal_metadata.stream_ordering, + event.event_id, + event.room_id, + event.type, + getattr(event, "state_key", None), + event.redacts, + relates_to, + backfilled=False, + ) + self._update_forward_extremities_txn( txn, new_forward_extremities=new_forward_extremities, @@ -459,6 +484,7 @@ class PersistEventsStore: # We call this last as it assumes we've inserted the events into # room_memberships, where applicable. + # NB: This function invalidates all state related caches self._update_current_state_txn(txn, state_delta_for_room, min_stream_order) def _persist_event_auth_chain_txn( @@ -1172,13 +1198,6 @@ class PersistEventsStore: ) # Invalidate the various caches - - for member in members_changed: - txn.call_after( - self.store.get_rooms_for_user_with_stream_ordering.invalidate, - (member,), - ) - self.store._invalidate_state_caches_and_stream( txn, room_id, members_changed ) @@ -1222,9 +1241,6 @@ class PersistEventsStore: self.db_pool.simple_delete_txn( txn, table="event_forward_extremities", keyvalues={"room_id": room_id} ) - txn.call_after( - self.store.get_latest_event_ids_in_room.invalidate, (room_id,) - ) self.db_pool.simple_insert_many_txn( txn, @@ -1294,8 +1310,6 @@ class PersistEventsStore: """ depth_updates: Dict[str, int] = {} for event, context in events_and_contexts: - # Remove the any existing cache entries for the event_ids - self.store.invalidate_get_event_cache_after_txn(txn, event.event_id) # Then update the `stream_ordering` position to mark the latest # event as the front of the room. This should not be done for # backfilled events because backfilled events have negative @@ -1697,16 +1711,7 @@ class PersistEventsStore: txn.async_call_after(prefill) def _store_redaction(self, txn: LoggingTransaction, event: EventBase) -> None: - """Invalidate the caches for the redacted event. - - Note that these caches are also cleared as part of event replication in - _invalidate_caches_for_event. - """ assert event.redacts is not None - self.store.invalidate_get_event_cache_after_txn(txn, event.redacts) - txn.call_after(self.store.get_relations_for_event.invalidate, (event.redacts,)) - txn.call_after(self.store.get_applicable_edit.invalidate, (event.redacts,)) - self.db_pool.simple_upsert_txn( txn, table="redactions", @@ -1807,34 +1812,6 @@ class PersistEventsStore: for event in events: assert event.internal_metadata.stream_ordering is not None - txn.call_after( - self.store._membership_stream_cache.entity_has_changed, - event.state_key, - event.internal_metadata.stream_ordering, - ) - txn.call_after( - self.store.get_invited_rooms_for_local_user.invalidate, - (event.state_key,), - ) - txn.call_after( - self.store.get_local_users_in_room.invalidate, - (event.room_id,), - ) - txn.call_after( - self.store.get_number_joined_users_in_room.invalidate, - (event.room_id,), - ) - txn.call_after( - self.store.get_user_in_room_with_profile.invalidate, - (event.room_id, event.state_key), - ) - - # The `_get_membership_from_event_id` is immutable, except for the - # case where we look up an event *before* persisting it. - txn.call_after( - self.store._get_membership_from_event_id.invalidate, - (event.event_id,), - ) # We update the local_current_membership table only if the event is # "current", i.e., its something that has just happened. @@ -1883,35 +1860,6 @@ class PersistEventsStore: }, ) - txn.call_after( - self.store.get_relations_for_event.invalidate, (relation.parent_id,) - ) - txn.call_after( - self.store.get_aggregation_groups_for_event.invalidate, - (relation.parent_id,), - ) - txn.call_after( - self.store.get_mutual_event_relations_for_rel_type.invalidate, - (relation.parent_id,), - ) - - if relation.rel_type == RelationTypes.REPLACE: - txn.call_after( - self.store.get_applicable_edit.invalidate, (relation.parent_id,) - ) - - if relation.rel_type == RelationTypes.THREAD: - txn.call_after( - self.store.get_thread_summary.invalidate, (relation.parent_id,) - ) - # It should be safe to only invalidate the cache if the user has not - # previously participated in the thread, but that's difficult (and - # potentially error-prone) so it is always invalidated. - txn.call_after( - self.store.get_thread_participated.invalidate, - (relation.parent_id, event.sender), - ) - def _handle_insertion_event( self, txn: LoggingTransaction, event: EventBase ) -> None: @@ -2213,28 +2161,6 @@ class PersistEventsStore: ), ) - room_to_event_ids: Dict[str, List[str]] = {} - for e in non_outlier_events: - room_to_event_ids.setdefault(e.room_id, []).append(e.event_id) - - for room_id, event_ids in room_to_event_ids.items(): - rows = self.db_pool.simple_select_many_txn( - txn, - table="event_push_actions_staging", - column="event_id", - iterable=event_ids, - keyvalues={}, - retcols=("user_id",), - ) - - user_ids = {row["user_id"] for row in rows} - - for user_id in user_ids: - txn.call_after( - self.store.get_unread_event_push_actions_by_room_for_user.invalidate, - (room_id, user_id), - ) - # Now we delete the staging area for *all* events that were being # persisted. txn.execute_batch( @@ -2249,11 +2175,6 @@ class PersistEventsStore: def _remove_push_actions_for_event_id_txn( self, txn: LoggingTransaction, room_id: str, event_id: str ) -> None: - # Sad that we have to blow away the cache for the whole room here - txn.call_after( - self.store.get_unread_event_push_actions_by_room_for_user.invalidate, - (room_id,), - ) txn.execute( "DELETE FROM event_push_actions WHERE room_id = ? AND event_id = ?", (room_id, event_id), -- cgit 1.5.1 From a466164647b969efd2e85168144cd75693443c05 Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Thu, 29 Sep 2022 14:55:12 +0100 Subject: Optimise get_rooms_for_user (drop with_stream_ordering) (#13787) --- changelog.d/13787.misc | 1 + synapse/handlers/device.py | 6 +- synapse/handlers/sync.py | 14 +--- synapse/storage/_base.py | 1 + synapse/storage/databases/main/cache.py | 1 + synapse/storage/databases/main/roommember.py | 117 +++++++++++++-------------- tests/handlers/test_sync.py | 1 + 7 files changed, 66 insertions(+), 75 deletions(-) create mode 100644 changelog.d/13787.misc (limited to 'synapse/storage/databases/main/cache.py') diff --git a/changelog.d/13787.misc b/changelog.d/13787.misc new file mode 100644 index 0000000000..a9b93717f0 --- /dev/null +++ b/changelog.d/13787.misc @@ -0,0 +1 @@ +Optimise get rooms for user calls. Contributed by Nick @ Beeper (@fizzadar). diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 03082fce42..f9cc5bddbc 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -273,11 +273,9 @@ class DeviceWorkerHandler: possibly_left = possibly_changed | possibly_left # Double check if we still share rooms with the given user. - users_rooms = await self.store.get_rooms_for_users_with_stream_ordering( - possibly_left - ) + users_rooms = await self.store.get_rooms_for_users(possibly_left) for changed_user_id, entries in users_rooms.items(): - if any(e.room_id in room_ids for e in entries): + if any(rid in room_ids for rid in entries): possibly_left.discard(changed_user_id) else: possibly_joined.discard(changed_user_id) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index e75fc6b947..4abb9b6127 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -1490,16 +1490,14 @@ class SyncHandler: since_token.device_list_key ) if changed_users is not None: - result = await self.store.get_rooms_for_users_with_stream_ordering( - changed_users - ) + result = await self.store.get_rooms_for_users(changed_users) for changed_user_id, entries in result.items(): # Check if the changed user shares any rooms with the user, # or if the changed user is the syncing user (as we always # want to include device list updates of their own devices). if user_id == changed_user_id or any( - e.room_id in joined_rooms for e in entries + rid in joined_rooms for rid in entries ): users_that_have_changed.add(changed_user_id) else: @@ -1533,13 +1531,9 @@ class SyncHandler: newly_left_users.update(left_users) # Remove any users that we still share a room with. - left_users_rooms = ( - await self.store.get_rooms_for_users_with_stream_ordering( - newly_left_users - ) - ) + left_users_rooms = await self.store.get_rooms_for_users(newly_left_users) for user_id, entries in left_users_rooms.items(): - if any(e.room_id in joined_rooms for e in entries): + if any(rid in joined_rooms for rid in entries): newly_left_users.discard(user_id) return DeviceListUpdates( diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 313e8aca7d..bf42aeb8d1 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -94,6 +94,7 @@ class SQLBaseStore(metaclass=ABCMeta): self._attempt_to_invalidate_cache( "get_rooms_for_user_with_stream_ordering", (user_id,) ) + self._attempt_to_invalidate_cache("get_rooms_for_user", (user_id,)) # Purge other caches based on room state. self._attempt_to_invalidate_cache("get_room_summary", (room_id,)) diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index db6ce83a2b..3b8ed1f7ee 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -205,6 +205,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self.get_rooms_for_user_with_stream_ordering.invalidate( (data.state_key,) ) + self.get_rooms_for_user.invalidate((data.state_key,)) else: raise Exception("Unknown events stream row type %s" % (row.type,)) diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 8ada3cdac3..982e1f08e3 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -15,7 +15,6 @@ import logging from typing import ( TYPE_CHECKING, - Callable, Collection, Dict, FrozenSet, @@ -52,7 +51,6 @@ from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain from synapse.util.async_helpers import Linearizer from synapse.util.caches import intern_string from synapse.util.caches.descriptors import _CacheContext, cached, cachedList -from synapse.util.cancellation import cancellable from synapse.util.iterutils import batch_iter from synapse.util.metrics import Measure @@ -600,58 +598,6 @@ class RoomMemberWorkerStore(EventsWorkerStore): for room_id, instance, stream_id in txn ) - @cachedList( - cached_method_name="get_rooms_for_user_with_stream_ordering", - list_name="user_ids", - ) - async def get_rooms_for_users_with_stream_ordering( - self, user_ids: Collection[str] - ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]: - """A batched version of `get_rooms_for_user_with_stream_ordering`. - - Returns: - Map from user_id to set of rooms that is currently in. - """ - return await self.db_pool.runInteraction( - "get_rooms_for_users_with_stream_ordering", - self._get_rooms_for_users_with_stream_ordering_txn, - user_ids, - ) - - def _get_rooms_for_users_with_stream_ordering_txn( - self, txn: LoggingTransaction, user_ids: Collection[str] - ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]: - - clause, args = make_in_list_sql_clause( - self.database_engine, - "c.state_key", - user_ids, - ) - - sql = f""" - SELECT c.state_key, room_id, e.instance_name, e.stream_ordering - FROM current_state_events AS c - INNER JOIN events AS e USING (room_id, event_id) - WHERE - c.type = 'm.room.member' - AND c.membership = ? - AND {clause} - """ - - txn.execute(sql, [Membership.JOIN] + args) - - result: Dict[str, Set[GetRoomsForUserWithStreamOrdering]] = { - user_id: set() for user_id in user_ids - } - for user_id, room_id, instance, stream_id in txn: - result[user_id].add( - GetRoomsForUserWithStreamOrdering( - room_id, PersistedEventPosition(instance, stream_id) - ) - ) - - return {user_id: frozenset(v) for user_id, v in result.items()} - async def get_users_server_still_shares_room_with( self, user_ids: Collection[str] ) -> Set[str]: @@ -693,19 +639,68 @@ class RoomMemberWorkerStore(EventsWorkerStore): return {row[0] for row in txn} - @cancellable - async def get_rooms_for_user( - self, user_id: str, on_invalidate: Optional[Callable[[], None]] = None - ) -> FrozenSet[str]: + @cached(max_entries=500000, iterable=True) + async def get_rooms_for_user(self, user_id: str) -> FrozenSet[str]: """Returns a set of room_ids the user is currently joined to. If a remote user only returns rooms this server is currently participating in. """ - rooms = await self.get_rooms_for_user_with_stream_ordering( - user_id, on_invalidate=on_invalidate + rooms = self.get_rooms_for_user_with_stream_ordering.cache.get_immediate( + (user_id,), + None, + update_metrics=False, + ) + if rooms: + return frozenset(r.room_id for r in rooms) + + room_ids = await self.db_pool.simple_select_onecol( + table="current_state_events", + keyvalues={ + "type": EventTypes.Member, + "membership": Membership.JOIN, + "state_key": user_id, + }, + retcol="room_id", + desc="get_rooms_for_user", ) - return frozenset(r.room_id for r in rooms) + + return frozenset(room_ids) + + @cachedList( + cached_method_name="get_rooms_for_user", + list_name="user_ids", + ) + async def get_rooms_for_users( + self, user_ids: Collection[str] + ) -> Dict[str, FrozenSet[str]]: + """A batched version of `get_rooms_for_user`. + + Returns: + Map from user_id to set of rooms that is currently in. + """ + + rows = await self.db_pool.simple_select_many_batch( + table="current_state_events", + column="state_key", + iterable=user_ids, + retcols=( + "state_key", + "room_id", + ), + keyvalues={ + "type": EventTypes.Member, + "membership": Membership.JOIN, + }, + desc="get_rooms_for_users", + ) + + user_rooms: Dict[str, Set[str]] = {user_id: set() for user_id in user_ids} + + for row in rows: + user_rooms[row["state_key"]].add(row["room_id"]) + + return {key: frozenset(rooms) for key, rooms in user_rooms.items()} @cached(max_entries=10000) async def does_pair_of_users_share_a_room( diff --git a/tests/handlers/test_sync.py b/tests/handlers/test_sync.py index e3f38fbcc5..ab5c101eb7 100644 --- a/tests/handlers/test_sync.py +++ b/tests/handlers/test_sync.py @@ -159,6 +159,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase): # Blow away caches (supported room versions can only change due to a restart). self.store.get_rooms_for_user_with_stream_ordering.invalidate_all() + self.store.get_rooms_for_user.invalidate_all() self.get_success(self.store._get_event_cache.clear()) self.store._event_ref.clear() -- cgit 1.5.1 From 09be8ab5f9d54fa1a577d8b0028abf8acc28f30d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 12 Oct 2022 06:26:39 -0400 Subject: Remove the experimental implementation of MSC3772. (#14094) MSC3772 has been abandoned. --- changelog.d/14094.removal | 1 + rust/src/push/base_rules.rs | 13 ---- rust/src/push/evaluator.rs | 105 +--------------------------- rust/src/push/mod.rs | 44 +++--------- stubs/synapse/synapse_rust/push.pyi | 6 +- synapse/config/experimental.py | 2 - synapse/push/bulk_push_rule_evaluator.py | 64 +---------------- synapse/storage/databases/main/cache.py | 3 - synapse/storage/databases/main/events.py | 5 -- synapse/storage/databases/main/push_rule.py | 15 ++-- synapse/storage/databases/main/relations.py | 53 -------------- tests/push/test_push_rule_evaluator.py | 76 +------------------- 12 files changed, 22 insertions(+), 365 deletions(-) create mode 100644 changelog.d/14094.removal (limited to 'synapse/storage/databases/main/cache.py') diff --git a/changelog.d/14094.removal b/changelog.d/14094.removal new file mode 100644 index 0000000000..6ef03b1a0f --- /dev/null +++ b/changelog.d/14094.removal @@ -0,0 +1 @@ +Remove the experimental implementation of [MSC3772](https://github.com/matrix-org/matrix-spec-proposals/pull/3772). diff --git a/rust/src/push/base_rules.rs b/rust/src/push/base_rules.rs index 2a09cf99ae..63240cacfc 100644 --- a/rust/src/push/base_rules.rs +++ b/rust/src/push/base_rules.rs @@ -257,19 +257,6 @@ pub const BASE_APPEND_UNDERRIDE_RULES: &[PushRule] = &[ default: true, default_enabled: true, }, - PushRule { - rule_id: Cow::Borrowed("global/underride/.org.matrix.msc3772.thread_reply"), - priority_class: 1, - conditions: Cow::Borrowed(&[Condition::Known(KnownCondition::RelationMatch { - rel_type: Cow::Borrowed("m.thread"), - event_type_pattern: None, - sender: None, - sender_type: Some(Cow::Borrowed("user_id")), - })]), - actions: Cow::Borrowed(&[Action::Notify, HIGHLIGHT_FALSE_ACTION]), - default: true, - default_enabled: true, - }, PushRule { rule_id: Cow::Borrowed("global/underride/.m.rule.message"), priority_class: 1, diff --git a/rust/src/push/evaluator.rs b/rust/src/push/evaluator.rs index efe88ec76e..0365dd01dc 100644 --- a/rust/src/push/evaluator.rs +++ b/rust/src/push/evaluator.rs @@ -12,10 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{ - borrow::Cow, - collections::{BTreeMap, BTreeSet}, -}; +use std::collections::BTreeMap; use anyhow::{Context, Error}; use lazy_static::lazy_static; @@ -49,13 +46,6 @@ pub struct PushRuleEvaluator { /// The `notifications` section of the current power levels in the room. notification_power_levels: BTreeMap, - /// The relations related to the event as a mapping from relation type to - /// set of sender/event type 2-tuples. - relations: BTreeMap>, - - /// Is running "relation" conditions enabled? - relation_match_enabled: bool, - /// The power level of the sender of the event, or None if event is an /// outlier. sender_power_level: Option, @@ -70,8 +60,6 @@ impl PushRuleEvaluator { room_member_count: u64, sender_power_level: Option, notification_power_levels: BTreeMap, - relations: BTreeMap>, - relation_match_enabled: bool, ) -> Result { let body = flattened_keys .get("content.body") @@ -83,8 +71,6 @@ impl PushRuleEvaluator { body, room_member_count, notification_power_levels, - relations, - relation_match_enabled, sender_power_level, }) } @@ -203,89 +189,11 @@ impl PushRuleEvaluator { false } } - KnownCondition::RelationMatch { - rel_type, - event_type_pattern, - sender, - sender_type, - } => { - self.match_relations(rel_type, sender, sender_type, user_id, event_type_pattern)? - } }; Ok(result) } - /// Evaluates a relation condition. - fn match_relations( - &self, - rel_type: &str, - sender: &Option>, - sender_type: &Option>, - user_id: Option<&str>, - event_type_pattern: &Option>, - ) -> Result { - // First check if relation matching is enabled... - if !self.relation_match_enabled { - return Ok(false); - } - - // ... and if there are any relations to match against. - let relations = if let Some(relations) = self.relations.get(rel_type) { - relations - } else { - return Ok(false); - }; - - // Extract the sender pattern from the condition - let sender_pattern = if let Some(sender) = sender { - Some(sender.as_ref()) - } else if let Some(sender_type) = sender_type { - if sender_type == "user_id" { - if let Some(user_id) = user_id { - Some(user_id) - } else { - return Ok(false); - } - } else { - warn!("Unrecognized sender_type: {sender_type}"); - return Ok(false); - } - } else { - None - }; - - let mut sender_compiled_pattern = if let Some(pattern) = sender_pattern { - Some(get_glob_matcher(pattern, GlobMatchType::Whole)?) - } else { - None - }; - - let mut type_compiled_pattern = if let Some(pattern) = event_type_pattern { - Some(get_glob_matcher(pattern, GlobMatchType::Whole)?) - } else { - None - }; - - for (relation_sender, event_type) in relations { - if let Some(pattern) = &mut sender_compiled_pattern { - if !pattern.is_match(relation_sender)? { - continue; - } - } - - if let Some(pattern) = &mut type_compiled_pattern { - if !pattern.is_match(event_type)? { - continue; - } - } - - return Ok(true); - } - - Ok(false) - } - /// Evaluates a `event_match` condition. fn match_event_match( &self, @@ -359,15 +267,8 @@ impl PushRuleEvaluator { fn push_rule_evaluator() { let mut flattened_keys = BTreeMap::new(); flattened_keys.insert("content.body".to_string(), "foo bar bob hello".to_string()); - let evaluator = PushRuleEvaluator::py_new( - flattened_keys, - 10, - Some(0), - BTreeMap::new(), - BTreeMap::new(), - true, - ) - .unwrap(); + let evaluator = + PushRuleEvaluator::py_new(flattened_keys, 10, Some(0), BTreeMap::new()).unwrap(); let result = evaluator.run(&FilteredPushRules::default(), None, Some("bob")); assert_eq!(result.len(), 3); diff --git a/rust/src/push/mod.rs b/rust/src/push/mod.rs index 208b9c0d73..0dabfab8b8 100644 --- a/rust/src/push/mod.rs +++ b/rust/src/push/mod.rs @@ -275,16 +275,6 @@ pub enum KnownCondition { SenderNotificationPermission { key: Cow<'static, str>, }, - #[serde(rename = "org.matrix.msc3772.relation_match")] - RelationMatch { - rel_type: Cow<'static, str>, - #[serde(skip_serializing_if = "Option::is_none", rename = "type")] - event_type_pattern: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - sender: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - sender_type: Option>, - }, } impl IntoPy for Condition { @@ -401,21 +391,15 @@ impl PushRules { pub struct FilteredPushRules { push_rules: PushRules, enabled_map: BTreeMap, - msc3772_enabled: bool, } #[pymethods] impl FilteredPushRules { #[new] - pub fn py_new( - push_rules: PushRules, - enabled_map: BTreeMap, - msc3772_enabled: bool, - ) -> Self { + pub fn py_new(push_rules: PushRules, enabled_map: BTreeMap) -> Self { Self { push_rules, enabled_map, - msc3772_enabled, } } @@ -430,25 +414,13 @@ impl FilteredPushRules { /// Iterates over all the rules and their enabled state, including base /// rules, in the order they should be executed in. fn iter(&self) -> impl Iterator { - self.push_rules - .iter() - .filter(|rule| { - // Ignore disabled experimental push rules - if !self.msc3772_enabled - && rule.rule_id == "global/underride/.org.matrix.msc3772.thread_reply" - { - return false; - } - - true - }) - .map(|r| { - let enabled = *self - .enabled_map - .get(&*r.rule_id) - .unwrap_or(&r.default_enabled); - (r, enabled) - }) + self.push_rules.iter().map(|r| { + let enabled = *self + .enabled_map + .get(&*r.rule_id) + .unwrap_or(&r.default_enabled); + (r, enabled) + }) } } diff --git a/stubs/synapse/synapse_rust/push.pyi b/stubs/synapse/synapse_rust/push.pyi index 5900e61450..f2a61df660 100644 --- a/stubs/synapse/synapse_rust/push.pyi +++ b/stubs/synapse/synapse_rust/push.pyi @@ -25,9 +25,7 @@ class PushRules: def rules(self) -> Collection[PushRule]: ... class FilteredPushRules: - def __init__( - self, push_rules: PushRules, enabled_map: Dict[str, bool], msc3772_enabled: bool - ): ... + def __init__(self, push_rules: PushRules, enabled_map: Dict[str, bool]): ... def rules(self) -> Collection[Tuple[PushRule, bool]]: ... def get_base_rule_ids() -> Collection[str]: ... @@ -39,8 +37,6 @@ class PushRuleEvaluator: room_member_count: int, sender_power_level: Optional[int], notification_power_levels: Mapping[str, int], - relations: Mapping[str, Set[Tuple[str, str]]], - relation_match_enabled: bool, ): ... def run( self, diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index e00cb7096c..f44655516e 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -95,8 +95,6 @@ class ExperimentalConfig(Config): # MSC2815 (allow room moderators to view redacted event content) self.msc2815_enabled: bool = experimental.get("msc2815_enabled", False) - # MSC3772: A push rule for mutual relations. - self.msc3772_enabled: bool = experimental.get("msc3772_enabled", False) # MSC3773: Thread notifications self.msc3773_enabled: bool = experimental.get("msc3773_enabled", False) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index eced182fd5..8d94aeaa32 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -13,18 +13,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -import itertools import logging from typing import ( TYPE_CHECKING, Any, Collection, Dict, - Iterable, List, Mapping, Optional, - Set, Tuple, Union, ) @@ -38,7 +35,7 @@ from synapse.events.snapshot import EventContext from synapse.state import POWER_KEY from synapse.storage.databases.main.roommember import EventIdMembership from synapse.storage.state import StateFilter -from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRuleEvaluator +from synapse.synapse_rust.push import FilteredPushRules, PushRuleEvaluator from synapse.util.caches import register_cache from synapse.util.metrics import measure_func from synapse.visibility import filter_event_for_clients_with_state @@ -117,9 +114,6 @@ class BulkPushRuleEvaluator: resizable=False, ) - # Whether to support MSC3772 is supported. - self._relations_match_enabled = self.hs.config.experimental.msc3772_enabled - async def _get_rules_for_event( self, event: EventBase, @@ -200,51 +194,6 @@ class BulkPushRuleEvaluator: return pl_event.content if pl_event else {}, sender_level - async def _get_mutual_relations( - self, parent_id: str, rules: Iterable[Tuple[PushRule, bool]] - ) -> Dict[str, Set[Tuple[str, str]]]: - """ - Fetch event metadata for events which related to the same event as the given event. - - If the given event has no relation information, returns an empty dictionary. - - Args: - parent_id: The event ID which is targeted by relations. - rules: The push rules which will be processed for this event. - - Returns: - A dictionary of relation type to: - A set of tuples of: - The sender - The event type - """ - - # If the experimental feature is not enabled, skip fetching relations. - if not self._relations_match_enabled: - return {} - - # Pre-filter to figure out which relation types are interesting. - rel_types = set() - for rule, enabled in rules: - if not enabled: - continue - - for condition in rule.conditions: - if condition["kind"] != "org.matrix.msc3772.relation_match": - continue - - # rel_type is required. - rel_type = condition.get("rel_type") - if rel_type: - rel_types.add(rel_type) - - # If no valid rules were found, no mutual relations. - if not rel_types: - return {} - - # If any valid rules were found, fetch the mutual relations. - return await self.store.get_mutual_event_relations(parent_id, rel_types) - @measure_func("action_for_event_by_user") async def action_for_event_by_user( self, event: EventBase, context: EventContext @@ -276,16 +225,11 @@ class BulkPushRuleEvaluator: sender_power_level, ) = await self._get_power_levels_and_sender_level(event, context) + # Find the event's thread ID. relation = relation_from_event(event) - # If the event does not have a relation, then cannot have any mutual - # relations or thread ID. - relations = {} + # If the event does not have a relation, then it cannot have a thread ID. thread_id = MAIN_TIMELINE if relation: - relations = await self._get_mutual_relations( - relation.parent_id, - itertools.chain(*(r.rules() for r in rules_by_user.values())), - ) # Recursively attempt to find the thread this event relates to. if relation.rel_type == RelationTypes.THREAD: thread_id = relation.parent_id @@ -306,8 +250,6 @@ class BulkPushRuleEvaluator: room_member_count, sender_power_level, notification_levels, - relations, - self._relations_match_enabled, ) users = rules_by_user.keys() diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 3b8ed1f7ee..a9f25a5904 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -259,9 +259,6 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._attempt_to_invalidate_cache("get_applicable_edit", (relates_to,)) self._attempt_to_invalidate_cache("get_thread_summary", (relates_to,)) self._attempt_to_invalidate_cache("get_thread_participated", (relates_to,)) - self._attempt_to_invalidate_cache( - "get_mutual_event_relations_for_rel_type", (relates_to,) - ) async def invalidate_cache_and_stream( self, cache_name: str, keys: Tuple[Any, ...] diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 3e15827986..060fe71454 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -2024,11 +2024,6 @@ class PersistEventsStore: self.store._invalidate_cache_and_stream( txn, self.store.get_thread_participated, (redacted_relates_to,) ) - self.store._invalidate_cache_and_stream( - txn, - self.store.get_mutual_event_relations_for_rel_type, - (redacted_relates_to,), - ) self.db_pool.simple_delete_txn( txn, table="event_relations", keyvalues={"event_id": redacted_event_id} diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 8295322b0e..51416b2236 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -29,7 +29,6 @@ from typing import ( ) from synapse.api.errors import StoreError -from synapse.config.homeserver import ExperimentalConfig from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage._base import SQLBaseStore from synapse.storage.database import ( @@ -63,9 +62,7 @@ logger = logging.getLogger(__name__) def _load_rules( - rawrules: List[JsonDict], - enabled_map: Dict[str, bool], - experimental_config: ExperimentalConfig, + rawrules: List[JsonDict], enabled_map: Dict[str, bool] ) -> FilteredPushRules: """Take the DB rows returned from the DB and convert them into a full `FilteredPushRules` object. @@ -83,9 +80,7 @@ def _load_rules( push_rules = PushRules(ruleslist) - filtered_rules = FilteredPushRules( - push_rules, enabled_map, msc3772_enabled=experimental_config.msc3772_enabled - ) + filtered_rules = FilteredPushRules(push_rules, enabled_map) return filtered_rules @@ -165,7 +160,7 @@ class PushRulesWorkerStore( enabled_map = await self.get_push_rules_enabled_for_user(user_id) - return _load_rules(rows, enabled_map, self.hs.config.experimental) + return _load_rules(rows, enabled_map) async def get_push_rules_enabled_for_user(self, user_id: str) -> Dict[str, bool]: results = await self.db_pool.simple_select_list( @@ -224,9 +219,7 @@ class PushRulesWorkerStore( results: Dict[str, FilteredPushRules] = {} for user_id, rules in raw_rules.items(): - results[user_id] = _load_rules( - rules, enabled_map_by_user.get(user_id, {}), self.hs.config.experimental - ) + results[user_id] = _load_rules(rules, enabled_map_by_user.get(user_id, {})) return results diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 116abef9de..6b7eec4bf2 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -776,59 +776,6 @@ class RelationsWorkerStore(SQLBaseStore): "get_if_user_has_annotated_event", _get_if_user_has_annotated_event ) - @cached(iterable=True) - async def get_mutual_event_relations_for_rel_type( - self, event_id: str, relation_type: str - ) -> Set[Tuple[str, str]]: - raise NotImplementedError() - - @cachedList( - cached_method_name="get_mutual_event_relations_for_rel_type", - list_name="relation_types", - ) - async def get_mutual_event_relations( - self, event_id: str, relation_types: Collection[str] - ) -> Dict[str, Set[Tuple[str, str]]]: - """ - Fetch event metadata for events which related to the same event as the given event. - - If the given event has no relation information, returns an empty dictionary. - - Args: - event_id: The event ID which is targeted by relations. - relation_types: The relation types to check for mutual relations. - - Returns: - A dictionary of relation type to: - A set of tuples of: - The sender - The event type - """ - rel_type_sql, rel_type_args = make_in_list_sql_clause( - self.database_engine, "relation_type", relation_types - ) - - sql = f""" - SELECT DISTINCT relation_type, sender, type FROM event_relations - INNER JOIN events USING (event_id) - WHERE relates_to_id = ? AND {rel_type_sql} - """ - - def _get_event_relations( - txn: LoggingTransaction, - ) -> Dict[str, Set[Tuple[str, str]]]: - txn.execute(sql, [event_id] + rel_type_args) - result: Dict[str, Set[Tuple[str, str]]] = { - rel_type: set() for rel_type in relation_types - } - for rel_type, sender, type in txn.fetchall(): - result[rel_type].add((sender, type)) - return result - - return await self.db_pool.runInteraction( - "get_event_relations", _get_event_relations - ) - @cached() async def get_thread_id(self, event_id: str) -> Optional[str]: """ diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py index 8804f0e0d3..decf619466 100644 --- a/tests/push/test_push_rule_evaluator.py +++ b/tests/push/test_push_rule_evaluator.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Optional, Set, Tuple, Union +from typing import Dict, Optional, Union import frozendict @@ -38,12 +38,7 @@ from tests.test_utils.event_injection import create_event, inject_member_event class PushRuleEvaluatorTestCase(unittest.TestCase): - def _get_evaluator( - self, - content: JsonDict, - relations: Optional[Dict[str, Set[Tuple[str, str]]]] = None, - relations_match_enabled: bool = False, - ) -> PushRuleEvaluator: + def _get_evaluator(self, content: JsonDict) -> PushRuleEvaluator: event = FrozenEvent( { "event_id": "$event_id", @@ -63,8 +58,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): room_member_count, sender_power_level, power_levels.get("notifications", {}), - relations or {}, - relations_match_enabled, ) def test_display_name(self) -> None: @@ -299,71 +292,6 @@ class PushRuleEvaluatorTestCase(unittest.TestCase): {"sound": "default", "highlight": True}, ) - def test_relation_match(self) -> None: - """Test the relation_match push rule kind.""" - - # Check if the experimental feature is disabled. - evaluator = self._get_evaluator( - {}, {"m.annotation": {("@user:test", "m.reaction")}} - ) - - # A push rule evaluator with the experimental rule enabled. - evaluator = self._get_evaluator( - {}, {"m.annotation": {("@user:test", "m.reaction")}}, True - ) - - # Check just relation type. - condition = { - "kind": "org.matrix.msc3772.relation_match", - "rel_type": "m.annotation", - } - self.assertTrue(evaluator.matches(condition, "@user:test", "foo")) - - # Check relation type and sender. - condition = { - "kind": "org.matrix.msc3772.relation_match", - "rel_type": "m.annotation", - "sender": "@user:test", - } - self.assertTrue(evaluator.matches(condition, "@user:test", "foo")) - condition = { - "kind": "org.matrix.msc3772.relation_match", - "rel_type": "m.annotation", - "sender": "@other:test", - } - self.assertFalse(evaluator.matches(condition, "@user:test", "foo")) - - # Check relation type and event type. - condition = { - "kind": "org.matrix.msc3772.relation_match", - "rel_type": "m.annotation", - "type": "m.reaction", - } - self.assertTrue(evaluator.matches(condition, "@user:test", "foo")) - - # Check just sender, this fails since rel_type is required. - condition = { - "kind": "org.matrix.msc3772.relation_match", - "sender": "@user:test", - } - self.assertFalse(evaluator.matches(condition, "@user:test", "foo")) - - # Check sender glob. - condition = { - "kind": "org.matrix.msc3772.relation_match", - "rel_type": "m.annotation", - "sender": "@*:test", - } - self.assertTrue(evaluator.matches(condition, "@user:test", "foo")) - - # Check event type glob. - condition = { - "kind": "org.matrix.msc3772.relation_match", - "rel_type": "m.annotation", - "event_type": "*.reaction", - } - self.assertTrue(evaluator.matches(condition, "@user:test", "foo")) - class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase): """Tests for the bulk push rule evaluator""" -- cgit 1.5.1 From 3bbe532abb7bfc41467597731ac1a18c0331f539 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 13 Oct 2022 08:02:11 -0400 Subject: Add an API for listing threads in a room. (#13394) Implement the /threads endpoint from MSC3856. This is currently unstable and behind an experimental configuration flag. It includes a background update to backfill data, results from the /threads endpoint will be partial until that finishes. --- changelog.d/13394.feature | 1 + synapse/_scripts/synapse_port_db.py | 2 + synapse/config/experimental.py | 3 + synapse/handlers/relations.py | 86 ++++++++++- synapse/rest/client/relations.py | 50 ++++++- synapse/storage/databases/main/cache.py | 1 + synapse/storage/databases/main/events.py | 38 ++++- synapse/storage/databases/main/relations.py | 166 ++++++++++++++++++++- .../schema/main/delta/73/09threads_table.sql | 30 ++++ tests/rest/client/test_relations.py | 151 +++++++++++++++++++ 10 files changed, 522 insertions(+), 6 deletions(-) create mode 100644 changelog.d/13394.feature create mode 100644 synapse/storage/schema/main/delta/73/09threads_table.sql (limited to 'synapse/storage/databases/main/cache.py') diff --git a/changelog.d/13394.feature b/changelog.d/13394.feature new file mode 100644 index 0000000000..68de079cf3 --- /dev/null +++ b/changelog.d/13394.feature @@ -0,0 +1 @@ +Experimental support for [MSC3856](https://github.com/matrix-org/matrix-spec-proposals/pull/3856): threads list API. diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index 5fa599e70e..d850e54e17 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -72,6 +72,7 @@ from synapse.storage.databases.main.registration import ( RegistrationBackgroundUpdateStore, find_max_generated_user_id_localpart, ) +from synapse.storage.databases.main.relations import RelationsWorkerStore from synapse.storage.databases.main.room import RoomBackgroundUpdateStore from synapse.storage.databases.main.roommember import RoomMemberBackgroundUpdateStore from synapse.storage.databases.main.search import SearchBackgroundUpdateStore @@ -206,6 +207,7 @@ class Store( PusherWorkerStore, PresenceBackgroundUpdateStore, ReceiptsBackgroundUpdateStore, + RelationsWorkerStore, ): def execute(self, f: Callable[..., R], *args: Any, **kwargs: Any) -> Awaitable[R]: return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs) diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index f44655516e..1860006536 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -101,6 +101,9 @@ class ExperimentalConfig(Config): # MSC3848: Introduce errcodes for specific event sending failures self.msc3848_enabled: bool = experimental.get("msc3848_enabled", False) + # MSC3856: Threads list API + self.msc3856_enabled: bool = experimental.get("msc3856_enabled", False) + # MSC3852: Expose last seen user agent field on /_matrix/client/v3/devices. self.msc3852_enabled: bool = experimental.get("msc3852_enabled", False) diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index cc5e45c241..1fdd7a10bc 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import enum import logging from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple @@ -20,7 +21,7 @@ from synapse.api.constants import RelationTypes from synapse.api.errors import SynapseError from synapse.events import EventBase, relation_from_event from synapse.logging.opentracing import trace -from synapse.storage.databases.main.relations import _RelatedEvent +from synapse.storage.databases.main.relations import ThreadsNextBatch, _RelatedEvent from synapse.streams.config import PaginationConfig from synapse.types import JsonDict, Requester, StreamToken, UserID from synapse.visibility import filter_events_for_client @@ -32,6 +33,13 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +class ThreadsListInclude(str, enum.Enum): + """Valid values for the 'include' flag of /threads.""" + + all = "all" + participated = "participated" + + @attr.s(slots=True, frozen=True, auto_attribs=True) class _ThreadAggregation: # The latest event in the thread. @@ -482,3 +490,79 @@ class RelationsHandler: results.setdefault(event_id, BundledAggregations()).replace = edit return results + + async def get_threads( + self, + requester: Requester, + room_id: str, + include: ThreadsListInclude, + limit: int = 5, + from_token: Optional[ThreadsNextBatch] = None, + ) -> JsonDict: + """Get related events of a event, ordered by topological ordering. + + Args: + requester: The user requesting the relations. + room_id: The room the event belongs to. + include: One of "all" or "participated" to indicate which threads should + be returned. + limit: Only fetch the most recent `limit` events. + from_token: Fetch rows from the given token, or from the start if None. + + Returns: + The pagination chunk. + """ + + user_id = requester.user.to_string() + + # TODO Properly handle a user leaving a room. + (_, member_event_id) = await self._auth.check_user_in_room_or_world_readable( + room_id, requester, allow_departed_users=True + ) + + # Note that ignored users are not passed into get_relations_for_event + # below. Ignored users are handled in filter_events_for_client (and by + # not passing them in here we should get a better cache hit rate). + thread_roots, next_batch = await self._main_store.get_threads( + room_id=room_id, limit=limit, from_token=from_token + ) + + events = await self._main_store.get_events_as_list(thread_roots) + + if include == ThreadsListInclude.participated: + # Pre-seed thread participation with whether the requester sent the event. + participated = {event.event_id: event.sender == user_id for event in events} + # For events the requester did not send, check the database for whether + # the requester sent a threaded reply. + participated.update( + await self._main_store.get_threads_participated( + [eid for eid, p in participated.items() if not p], + user_id, + ) + ) + + # Limit the returned threads to those the user has participated in. + events = [event for event in events if participated[event.event_id]] + + events = await filter_events_for_client( + self._storage_controllers, + user_id, + events, + is_peeking=(member_event_id is None), + ) + + aggregations = await self.get_bundled_aggregations( + events, requester.user.to_string() + ) + + now = self._clock.time_msec() + serialized_events = self._event_serializer.serialize_events( + events, now, bundle_aggregations=aggregations + ) + + return_value: JsonDict = {"chunk": serialized_events} + + if next_batch: + return_value["next_batch"] = str(next_batch) + + return return_value diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index b31ce5a0d3..d1aa1947a5 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -13,12 +13,15 @@ # limitations under the License. import logging +import re from typing import TYPE_CHECKING, Optional, Tuple +from synapse.handlers.relations import ThreadsListInclude from synapse.http.server import HttpServer -from synapse.http.servlet import RestServlet +from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns +from synapse.storage.databases.main.relations import ThreadsNextBatch from synapse.streams.config import PaginationConfig from synapse.types import JsonDict @@ -78,5 +81,50 @@ class RelationPaginationServlet(RestServlet): return 200, result +class ThreadsServlet(RestServlet): + PATTERNS = ( + re.compile( + "^/_matrix/client/unstable/org.matrix.msc3856/rooms/(?P[^/]*)/threads" + ), + ) + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastores().main + self._relations_handler = hs.get_relations_handler() + + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + + limit = parse_integer(request, "limit", default=5) + from_token_str = parse_string(request, "from") + include = parse_string( + request, + "include", + default=ThreadsListInclude.all.value, + allowed_values=[v.value for v in ThreadsListInclude], + ) + + # Return the relations + from_token = None + if from_token_str: + from_token = ThreadsNextBatch.from_string(from_token_str) + + result = await self._relations_handler.get_threads( + requester=requester, + room_id=room_id, + include=ThreadsListInclude(include), + limit=limit, + from_token=from_token, + ) + + return 200, result + + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RelationPaginationServlet(hs).register(http_server) + if hs.config.experimental.msc3856_enabled: + ThreadsServlet(hs).register(http_server) diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index a9f25a5904..0ce3156c9c 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -259,6 +259,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._attempt_to_invalidate_cache("get_applicable_edit", (relates_to,)) self._attempt_to_invalidate_cache("get_thread_summary", (relates_to,)) self._attempt_to_invalidate_cache("get_thread_participated", (relates_to,)) + self._attempt_to_invalidate_cache("get_threads", (room_id,)) async def invalidate_cache_and_stream( self, cache_name: str, keys: Tuple[Any, ...] diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 060fe71454..6698cbf664 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -35,7 +35,7 @@ import attr from prometheus_client import Counter import synapse.metrics -from synapse.api.constants import EventContentFields, EventTypes +from synapse.api.constants import EventContentFields, EventTypes, RelationTypes from synapse.api.errors import Codes, SynapseError from synapse.api.room_versions import RoomVersions from synapse.events import EventBase, relation_from_event @@ -1616,7 +1616,7 @@ class PersistEventsStore: ) # Remove from relations table. - self._handle_redact_relations(txn, event.redacts) + self._handle_redact_relations(txn, event.room_id, event.redacts) # Update the event_forward_extremities, event_backward_extremities and # event_edges tables. @@ -1866,6 +1866,34 @@ class PersistEventsStore: }, ) + if relation.rel_type == RelationTypes.THREAD: + # Upsert into the threads table, but only overwrite the value if the + # new event is of a later topological order OR if the topological + # ordering is equal, but the stream ordering is later. + sql = """ + INSERT INTO threads (room_id, thread_id, latest_event_id, topological_ordering, stream_ordering) + VALUES (?, ?, ?, ?, ?) + ON CONFLICT (room_id, thread_id) + DO UPDATE SET + latest_event_id = excluded.latest_event_id, + topological_ordering = excluded.topological_ordering, + stream_ordering = excluded.stream_ordering + WHERE + threads.topological_ordering <= excluded.topological_ordering AND + threads.stream_ordering < excluded.stream_ordering + """ + + txn.execute( + sql, + ( + event.room_id, + relation.parent_id, + event.event_id, + event.depth, + event.internal_metadata.stream_ordering, + ), + ) + def _handle_insertion_event( self, txn: LoggingTransaction, event: EventBase ) -> None: @@ -1989,13 +2017,14 @@ class PersistEventsStore: txn.execute(sql, (batch_id,)) def _handle_redact_relations( - self, txn: LoggingTransaction, redacted_event_id: str + self, txn: LoggingTransaction, room_id: str, redacted_event_id: str ) -> None: """Handles receiving a redaction and checking whether the redacted event has any relations which must be removed from the database. Args: txn + room_id: The room ID of the event that was redacted. redacted_event_id: The event that was redacted. """ @@ -2024,6 +2053,9 @@ class PersistEventsStore: self.store._invalidate_cache_and_stream( txn, self.store.get_thread_participated, (redacted_relates_to,) ) + self.store._invalidate_cache_and_stream( + txn, self.store.get_threads, (room_id,) + ) self.db_pool.simple_delete_txn( txn, table="event_relations", keyvalues={"event_id": redacted_event_id} diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index e7fbf950e6..ac9b96ab44 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -14,6 +14,7 @@ import logging from typing import ( + TYPE_CHECKING, Collection, Dict, FrozenSet, @@ -29,17 +30,46 @@ from typing import ( import attr from synapse.api.constants import MAIN_TIMELINE, RelationTypes +from synapse.api.errors import SynapseError from synapse.events import EventBase from synapse.storage._base import SQLBaseStore -from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause +from synapse.storage.database import ( + DatabasePool, + LoggingDatabaseConnection, + LoggingTransaction, + make_in_list_sql_clause, +) from synapse.storage.databases.main.stream import generate_pagination_where_clause from synapse.storage.engines import PostgresEngine from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken from synapse.util.caches.descriptors import cached, cachedList +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class ThreadsNextBatch: + topological_ordering: int + stream_ordering: int + + def __str__(self) -> str: + return f"{self.topological_ordering}_{self.stream_ordering}" + + @classmethod + def from_string(cls, string: str) -> "ThreadsNextBatch": + """ + Creates a ThreadsNextBatch from its textual representation. + """ + try: + keys = (int(s) for s in string.split("_")) + return cls(*keys) + except Exception: + raise SynapseError(400, "Invalid threads token") + + @attr.s(slots=True, frozen=True, auto_attribs=True) class _RelatedEvent: """ @@ -56,6 +86,76 @@ class _RelatedEvent: class RelationsWorkerStore(SQLBaseStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + + self.db_pool.updates.register_background_update_handler( + "threads_backfill", self._backfill_threads + ) + + async def _backfill_threads(self, progress: JsonDict, batch_size: int) -> int: + """Backfill the threads table.""" + + def threads_backfill_txn(txn: LoggingTransaction) -> int: + last_thread_id = progress.get("last_thread_id", "") + + # Get the latest event in each thread by topo ordering / stream ordering. + # + # Note that the MAX(event_id) is needed to abide by the rules of group by, + # but doesn't actually do anything since there should only be a single event + # ID per topo/stream ordering pair. + sql = f""" + SELECT room_id, relates_to_id, MAX(topological_ordering), MAX(stream_ordering), MAX(event_id) + FROM event_relations + INNER JOIN events USING (event_id) + WHERE + relates_to_id > ? AND + relation_type = '{RelationTypes.THREAD}' + GROUP BY room_id, relates_to_id + ORDER BY relates_to_id + LIMIT ? + """ + txn.execute(sql, (last_thread_id, batch_size)) + + # No more rows to process. + rows = txn.fetchall() + if not rows: + return 0 + + # Insert the rows into the threads table. If a matching thread already exists, + # assume it is from a newer event. + sql = """ + INSERT INTO threads (room_id, thread_id, topological_ordering, stream_ordering, latest_event_id) + VALUES %s + ON CONFLICT (room_id, thread_id) + DO NOTHING + """ + if isinstance(txn.database_engine, PostgresEngine): + txn.execute_values(sql % ("?",), rows, fetch=False) + else: + txn.execute_batch(sql % ("?, ?, ?, ?, ?",), rows) + + # Mark the progress. + self.db_pool.updates._background_update_progress_txn( + txn, "threads_backfill", {"last_thread_id": rows[-1][1]} + ) + + return txn.rowcount + + result = await self.db_pool.runInteraction( + "threads_backfill", threads_backfill_txn + ) + + if not result: + await self.db_pool.updates._end_background_update("threads_backfill") + + return result + @cached(uncached_args=("event",), tree=True) async def get_relations_for_event( self, @@ -776,6 +876,70 @@ class RelationsWorkerStore(SQLBaseStore): "get_if_user_has_annotated_event", _get_if_user_has_annotated_event ) + @cached(tree=True) + async def get_threads( + self, + room_id: str, + limit: int = 5, + from_token: Optional[ThreadsNextBatch] = None, + ) -> Tuple[List[str], Optional[ThreadsNextBatch]]: + """Get a list of thread IDs, ordered by topological ordering of their + latest reply. + + Args: + room_id: The room the event belongs to. + limit: Only fetch the most recent `limit` threads. + from_token: Fetch rows from a previous next_batch, or from the start if None. + + Returns: + A tuple of: + A list of thread root event IDs. + + The next_batch, if one exists. + """ + # Generate the pagination clause, if necessary. + # + # Find any threads where the latest reply is equal / before the last + # thread's topo ordering and earlier in stream ordering. + pagination_clause = "" + pagination_args: tuple = () + if from_token: + pagination_clause = "AND topological_ordering <= ? AND stream_ordering < ?" + pagination_args = ( + from_token.topological_ordering, + from_token.stream_ordering, + ) + + sql = f""" + SELECT thread_id, topological_ordering, stream_ordering + FROM threads + WHERE + room_id = ? + {pagination_clause} + ORDER BY topological_ordering DESC, stream_ordering DESC + LIMIT ? + """ + + def _get_threads_txn( + txn: LoggingTransaction, + ) -> Tuple[List[str], Optional[ThreadsNextBatch]]: + txn.execute(sql, (room_id, *pagination_args, limit + 1)) + + rows = cast(List[Tuple[str, int, int]], txn.fetchall()) + thread_ids = [r[0] for r in rows] + + # If there are more events, generate the next pagination key from the + # last thread which will be returned. + next_token = None + if len(thread_ids) > limit: + last_topo_id = rows[-2][1] + last_stream_id = rows[-2][2] + next_token = ThreadsNextBatch(last_topo_id, last_stream_id) + + return thread_ids[:limit], next_token + + return await self.db_pool.runInteraction("get_threads", _get_threads_txn) + @cached() async def get_thread_id(self, event_id: str) -> str: """ diff --git a/synapse/storage/schema/main/delta/73/09threads_table.sql b/synapse/storage/schema/main/delta/73/09threads_table.sql new file mode 100644 index 0000000000..aa7c5e9a2e --- /dev/null +++ b/synapse/storage/schema/main/delta/73/09threads_table.sql @@ -0,0 +1,30 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE threads ( + room_id TEXT NOT NULL, + -- The event ID of the root event in the thread. + thread_id TEXT NOT NULL, + -- The latest event ID and corresponding topo / stream ordering. + latest_event_id TEXT NOT NULL, + topological_ordering BIGINT NOT NULL, + stream_ordering BIGINT NOT NULL, + CONSTRAINT threads_uniqueness UNIQUE (room_id, thread_id) +); + +CREATE INDEX threads_ordering_idx ON threads(room_id, topological_ordering, stream_ordering); + +INSERT INTO background_updates (ordering, update_name, progress_json) VALUES + (7309, 'threads_backfill', '{}'); diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 988cdb746d..d595295e2c 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1707,3 +1707,154 @@ class RelationRedactionTestCase(BaseRelationsTestCase): relations[RelationTypes.THREAD]["latest_event"]["event_id"], related_event_id, ) + + +class ThreadsTestCase(BaseRelationsTestCase): + @unittest.override_config({"experimental_features": {"msc3856_enabled": True}}) + def test_threads(self) -> None: + """Create threads and ensure the ordering is due to their latest event.""" + # Create 2 threads. + thread_1 = self.parent_id + res = self.helper.send(self.room, body="Thread Root!", tok=self.user_token) + thread_2 = res["event_id"] + + self._send_relation(RelationTypes.THREAD, "m.room.test") + self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2) + + # Request the threads in the room. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] + self.assertEqual(thread_roots, [thread_2, thread_1]) + + # Update the first thread, the ordering should swap. + self._send_relation(RelationTypes.THREAD, "m.room.test") + + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] + self.assertEqual(thread_roots, [thread_1, thread_2]) + + @unittest.override_config({"experimental_features": {"msc3856_enabled": True}}) + def test_pagination(self) -> None: + """Create threads and paginate through them.""" + # Create 2 threads. + thread_1 = self.parent_id + res = self.helper.send(self.room, body="Thread Root!", tok=self.user_token) + thread_2 = res["event_id"] + + self._send_relation(RelationTypes.THREAD, "m.room.test") + self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2) + + # Request the threads in the room. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?limit=1", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] + self.assertEqual(thread_roots, [thread_2]) + + # Make sure next_batch has something in it that looks like it could be a + # valid token. + next_batch = channel.json_body.get("next_batch") + self.assertIsInstance(next_batch, str, channel.json_body) + + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?limit=1&from={next_batch}", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] + self.assertEqual(thread_roots, [thread_1], channel.json_body) + + self.assertNotIn("next_batch", channel.json_body, channel.json_body) + + @unittest.override_config({"experimental_features": {"msc3856_enabled": True}}) + def test_include(self) -> None: + """Filtering threads to all or participated in should work.""" + # Thread 1 has the user as the root event. + thread_1 = self.parent_id + self._send_relation( + RelationTypes.THREAD, "m.room.test", access_token=self.user2_token + ) + + # Thread 2 has the user replying. + res = self.helper.send(self.room, body="Thread Root!", tok=self.user2_token) + thread_2 = res["event_id"] + self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2) + + # Thread 3 has the user not participating in. + res = self.helper.send(self.room, body="Another thread!", tok=self.user2_token) + thread_3 = res["event_id"] + self._send_relation( + RelationTypes.THREAD, + "m.room.test", + access_token=self.user2_token, + parent_id=thread_3, + ) + + # All threads in the room. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] + self.assertEqual( + thread_roots, [thread_3, thread_2, thread_1], channel.json_body + ) + + # Only participated threads. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?include=participated", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] + self.assertEqual(thread_roots, [thread_2, thread_1], channel.json_body) + + @unittest.override_config({"experimental_features": {"msc3856_enabled": True}}) + def test_ignored_user(self) -> None: + """Events from ignored users should be ignored.""" + # Thread 1 has a reply from an ignored user. + thread_1 = self.parent_id + self._send_relation( + RelationTypes.THREAD, "m.room.test", access_token=self.user2_token + ) + + # Thread 2 is created by an ignored user. + res = self.helper.send(self.room, body="Thread Root!", tok=self.user2_token) + thread_2 = res["event_id"] + self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2) + + # Ignore user2. + self.get_success( + self.store.add_account_data_for_user( + self.user_id, + AccountDataTypes.IGNORED_USER_LIST, + {"ignored_users": {self.user2_id: {}}}, + ) + ) + + # Only thread 1 is returned. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads", + access_token=self.user_token, + ) + self.assertEquals(200, channel.code, channel.json_body) + thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]] + self.assertEqual(thread_roots, [thread_1], channel.json_body) -- cgit 1.5.1 From 9ff4155f6cc9fc0b7aff82da9f0a1cae677dbda5 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 14 Oct 2022 07:10:44 -0400 Subject: Properly invalidate get_thread_id cache. (#14163) This was missed in 2b6d41ebd685fb546e52acdbcb0024dfcf5a5db1 (#13824). --- changelog.d/14163.feature | 1 + synapse/storage/databases/main/cache.py | 1 + 2 files changed, 2 insertions(+) create mode 100644 changelog.d/14163.feature (limited to 'synapse/storage/databases/main/cache.py') diff --git a/changelog.d/14163.feature b/changelog.d/14163.feature new file mode 100644 index 0000000000..5d0ae16e13 --- /dev/null +++ b/changelog.d/14163.feature @@ -0,0 +1 @@ +Support for thread-specific notifications & receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771) and [MSC3773](https://github.com/matrix-org/matrix-spec-proposals/pull/3773)). diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 0ce3156c9c..b47fc606c7 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -244,6 +244,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): # redacted. self._attempt_to_invalidate_cache("get_relations_for_event", (redacts,)) self._attempt_to_invalidate_cache("get_applicable_edit", (redacts,)) + self._attempt_to_invalidate_cache("get_thread_id", (redacts,)) if etype == EventTypes.Member: self._membership_stream_cache.entity_has_changed(state_key, stream_ordering) -- cgit 1.5.1 From d1bdeccb50550ef454067aa01dd9d004c4704633 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 14 Oct 2022 14:05:25 -0400 Subject: Accept threaded receipts for events related to the root event. (#14174) The root node of a thread (and events related to it) are considered "part of a thread" when validating receipts. This allows clients which show the root node in both the main timeline and the threaded timeline to easily send receipts in either. Note that threaded notifications are not created for these events, these events created notifications on the main timeline. --- changelog.d/14174.feature | 1 + synapse/rest/client/receipts.py | 44 ++++++++++- synapse/storage/databases/main/cache.py | 1 + synapse/storage/databases/main/relations.py | 98 ++++++++++++++++++++++-- tests/storage/test_relations.py | 111 ++++++++++++++++++++++++++++ 5 files changed, 247 insertions(+), 8 deletions(-) create mode 100644 changelog.d/14174.feature create mode 100644 tests/storage/test_relations.py (limited to 'synapse/storage/databases/main/cache.py') diff --git a/changelog.d/14174.feature b/changelog.d/14174.feature new file mode 100644 index 0000000000..5d0ae16e13 --- /dev/null +++ b/changelog.d/14174.feature @@ -0,0 +1 @@ +Support for thread-specific notifications & receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771) and [MSC3773](https://github.com/matrix-org/matrix-spec-proposals/pull/3773)). diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index 14dec7ac4e..18a282b22c 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -15,7 +15,7 @@ import logging from typing import TYPE_CHECKING, Tuple -from synapse.api.constants import ReceiptTypes +from synapse.api.constants import MAIN_TIMELINE, ReceiptTypes from synapse.api.errors import Codes, SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -83,7 +83,7 @@ class ReceiptRestServlet(RestServlet): ) # Ensure the event ID roughly correlates to the thread ID. - if thread_id != await self._main_store.get_thread_id(event_id): + if not await self._is_event_in_thread(event_id, thread_id): raise SynapseError( 400, f"event_id {event_id} is not related to thread {thread_id}", @@ -109,6 +109,46 @@ class ReceiptRestServlet(RestServlet): return 200, {} + async def _is_event_in_thread(self, event_id: str, thread_id: str) -> bool: + """ + The event must be related to the thread ID (in a vague sense) to ensure + clients aren't sending bogus receipts. + + A thread ID is considered valid for a given event E if: + + 1. E has a thread relation which matches the thread ID; + 2. E has another event which has a thread relation to E matching the + thread ID; or + 3. E is recursively related (via any rel_type) to an event which + satisfies 1 or 2. + + Given the following DAG: + + A <---[m.thread]-- B <--[m.annotation]-- C + ^ + |--[m.reference]-- D <--[m.annotation]-- E + + It is valid to send a receipt for thread A on A, B, C, D, or E. + + It is valid to send a receipt for the main timeline on A, D, and E. + + Args: + event_id: The event ID to check. + thread_id: The thread ID the event is potentially part of. + + Returns: + True if the event belongs to the given thread, otherwise False. + """ + + # If the receipt is on the main timeline, it is enough to check whether + # the event is directly related to a thread. + if thread_id == MAIN_TIMELINE: + return MAIN_TIMELINE == await self._main_store.get_thread_id(event_id) + + # Otherwise, check if the event is directly part of a thread, or is the + # root message (or related to the root message) of a thread. + return thread_id == await self._main_store.get_thread_id_for_receipts(event_id) + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: ReceiptRestServlet(hs).register(http_server) diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index b47fc606c7..ed0be4abe5 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -245,6 +245,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._attempt_to_invalidate_cache("get_relations_for_event", (redacts,)) self._attempt_to_invalidate_cache("get_applicable_edit", (redacts,)) self._attempt_to_invalidate_cache("get_thread_id", (redacts,)) + self._attempt_to_invalidate_cache("get_thread_id_for_receipts", (redacts,)) if etype == EventTypes.Member: self._membership_stream_cache.entity_has_changed(state_key, stream_ordering) diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 7c54ce0b2e..1de62ee9df 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -946,6 +946,20 @@ class RelationsWorkerStore(SQLBaseStore): Get the thread ID for an event. This considers multi-level relations, e.g. an annotation to an event which is part of a thread. + It only searches up the relations tree, i.e. it only searches for events + which the given event is related to (and which those events are related + to, etc.) + + Given the following DAG: + + A <---[m.thread]-- B <--[m.annotation]-- C + ^ + |--[m.reference]-- D <--[m.annotation]-- E + + get_thread_id(X) considers events B and C as part of thread A. + + See also get_thread_id_for_receipts. + Args: event_id: The event ID to fetch the thread ID for. @@ -953,22 +967,32 @@ class RelationsWorkerStore(SQLBaseStore): The event ID of the root event in the thread, if this event is part of a thread. "main", otherwise. """ - # Since event relations form a tree, we should only ever find 0 or 1 - # results from the below query. + + # Recurse event relations up to the *root* event, then search that chain + # of relations for a thread relation. If one is found, the root event is + # returned. + # + # Note that this should only ever find 0 or 1 entries since it is invalid + # for an event to have a thread relation to an event which also has a + # relation. sql = """ WITH RECURSIVE related_events AS ( - SELECT event_id, relates_to_id, relation_type + SELECT event_id, relates_to_id, relation_type, 0 depth FROM event_relations WHERE event_id = ? - UNION SELECT e.event_id, e.relates_to_id, e.relation_type + UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1 FROM event_relations e INNER JOIN related_events r ON r.relates_to_id = e.event_id - ) SELECT relates_to_id FROM related_events WHERE relation_type = 'm.thread'; + WHERE depth <= 3 + ) + SELECT relates_to_id FROM related_events + WHERE relation_type = 'm.thread' + ORDER BY depth DESC + LIMIT 1; """ def _get_thread_id(txn: LoggingTransaction) -> str: txn.execute(sql, (event_id,)) - # TODO Should we ensure there's only a single result here? row = txn.fetchone() if row: return row[0] @@ -978,6 +1002,68 @@ class RelationsWorkerStore(SQLBaseStore): return await self.db_pool.runInteraction("get_thread_id", _get_thread_id) + @cached() + async def get_thread_id_for_receipts(self, event_id: str) -> str: + """ + Get the thread ID for an event by traversing to the top-most related event + and confirming any children events form a thread. + + Given the following DAG: + + A <---[m.thread]-- B <--[m.annotation]-- C + ^ + |--[m.reference]-- D <--[m.annotation]-- E + + get_thread_id_for_receipts(X) considers events A, B, C, D, and E as part + of thread A. + + See also get_thread_id. + + Args: + event_id: The event ID to fetch the thread ID for. + + Returns: + The event ID of the root event in the thread, if this event is part + of a thread. "main", otherwise. + """ + + # Recurse event relations up to the *root* event, then search for any events + # related to that root node for a thread relation. If one is found, the + # root event is returned. + # + # Note that there cannot be thread relations in the middle of the chain since + # it is invalid for an event to have a thread relation to an event which also + # has a relation. + sql = """ + SELECT relates_to_id FROM event_relations WHERE relates_to_id = COALESCE(( + WITH RECURSIVE related_events AS ( + SELECT event_id, relates_to_id, relation_type, 0 depth + FROM event_relations + WHERE event_id = ? + UNION SELECT e.event_id, e.relates_to_id, e.relation_type, depth + 1 + FROM event_relations e + INNER JOIN related_events r ON r.relates_to_id = e.event_id + WHERE depth <= 3 + ) + SELECT relates_to_id FROM related_events + ORDER BY depth DESC + LIMIT 1 + ), ?) AND relation_type = 'm.thread' LIMIT 1; + """ + + def _get_related_thread_id(txn: LoggingTransaction) -> str: + txn.execute(sql, (event_id, event_id)) + row = txn.fetchone() + if row: + return row[0] + + # If no thread was found, it is part of the main timeline. + return MAIN_TIMELINE + + return await self.db_pool.runInteraction( + "get_related_thread_id", _get_related_thread_id + ) + class RelationsStore(RelationsWorkerStore): pass diff --git a/tests/storage/test_relations.py b/tests/storage/test_relations.py new file mode 100644 index 0000000000..cd1d00208b --- /dev/null +++ b/tests/storage/test_relations.py @@ -0,0 +1,111 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.api.constants import MAIN_TIMELINE +from synapse.server import HomeServer +from synapse.util import Clock + +from tests import unittest + + +class RelationsStoreTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + """ + Creates a DAG: + + A <---[m.thread]-- B <--[m.annotation]-- C + ^ + |--[m.reference]-- D <--[m.annotation]-- E + + F <--[m.annotation]-- G + + """ + self._main_store = self.hs.get_datastores().main + + self._create_relation("A", "B", "m.thread") + self._create_relation("B", "C", "m.annotation") + self._create_relation("A", "D", "m.reference") + self._create_relation("D", "E", "m.annotation") + self._create_relation("F", "G", "m.annotation") + + def _create_relation(self, parent_id: str, event_id: str, rel_type: str) -> None: + self.get_success( + self._main_store.db_pool.simple_insert( + table="event_relations", + values={ + "event_id": event_id, + "relates_to_id": parent_id, + "relation_type": rel_type, + }, + ) + ) + + def test_get_thread_id(self) -> None: + """ + Ensure that get_thread_id only searches up the tree for threads. + """ + # The thread itself and children of it return the thread. + thread_id = self.get_success(self._main_store.get_thread_id("B")) + self.assertEqual("A", thread_id) + + thread_id = self.get_success(self._main_store.get_thread_id("C")) + self.assertEqual("A", thread_id) + + # But the root and events related to the root do not. + thread_id = self.get_success(self._main_store.get_thread_id("A")) + self.assertEqual(MAIN_TIMELINE, thread_id) + + thread_id = self.get_success(self._main_store.get_thread_id("D")) + self.assertEqual(MAIN_TIMELINE, thread_id) + + thread_id = self.get_success(self._main_store.get_thread_id("E")) + self.assertEqual(MAIN_TIMELINE, thread_id) + + # Events which are not related to a thread at all should return the + # main timeline. + thread_id = self.get_success(self._main_store.get_thread_id("F")) + self.assertEqual(MAIN_TIMELINE, thread_id) + + thread_id = self.get_success(self._main_store.get_thread_id("G")) + self.assertEqual(MAIN_TIMELINE, thread_id) + + def test_get_thread_id_for_receipts(self) -> None: + """ + Ensure that get_thread_id_for_receipts searches up and down the tree for a thread. + """ + # All of the events are considered related to this thread. + thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("A")) + self.assertEqual("A", thread_id) + + thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("B")) + self.assertEqual("A", thread_id) + + thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("C")) + self.assertEqual("A", thread_id) + + thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("D")) + self.assertEqual("A", thread_id) + + thread_id = self.get_success(self._main_store.get_thread_id_for_receipts("E")) + self.assertEqual("A", thread_id) + + # Events which are not related to a thread at all should return the + # main timeline. + thread_id = self.get_success(self._main_store.get_thread_id("F")) + self.assertEqual(MAIN_TIMELINE, thread_id) + + thread_id = self.get_success(self._main_store.get_thread_id("G")) + self.assertEqual(MAIN_TIMELINE, thread_id) -- cgit 1.5.1 From 2c2c3f8b2c1e33d5aee6d480c60c75c1179e3dba Mon Sep 17 00:00:00 2001 From: Nick Mills-Barrett Date: Mon, 17 Oct 2022 13:27:51 +0100 Subject: Invalidate rooms for user caches when receiving membership events (#14155) This should fix a race where the event notification comes in over replication before the state replication, leaving a window during which a sync may get an incorrect list of rooms for the user. --- changelog.d/14155.misc | 1 + synapse/storage/databases/main/cache.py | 4 ++++ 2 files changed, 5 insertions(+) create mode 100644 changelog.d/14155.misc (limited to 'synapse/storage/databases/main/cache.py') diff --git a/changelog.d/14155.misc b/changelog.d/14155.misc new file mode 100644 index 0000000000..79539cdc32 --- /dev/null +++ b/changelog.d/14155.misc @@ -0,0 +1 @@ +Invalidate rooms for user caches on replicated event, fix sync cache race in synapse workers. Contributed by Nick @ Beeper (@fizzadar). diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index ed0be4abe5..ddb7397714 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -252,6 +252,10 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._attempt_to_invalidate_cache( "get_invited_rooms_for_local_user", (state_key,) ) + self._attempt_to_invalidate_cache( + "get_rooms_for_user_with_stream_ordering", (state_key,) + ) + self._attempt_to_invalidate_cache("get_rooms_for_user", (state_key,)) if relates_to: self._attempt_to_invalidate_cache("get_relations_for_event", (relates_to,)) -- cgit 1.5.1 From 6d7523ef1484ec56f4a6dffdd2ea3d8736b4cc98 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 22 Nov 2022 09:41:09 -0500 Subject: Batch fetch bundled references (#14508) Avoid an n+1 query problem and fetch the bundled aggregations for m.reference relations in a single query instead of a query per event. This applies similar logic for as was previously done for edits in 8b309adb436c162510ed1402f33b8741d71fc058 (#11660; threads in b65acead428653b988351ae8d7b22127a22039cd (#11752); and annotations in 1799a54a545618782840a60950ef4b64da9ee24d (#14491). --- changelog.d/14508.feature | 1 + synapse/handlers/relations.py | 128 +++++++++++++--------------- synapse/storage/databases/main/cache.py | 1 + synapse/storage/databases/main/events.py | 4 + synapse/storage/databases/main/relations.py | 74 ++++++++++++++-- tests/rest/client/test_relations.py | 4 +- 6 files changed, 133 insertions(+), 79 deletions(-) create mode 100644 changelog.d/14508.feature (limited to 'synapse/storage/databases/main/cache.py') diff --git a/changelog.d/14508.feature b/changelog.d/14508.feature new file mode 100644 index 0000000000..4fca7282f7 --- /dev/null +++ b/changelog.d/14508.feature @@ -0,0 +1 @@ +Reduce database load of [Client-Server endpoints](https://spec.matrix.org/v1.4/client-server-api/#aggregations) which return bundled aggregations. diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index ca94239f61..8414be5879 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -13,16 +13,7 @@ # limitations under the License. import enum import logging -from typing import ( - TYPE_CHECKING, - Collection, - Dict, - FrozenSet, - Iterable, - List, - Optional, - Tuple, -) +from typing import TYPE_CHECKING, Collection, Dict, FrozenSet, Iterable, List, Optional import attr @@ -32,7 +23,7 @@ from synapse.events import EventBase, relation_from_event from synapse.logging.opentracing import trace from synapse.storage.databases.main.relations import ThreadsNextBatch, _RelatedEvent from synapse.streams.config import PaginationConfig -from synapse.types import JsonDict, Requester, StreamToken, UserID +from synapse.types import JsonDict, Requester, UserID from synapse.visibility import filter_events_for_client if TYPE_CHECKING: @@ -181,40 +172,6 @@ class RelationsHandler: return return_value - async def get_relations_for_event( - self, - event_id: str, - event: EventBase, - room_id: str, - relation_type: str, - ignored_users: FrozenSet[str] = frozenset(), - ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]: - """Get a list of events which relate to an event, ordered by topological ordering. - - Args: - event_id: Fetch events that relate to this event ID. - event: The matching EventBase to event_id. - room_id: The room the event belongs to. - relation_type: The type of relation. - ignored_users: The users ignored by the requesting user. - - Returns: - List of event IDs that match relations requested. The rows are of - the form `{"event_id": "..."}`. - """ - - # Call the underlying storage method, which is cached. - related_events, next_token = await self._main_store.get_relations_for_event( - event_id, event, room_id, relation_type, direction="f" - ) - - # Filter out ignored users and convert to the expected format. - related_events = [ - event for event in related_events if event.sender not in ignored_users - ] - - return related_events, next_token - async def redact_events_related_to( self, requester: Requester, @@ -329,6 +286,46 @@ class RelationsHandler: return filtered_results + async def get_references_for_events( + self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset() + ) -> Dict[str, List[_RelatedEvent]]: + """Get a list of references to the given events. + + Args: + event_ids: Fetch events that relate to this event ID. + ignored_users: The users ignored by the requesting user. + + Returns: + A map of event IDs to a list related events. + """ + + related_events = await self._main_store.get_references_for_events(event_ids) + + # Avoid additional logic if there are no ignored users. + if not ignored_users: + return { + event_id: results + for event_id, results in related_events.items() + if results + } + + # Filter out ignored users. + results = {} + for event_id, events in related_events.items(): + # If no references, skip. + if not events: + continue + + # Filter ignored users out. + events = [event for event in events if event.sender not in ignored_users] + # If there are no events left, skip this event. + if not events: + continue + + results[event_id] = events + + return results + async def _get_threads_for_events( self, events_by_id: Dict[str, EventBase], @@ -412,14 +409,18 @@ class RelationsHandler: if event is None: continue - potential_events, _ = await self.get_relations_for_event( - event_id, - event, - room_id, - RelationTypes.THREAD, - ignored_users, + # Attempt to find another event to use as the latest event. + potential_events, _ = await self._main_store.get_relations_for_event( + event_id, event, room_id, RelationTypes.THREAD, direction="f" ) + # Filter out ignored users. + potential_events = [ + event + for event in potential_events + if event.sender not in ignored_users + ] + # If all found events are from ignored users, do not include # a summary of the thread. if not potential_events: @@ -534,27 +535,16 @@ class RelationsHandler: "chunk": annotations } - # Fetch other relations per event. - for event in events_by_id.values(): - # Fetch any references to bundle with this event. - references, next_token = await self.get_relations_for_event( - event.event_id, - event, - event.room_id, - RelationTypes.REFERENCE, - ignored_users=ignored_users, - ) + # Fetch any references to bundle with this event. + references_by_event_id = await self.get_references_for_events( + events_by_id.keys(), ignored_users=ignored_users + ) + for event_id, references in references_by_event_id.items(): if references: - aggregations = results.setdefault(event.event_id, BundledAggregations()) - aggregations.references = { + results.setdefault(event_id, BundledAggregations()).references = { "chunk": [{"event_id": ev.event_id} for ev in references] } - if next_token: - aggregations.references["next_batch"] = await next_token.to_string( - self._main_store - ) - # Fetch any edits (but not for redacted events). # # Note that there is no use in limiting edits by ignored users since the @@ -600,7 +590,7 @@ class RelationsHandler: room_id, requester, allow_departed_users=True ) - # Note that ignored users are not passed into get_relations_for_event + # Note that ignored users are not passed into get_threads # below. Ignored users are handled in filter_events_for_client (and by # not passing them in here we should get a better cache hit rate). thread_roots, next_batch = await self._main_store.get_threads( diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index ddb7397714..a58668a380 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -259,6 +259,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): if relates_to: self._attempt_to_invalidate_cache("get_relations_for_event", (relates_to,)) + self._attempt_to_invalidate_cache("get_references_for_event", (relates_to,)) self._attempt_to_invalidate_cache( "get_aggregation_groups_for_event", (relates_to,) ) diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index d68f127f9b..0f097a2927 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -2049,6 +2049,10 @@ class PersistEventsStore: self.store._invalidate_cache_and_stream( txn, self.store.get_aggregation_groups_for_event, (redacted_relates_to,) ) + if rel_type == RelationTypes.REFERENCE: + self.store._invalidate_cache_and_stream( + txn, self.store.get_references_for_event, (redacted_relates_to,) + ) if rel_type == RelationTypes.REPLACE: self.store._invalidate_cache_and_stream( txn, self.store.get_applicable_edit, (redacted_relates_to,) diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index f96a16956a..aea96e9d24 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -82,8 +82,6 @@ class _RelatedEvent: event_id: str # The sender of the related event. sender: str - topological_ordering: Optional[int] - stream_ordering: int class RelationsWorkerStore(SQLBaseStore): @@ -246,13 +244,17 @@ class RelationsWorkerStore(SQLBaseStore): txn.execute(sql, where_args + [limit + 1]) events = [] - for event_id, relation_type, sender, topo_ordering, stream_ordering in txn: + topo_orderings: List[int] = [] + stream_orderings: List[int] = [] + for event_id, relation_type, sender, topo_ordering, stream_ordering in cast( + List[Tuple[str, str, str, int, int]], txn + ): # Do not include edits for redacted events as they leak event # content. if not is_redacted or relation_type != RelationTypes.REPLACE: - events.append( - _RelatedEvent(event_id, sender, topo_ordering, stream_ordering) - ) + events.append(_RelatedEvent(event_id, sender)) + topo_orderings.append(topo_ordering) + stream_orderings.append(stream_ordering) # If there are more events, generate the next pagination key from the # last event returned. @@ -261,9 +263,11 @@ class RelationsWorkerStore(SQLBaseStore): # Instead of using the last row (which tells us there is more # data), use the last row to be returned. events = events[:limit] + topo_orderings = topo_orderings[:limit] + stream_orderings = stream_orderings[:limit] - topo = events[-1].topological_ordering - token = events[-1].stream_ordering + topo = topo_orderings[-1] + token = stream_orderings[-1] if direction == "b": # Tokens are positions between events. # This token points *after* the last event in the chunk. @@ -530,6 +534,60 @@ class RelationsWorkerStore(SQLBaseStore): "get_aggregation_groups_for_users", _get_aggregation_groups_for_users_txn ) + @cached() + async def get_references_for_event(self, event_id: str) -> List[JsonDict]: + raise NotImplementedError() + + @cachedList(cached_method_name="get_references_for_event", list_name="event_ids") + async def get_references_for_events( + self, event_ids: Collection[str] + ) -> Mapping[str, Optional[List[_RelatedEvent]]]: + """Get a list of references to the given events. + + Args: + event_ids: Fetch events that relate to these event IDs. + + Returns: + A map of event IDs to a list of related event IDs (and their senders). + """ + + clause, args = make_in_list_sql_clause( + self.database_engine, "relates_to_id", event_ids + ) + args.append(RelationTypes.REFERENCE) + + sql = f""" + SELECT relates_to_id, ref.event_id, ref.sender + FROM events AS ref + INNER JOIN event_relations USING (event_id) + INNER JOIN events AS parent ON + parent.event_id = relates_to_id + AND parent.room_id = ref.room_id + WHERE + {clause} + AND relation_type = ? + ORDER BY ref.topological_ordering, ref.stream_ordering + """ + + def _get_references_for_events_txn( + txn: LoggingTransaction, + ) -> Mapping[str, List[_RelatedEvent]]: + txn.execute(sql, args) + + result: Dict[str, List[_RelatedEvent]] = {} + for relates_to_id, event_id, sender in cast( + List[Tuple[str, str, str]], txn + ): + result.setdefault(relates_to_id, []).append( + _RelatedEvent(event_id, sender) + ) + + return result + + return await self.db_pool.runInteraction( + "_get_references_for_events_txn", _get_references_for_events_txn + ) + @cached() def get_applicable_edit(self, event_id: str) -> Optional[EventBase]: raise NotImplementedError() diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 2d2b683548..b86f341ff5 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1108,7 +1108,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): # The "user" sent the root event and is making queries for the bundled # aggregations: they have participated. - self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 8) + self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 7) # The "user2" sent replies in the thread and is making queries for the # bundled aggregations: they have participated. # @@ -1170,7 +1170,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): bundled_aggregations["latest_event"].get("unsigned"), ) - self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 8) + self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 7) def test_nested_thread(self) -> None: """ -- cgit 1.5.1