From 88cd6f937807e64c05458cec86ef0ba0c1c656b3 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 10 Mar 2022 09:03:59 -0500 Subject: Allow retrieving the relations of a redacted event. (#12130) This is allowed per MSC2675, although the original implementation did not allow for it and would return an empty chunk / not bundle aggregations. The main thing to improve is that the various caches get cleared properly when an event is redacted, and that edits must not leak if the original event is redacted (as that would presumably leak something similar to the original event content). --- synapse/storage/databases/main/relations.py | 60 +++++++++++++++-------------- 1 file changed, 32 insertions(+), 28 deletions(-) (limited to 'synapse/storage/databases/main/relations.py') diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 36aa1092f6..be1500092b 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -91,10 +91,11 @@ class RelationsWorkerStore(SQLBaseStore): self._msc3440_enabled = hs.config.experimental.msc3440_enabled - @cached(tree=True) + @cached(uncached_args=("event",), tree=True) async def get_relations_for_event( self, event_id: str, + event: EventBase, room_id: str, relation_type: Optional[str] = None, event_type: Optional[str] = None, @@ -108,6 +109,7 @@ class RelationsWorkerStore(SQLBaseStore): Args: event_id: Fetch events that relate to this event ID. + event: The matching EventBase to event_id. room_id: The room the event belongs to. relation_type: Only fetch events with this relation type, if given. event_type: Only fetch events with this event type, if given. @@ -122,9 +124,13 @@ class RelationsWorkerStore(SQLBaseStore): List of event IDs that match relations requested. The rows are of the form `{"event_id": "..."}`. """ + # We don't use `event_id`, it's there so that we can cache based on + # it. The `event_id` must match the `event.event_id`. + assert event.event_id == event_id where_clause = ["relates_to_id = ?", "room_id = ?"] - where_args: List[Union[str, int]] = [event_id, room_id] + where_args: List[Union[str, int]] = [event.event_id, room_id] + is_redacted = event.internal_metadata.is_redacted() if relation_type is not None: where_clause.append("relation_type = ?") @@ -157,7 +163,7 @@ class RelationsWorkerStore(SQLBaseStore): order = "ASC" sql = """ - SELECT event_id, topological_ordering, stream_ordering + SELECT event_id, relation_type, topological_ordering, stream_ordering FROM event_relations INNER JOIN events USING (event_id) WHERE %s @@ -178,9 +184,12 @@ class RelationsWorkerStore(SQLBaseStore): last_stream_id = None events = [] for row in txn: - events.append({"event_id": row[0]}) - last_topo_id = row[1] - last_stream_id = row[2] + # Do not include edits for redacted events as they leak event + # content. + if not is_redacted or row[1] != RelationTypes.REPLACE: + events.append({"event_id": row[0]}) + last_topo_id = row[2] + last_stream_id = row[3] # If there are more events, generate the next pagination key. next_token = None @@ -776,7 +785,7 @@ class RelationsWorkerStore(SQLBaseStore): ) references = await self.get_relations_for_event( - event_id, room_id, RelationTypes.REFERENCE, direction="f" + event_id, event, room_id, RelationTypes.REFERENCE, direction="f" ) if references.chunk: aggregations.references = await references.to_dict(cast("DataStore", self)) @@ -797,41 +806,36 @@ class RelationsWorkerStore(SQLBaseStore): A map of event ID to the bundled aggregation for the event. Not all events may have bundled aggregations in the results. """ - # The already processed event IDs. Tracked separately from the result - # since the result omits events which do not have bundled aggregations. - seen_event_ids = set() - - # State events and redacted events do not get bundled aggregations. - events = [ - event - for event in events - if not event.is_state() and not event.internal_metadata.is_redacted() - ] + # De-duplicate events by ID to handle the same event requested multiple times. + # + # State events do not get bundled aggregations. + events_by_id = { + event.event_id: event for event in events if not event.is_state() + } # event ID -> bundled aggregation in non-serialized form. results: Dict[str, BundledAggregations] = {} # Fetch other relations per event. - for event in events: - # De-duplicate events by ID to handle the same event requested multiple - # times. The caches that _get_bundled_aggregation_for_event use should - # capture this, but best to reduce work. - if event.event_id in seen_event_ids: - continue - seen_event_ids.add(event.event_id) - + for event in events_by_id.values(): event_result = await self._get_bundled_aggregation_for_event(event, user_id) if event_result: results[event.event_id] = event_result - # Fetch any edits. - edits = await self._get_applicable_edits(seen_event_ids) + # Fetch any edits (but not for redacted events). + edits = await self._get_applicable_edits( + [ + event_id + for event_id, event in events_by_id.items() + if not event.internal_metadata.is_redacted() + ] + ) for event_id, edit in edits.items(): results.setdefault(event_id, BundledAggregations()).replace = edit # Fetch thread summaries. if self._msc3440_enabled: - summaries = await self._get_thread_summaries(seen_event_ids) + summaries = await self._get_thread_summaries(events_by_id.keys()) # Only fetch participated for a limited selection based on what had # summaries. participated = await self._get_threads_participated( -- cgit 1.5.1 From ea27528b5d177dcfc5a4e38b463baeace916dc8e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 10 Mar 2022 10:36:13 -0500 Subject: Support stable identifiers for MSC3440: Threading (#12151) The unstable identifiers are still supported if the experimental configuration flag is enabled. The unstable identifiers will be removed in a future release. --- changelog.d/12151.feature | 1 + synapse/api/constants.py | 4 +- synapse/api/filtering.py | 23 ++++----- synapse/events/utils.py | 9 +++- synapse/handlers/message.py | 5 +- synapse/rest/client/versions.py | 1 + synapse/server.py | 2 +- synapse/storage/databases/main/events.py | 5 +- synapse/storage/databases/main/relations.py | 77 ++++++++++++++++++----------- synapse/storage/databases/main/stream.py | 18 ++++--- tests/rest/client/test_relations.py | 7 +-- tests/rest/client/test_rooms.py | 18 +++---- tests/storage/test_stream.py | 20 ++++---- 13 files changed, 109 insertions(+), 81 deletions(-) create mode 100644 changelog.d/12151.feature (limited to 'synapse/storage/databases/main/relations.py') diff --git a/changelog.d/12151.feature b/changelog.d/12151.feature new file mode 100644 index 0000000000..18432b2da9 --- /dev/null +++ b/changelog.d/12151.feature @@ -0,0 +1 @@ +Support the stable identifiers from [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440): threads. diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 36ace7c613..b0c08a074d 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -178,7 +178,9 @@ class RelationTypes: ANNOTATION: Final = "m.annotation" REPLACE: Final = "m.replace" REFERENCE: Final = "m.reference" - THREAD: Final = "io.element.thread" + THREAD: Final = "m.thread" + # TODO Remove this in Synapse >= v1.57.0. + UNSTABLE_THREAD: Final = "io.element.thread" class LimitBlockingTypes: diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index cb532d7238..27e97d6f37 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -88,7 +88,9 @@ ROOM_EVENT_FILTER_SCHEMA = { "org.matrix.labels": {"type": "array", "items": {"type": "string"}}, "org.matrix.not_labels": {"type": "array", "items": {"type": "string"}}, # MSC3440, filtering by event relations. + "related_by_senders": {"type": "array", "items": {"type": "string"}}, "io.element.relation_senders": {"type": "array", "items": {"type": "string"}}, + "related_by_rel_types": {"type": "array", "items": {"type": "string"}}, "io.element.relation_types": {"type": "array", "items": {"type": "string"}}, }, } @@ -318,19 +320,18 @@ class Filter: self.labels = filter_json.get("org.matrix.labels", None) self.not_labels = filter_json.get("org.matrix.not_labels", []) - # Ideally these would be rejected at the endpoint if they were provided - # and not supported, but that would involve modifying the JSON schema - # based on the homeserver configuration. + self.related_by_senders = self.filter_json.get("related_by_senders", None) + self.related_by_rel_types = self.filter_json.get("related_by_rel_types", None) + + # Fallback to the unstable prefix if the stable version is not given. if hs.config.experimental.msc3440_enabled: - self.relation_senders = self.filter_json.get( + self.related_by_senders = self.related_by_senders or self.filter_json.get( "io.element.relation_senders", None ) - self.relation_types = self.filter_json.get( - "io.element.relation_types", None + self.related_by_rel_types = ( + self.related_by_rel_types + or self.filter_json.get("io.element.relation_types", None) ) - else: - self.relation_senders = None - self.relation_types = None def filters_all_types(self) -> bool: return "*" in self.not_types @@ -461,7 +462,7 @@ class Filter: event_ids = [event.event_id for event in events if isinstance(event, EventBase)] # type: ignore[attr-defined] event_ids_to_keep = set( await self._store.events_have_relations( - event_ids, self.relation_senders, self.relation_types + event_ids, self.related_by_senders, self.related_by_rel_types ) ) @@ -474,7 +475,7 @@ class Filter: async def filter(self, events: Iterable[FilterEvent]) -> List[FilterEvent]: result = [event for event in events if self._check(event)] - if self.relation_senders or self.relation_types: + if self.related_by_senders or self.related_by_rel_types: return await self._check_event_relations(result) return result diff --git a/synapse/events/utils.py b/synapse/events/utils.py index ee34cb46e4..b2a237c1e0 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -38,6 +38,7 @@ from synapse.util.frozenutils import unfreeze from . import EventBase if TYPE_CHECKING: + from synapse.server import HomeServer from synapse.storage.databases.main.relations import BundledAggregations @@ -395,6 +396,9 @@ class EventClientSerializer: clients. """ + def __init__(self, hs: "HomeServer"): + self._msc3440_enabled = hs.config.experimental.msc3440_enabled + def serialize_event( self, event: Union[JsonDict, EventBase], @@ -515,11 +519,14 @@ class EventClientSerializer: thread.latest_event, serialized_latest_event, thread.latest_edit ) - serialized_aggregations[RelationTypes.THREAD] = { + thread_summary = { "latest_event": serialized_latest_event, "count": thread.count, "current_user_participated": thread.current_user_participated, } + serialized_aggregations[RelationTypes.THREAD] = thread_summary + if self._msc3440_enabled: + serialized_aggregations[RelationTypes.UNSTABLE_THREAD] = thread_summary # Include the bundled aggregations in the event. if serialized_aggregations: diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 0799ec9a84..f9544fe7fb 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1079,7 +1079,10 @@ class EventCreationHandler: raise SynapseError(400, "Can't send same reaction twice") # Don't attempt to start a thread if the parent event is a relation. - elif relation_type == RelationTypes.THREAD: + elif ( + relation_type == RelationTypes.THREAD + or relation_type == RelationTypes.UNSTABLE_THREAD + ): if await self.store.event_includes_relation(relates_to): raise SynapseError( 400, "Cannot start threads from an event with a relation" diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 2e5d0e4e22..9a65aa4843 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -101,6 +101,7 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc3030": self.config.experimental.msc3030_enabled, # Adds support for thread relations, per MSC3440. "org.matrix.msc3440": self.config.experimental.msc3440_enabled, + "org.matrix.msc3440.stable": True, # TODO: remove when "v1.3" is added above }, }, ) diff --git a/synapse/server.py b/synapse/server.py index 1270abb5a3..7741ff29dc 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -754,7 +754,7 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_event_client_serializer(self) -> EventClientSerializer: - return EventClientSerializer() + return EventClientSerializer(self) @cache_in_self def get_password_policy_handler(self) -> PasswordPolicyHandler: diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 1a322882bf..1f60aef180 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1814,7 +1814,10 @@ class PersistEventsStore: if rel_type == RelationTypes.REPLACE: txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,)) - if rel_type == RelationTypes.THREAD: + if ( + rel_type == RelationTypes.THREAD + or rel_type == RelationTypes.UNSTABLE_THREAD + ): txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,)) # It should be safe to only invalidate the cache if the user has not # previously participated in the thread, but that's difficult (and diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index be1500092b..c4869d64e6 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -508,7 +508,7 @@ class RelationsWorkerStore(SQLBaseStore): AND parent.room_id = child.room_id WHERE %s - AND relation_type = ? + AND %s ORDER BY parent.event_id, child.topological_ordering DESC, child.stream_ordering DESC """ else: @@ -523,16 +523,22 @@ class RelationsWorkerStore(SQLBaseStore): AND parent.room_id = child.room_id WHERE %s - AND relation_type = ? + AND %s ORDER BY child.topological_ordering DESC, child.stream_ordering DESC """ clause, args = make_in_list_sql_clause( txn.database_engine, "relates_to_id", event_ids ) - args.append(RelationTypes.THREAD) - txn.execute(sql % (clause,), args) + if self._msc3440_enabled: + relations_clause = "(relation_type = ? OR relation_type = ?)" + args.extend((RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD)) + else: + relations_clause = "relation_type = ?" + args.append(RelationTypes.THREAD) + + txn.execute(sql % (clause, relations_clause), args) latest_event_ids = {} for parent_event_id, child_event_id in txn: # Only consider the latest threaded reply (by topological ordering). @@ -552,7 +558,7 @@ class RelationsWorkerStore(SQLBaseStore): AND parent.room_id = child.room_id WHERE %s - AND relation_type = ? + AND %s GROUP BY parent.event_id """ @@ -561,9 +567,15 @@ class RelationsWorkerStore(SQLBaseStore): clause, args = make_in_list_sql_clause( txn.database_engine, "relates_to_id", latest_event_ids.keys() ) - args.append(RelationTypes.THREAD) - txn.execute(sql % (clause,), args) + if self._msc3440_enabled: + relations_clause = "(relation_type = ? OR relation_type = ?)" + args.extend((RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD)) + else: + relations_clause = "relation_type = ?" + args.append(RelationTypes.THREAD) + + txn.execute(sql % (clause, relations_clause), args) counts = dict(cast(List[Tuple[str, int]], txn.fetchall())) return counts, latest_event_ids @@ -626,16 +638,24 @@ class RelationsWorkerStore(SQLBaseStore): AND parent.room_id = child.room_id WHERE %s - AND relation_type = ? + AND %s AND child.sender = ? """ clause, args = make_in_list_sql_clause( txn.database_engine, "relates_to_id", event_ids ) - args.extend((RelationTypes.THREAD, user_id)) - txn.execute(sql % (clause,), args) + if self._msc3440_enabled: + relations_clause = "(relation_type = ? OR relation_type = ?)" + args.extend((RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD)) + else: + relations_clause = "relation_type = ?" + args.append(RelationTypes.THREAD) + + args.append(user_id) + + txn.execute(sql % (clause, relations_clause), args) return {row[0] for row in txn.fetchall()} participated_threads = await self.db_pool.runInteraction( @@ -834,26 +854,23 @@ class RelationsWorkerStore(SQLBaseStore): results.setdefault(event_id, BundledAggregations()).replace = edit # Fetch thread summaries. - if self._msc3440_enabled: - summaries = await self._get_thread_summaries(events_by_id.keys()) - # Only fetch participated for a limited selection based on what had - # summaries. - participated = await self._get_threads_participated( - summaries.keys(), user_id - ) - for event_id, summary in summaries.items(): - if summary: - thread_count, latest_thread_event, edit = summary - results.setdefault( - event_id, BundledAggregations() - ).thread = _ThreadAggregation( - latest_event=latest_thread_event, - latest_edit=edit, - count=thread_count, - # If there's a thread summary it must also exist in the - # participated dictionary. - current_user_participated=participated[event_id], - ) + summaries = await self._get_thread_summaries(events_by_id.keys()) + # Only fetch participated for a limited selection based on what had + # summaries. + participated = await self._get_threads_participated(summaries.keys(), user_id) + for event_id, summary in summaries.items(): + if summary: + thread_count, latest_thread_event, edit = summary + results.setdefault( + event_id, BundledAggregations() + ).thread = _ThreadAggregation( + latest_event=latest_thread_event, + latest_edit=edit, + count=thread_count, + # If there's a thread summary it must also exist in the + # participated dictionary. + current_user_participated=participated[event_id], + ) return results diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index a898f847e7..39e1efe373 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -325,21 +325,23 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]: args.extend(event_filter.labels) # Filter on relation_senders / relation types from the joined tables. - if event_filter.relation_senders: + if event_filter.related_by_senders: clauses.append( "(%s)" % " OR ".join( - "related_event.sender = ?" for _ in event_filter.relation_senders + "related_event.sender = ?" for _ in event_filter.related_by_senders ) ) - args.extend(event_filter.relation_senders) + args.extend(event_filter.related_by_senders) - if event_filter.relation_types: + if event_filter.related_by_rel_types: clauses.append( "(%s)" - % " OR ".join("relation_type = ?" for _ in event_filter.relation_types) + % " OR ".join( + "relation_type = ?" for _ in event_filter.related_by_rel_types + ) ) - args.extend(event_filter.relation_types) + args.extend(event_filter.related_by_rel_types) return " AND ".join(clauses), args @@ -1203,7 +1205,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): # If there is a filter on relation_senders and relation_types join to the # relations table. if event_filter and ( - event_filter.relation_senders or event_filter.relation_types + event_filter.related_by_senders or event_filter.related_by_rel_types ): # Filtering by relations could cause the same event to appear multiple # times (since there's no limit on the number of relations to an event). @@ -1211,7 +1213,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): join_clause += """ LEFT JOIN event_relations AS relation ON (event.event_id = relation.relates_to_id) """ - if event_filter.relation_senders: + if event_filter.related_by_senders: join_clause += """ LEFT JOIN events AS related_event ON (relation.event_id = related_event.event_id) """ diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index f9ae6e663f..0cbe6c0cf7 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -547,9 +547,7 @@ class RelationsTestCase(BaseRelationsTestCase): ) self.assertEqual(400, channel.code, channel.json_body) - @unittest.override_config( - {"experimental_features": {"msc3440_enabled": True, "msc3666_enabled": True}} - ) + @unittest.override_config({"experimental_features": {"msc3666_enabled": True}}) def test_bundled_aggregations(self) -> None: """ Test that annotations, references, and threads get correctly bundled. @@ -758,7 +756,6 @@ class RelationsTestCase(BaseRelationsTestCase): }, ) - @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) def test_ignore_invalid_room(self) -> None: """Test that we ignore invalid relations over federation.""" # Create another room and send a message in it. @@ -1065,7 +1062,6 @@ class RelationsTestCase(BaseRelationsTestCase): {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict ) - @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) def test_edit_thread(self) -> None: """Test that editing a thread works.""" @@ -1383,7 +1379,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase): chunk = self._get_aggregations() self.assertEqual(chunk, [{"type": "m.reaction", "key": "a", "count": 1}]) - @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) def test_redact_relation_thread(self) -> None: """ Test that thread replies are properly handled after the thread reply redacted. diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 37866ee330..3a9617d6da 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -2141,21 +2141,19 @@ class RelationsTestCase(unittest.HomeserverTestCase): def test_filter_relation_senders(self) -> None: # Messages which second user reacted to. - filter = {"io.element.relation_senders": [self.second_user_id]} + filter = {"related_by_senders": [self.second_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0]["event_id"], self.event_id_1) # Messages which third user reacted to. - filter = {"io.element.relation_senders": [self.third_user_id]} + filter = {"related_by_senders": [self.third_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0]["event_id"], self.event_id_2) # Messages which either user reacted to. - filter = { - "io.element.relation_senders": [self.second_user_id, self.third_user_id] - } + filter = {"related_by_senders": [self.second_user_id, self.third_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 2, chunk) self.assertCountEqual( @@ -2164,20 +2162,20 @@ class RelationsTestCase(unittest.HomeserverTestCase): def test_filter_relation_type(self) -> None: # Messages which have annotations. - filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]} + filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0]["event_id"], self.event_id_1) # Messages which have references. - filter = {"io.element.relation_types": [RelationTypes.REFERENCE]} + filter = {"related_by_rel_types": [RelationTypes.REFERENCE]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0]["event_id"], self.event_id_2) # Messages which have either annotations or references. filter = { - "io.element.relation_types": [ + "related_by_rel_types": [ RelationTypes.ANNOTATION, RelationTypes.REFERENCE, ] @@ -2191,8 +2189,8 @@ class RelationsTestCase(unittest.HomeserverTestCase): def test_filter_relation_senders_and_type(self) -> None: # Messages which second user reacted to. filter = { - "io.element.relation_senders": [self.second_user_id], - "io.element.relation_types": [RelationTypes.ANNOTATION], + "related_by_senders": [self.second_user_id], + "related_by_rel_types": [RelationTypes.ANNOTATION], } chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py index 6a1cf33054..eaa0d7d749 100644 --- a/tests/storage/test_stream.py +++ b/tests/storage/test_stream.py @@ -129,21 +129,19 @@ class PaginationTestCase(HomeserverTestCase): def test_filter_relation_senders(self): # Messages which second user reacted to. - filter = {"io.element.relation_senders": [self.second_user_id]} + filter = {"related_by_senders": [self.second_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0].event_id, self.event_id_1) # Messages which third user reacted to. - filter = {"io.element.relation_senders": [self.third_user_id]} + filter = {"related_by_senders": [self.third_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0].event_id, self.event_id_2) # Messages which either user reacted to. - filter = { - "io.element.relation_senders": [self.second_user_id, self.third_user_id] - } + filter = {"related_by_senders": [self.second_user_id, self.third_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 2, chunk) self.assertCountEqual( @@ -152,20 +150,20 @@ class PaginationTestCase(HomeserverTestCase): def test_filter_relation_type(self): # Messages which have annotations. - filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]} + filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0].event_id, self.event_id_1) # Messages which have references. - filter = {"io.element.relation_types": [RelationTypes.REFERENCE]} + filter = {"related_by_rel_types": [RelationTypes.REFERENCE]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0].event_id, self.event_id_2) # Messages which have either annotations or references. filter = { - "io.element.relation_types": [ + "related_by_rel_types": [ RelationTypes.ANNOTATION, RelationTypes.REFERENCE, ] @@ -179,8 +177,8 @@ class PaginationTestCase(HomeserverTestCase): def test_filter_relation_senders_and_type(self): # Messages which second user reacted to. filter = { - "io.element.relation_senders": [self.second_user_id], - "io.element.relation_types": [RelationTypes.ANNOTATION], + "related_by_senders": [self.second_user_id], + "related_by_rel_types": [RelationTypes.ANNOTATION], } chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) @@ -201,7 +199,7 @@ class PaginationTestCase(HomeserverTestCase): tok=self.second_tok, ) - filter = {"io.element.relation_senders": [self.second_user_id]} + filter = {"related_by_senders": [self.second_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0].event_id, self.event_id_1) -- cgit 1.5.1 From 80e0e1f35e6b1cdfa0267f9c40a6f212b7d774de Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 18 Mar 2022 13:15:45 -0400 Subject: Only fetch thread participation for events with threads. (#12228) We fetch the thread summary in two phases: 1. The summary that is shared by all users (count of messages and latest event). 2. Whether the requesting user has participated in the thread. There's no use in attempting step 2 for events which did not return a summary from step 1. --- changelog.d/12228.bugfix | 1 + synapse/storage/databases/main/relations.py | 4 +- tests/rest/client/test_relations.py | 509 +++++++++++++++------------- tests/server.py | 20 +- 4 files changed, 289 insertions(+), 245 deletions(-) create mode 100644 changelog.d/12228.bugfix (limited to 'synapse/storage/databases/main/relations.py') diff --git a/changelog.d/12228.bugfix b/changelog.d/12228.bugfix new file mode 100644 index 0000000000..4755777139 --- /dev/null +++ b/changelog.d/12228.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in v1.53.0 where an unnecessary query could be performed when fetching bundled aggregations for threads. diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index c4869d64e6..af2334a65e 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -857,7 +857,9 @@ class RelationsWorkerStore(SQLBaseStore): summaries = await self._get_thread_summaries(events_by_id.keys()) # Only fetch participated for a limited selection based on what had # summaries. - participated = await self._get_threads_participated(summaries.keys(), user_id) + participated = await self._get_threads_participated( + [event_id for event_id, summary in summaries.items() if summary], user_id + ) for event_id, summary in summaries.items(): if summary: thread_count, latest_thread_event, edit = summary diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index f3741b3001..329690f8f7 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -15,7 +15,7 @@ import itertools import urllib.parse -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple from unittest.mock import patch from twisted.test.proto_helpers import MemoryReactor @@ -155,6 +155,16 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase): self.assertEqual(200, channel.code, channel.json_body) return channel.json_body["chunk"] + def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict: + """ + Find the parent event in a chunk of events and assert that it has the proper bundled aggregations. + """ + for event in events: + if event["event_id"] == self.parent_id: + return event + + raise AssertionError(f"Event {self.parent_id} not found in chunk") + class RelationsTestCase(BaseRelationsTestCase): def test_send_relation(self) -> None: @@ -291,202 +301,6 @@ class RelationsTestCase(BaseRelationsTestCase): ) self.assertEqual(400, channel.code, channel.json_body) - @unittest.override_config({"experimental_features": {"msc3666_enabled": True}}) - def test_bundled_aggregations(self) -> None: - """ - Test that annotations, references, and threads get correctly bundled. - - Note that this doesn't test against /relations since only thread relations - get bundled via that API. See test_aggregation_get_event_for_thread. - - See test_edit for a similar test for edits. - """ - # Setup by sending a variety of relations. - self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token - ) - self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") - - channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") - reply_1 = channel.json_body["event_id"] - - channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") - reply_2 = channel.json_body["event_id"] - - self._send_relation(RelationTypes.THREAD, "m.room.test") - - channel = self._send_relation(RelationTypes.THREAD, "m.room.test") - thread_2 = channel.json_body["event_id"] - - def assert_bundle(event_json: JsonDict) -> None: - """Assert the expected values of the bundled aggregations.""" - relations_dict = event_json["unsigned"].get("m.relations") - - # Ensure the fields are as expected. - self.assertCountEqual( - relations_dict.keys(), - ( - RelationTypes.ANNOTATION, - RelationTypes.REFERENCE, - RelationTypes.THREAD, - ), - ) - - # Check the values of each field. - self.assertEqual( - { - "chunk": [ - {"type": "m.reaction", "key": "a", "count": 2}, - {"type": "m.reaction", "key": "b", "count": 1}, - ] - }, - relations_dict[RelationTypes.ANNOTATION], - ) - - self.assertEqual( - {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]}, - relations_dict[RelationTypes.REFERENCE], - ) - - self.assertEqual( - 2, - relations_dict[RelationTypes.THREAD].get("count"), - ) - self.assertTrue( - relations_dict[RelationTypes.THREAD].get("current_user_participated") - ) - # The latest thread event has some fields that don't matter. - self.assert_dict( - { - "content": { - "m.relates_to": { - "event_id": self.parent_id, - "rel_type": RelationTypes.THREAD, - } - }, - "event_id": thread_2, - "sender": self.user_id, - "type": "m.room.test", - }, - relations_dict[RelationTypes.THREAD].get("latest_event"), - ) - - # Request the event directly. - channel = self.make_request( - "GET", - f"/rooms/{self.room}/event/{self.parent_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - assert_bundle(channel.json_body) - - # Request the room messages. - channel = self.make_request( - "GET", - f"/rooms/{self.room}/messages?dir=b", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"])) - - # Request the room context. - channel = self.make_request( - "GET", - f"/rooms/{self.room}/context/{self.parent_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - assert_bundle(channel.json_body["event"]) - - # Request sync. - channel = self.make_request("GET", "/sync", access_token=self.user_token) - self.assertEqual(200, channel.code, channel.json_body) - room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] - self.assertTrue(room_timeline["limited"]) - assert_bundle(self._find_event_in_chunk(room_timeline["events"])) - - # Request search. - channel = self.make_request( - "POST", - "/search", - # Search term matches the parent message. - content={"search_categories": {"room_events": {"search_term": "Hi"}}}, - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - chunk = [ - result["result"] - for result in channel.json_body["search_categories"]["room_events"][ - "results" - ] - ] - assert_bundle(self._find_event_in_chunk(chunk)) - - def test_aggregation_get_event_for_annotation(self) -> None: - """Test that annotations do not get bundled aggregations included - when directly requested. - """ - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - annotation_id = channel.json_body["event_id"] - - # Annotate the annotation. - self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=annotation_id - ) - - channel = self.make_request( - "GET", - f"/rooms/{self.room}/event/{annotation_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - self.assertIsNone(channel.json_body["unsigned"].get("m.relations")) - - def test_aggregation_get_event_for_thread(self) -> None: - """Test that threads get bundled aggregations included when directly requested.""" - channel = self._send_relation(RelationTypes.THREAD, "m.room.test") - thread_id = channel.json_body["event_id"] - - # Annotate the annotation. - self._send_relation( - RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id - ) - - channel = self.make_request( - "GET", - f"/rooms/{self.room}/event/{thread_id}", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - self.assertEqual( - channel.json_body["unsigned"].get("m.relations"), - { - RelationTypes.ANNOTATION: { - "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}] - }, - }, - ) - - # It should also be included when the entire thread is requested. - channel = self.make_request( - "GET", - f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1", - access_token=self.user_token, - ) - self.assertEqual(200, channel.code, channel.json_body) - self.assertEqual(len(channel.json_body["chunk"]), 1) - - thread_message = channel.json_body["chunk"][0] - self.assertEqual( - thread_message["unsigned"].get("m.relations"), - { - RelationTypes.ANNOTATION: { - "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}] - }, - }, - ) - def test_ignore_invalid_room(self) -> None: """Test that we ignore invalid relations over federation.""" # Create another room and send a message in it. @@ -796,7 +610,7 @@ class RelationsTestCase(BaseRelationsTestCase): threaded_event_id = channel.json_body["event_id"] new_body = {"msgtype": "m.text", "body": "I've been edited!"} - channel = self._send_relation( + self._send_relation( RelationTypes.REPLACE, "m.room.message", content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body}, @@ -836,7 +650,7 @@ class RelationsTestCase(BaseRelationsTestCase): edit_event_id = channel.json_body["event_id"] # Edit the edit event. - channel = self._send_relation( + self._send_relation( RelationTypes.REPLACE, "m.room.message", content={ @@ -912,16 +726,6 @@ class RelationsTestCase(BaseRelationsTestCase): self.assertEqual(200, channel.code, channel.json_body) self.assertEqual(channel.json_body["chunk"], []) - def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict: - """ - Find the parent event in a chunk of events and assert that it has the proper bundled aggregations. - """ - for event in events: - if event["event_id"] == self.parent_id: - return event - - raise AssertionError(f"Event {self.parent_id} not found in chunk") - def test_background_update(self) -> None: """Test the event_arbitrary_relations background update.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍") @@ -981,34 +785,6 @@ class RelationsTestCase(BaseRelationsTestCase): [annotation_event_id_good, thread_event_id], ) - def test_bundled_aggregations_with_filter(self) -> None: - """ - If "unsigned" is an omitted field (due to filtering), adding the bundled - aggregations should not break. - - Note that the spec allows for a server to return additional fields beyond - what is specified. - """ - self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") - - # Note that the sync filter does not include "unsigned" as a field. - filter = urllib.parse.quote_plus( - b'{"event_fields": ["content", "event_id"], "room": {"timeline": {"limit": 3}}}' - ) - channel = self.make_request( - "GET", f"/sync?filter={filter}", access_token=self.user_token - ) - self.assertEqual(200, channel.code, channel.json_body) - - # Ensure the timeline is limited, find the parent event. - room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] - self.assertTrue(room_timeline["limited"]) - parent_event = self._find_event_in_chunk(room_timeline["events"]) - - # Ensure there's bundled aggregations on it. - self.assertIn("unsigned", parent_event) - self.assertIn("m.relations", parent_event["unsigned"]) - class RelationPaginationTestCase(BaseRelationsTestCase): def test_basic_paginate_relations(self) -> None: @@ -1255,7 +1031,7 @@ class RelationPaginationTestCase(BaseRelationsTestCase): idx += 1 # Also send a different type of reaction so that we test we don't see it - channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") + self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a") prev_token = "" found_event_ids: List[str] = [] @@ -1291,6 +1067,263 @@ class RelationPaginationTestCase(BaseRelationsTestCase): self.assertEqual(found_event_ids, expected_event_ids) +class BundledAggregationsTestCase(BaseRelationsTestCase): + """ + See RelationsTestCase.test_edit for a similar test for edits. + + Note that this doesn't test against /relations since only thread relations + get bundled via that API. See test_aggregation_get_event_for_thread. + """ + + def _test_bundled_aggregations( + self, + relation_type: str, + assertion_callable: Callable[[JsonDict], None], + expected_db_txn_for_event: int, + ) -> None: + """ + Makes requests to various endpoints which should include bundled aggregations + and then calls an assertion function on the bundled aggregations. + + Args: + relation_type: The field to search for in the `m.relations` field in unsigned. + assertion_callable: Called with the contents of unsigned["m.relations"][relation_type] + for relation-specific assertions. + expected_db_txn_for_event: The number of database transactions which + are expected for a call to /event/. + """ + + def assert_bundle(event_json: JsonDict) -> None: + """Assert the expected values of the bundled aggregations.""" + relations_dict = event_json["unsigned"].get("m.relations") + + # Ensure the fields are as expected. + self.assertCountEqual(relations_dict.keys(), (relation_type,)) + assertion_callable(relations_dict[relation_type]) + + # Request the event directly. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/event/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + assert_bundle(channel.json_body) + assert channel.resource_usage is not None + self.assertEqual(channel.resource_usage.db_txn_count, expected_db_txn_for_event) + + # Request the room messages. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/messages?dir=b", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"])) + + # Request the room context. + channel = self.make_request( + "GET", + f"/rooms/{self.room}/context/{self.parent_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + assert_bundle(channel.json_body["event"]) + + # Request sync. + filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 4}}}') + channel = self.make_request( + "GET", f"/sync?filter={filter}", access_token=self.user_token + ) + self.assertEqual(200, channel.code, channel.json_body) + room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] + self.assertTrue(room_timeline["limited"]) + assert_bundle(self._find_event_in_chunk(room_timeline["events"])) + + # Request search. + channel = self.make_request( + "POST", + "/search", + # Search term matches the parent message. + content={"search_categories": {"room_events": {"search_term": "Hi"}}}, + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + chunk = [ + result["result"] + for result in channel.json_body["search_categories"]["room_events"][ + "results" + ] + ] + assert_bundle(self._find_event_in_chunk(chunk)) + + @unittest.override_config({"experimental_features": {"msc3666_enabled": True}}) + def test_annotation(self) -> None: + """ + Test that annotations get correctly bundled. + """ + # Setup by sending a variety of relations. + self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token + ) + self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b") + + def assert_annotations(bundled_aggregations: JsonDict) -> None: + self.assertEqual( + { + "chunk": [ + {"type": "m.reaction", "key": "a", "count": 2}, + {"type": "m.reaction", "key": "b", "count": 1}, + ] + }, + bundled_aggregations, + ) + + self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 7) + + @unittest.override_config({"experimental_features": {"msc3666_enabled": True}}) + def test_reference(self) -> None: + """ + Test that references get correctly bundled. + """ + channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") + reply_1 = channel.json_body["event_id"] + + channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test") + reply_2 = channel.json_body["event_id"] + + def assert_annotations(bundled_aggregations: JsonDict) -> None: + self.assertEqual( + {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]}, + bundled_aggregations, + ) + + self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 7) + + @unittest.override_config({"experimental_features": {"msc3666_enabled": True}}) + def test_thread(self) -> None: + """ + Test that threads get correctly bundled. + """ + self._send_relation(RelationTypes.THREAD, "m.room.test") + channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + thread_2 = channel.json_body["event_id"] + + def assert_annotations(bundled_aggregations: JsonDict) -> None: + self.assertEqual(2, bundled_aggregations.get("count")) + self.assertTrue(bundled_aggregations.get("current_user_participated")) + # The latest thread event has some fields that don't matter. + self.assert_dict( + { + "content": { + "m.relates_to": { + "event_id": self.parent_id, + "rel_type": RelationTypes.THREAD, + } + }, + "event_id": thread_2, + "sender": self.user_id, + "type": "m.room.test", + }, + bundled_aggregations.get("latest_event"), + ) + + self._test_bundled_aggregations(RelationTypes.THREAD, assert_annotations, 9) + + def test_aggregation_get_event_for_annotation(self) -> None: + """Test that annotations do not get bundled aggregations included + when directly requested. + """ + channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + annotation_id = channel.json_body["event_id"] + + # Annotate the annotation. + self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=annotation_id + ) + + channel = self.make_request( + "GET", + f"/rooms/{self.room}/event/{annotation_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + self.assertIsNone(channel.json_body["unsigned"].get("m.relations")) + + def test_aggregation_get_event_for_thread(self) -> None: + """Test that threads get bundled aggregations included when directly requested.""" + channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + thread_id = channel.json_body["event_id"] + + # Annotate the annotation. + self._send_relation( + RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id + ) + + channel = self.make_request( + "GET", + f"/rooms/{self.room}/event/{thread_id}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + self.assertEqual( + channel.json_body["unsigned"].get("m.relations"), + { + RelationTypes.ANNOTATION: { + "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}] + }, + }, + ) + + # It should also be included when the entire thread is requested. + channel = self.make_request( + "GET", + f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + self.assertEqual(len(channel.json_body["chunk"]), 1) + + thread_message = channel.json_body["chunk"][0] + self.assertEqual( + thread_message["unsigned"].get("m.relations"), + { + RelationTypes.ANNOTATION: { + "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}] + }, + }, + ) + + def test_bundled_aggregations_with_filter(self) -> None: + """ + If "unsigned" is an omitted field (due to filtering), adding the bundled + aggregations should not break. + + Note that the spec allows for a server to return additional fields beyond + what is specified. + """ + self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + + # Note that the sync filter does not include "unsigned" as a field. + filter = urllib.parse.quote_plus( + b'{"event_fields": ["content", "event_id"], "room": {"timeline": {"limit": 3}}}' + ) + channel = self.make_request( + "GET", f"/sync?filter={filter}", access_token=self.user_token + ) + self.assertEqual(200, channel.code, channel.json_body) + + # Ensure the timeline is limited, find the parent event. + room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] + self.assertTrue(room_timeline["limited"]) + parent_event = self._find_event_in_chunk(room_timeline["events"]) + + # Ensure there's bundled aggregations on it. + self.assertIn("unsigned", parent_event) + self.assertIn("m.relations", parent_event["unsigned"]) + + class RelationRedactionTestCase(BaseRelationsTestCase): """ Test the behaviour of relations when the parent or child event is redacted. diff --git a/tests/server.py b/tests/server.py index 82990c2eb9..6ce2a17bf4 100644 --- a/tests/server.py +++ b/tests/server.py @@ -54,13 +54,18 @@ from twisted.internet.interfaces import ( ITransport, ) from twisted.python.failure import Failure -from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactorClock +from twisted.test.proto_helpers import ( + AccumulatingProtocol, + MemoryReactor, + MemoryReactorClock, +) from twisted.web.http_headers import Headers from twisted.web.resource import IResource from twisted.web.server import Request, Site from synapse.config.database import DatabaseConnectionConfig from synapse.http.site import SynapseRequest +from synapse.logging.context import ContextResourceUsage from synapse.server import HomeServer from synapse.storage import DataStore from synapse.storage.engines import PostgresEngine, create_engine @@ -88,18 +93,19 @@ class TimedOutException(Exception): """ -@attr.s +@attr.s(auto_attribs=True) class FakeChannel: """ A fake Twisted Web Channel (the part that interfaces with the wire). """ - site = attr.ib(type=Union[Site, "FakeSite"]) - _reactor = attr.ib() - result = attr.ib(type=dict, default=attr.Factory(dict)) - _ip = attr.ib(type=str, default="127.0.0.1") + site: Union[Site, "FakeSite"] + _reactor: MemoryReactor + result: dict = attr.Factory(dict) + _ip: str = "127.0.0.1" _producer: Optional[Union[IPullProducer, IPushProducer]] = None + resource_usage: Optional[ContextResourceUsage] = None @property def json_body(self): @@ -168,6 +174,8 @@ class FakeChannel: def requestDone(self, _self): self.result["done"] = True + if isinstance(_self, SynapseRequest): + self.resource_usage = _self.logcontext.get_resource_usage() def getPeer(self): # We give an address so that getClientIP returns a non null entry, -- cgit 1.5.1 From 8fe930c215f69913fbcd96d609ec6950644e4ec4 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 18 Mar 2022 13:49:32 -0400 Subject: Move get_bundled_aggregations to relations handler. (#12237) The get_bundled_aggregations code is fairly high-level and uses a lot of store methods, we move it into the handler as that seems like a better fit. --- changelog.d/12237.misc | 1 + synapse/events/utils.py | 2 +- synapse/handlers/pagination.py | 5 +- synapse/handlers/relations.py | 151 +++++++++++++++++++++++++++- synapse/handlers/room.py | 5 +- synapse/handlers/search.py | 3 +- synapse/handlers/sync.py | 9 +- synapse/rest/client/room.py | 3 +- synapse/storage/databases/main/relations.py | 151 +--------------------------- 9 files changed, 173 insertions(+), 157 deletions(-) create mode 100644 changelog.d/12237.misc (limited to 'synapse/storage/databases/main/relations.py') diff --git a/changelog.d/12237.misc b/changelog.d/12237.misc new file mode 100644 index 0000000000..41c9dcbd37 --- /dev/null +++ b/changelog.d/12237.misc @@ -0,0 +1 @@ +Refactor the relations endpoints to add a `RelationsHandler`. diff --git a/synapse/events/utils.py b/synapse/events/utils.py index a0520068e0..7120062127 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -38,8 +38,8 @@ from synapse.util.frozenutils import unfreeze from . import EventBase if TYPE_CHECKING: + from synapse.handlers.relations import BundledAggregations from synapse.server import HomeServer - from synapse.storage.databases.main.relations import BundledAggregations # Split strings on "." but not "\." This uses a negative lookbehind assertion for '\' diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 41679f7f86..876b879483 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -134,6 +134,7 @@ class PaginationHandler: self.clock = hs.get_clock() self._server_name = hs.hostname self._room_shutdown_handler = hs.get_room_shutdown_handler() + self._relations_handler = hs.get_relations_handler() self.pagination_lock = ReadWriteLock() # IDs of rooms in which there currently an active purge *or delete* operation. @@ -539,7 +540,9 @@ class PaginationHandler: state_dict = await self.store.get_events(list(state_ids.values())) state = state_dict.values() - aggregations = await self.store.get_bundled_aggregations(events, user_id) + aggregations = await self._relations_handler.get_bundled_aggregations( + events, user_id + ) time_now = self.clock.time_msec() diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 8e475475ad..57135d4519 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -12,18 +12,53 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Dict, Iterable, Optional, cast +import attr +from frozendict import frozendict + +from synapse.api.constants import RelationTypes from synapse.api.errors import SynapseError +from synapse.events import EventBase from synapse.types import JsonDict, Requester, StreamToken if TYPE_CHECKING: from synapse.server import HomeServer + from synapse.storage.databases.main import DataStore logger = logging.getLogger(__name__) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _ThreadAggregation: + # The latest event in the thread. + latest_event: EventBase + # The latest edit to the latest event in the thread. + latest_edit: Optional[EventBase] + # The total number of events in the thread. + count: int + # True if the current user has sent an event to the thread. + current_user_participated: bool + + +@attr.s(slots=True, auto_attribs=True) +class BundledAggregations: + """ + The bundled aggregations for an event. + + Some values require additional processing during serialization. + """ + + annotations: Optional[JsonDict] = None + references: Optional[JsonDict] = None + replace: Optional[EventBase] = None + thread: Optional[_ThreadAggregation] = None + + def __bool__(self) -> bool: + return bool(self.annotations or self.references or self.replace or self.thread) + + class RelationsHandler: def __init__(self, hs: "HomeServer"): self._main_store = hs.get_datastores().main @@ -103,7 +138,7 @@ class RelationsHandler: ) # The relations returned for the requested event do include their # bundled aggregations. - aggregations = await self._main_store.get_bundled_aggregations( + aggregations = await self.get_bundled_aggregations( events, requester.user.to_string() ) serialized_events = self._event_serializer.serialize_events( @@ -115,3 +150,115 @@ class RelationsHandler: return_value["original_event"] = original_event return return_value + + async def _get_bundled_aggregation_for_event( + self, event: EventBase, user_id: str + ) -> Optional[BundledAggregations]: + """Generate bundled aggregations for an event. + + Note that this does not use a cache, but depends on cached methods. + + Args: + event: The event to calculate bundled aggregations for. + user_id: The user requesting the bundled aggregations. + + Returns: + The bundled aggregations for an event, if bundled aggregations are + enabled and the event can have bundled aggregations. + """ + + # Do not bundle aggregations for an event which represents an edit or an + # annotation. It does not make sense for them to have related events. + relates_to = event.content.get("m.relates_to") + if isinstance(relates_to, (dict, frozendict)): + relation_type = relates_to.get("rel_type") + if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE): + return None + + event_id = event.event_id + room_id = event.room_id + + # The bundled aggregations to include, a mapping of relation type to a + # type-specific value. Some types include the direct return type here + # while others need more processing during serialization. + aggregations = BundledAggregations() + + annotations = await self._main_store.get_aggregation_groups_for_event( + event_id, room_id + ) + if annotations.chunk: + aggregations.annotations = await annotations.to_dict( + cast("DataStore", self) + ) + + references = await self._main_store.get_relations_for_event( + event_id, event, room_id, RelationTypes.REFERENCE, direction="f" + ) + if references.chunk: + aggregations.references = await references.to_dict(cast("DataStore", self)) + + # Store the bundled aggregations in the event metadata for later use. + return aggregations + + async def get_bundled_aggregations( + self, events: Iterable[EventBase], user_id: str + ) -> Dict[str, BundledAggregations]: + """Generate bundled aggregations for events. + + Args: + events: The iterable of events to calculate bundled aggregations for. + user_id: The user requesting the bundled aggregations. + + Returns: + A map of event ID to the bundled aggregation for the event. Not all + events may have bundled aggregations in the results. + """ + # De-duplicate events by ID to handle the same event requested multiple times. + # + # State events do not get bundled aggregations. + events_by_id = { + event.event_id: event for event in events if not event.is_state() + } + + # event ID -> bundled aggregation in non-serialized form. + results: Dict[str, BundledAggregations] = {} + + # Fetch other relations per event. + for event in events_by_id.values(): + event_result = await self._get_bundled_aggregation_for_event(event, user_id) + if event_result: + results[event.event_id] = event_result + + # Fetch any edits (but not for redacted events). + edits = await self._main_store.get_applicable_edits( + [ + event_id + for event_id, event in events_by_id.items() + if not event.internal_metadata.is_redacted() + ] + ) + for event_id, edit in edits.items(): + results.setdefault(event_id, BundledAggregations()).replace = edit + + # Fetch thread summaries. + summaries = await self._main_store.get_thread_summaries(events_by_id.keys()) + # Only fetch participated for a limited selection based on what had + # summaries. + participated = await self._main_store.get_threads_participated( + [event_id for event_id, summary in summaries.items() if summary], user_id + ) + for event_id, summary in summaries.items(): + if summary: + thread_count, latest_thread_event, edit = summary + results.setdefault( + event_id, BundledAggregations() + ).thread = _ThreadAggregation( + latest_event=latest_thread_event, + latest_edit=edit, + count=thread_count, + # If there's a thread summary it must also exist in the + # participated dictionary. + current_user_participated=participated[event_id], + ) + + return results diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index b9735631fc..092e185c99 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -60,8 +60,8 @@ from synapse.events import EventBase from synapse.events.utils import copy_power_levels_contents from synapse.federation.federation_client import InvalidResponseError from synapse.handlers.federation import get_domains_from_state +from synapse.handlers.relations import BundledAggregations from synapse.rest.admin._base import assert_user_is_admin -from synapse.storage.databases.main.relations import BundledAggregations from synapse.storage.state import StateFilter from synapse.streams import EventSource from synapse.types import ( @@ -1118,6 +1118,7 @@ class RoomContextHandler: self.store = hs.get_datastores().main self.storage = hs.get_storage() self.state_store = self.storage.state + self._relations_handler = hs.get_relations_handler() async def get_event_context( self, @@ -1190,7 +1191,7 @@ class RoomContextHandler: event = filtered[0] # Fetch the aggregations. - aggregations = await self.store.get_bundled_aggregations( + aggregations = await self._relations_handler.get_bundled_aggregations( itertools.chain(events_before, (event,), events_after), user.to_string(), ) diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index aa16e417eb..30eddda65f 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -54,6 +54,7 @@ class SearchHandler: self.clock = hs.get_clock() self.hs = hs self._event_serializer = hs.get_event_client_serializer() + self._relations_handler = hs.get_relations_handler() self.storage = hs.get_storage() self.state_store = self.storage.state self.auth = hs.get_auth() @@ -354,7 +355,7 @@ class SearchHandler: aggregations = None if self._msc3666_enabled: - aggregations = await self.store.get_bundled_aggregations( + aggregations = await self._relations_handler.get_bundled_aggregations( # Generate an iterable of EventBase for all the events that will be # returned, including contextual events. itertools.chain( diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index c9d6a18bd7..6c569cfb1c 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -33,11 +33,11 @@ from synapse.api.filtering import FilterCollection from synapse.api.presence import UserPresenceState from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase +from synapse.handlers.relations import BundledAggregations from synapse.logging.context import current_context from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span from synapse.push.clientformat import format_push_rules_for_user from synapse.storage.databases.main.event_push_actions import NotifCounts -from synapse.storage.databases.main.relations import BundledAggregations from synapse.storage.roommember import MemberSummary from synapse.storage.state import StateFilter from synapse.types import ( @@ -269,6 +269,7 @@ class SyncHandler: self.store = hs.get_datastores().main self.notifier = hs.get_notifier() self.presence_handler = hs.get_presence_handler() + self._relations_handler = hs.get_relations_handler() self.event_sources = hs.get_event_sources() self.clock = hs.get_clock() self.state = hs.get_state_handler() @@ -638,8 +639,10 @@ class SyncHandler: # as clients will have all the necessary information. bundled_aggregations = None if limited or newly_joined_room: - bundled_aggregations = await self.store.get_bundled_aggregations( - recents, sync_config.user.to_string() + bundled_aggregations = ( + await self._relations_handler.get_bundled_aggregations( + recents, sync_config.user.to_string() + ) ) return TimelineBatch( diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 8a06ab8c5f..47e152c8cc 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -645,6 +645,7 @@ class RoomEventServlet(RestServlet): self._store = hs.get_datastores().main self.event_handler = hs.get_event_handler() self._event_serializer = hs.get_event_client_serializer() + self._relations_handler = hs.get_relations_handler() self.auth = hs.get_auth() async def on_GET( @@ -663,7 +664,7 @@ class RoomEventServlet(RestServlet): if event: # Ensure there are bundled aggregations available. - aggregations = await self._store.get_bundled_aggregations( + aggregations = await self._relations_handler.get_bundled_aggregations( [event], requester.user.to_string() ) diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index af2334a65e..b2295fd51f 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -27,7 +27,6 @@ from typing import ( ) import attr -from frozendict import frozendict from synapse.api.constants import RelationTypes from synapse.events import EventBase @@ -41,45 +40,15 @@ from synapse.storage.database import ( from synapse.storage.databases.main.stream import generate_pagination_where_clause from synapse.storage.engines import PostgresEngine from synapse.storage.relations import AggregationPaginationToken, PaginationChunk -from synapse.types import JsonDict, RoomStreamToken, StreamToken +from synapse.types import RoomStreamToken, StreamToken from synapse.util.caches.descriptors import cached, cachedList if TYPE_CHECKING: from synapse.server import HomeServer - from synapse.storage.databases.main import DataStore logger = logging.getLogger(__name__) -@attr.s(slots=True, frozen=True, auto_attribs=True) -class _ThreadAggregation: - # The latest event in the thread. - latest_event: EventBase - # The latest edit to the latest event in the thread. - latest_edit: Optional[EventBase] - # The total number of events in the thread. - count: int - # True if the current user has sent an event to the thread. - current_user_participated: bool - - -@attr.s(slots=True, auto_attribs=True) -class BundledAggregations: - """ - The bundled aggregations for an event. - - Some values require additional processing during serialization. - """ - - annotations: Optional[JsonDict] = None - references: Optional[JsonDict] = None - replace: Optional[EventBase] = None - thread: Optional[_ThreadAggregation] = None - - def __bool__(self) -> bool: - return bool(self.annotations or self.references or self.replace or self.thread) - - class RelationsWorkerStore(SQLBaseStore): def __init__( self, @@ -384,7 +353,7 @@ class RelationsWorkerStore(SQLBaseStore): raise NotImplementedError() @cachedList(cached_method_name="get_applicable_edit", list_name="event_ids") - async def _get_applicable_edits( + async def get_applicable_edits( self, event_ids: Collection[str] ) -> Dict[str, Optional[EventBase]]: """Get the most recent edit (if any) that has happened for the given @@ -473,7 +442,7 @@ class RelationsWorkerStore(SQLBaseStore): raise NotImplementedError() @cachedList(cached_method_name="get_thread_summary", list_name="event_ids") - async def _get_thread_summaries( + async def get_thread_summaries( self, event_ids: Collection[str] ) -> Dict[str, Optional[Tuple[int, EventBase, Optional[EventBase]]]]: """Get the number of threaded replies, the latest reply (if any), and the latest edit for that reply for the given event. @@ -587,7 +556,7 @@ class RelationsWorkerStore(SQLBaseStore): latest_events = await self.get_events(latest_event_ids.values()) # type: ignore[attr-defined] # Check to see if any of those events are edited. - latest_edits = await self._get_applicable_edits(latest_event_ids.values()) + latest_edits = await self.get_applicable_edits(latest_event_ids.values()) # Map to the event IDs to the thread summary. # @@ -610,7 +579,7 @@ class RelationsWorkerStore(SQLBaseStore): raise NotImplementedError() @cachedList(cached_method_name="get_thread_participated", list_name="event_ids") - async def _get_threads_participated( + async def get_threads_participated( self, event_ids: Collection[str], user_id: str ) -> Dict[str, bool]: """Get whether the requesting user participated in the given threads. @@ -766,116 +735,6 @@ class RelationsWorkerStore(SQLBaseStore): "get_if_user_has_annotated_event", _get_if_user_has_annotated_event ) - async def _get_bundled_aggregation_for_event( - self, event: EventBase, user_id: str - ) -> Optional[BundledAggregations]: - """Generate bundled aggregations for an event. - - Note that this does not use a cache, but depends on cached methods. - - Args: - event: The event to calculate bundled aggregations for. - user_id: The user requesting the bundled aggregations. - - Returns: - The bundled aggregations for an event, if bundled aggregations are - enabled and the event can have bundled aggregations. - """ - - # Do not bundle aggregations for an event which represents an edit or an - # annotation. It does not make sense for them to have related events. - relates_to = event.content.get("m.relates_to") - if isinstance(relates_to, (dict, frozendict)): - relation_type = relates_to.get("rel_type") - if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE): - return None - - event_id = event.event_id - room_id = event.room_id - - # The bundled aggregations to include, a mapping of relation type to a - # type-specific value. Some types include the direct return type here - # while others need more processing during serialization. - aggregations = BundledAggregations() - - annotations = await self.get_aggregation_groups_for_event(event_id, room_id) - if annotations.chunk: - aggregations.annotations = await annotations.to_dict( - cast("DataStore", self) - ) - - references = await self.get_relations_for_event( - event_id, event, room_id, RelationTypes.REFERENCE, direction="f" - ) - if references.chunk: - aggregations.references = await references.to_dict(cast("DataStore", self)) - - # Store the bundled aggregations in the event metadata for later use. - return aggregations - - async def get_bundled_aggregations( - self, events: Iterable[EventBase], user_id: str - ) -> Dict[str, BundledAggregations]: - """Generate bundled aggregations for events. - - Args: - events: The iterable of events to calculate bundled aggregations for. - user_id: The user requesting the bundled aggregations. - - Returns: - A map of event ID to the bundled aggregation for the event. Not all - events may have bundled aggregations in the results. - """ - # De-duplicate events by ID to handle the same event requested multiple times. - # - # State events do not get bundled aggregations. - events_by_id = { - event.event_id: event for event in events if not event.is_state() - } - - # event ID -> bundled aggregation in non-serialized form. - results: Dict[str, BundledAggregations] = {} - - # Fetch other relations per event. - for event in events_by_id.values(): - event_result = await self._get_bundled_aggregation_for_event(event, user_id) - if event_result: - results[event.event_id] = event_result - - # Fetch any edits (but not for redacted events). - edits = await self._get_applicable_edits( - [ - event_id - for event_id, event in events_by_id.items() - if not event.internal_metadata.is_redacted() - ] - ) - for event_id, edit in edits.items(): - results.setdefault(event_id, BundledAggregations()).replace = edit - - # Fetch thread summaries. - summaries = await self._get_thread_summaries(events_by_id.keys()) - # Only fetch participated for a limited selection based on what had - # summaries. - participated = await self._get_threads_participated( - [event_id for event_id, summary in summaries.items() if summary], user_id - ) - for event_id, summary in summaries.items(): - if summary: - thread_count, latest_thread_event, edit = summary - results.setdefault( - event_id, BundledAggregations() - ).thread = _ThreadAggregation( - latest_event=latest_thread_event, - latest_edit=edit, - count=thread_count, - # If there's a thread summary it must also exist in the - # participated dictionary. - current_user_participated=participated[event_id], - ) - - return results - class RelationsStore(RelationsWorkerStore): pass -- cgit 1.5.1