From 1d11b452b70c768e4919bd9cf6bcaeda2050a3d4 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 3 Mar 2022 10:43:06 -0500 Subject: Use the proper serialization format when bundling aggregations. (#12090) This ensures that the `latest_event` field of the bundled aggregation for threads uses the same format as the other events in the response. --- synapse/events/utils.py | 81 ++++++++++++++++++++++++++++++++----------------- 1 file changed, 54 insertions(+), 27 deletions(-) (limited to 'synapse/events/utils.py') diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 9386fa29dd..ee34cb46e4 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -26,6 +26,7 @@ from typing import ( Union, ) +import attr from frozendict import frozendict from synapse.api.constants import EventContentFields, EventTypes, RelationTypes @@ -303,29 +304,37 @@ def format_event_for_client_v2_without_room_id(d: JsonDict) -> JsonDict: return d +@attr.s(slots=True, frozen=True, auto_attribs=True) +class SerializeEventConfig: + as_client_event: bool = True + # Function to convert from federation format to client format + event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1 + # ID of the user's auth token - used for namespacing of transaction IDs + token_id: Optional[int] = None + # List of event fields to include. If empty, all fields will be returned. + only_event_fields: Optional[List[str]] = None + # Some events can have stripped room state stored in the `unsigned` field. + # This is required for invite and knock functionality. If this option is + # False, that state will be removed from the event before it is returned. + # Otherwise, it will be kept. + include_stripped_room_state: bool = False + + +_DEFAULT_SERIALIZE_EVENT_CONFIG = SerializeEventConfig() + + def serialize_event( e: Union[JsonDict, EventBase], time_now_ms: int, *, - as_client_event: bool = True, - event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1, - token_id: Optional[str] = None, - only_event_fields: Optional[List[str]] = None, - include_stripped_room_state: bool = False, + config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG, ) -> JsonDict: """Serialize event for clients Args: e time_now_ms - as_client_event - event_format - token_id - only_event_fields - include_stripped_room_state: Some events can have stripped room state - stored in the `unsigned` field. This is required for invite and knock - functionality. If this option is False, that state will be removed from the - event before it is returned. Otherwise, it will be kept. + config: Event serialization config Returns: The serialized event dictionary. @@ -348,11 +357,11 @@ def serialize_event( if "redacted_because" in e.unsigned: d["unsigned"]["redacted_because"] = serialize_event( - e.unsigned["redacted_because"], time_now_ms, event_format=event_format + e.unsigned["redacted_because"], time_now_ms, config=config ) - if token_id is not None: - if token_id == getattr(e.internal_metadata, "token_id", None): + if config.token_id is not None: + if config.token_id == getattr(e.internal_metadata, "token_id", None): txn_id = getattr(e.internal_metadata, "txn_id", None) if txn_id is not None: d["unsigned"]["transaction_id"] = txn_id @@ -361,13 +370,14 @@ def serialize_event( # that are meant to provide metadata about a room to an invitee/knocker. They are # intended to only be included in specific circumstances, such as down sync, and # should not be included in any other case. - if not include_stripped_room_state: + if not config.include_stripped_room_state: d["unsigned"].pop("invite_room_state", None) d["unsigned"].pop("knock_room_state", None) - if as_client_event: - d = event_format(d) + if config.as_client_event: + d = config.event_format(d) + only_event_fields = config.only_event_fields if only_event_fields: if not isinstance(only_event_fields, list) or not all( isinstance(f, str) for f in only_event_fields @@ -390,18 +400,18 @@ class EventClientSerializer: event: Union[JsonDict, EventBase], time_now: int, *, + config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG, bundle_aggregations: Optional[Dict[str, "BundledAggregations"]] = None, - **kwargs: Any, ) -> JsonDict: """Serializes a single event. Args: event: The event being serialized. time_now: The current time in milliseconds + config: Event serialization config bundle_aggregations: Whether to include the bundled aggregations for this event. Only applies to non-state events. (State events never include bundled aggregations.) - **kwargs: Arguments to pass to `serialize_event` Returns: The serialized event @@ -410,7 +420,7 @@ class EventClientSerializer: if not isinstance(event, EventBase): return event - serialized_event = serialize_event(event, time_now, **kwargs) + serialized_event = serialize_event(event, time_now, config=config) # Check if there are any bundled aggregations to include with the event. if bundle_aggregations: @@ -419,6 +429,7 @@ class EventClientSerializer: self._inject_bundled_aggregations( event, time_now, + config, bundle_aggregations[event.event_id], serialized_event, ) @@ -456,6 +467,7 @@ class EventClientSerializer: self, event: EventBase, time_now: int, + config: SerializeEventConfig, aggregations: "BundledAggregations", serialized_event: JsonDict, ) -> None: @@ -466,6 +478,7 @@ class EventClientSerializer: time_now: The current time in milliseconds aggregations: The bundled aggregation to serialize. serialized_event: The serialized event which may be modified. + config: Event serialization config """ serialized_aggregations = {} @@ -493,8 +506,8 @@ class EventClientSerializer: thread = aggregations.thread # Don't bundle aggregations as this could recurse forever. - serialized_latest_event = self.serialize_event( - thread.latest_event, time_now, bundle_aggregations=None + serialized_latest_event = serialize_event( + thread.latest_event, time_now, config=config ) # Manually apply an edit, if one exists. if thread.latest_edit: @@ -515,20 +528,34 @@ class EventClientSerializer: ) def serialize_events( - self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any + self, + events: Iterable[Union[JsonDict, EventBase]], + time_now: int, + *, + config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG, + bundle_aggregations: Optional[Dict[str, "BundledAggregations"]] = None, ) -> List[JsonDict]: """Serializes multiple events. Args: event time_now: The current time in milliseconds - **kwargs: Arguments to pass to `serialize_event` + config: Event serialization config + bundle_aggregations: Whether to include the bundled aggregations for this + event. Only applies to non-state events. (State events never include + bundled aggregations.) Returns: The list of serialized events """ return [ - self.serialize_event(event, time_now=time_now, **kwargs) for event in events + self.serialize_event( + event, + time_now, + config=config, + bundle_aggregations=bundle_aggregations, + ) + for event in events ] -- cgit 1.5.1 From ea27528b5d177dcfc5a4e38b463baeace916dc8e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 10 Mar 2022 10:36:13 -0500 Subject: Support stable identifiers for MSC3440: Threading (#12151) The unstable identifiers are still supported if the experimental configuration flag is enabled. The unstable identifiers will be removed in a future release. --- changelog.d/12151.feature | 1 + synapse/api/constants.py | 4 +- synapse/api/filtering.py | 23 ++++----- synapse/events/utils.py | 9 +++- synapse/handlers/message.py | 5 +- synapse/rest/client/versions.py | 1 + synapse/server.py | 2 +- synapse/storage/databases/main/events.py | 5 +- synapse/storage/databases/main/relations.py | 77 ++++++++++++++++++----------- synapse/storage/databases/main/stream.py | 18 ++++--- tests/rest/client/test_relations.py | 7 +-- tests/rest/client/test_rooms.py | 18 +++---- tests/storage/test_stream.py | 20 ++++---- 13 files changed, 109 insertions(+), 81 deletions(-) create mode 100644 changelog.d/12151.feature (limited to 'synapse/events/utils.py') diff --git a/changelog.d/12151.feature b/changelog.d/12151.feature new file mode 100644 index 0000000000..18432b2da9 --- /dev/null +++ b/changelog.d/12151.feature @@ -0,0 +1 @@ +Support the stable identifiers from [MSC3440](https://github.com/matrix-org/matrix-doc/pull/3440): threads. diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 36ace7c613..b0c08a074d 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -178,7 +178,9 @@ class RelationTypes: ANNOTATION: Final = "m.annotation" REPLACE: Final = "m.replace" REFERENCE: Final = "m.reference" - THREAD: Final = "io.element.thread" + THREAD: Final = "m.thread" + # TODO Remove this in Synapse >= v1.57.0. + UNSTABLE_THREAD: Final = "io.element.thread" class LimitBlockingTypes: diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py index cb532d7238..27e97d6f37 100644 --- a/synapse/api/filtering.py +++ b/synapse/api/filtering.py @@ -88,7 +88,9 @@ ROOM_EVENT_FILTER_SCHEMA = { "org.matrix.labels": {"type": "array", "items": {"type": "string"}}, "org.matrix.not_labels": {"type": "array", "items": {"type": "string"}}, # MSC3440, filtering by event relations. + "related_by_senders": {"type": "array", "items": {"type": "string"}}, "io.element.relation_senders": {"type": "array", "items": {"type": "string"}}, + "related_by_rel_types": {"type": "array", "items": {"type": "string"}}, "io.element.relation_types": {"type": "array", "items": {"type": "string"}}, }, } @@ -318,19 +320,18 @@ class Filter: self.labels = filter_json.get("org.matrix.labels", None) self.not_labels = filter_json.get("org.matrix.not_labels", []) - # Ideally these would be rejected at the endpoint if they were provided - # and not supported, but that would involve modifying the JSON schema - # based on the homeserver configuration. + self.related_by_senders = self.filter_json.get("related_by_senders", None) + self.related_by_rel_types = self.filter_json.get("related_by_rel_types", None) + + # Fallback to the unstable prefix if the stable version is not given. if hs.config.experimental.msc3440_enabled: - self.relation_senders = self.filter_json.get( + self.related_by_senders = self.related_by_senders or self.filter_json.get( "io.element.relation_senders", None ) - self.relation_types = self.filter_json.get( - "io.element.relation_types", None + self.related_by_rel_types = ( + self.related_by_rel_types + or self.filter_json.get("io.element.relation_types", None) ) - else: - self.relation_senders = None - self.relation_types = None def filters_all_types(self) -> bool: return "*" in self.not_types @@ -461,7 +462,7 @@ class Filter: event_ids = [event.event_id for event in events if isinstance(event, EventBase)] # type: ignore[attr-defined] event_ids_to_keep = set( await self._store.events_have_relations( - event_ids, self.relation_senders, self.relation_types + event_ids, self.related_by_senders, self.related_by_rel_types ) ) @@ -474,7 +475,7 @@ class Filter: async def filter(self, events: Iterable[FilterEvent]) -> List[FilterEvent]: result = [event for event in events if self._check(event)] - if self.relation_senders or self.relation_types: + if self.related_by_senders or self.related_by_rel_types: return await self._check_event_relations(result) return result diff --git a/synapse/events/utils.py b/synapse/events/utils.py index ee34cb46e4..b2a237c1e0 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -38,6 +38,7 @@ from synapse.util.frozenutils import unfreeze from . import EventBase if TYPE_CHECKING: + from synapse.server import HomeServer from synapse.storage.databases.main.relations import BundledAggregations @@ -395,6 +396,9 @@ class EventClientSerializer: clients. """ + def __init__(self, hs: "HomeServer"): + self._msc3440_enabled = hs.config.experimental.msc3440_enabled + def serialize_event( self, event: Union[JsonDict, EventBase], @@ -515,11 +519,14 @@ class EventClientSerializer: thread.latest_event, serialized_latest_event, thread.latest_edit ) - serialized_aggregations[RelationTypes.THREAD] = { + thread_summary = { "latest_event": serialized_latest_event, "count": thread.count, "current_user_participated": thread.current_user_participated, } + serialized_aggregations[RelationTypes.THREAD] = thread_summary + if self._msc3440_enabled: + serialized_aggregations[RelationTypes.UNSTABLE_THREAD] = thread_summary # Include the bundled aggregations in the event. if serialized_aggregations: diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 0799ec9a84..f9544fe7fb 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1079,7 +1079,10 @@ class EventCreationHandler: raise SynapseError(400, "Can't send same reaction twice") # Don't attempt to start a thread if the parent event is a relation. - elif relation_type == RelationTypes.THREAD: + elif ( + relation_type == RelationTypes.THREAD + or relation_type == RelationTypes.UNSTABLE_THREAD + ): if await self.store.event_includes_relation(relates_to): raise SynapseError( 400, "Cannot start threads from an event with a relation" diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index 2e5d0e4e22..9a65aa4843 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -101,6 +101,7 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc3030": self.config.experimental.msc3030_enabled, # Adds support for thread relations, per MSC3440. "org.matrix.msc3440": self.config.experimental.msc3440_enabled, + "org.matrix.msc3440.stable": True, # TODO: remove when "v1.3" is added above }, }, ) diff --git a/synapse/server.py b/synapse/server.py index 1270abb5a3..7741ff29dc 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -754,7 +754,7 @@ class HomeServer(metaclass=abc.ABCMeta): @cache_in_self def get_event_client_serializer(self) -> EventClientSerializer: - return EventClientSerializer() + return EventClientSerializer(self) @cache_in_self def get_password_policy_handler(self) -> PasswordPolicyHandler: diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 1a322882bf..1f60aef180 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -1814,7 +1814,10 @@ class PersistEventsStore: if rel_type == RelationTypes.REPLACE: txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,)) - if rel_type == RelationTypes.THREAD: + if ( + rel_type == RelationTypes.THREAD + or rel_type == RelationTypes.UNSTABLE_THREAD + ): txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,)) # It should be safe to only invalidate the cache if the user has not # previously participated in the thread, but that's difficult (and diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index be1500092b..c4869d64e6 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -508,7 +508,7 @@ class RelationsWorkerStore(SQLBaseStore): AND parent.room_id = child.room_id WHERE %s - AND relation_type = ? + AND %s ORDER BY parent.event_id, child.topological_ordering DESC, child.stream_ordering DESC """ else: @@ -523,16 +523,22 @@ class RelationsWorkerStore(SQLBaseStore): AND parent.room_id = child.room_id WHERE %s - AND relation_type = ? + AND %s ORDER BY child.topological_ordering DESC, child.stream_ordering DESC """ clause, args = make_in_list_sql_clause( txn.database_engine, "relates_to_id", event_ids ) - args.append(RelationTypes.THREAD) - txn.execute(sql % (clause,), args) + if self._msc3440_enabled: + relations_clause = "(relation_type = ? OR relation_type = ?)" + args.extend((RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD)) + else: + relations_clause = "relation_type = ?" + args.append(RelationTypes.THREAD) + + txn.execute(sql % (clause, relations_clause), args) latest_event_ids = {} for parent_event_id, child_event_id in txn: # Only consider the latest threaded reply (by topological ordering). @@ -552,7 +558,7 @@ class RelationsWorkerStore(SQLBaseStore): AND parent.room_id = child.room_id WHERE %s - AND relation_type = ? + AND %s GROUP BY parent.event_id """ @@ -561,9 +567,15 @@ class RelationsWorkerStore(SQLBaseStore): clause, args = make_in_list_sql_clause( txn.database_engine, "relates_to_id", latest_event_ids.keys() ) - args.append(RelationTypes.THREAD) - txn.execute(sql % (clause,), args) + if self._msc3440_enabled: + relations_clause = "(relation_type = ? OR relation_type = ?)" + args.extend((RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD)) + else: + relations_clause = "relation_type = ?" + args.append(RelationTypes.THREAD) + + txn.execute(sql % (clause, relations_clause), args) counts = dict(cast(List[Tuple[str, int]], txn.fetchall())) return counts, latest_event_ids @@ -626,16 +638,24 @@ class RelationsWorkerStore(SQLBaseStore): AND parent.room_id = child.room_id WHERE %s - AND relation_type = ? + AND %s AND child.sender = ? """ clause, args = make_in_list_sql_clause( txn.database_engine, "relates_to_id", event_ids ) - args.extend((RelationTypes.THREAD, user_id)) - txn.execute(sql % (clause,), args) + if self._msc3440_enabled: + relations_clause = "(relation_type = ? OR relation_type = ?)" + args.extend((RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD)) + else: + relations_clause = "relation_type = ?" + args.append(RelationTypes.THREAD) + + args.append(user_id) + + txn.execute(sql % (clause, relations_clause), args) return {row[0] for row in txn.fetchall()} participated_threads = await self.db_pool.runInteraction( @@ -834,26 +854,23 @@ class RelationsWorkerStore(SQLBaseStore): results.setdefault(event_id, BundledAggregations()).replace = edit # Fetch thread summaries. - if self._msc3440_enabled: - summaries = await self._get_thread_summaries(events_by_id.keys()) - # Only fetch participated for a limited selection based on what had - # summaries. - participated = await self._get_threads_participated( - summaries.keys(), user_id - ) - for event_id, summary in summaries.items(): - if summary: - thread_count, latest_thread_event, edit = summary - results.setdefault( - event_id, BundledAggregations() - ).thread = _ThreadAggregation( - latest_event=latest_thread_event, - latest_edit=edit, - count=thread_count, - # If there's a thread summary it must also exist in the - # participated dictionary. - current_user_participated=participated[event_id], - ) + summaries = await self._get_thread_summaries(events_by_id.keys()) + # Only fetch participated for a limited selection based on what had + # summaries. + participated = await self._get_threads_participated(summaries.keys(), user_id) + for event_id, summary in summaries.items(): + if summary: + thread_count, latest_thread_event, edit = summary + results.setdefault( + event_id, BundledAggregations() + ).thread = _ThreadAggregation( + latest_event=latest_thread_event, + latest_edit=edit, + count=thread_count, + # If there's a thread summary it must also exist in the + # participated dictionary. + current_user_participated=participated[event_id], + ) return results diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index a898f847e7..39e1efe373 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -325,21 +325,23 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]: args.extend(event_filter.labels) # Filter on relation_senders / relation types from the joined tables. - if event_filter.relation_senders: + if event_filter.related_by_senders: clauses.append( "(%s)" % " OR ".join( - "related_event.sender = ?" for _ in event_filter.relation_senders + "related_event.sender = ?" for _ in event_filter.related_by_senders ) ) - args.extend(event_filter.relation_senders) + args.extend(event_filter.related_by_senders) - if event_filter.relation_types: + if event_filter.related_by_rel_types: clauses.append( "(%s)" - % " OR ".join("relation_type = ?" for _ in event_filter.relation_types) + % " OR ".join( + "relation_type = ?" for _ in event_filter.related_by_rel_types + ) ) - args.extend(event_filter.relation_types) + args.extend(event_filter.related_by_rel_types) return " AND ".join(clauses), args @@ -1203,7 +1205,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): # If there is a filter on relation_senders and relation_types join to the # relations table. if event_filter and ( - event_filter.relation_senders or event_filter.relation_types + event_filter.related_by_senders or event_filter.related_by_rel_types ): # Filtering by relations could cause the same event to appear multiple # times (since there's no limit on the number of relations to an event). @@ -1211,7 +1213,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): join_clause += """ LEFT JOIN event_relations AS relation ON (event.event_id = relation.relates_to_id) """ - if event_filter.relation_senders: + if event_filter.related_by_senders: join_clause += """ LEFT JOIN events AS related_event ON (relation.event_id = related_event.event_id) """ diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index f9ae6e663f..0cbe6c0cf7 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -547,9 +547,7 @@ class RelationsTestCase(BaseRelationsTestCase): ) self.assertEqual(400, channel.code, channel.json_body) - @unittest.override_config( - {"experimental_features": {"msc3440_enabled": True, "msc3666_enabled": True}} - ) + @unittest.override_config({"experimental_features": {"msc3666_enabled": True}}) def test_bundled_aggregations(self) -> None: """ Test that annotations, references, and threads get correctly bundled. @@ -758,7 +756,6 @@ class RelationsTestCase(BaseRelationsTestCase): }, ) - @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) def test_ignore_invalid_room(self) -> None: """Test that we ignore invalid relations over federation.""" # Create another room and send a message in it. @@ -1065,7 +1062,6 @@ class RelationsTestCase(BaseRelationsTestCase): {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict ) - @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) def test_edit_thread(self) -> None: """Test that editing a thread works.""" @@ -1383,7 +1379,6 @@ class RelationRedactionTestCase(BaseRelationsTestCase): chunk = self._get_aggregations() self.assertEqual(chunk, [{"type": "m.reaction", "key": "a", "count": 1}]) - @unittest.override_config({"experimental_features": {"msc3440_enabled": True}}) def test_redact_relation_thread(self) -> None: """ Test that thread replies are properly handled after the thread reply redacted. diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py index 37866ee330..3a9617d6da 100644 --- a/tests/rest/client/test_rooms.py +++ b/tests/rest/client/test_rooms.py @@ -2141,21 +2141,19 @@ class RelationsTestCase(unittest.HomeserverTestCase): def test_filter_relation_senders(self) -> None: # Messages which second user reacted to. - filter = {"io.element.relation_senders": [self.second_user_id]} + filter = {"related_by_senders": [self.second_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0]["event_id"], self.event_id_1) # Messages which third user reacted to. - filter = {"io.element.relation_senders": [self.third_user_id]} + filter = {"related_by_senders": [self.third_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0]["event_id"], self.event_id_2) # Messages which either user reacted to. - filter = { - "io.element.relation_senders": [self.second_user_id, self.third_user_id] - } + filter = {"related_by_senders": [self.second_user_id, self.third_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 2, chunk) self.assertCountEqual( @@ -2164,20 +2162,20 @@ class RelationsTestCase(unittest.HomeserverTestCase): def test_filter_relation_type(self) -> None: # Messages which have annotations. - filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]} + filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0]["event_id"], self.event_id_1) # Messages which have references. - filter = {"io.element.relation_types": [RelationTypes.REFERENCE]} + filter = {"related_by_rel_types": [RelationTypes.REFERENCE]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0]["event_id"], self.event_id_2) # Messages which have either annotations or references. filter = { - "io.element.relation_types": [ + "related_by_rel_types": [ RelationTypes.ANNOTATION, RelationTypes.REFERENCE, ] @@ -2191,8 +2189,8 @@ class RelationsTestCase(unittest.HomeserverTestCase): def test_filter_relation_senders_and_type(self) -> None: # Messages which second user reacted to. filter = { - "io.element.relation_senders": [self.second_user_id], - "io.element.relation_types": [RelationTypes.ANNOTATION], + "related_by_senders": [self.second_user_id], + "related_by_rel_types": [RelationTypes.ANNOTATION], } chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py index 6a1cf33054..eaa0d7d749 100644 --- a/tests/storage/test_stream.py +++ b/tests/storage/test_stream.py @@ -129,21 +129,19 @@ class PaginationTestCase(HomeserverTestCase): def test_filter_relation_senders(self): # Messages which second user reacted to. - filter = {"io.element.relation_senders": [self.second_user_id]} + filter = {"related_by_senders": [self.second_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0].event_id, self.event_id_1) # Messages which third user reacted to. - filter = {"io.element.relation_senders": [self.third_user_id]} + filter = {"related_by_senders": [self.third_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0].event_id, self.event_id_2) # Messages which either user reacted to. - filter = { - "io.element.relation_senders": [self.second_user_id, self.third_user_id] - } + filter = {"related_by_senders": [self.second_user_id, self.third_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 2, chunk) self.assertCountEqual( @@ -152,20 +150,20 @@ class PaginationTestCase(HomeserverTestCase): def test_filter_relation_type(self): # Messages which have annotations. - filter = {"io.element.relation_types": [RelationTypes.ANNOTATION]} + filter = {"related_by_rel_types": [RelationTypes.ANNOTATION]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0].event_id, self.event_id_1) # Messages which have references. - filter = {"io.element.relation_types": [RelationTypes.REFERENCE]} + filter = {"related_by_rel_types": [RelationTypes.REFERENCE]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0].event_id, self.event_id_2) # Messages which have either annotations or references. filter = { - "io.element.relation_types": [ + "related_by_rel_types": [ RelationTypes.ANNOTATION, RelationTypes.REFERENCE, ] @@ -179,8 +177,8 @@ class PaginationTestCase(HomeserverTestCase): def test_filter_relation_senders_and_type(self): # Messages which second user reacted to. filter = { - "io.element.relation_senders": [self.second_user_id], - "io.element.relation_types": [RelationTypes.ANNOTATION], + "related_by_senders": [self.second_user_id], + "related_by_rel_types": [RelationTypes.ANNOTATION], } chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) @@ -201,7 +199,7 @@ class PaginationTestCase(HomeserverTestCase): tok=self.second_tok, ) - filter = {"io.element.relation_senders": [self.second_user_id]} + filter = {"related_by_senders": [self.second_user_id]} chunk = self._filter_messages(filter) self.assertEqual(len(chunk), 1, chunk) self.assertEqual(chunk[0].event_id, self.event_id_1) -- cgit 1.5.1 From 96274565ff0dbb7d21b02b04fcef115330426707 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 16 Mar 2022 12:17:39 -0400 Subject: Fix bundling aggregations if unsigned is not a returned event field. (#12234) An error occured if a filter was supplied with `event_fields` which did not include `unsigned`. In that case, bundled aggregations are still added as the spec states it is allowed for servers to add additional fields. --- changelog.d/12234.bugfix | 1 + synapse/events/utils.py | 9 ++++++--- tests/rest/client/test_relations.py | 28 ++++++++++++++++++++++++++++ 3 files changed, 35 insertions(+), 3 deletions(-) create mode 100644 changelog.d/12234.bugfix (limited to 'synapse/events/utils.py') diff --git a/changelog.d/12234.bugfix b/changelog.d/12234.bugfix new file mode 100644 index 0000000000..dbb77f36ff --- /dev/null +++ b/changelog.d/12234.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug when a `filter` argument with `event_fields` supplied but not including the `unsigned` field could result in a 500 error on `/sync`. diff --git a/synapse/events/utils.py b/synapse/events/utils.py index b2a237c1e0..a0520068e0 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -530,9 +530,12 @@ class EventClientSerializer: # Include the bundled aggregations in the event. if serialized_aggregations: - serialized_event["unsigned"].setdefault("m.relations", {}).update( - serialized_aggregations - ) + # There is likely already an "unsigned" field, but a filter might + # have stripped it off (via the event_fields option). The server is + # allowed to return additional fields, so add it back. + serialized_event.setdefault("unsigned", {}).setdefault( + "m.relations", {} + ).update(serialized_aggregations) def serialize_events( self, diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 0cbe6c0cf7..171f4e97c8 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -1267,6 +1267,34 @@ class RelationsTestCase(BaseRelationsTestCase): [annotation_event_id_good, thread_event_id], ) + def test_bundled_aggregations_with_filter(self) -> None: + """ + If "unsigned" is an omitted field (due to filtering), adding the bundled + aggregations should not break. + + Note that the spec allows for a server to return additional fields beyond + what is specified. + """ + self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a") + + # Note that the sync filter does not include "unsigned" as a field. + filter = urllib.parse.quote_plus( + b'{"event_fields": ["content", "event_id"], "room": {"timeline": {"limit": 3}}}' + ) + channel = self.make_request( + "GET", f"/sync?filter={filter}", access_token=self.user_token + ) + self.assertEqual(200, channel.code, channel.json_body) + + # Ensure the timeline is limited, find the parent event. + room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] + self.assertTrue(room_timeline["limited"]) + parent_event = self._find_event_in_chunk(room_timeline["events"]) + + # Ensure there's bundled aggregations on it. + self.assertIn("unsigned", parent_event) + self.assertIn("m.relations", parent_event["unsigned"]) + class RelationRedactionTestCase(BaseRelationsTestCase): """ -- cgit 1.5.1 From 8fe930c215f69913fbcd96d609ec6950644e4ec4 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 18 Mar 2022 13:49:32 -0400 Subject: Move get_bundled_aggregations to relations handler. (#12237) The get_bundled_aggregations code is fairly high-level and uses a lot of store methods, we move it into the handler as that seems like a better fit. --- changelog.d/12237.misc | 1 + synapse/events/utils.py | 2 +- synapse/handlers/pagination.py | 5 +- synapse/handlers/relations.py | 151 +++++++++++++++++++++++++++- synapse/handlers/room.py | 5 +- synapse/handlers/search.py | 3 +- synapse/handlers/sync.py | 9 +- synapse/rest/client/room.py | 3 +- synapse/storage/databases/main/relations.py | 151 +--------------------------- 9 files changed, 173 insertions(+), 157 deletions(-) create mode 100644 changelog.d/12237.misc (limited to 'synapse/events/utils.py') diff --git a/changelog.d/12237.misc b/changelog.d/12237.misc new file mode 100644 index 0000000000..41c9dcbd37 --- /dev/null +++ b/changelog.d/12237.misc @@ -0,0 +1 @@ +Refactor the relations endpoints to add a `RelationsHandler`. diff --git a/synapse/events/utils.py b/synapse/events/utils.py index a0520068e0..7120062127 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -38,8 +38,8 @@ from synapse.util.frozenutils import unfreeze from . import EventBase if TYPE_CHECKING: + from synapse.handlers.relations import BundledAggregations from synapse.server import HomeServer - from synapse.storage.databases.main.relations import BundledAggregations # Split strings on "." but not "\." This uses a negative lookbehind assertion for '\' diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py index 41679f7f86..876b879483 100644 --- a/synapse/handlers/pagination.py +++ b/synapse/handlers/pagination.py @@ -134,6 +134,7 @@ class PaginationHandler: self.clock = hs.get_clock() self._server_name = hs.hostname self._room_shutdown_handler = hs.get_room_shutdown_handler() + self._relations_handler = hs.get_relations_handler() self.pagination_lock = ReadWriteLock() # IDs of rooms in which there currently an active purge *or delete* operation. @@ -539,7 +540,9 @@ class PaginationHandler: state_dict = await self.store.get_events(list(state_ids.values())) state = state_dict.values() - aggregations = await self.store.get_bundled_aggregations(events, user_id) + aggregations = await self._relations_handler.get_bundled_aggregations( + events, user_id + ) time_now = self.clock.time_msec() diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 8e475475ad..57135d4519 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -12,18 +12,53 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Dict, Iterable, Optional, cast +import attr +from frozendict import frozendict + +from synapse.api.constants import RelationTypes from synapse.api.errors import SynapseError +from synapse.events import EventBase from synapse.types import JsonDict, Requester, StreamToken if TYPE_CHECKING: from synapse.server import HomeServer + from synapse.storage.databases.main import DataStore logger = logging.getLogger(__name__) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _ThreadAggregation: + # The latest event in the thread. + latest_event: EventBase + # The latest edit to the latest event in the thread. + latest_edit: Optional[EventBase] + # The total number of events in the thread. + count: int + # True if the current user has sent an event to the thread. + current_user_participated: bool + + +@attr.s(slots=True, auto_attribs=True) +class BundledAggregations: + """ + The bundled aggregations for an event. + + Some values require additional processing during serialization. + """ + + annotations: Optional[JsonDict] = None + references: Optional[JsonDict] = None + replace: Optional[EventBase] = None + thread: Optional[_ThreadAggregation] = None + + def __bool__(self) -> bool: + return bool(self.annotations or self.references or self.replace or self.thread) + + class RelationsHandler: def __init__(self, hs: "HomeServer"): self._main_store = hs.get_datastores().main @@ -103,7 +138,7 @@ class RelationsHandler: ) # The relations returned for the requested event do include their # bundled aggregations. - aggregations = await self._main_store.get_bundled_aggregations( + aggregations = await self.get_bundled_aggregations( events, requester.user.to_string() ) serialized_events = self._event_serializer.serialize_events( @@ -115,3 +150,115 @@ class RelationsHandler: return_value["original_event"] = original_event return return_value + + async def _get_bundled_aggregation_for_event( + self, event: EventBase, user_id: str + ) -> Optional[BundledAggregations]: + """Generate bundled aggregations for an event. + + Note that this does not use a cache, but depends on cached methods. + + Args: + event: The event to calculate bundled aggregations for. + user_id: The user requesting the bundled aggregations. + + Returns: + The bundled aggregations for an event, if bundled aggregations are + enabled and the event can have bundled aggregations. + """ + + # Do not bundle aggregations for an event which represents an edit or an + # annotation. It does not make sense for them to have related events. + relates_to = event.content.get("m.relates_to") + if isinstance(relates_to, (dict, frozendict)): + relation_type = relates_to.get("rel_type") + if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE): + return None + + event_id = event.event_id + room_id = event.room_id + + # The bundled aggregations to include, a mapping of relation type to a + # type-specific value. Some types include the direct return type here + # while others need more processing during serialization. + aggregations = BundledAggregations() + + annotations = await self._main_store.get_aggregation_groups_for_event( + event_id, room_id + ) + if annotations.chunk: + aggregations.annotations = await annotations.to_dict( + cast("DataStore", self) + ) + + references = await self._main_store.get_relations_for_event( + event_id, event, room_id, RelationTypes.REFERENCE, direction="f" + ) + if references.chunk: + aggregations.references = await references.to_dict(cast("DataStore", self)) + + # Store the bundled aggregations in the event metadata for later use. + return aggregations + + async def get_bundled_aggregations( + self, events: Iterable[EventBase], user_id: str + ) -> Dict[str, BundledAggregations]: + """Generate bundled aggregations for events. + + Args: + events: The iterable of events to calculate bundled aggregations for. + user_id: The user requesting the bundled aggregations. + + Returns: + A map of event ID to the bundled aggregation for the event. Not all + events may have bundled aggregations in the results. + """ + # De-duplicate events by ID to handle the same event requested multiple times. + # + # State events do not get bundled aggregations. + events_by_id = { + event.event_id: event for event in events if not event.is_state() + } + + # event ID -> bundled aggregation in non-serialized form. + results: Dict[str, BundledAggregations] = {} + + # Fetch other relations per event. + for event in events_by_id.values(): + event_result = await self._get_bundled_aggregation_for_event(event, user_id) + if event_result: + results[event.event_id] = event_result + + # Fetch any edits (but not for redacted events). + edits = await self._main_store.get_applicable_edits( + [ + event_id + for event_id, event in events_by_id.items() + if not event.internal_metadata.is_redacted() + ] + ) + for event_id, edit in edits.items(): + results.setdefault(event_id, BundledAggregations()).replace = edit + + # Fetch thread summaries. + summaries = await self._main_store.get_thread_summaries(events_by_id.keys()) + # Only fetch participated for a limited selection based on what had + # summaries. + participated = await self._main_store.get_threads_participated( + [event_id for event_id, summary in summaries.items() if summary], user_id + ) + for event_id, summary in summaries.items(): + if summary: + thread_count, latest_thread_event, edit = summary + results.setdefault( + event_id, BundledAggregations() + ).thread = _ThreadAggregation( + latest_event=latest_thread_event, + latest_edit=edit, + count=thread_count, + # If there's a thread summary it must also exist in the + # participated dictionary. + current_user_participated=participated[event_id], + ) + + return results diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index b9735631fc..092e185c99 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -60,8 +60,8 @@ from synapse.events import EventBase from synapse.events.utils import copy_power_levels_contents from synapse.federation.federation_client import InvalidResponseError from synapse.handlers.federation import get_domains_from_state +from synapse.handlers.relations import BundledAggregations from synapse.rest.admin._base import assert_user_is_admin -from synapse.storage.databases.main.relations import BundledAggregations from synapse.storage.state import StateFilter from synapse.streams import EventSource from synapse.types import ( @@ -1118,6 +1118,7 @@ class RoomContextHandler: self.store = hs.get_datastores().main self.storage = hs.get_storage() self.state_store = self.storage.state + self._relations_handler = hs.get_relations_handler() async def get_event_context( self, @@ -1190,7 +1191,7 @@ class RoomContextHandler: event = filtered[0] # Fetch the aggregations. - aggregations = await self.store.get_bundled_aggregations( + aggregations = await self._relations_handler.get_bundled_aggregations( itertools.chain(events_before, (event,), events_after), user.to_string(), ) diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index aa16e417eb..30eddda65f 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -54,6 +54,7 @@ class SearchHandler: self.clock = hs.get_clock() self.hs = hs self._event_serializer = hs.get_event_client_serializer() + self._relations_handler = hs.get_relations_handler() self.storage = hs.get_storage() self.state_store = self.storage.state self.auth = hs.get_auth() @@ -354,7 +355,7 @@ class SearchHandler: aggregations = None if self._msc3666_enabled: - aggregations = await self.store.get_bundled_aggregations( + aggregations = await self._relations_handler.get_bundled_aggregations( # Generate an iterable of EventBase for all the events that will be # returned, including contextual events. itertools.chain( diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index c9d6a18bd7..6c569cfb1c 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -33,11 +33,11 @@ from synapse.api.filtering import FilterCollection from synapse.api.presence import UserPresenceState from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase +from synapse.handlers.relations import BundledAggregations from synapse.logging.context import current_context from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span from synapse.push.clientformat import format_push_rules_for_user from synapse.storage.databases.main.event_push_actions import NotifCounts -from synapse.storage.databases.main.relations import BundledAggregations from synapse.storage.roommember import MemberSummary from synapse.storage.state import StateFilter from synapse.types import ( @@ -269,6 +269,7 @@ class SyncHandler: self.store = hs.get_datastores().main self.notifier = hs.get_notifier() self.presence_handler = hs.get_presence_handler() + self._relations_handler = hs.get_relations_handler() self.event_sources = hs.get_event_sources() self.clock = hs.get_clock() self.state = hs.get_state_handler() @@ -638,8 +639,10 @@ class SyncHandler: # as clients will have all the necessary information. bundled_aggregations = None if limited or newly_joined_room: - bundled_aggregations = await self.store.get_bundled_aggregations( - recents, sync_config.user.to_string() + bundled_aggregations = ( + await self._relations_handler.get_bundled_aggregations( + recents, sync_config.user.to_string() + ) ) return TimelineBatch( diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 8a06ab8c5f..47e152c8cc 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -645,6 +645,7 @@ class RoomEventServlet(RestServlet): self._store = hs.get_datastores().main self.event_handler = hs.get_event_handler() self._event_serializer = hs.get_event_client_serializer() + self._relations_handler = hs.get_relations_handler() self.auth = hs.get_auth() async def on_GET( @@ -663,7 +664,7 @@ class RoomEventServlet(RestServlet): if event: # Ensure there are bundled aggregations available. - aggregations = await self._store.get_bundled_aggregations( + aggregations = await self._relations_handler.get_bundled_aggregations( [event], requester.user.to_string() ) diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index af2334a65e..b2295fd51f 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -27,7 +27,6 @@ from typing import ( ) import attr -from frozendict import frozendict from synapse.api.constants import RelationTypes from synapse.events import EventBase @@ -41,45 +40,15 @@ from synapse.storage.database import ( from synapse.storage.databases.main.stream import generate_pagination_where_clause from synapse.storage.engines import PostgresEngine from synapse.storage.relations import AggregationPaginationToken, PaginationChunk -from synapse.types import JsonDict, RoomStreamToken, StreamToken +from synapse.types import RoomStreamToken, StreamToken from synapse.util.caches.descriptors import cached, cachedList if TYPE_CHECKING: from synapse.server import HomeServer - from synapse.storage.databases.main import DataStore logger = logging.getLogger(__name__) -@attr.s(slots=True, frozen=True, auto_attribs=True) -class _ThreadAggregation: - # The latest event in the thread. - latest_event: EventBase - # The latest edit to the latest event in the thread. - latest_edit: Optional[EventBase] - # The total number of events in the thread. - count: int - # True if the current user has sent an event to the thread. - current_user_participated: bool - - -@attr.s(slots=True, auto_attribs=True) -class BundledAggregations: - """ - The bundled aggregations for an event. - - Some values require additional processing during serialization. - """ - - annotations: Optional[JsonDict] = None - references: Optional[JsonDict] = None - replace: Optional[EventBase] = None - thread: Optional[_ThreadAggregation] = None - - def __bool__(self) -> bool: - return bool(self.annotations or self.references or self.replace or self.thread) - - class RelationsWorkerStore(SQLBaseStore): def __init__( self, @@ -384,7 +353,7 @@ class RelationsWorkerStore(SQLBaseStore): raise NotImplementedError() @cachedList(cached_method_name="get_applicable_edit", list_name="event_ids") - async def _get_applicable_edits( + async def get_applicable_edits( self, event_ids: Collection[str] ) -> Dict[str, Optional[EventBase]]: """Get the most recent edit (if any) that has happened for the given @@ -473,7 +442,7 @@ class RelationsWorkerStore(SQLBaseStore): raise NotImplementedError() @cachedList(cached_method_name="get_thread_summary", list_name="event_ids") - async def _get_thread_summaries( + async def get_thread_summaries( self, event_ids: Collection[str] ) -> Dict[str, Optional[Tuple[int, EventBase, Optional[EventBase]]]]: """Get the number of threaded replies, the latest reply (if any), and the latest edit for that reply for the given event. @@ -587,7 +556,7 @@ class RelationsWorkerStore(SQLBaseStore): latest_events = await self.get_events(latest_event_ids.values()) # type: ignore[attr-defined] # Check to see if any of those events are edited. - latest_edits = await self._get_applicable_edits(latest_event_ids.values()) + latest_edits = await self.get_applicable_edits(latest_event_ids.values()) # Map to the event IDs to the thread summary. # @@ -610,7 +579,7 @@ class RelationsWorkerStore(SQLBaseStore): raise NotImplementedError() @cachedList(cached_method_name="get_thread_participated", list_name="event_ids") - async def _get_threads_participated( + async def get_threads_participated( self, event_ids: Collection[str], user_id: str ) -> Dict[str, bool]: """Get whether the requesting user participated in the given threads. @@ -766,116 +735,6 @@ class RelationsWorkerStore(SQLBaseStore): "get_if_user_has_annotated_event", _get_if_user_has_annotated_event ) - async def _get_bundled_aggregation_for_event( - self, event: EventBase, user_id: str - ) -> Optional[BundledAggregations]: - """Generate bundled aggregations for an event. - - Note that this does not use a cache, but depends on cached methods. - - Args: - event: The event to calculate bundled aggregations for. - user_id: The user requesting the bundled aggregations. - - Returns: - The bundled aggregations for an event, if bundled aggregations are - enabled and the event can have bundled aggregations. - """ - - # Do not bundle aggregations for an event which represents an edit or an - # annotation. It does not make sense for them to have related events. - relates_to = event.content.get("m.relates_to") - if isinstance(relates_to, (dict, frozendict)): - relation_type = relates_to.get("rel_type") - if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE): - return None - - event_id = event.event_id - room_id = event.room_id - - # The bundled aggregations to include, a mapping of relation type to a - # type-specific value. Some types include the direct return type here - # while others need more processing during serialization. - aggregations = BundledAggregations() - - annotations = await self.get_aggregation_groups_for_event(event_id, room_id) - if annotations.chunk: - aggregations.annotations = await annotations.to_dict( - cast("DataStore", self) - ) - - references = await self.get_relations_for_event( - event_id, event, room_id, RelationTypes.REFERENCE, direction="f" - ) - if references.chunk: - aggregations.references = await references.to_dict(cast("DataStore", self)) - - # Store the bundled aggregations in the event metadata for later use. - return aggregations - - async def get_bundled_aggregations( - self, events: Iterable[EventBase], user_id: str - ) -> Dict[str, BundledAggregations]: - """Generate bundled aggregations for events. - - Args: - events: The iterable of events to calculate bundled aggregations for. - user_id: The user requesting the bundled aggregations. - - Returns: - A map of event ID to the bundled aggregation for the event. Not all - events may have bundled aggregations in the results. - """ - # De-duplicate events by ID to handle the same event requested multiple times. - # - # State events do not get bundled aggregations. - events_by_id = { - event.event_id: event for event in events if not event.is_state() - } - - # event ID -> bundled aggregation in non-serialized form. - results: Dict[str, BundledAggregations] = {} - - # Fetch other relations per event. - for event in events_by_id.values(): - event_result = await self._get_bundled_aggregation_for_event(event, user_id) - if event_result: - results[event.event_id] = event_result - - # Fetch any edits (but not for redacted events). - edits = await self._get_applicable_edits( - [ - event_id - for event_id, event in events_by_id.items() - if not event.internal_metadata.is_redacted() - ] - ) - for event_id, edit in edits.items(): - results.setdefault(event_id, BundledAggregations()).replace = edit - - # Fetch thread summaries. - summaries = await self._get_thread_summaries(events_by_id.keys()) - # Only fetch participated for a limited selection based on what had - # summaries. - participated = await self._get_threads_participated( - [event_id for event_id, summary in summaries.items() if summary], user_id - ) - for event_id, summary in summaries.items(): - if summary: - thread_count, latest_thread_event, edit = summary - results.setdefault( - event_id, BundledAggregations() - ).thread = _ThreadAggregation( - latest_event=latest_thread_event, - latest_edit=edit, - count=thread_count, - # If there's a thread summary it must also exist in the - # participated dictionary. - current_user_participated=participated[event_id], - ) - - return results - class RelationsStore(RelationsWorkerStore): pass -- cgit 1.5.1