summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/relations.py61
-rw-r--r--synapse/storage/databases/main/stream.py22
2 files changed, 53 insertions, 30 deletions
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 2cb5d06c13..a9a5dd5f03 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -13,17 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Dict,
-    Iterable,
-    List,
-    Optional,
-    Tuple,
-    Union,
-    cast,
-)
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union, cast
 
 import attr
 from frozendict import frozendict
@@ -43,6 +33,7 @@ from synapse.storage.relations import (
     PaginationChunk,
     RelationPaginationToken,
 )
+from synapse.types import JsonDict
 from synapse.util.caches.descriptors import cached
 
 if TYPE_CHECKING:
@@ -51,6 +42,30 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _ThreadAggregation:
+    latest_event: EventBase
+    count: int
+    current_user_participated: bool
+
+
+@attr.s(slots=True, auto_attribs=True)
+class BundledAggregations:
+    """
+    The bundled aggregations for an event.
+
+    Some values require additional processing during serialization.
+    """
+
+    annotations: Optional[JsonDict] = None
+    references: Optional[JsonDict] = None
+    replace: Optional[EventBase] = None
+    thread: Optional[_ThreadAggregation] = None
+
+    def __bool__(self) -> bool:
+        return bool(self.annotations or self.references or self.replace or self.thread)
+
+
 class RelationsWorkerStore(SQLBaseStore):
     def __init__(
         self,
@@ -585,7 +600,7 @@ class RelationsWorkerStore(SQLBaseStore):
 
     async def _get_bundled_aggregation_for_event(
         self, event: EventBase, user_id: str
-    ) -> Optional[Dict[str, Any]]:
+    ) -> Optional[BundledAggregations]:
         """Generate bundled aggregations for an event.
 
         Note that this does not use a cache, but depends on cached methods.
@@ -616,24 +631,24 @@ class RelationsWorkerStore(SQLBaseStore):
         # The bundled aggregations to include, a mapping of relation type to a
         # type-specific value. Some types include the direct return type here
         # while others need more processing during serialization.
-        aggregations: Dict[str, Any] = {}
+        aggregations = BundledAggregations()
 
         annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
         if annotations.chunk:
-            aggregations[RelationTypes.ANNOTATION] = annotations.to_dict()
+            aggregations.annotations = annotations.to_dict()
 
         references = await self.get_relations_for_event(
             event_id, room_id, RelationTypes.REFERENCE, direction="f"
         )
         if references.chunk:
-            aggregations[RelationTypes.REFERENCE] = references.to_dict()
+            aggregations.references = references.to_dict()
 
         edit = None
         if event.type == EventTypes.Message:
             edit = await self.get_applicable_edit(event_id, room_id)
 
         if edit:
-            aggregations[RelationTypes.REPLACE] = edit
+            aggregations.replace = edit
 
         # If this event is the start of a thread, include a summary of the replies.
         if self._msc3440_enabled:
@@ -644,11 +659,11 @@ class RelationsWorkerStore(SQLBaseStore):
                 event_id, room_id, user_id
             )
             if latest_thread_event:
-                aggregations[RelationTypes.THREAD] = {
-                    "latest_event": latest_thread_event,
-                    "count": thread_count,
-                    "current_user_participated": participated,
-                }
+                aggregations.thread = _ThreadAggregation(
+                    latest_event=latest_thread_event,
+                    count=thread_count,
+                    current_user_participated=participated,
+                )
 
         # Store the bundled aggregations in the event metadata for later use.
         return aggregations
@@ -657,7 +672,7 @@ class RelationsWorkerStore(SQLBaseStore):
         self,
         events: Iterable[EventBase],
         user_id: str,
-    ) -> Dict[str, Dict[str, Any]]:
+    ) -> Dict[str, BundledAggregations]:
         """Generate bundled aggregations for events.
 
         Args:
@@ -676,7 +691,7 @@ class RelationsWorkerStore(SQLBaseStore):
         results = {}
         for event in events:
             event_result = await self._get_bundled_aggregation_for_event(event, user_id)
-            if event_result is not None:
+            if event_result:
                 results[event.event_id] = event_result
 
         return results
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 319464b1fa..a898f847e7 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -81,6 +81,14 @@ class _EventDictReturn:
     stream_ordering: int
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _EventsAround:
+    events_before: List[EventBase]
+    events_after: List[EventBase]
+    start: RoomStreamToken
+    end: RoomStreamToken
+
+
 def generate_pagination_where_clause(
     direction: str,
     column_names: Tuple[str, str],
@@ -846,7 +854,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         before_limit: int,
         after_limit: int,
         event_filter: Optional[Filter] = None,
-    ) -> dict:
+    ) -> _EventsAround:
         """Retrieve events and pagination tokens around a given event in a
         room.
         """
@@ -869,12 +877,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             list(results["after"]["event_ids"]), get_prev_content=True
         )
 
-        return {
-            "events_before": events_before,
-            "events_after": events_after,
-            "start": results["before"]["token"],
-            "end": results["after"]["token"],
-        }
+        return _EventsAround(
+            events_before=events_before,
+            events_after=events_after,
+            start=results["before"]["token"],
+            end=results["after"]["token"],
+        )
 
     def _get_events_around_txn(
         self,