summary refs log tree commit diff
path: root/synapse/storage/databases/main/stream.py
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2022-01-26 08:27:04 -0500
committerGitHub <noreply@github.com>2022-01-26 08:27:04 -0500
commit2897fb6b4fb8bdaea0e919233d5ccaf5dea12742 (patch)
tree86973f5af5bee99ca612fe553372eac4fa7f1080 /synapse/storage/databases/main/stream.py
parentDon't print HTTPStatus.* in "Processed..." logs (#11827) (diff)
downloadsynapse-2897fb6b4fb8bdaea0e919233d5ccaf5dea12742.tar.xz
Improvements to bundling aggregations. (#11815)
This is some odds and ends found during the review of #11791
and while continuing to work in this code:

* Return attrs classes instead of dictionaries from some methods
  to improve type safety.
* Call `get_bundled_aggregations` fewer times.
* Adds a missing assertion in the tests.
* Do not return empty bundled aggregations for an event (preferring
  to not include the bundle at all, as the docstring states).
Diffstat (limited to 'synapse/storage/databases/main/stream.py')
-rw-r--r--synapse/storage/databases/main/stream.py22
1 files changed, 15 insertions, 7 deletions
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 319464b1fa..a898f847e7 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -81,6 +81,14 @@ class _EventDictReturn:
     stream_ordering: int
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _EventsAround:
+    events_before: List[EventBase]
+    events_after: List[EventBase]
+    start: RoomStreamToken
+    end: RoomStreamToken
+
+
 def generate_pagination_where_clause(
     direction: str,
     column_names: Tuple[str, str],
@@ -846,7 +854,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         before_limit: int,
         after_limit: int,
         event_filter: Optional[Filter] = None,
-    ) -> dict:
+    ) -> _EventsAround:
         """Retrieve events and pagination tokens around a given event in a
         room.
         """
@@ -869,12 +877,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             list(results["after"]["event_ids"]), get_prev_content=True
         )
 
-        return {
-            "events_before": events_before,
-            "events_after": events_after,
-            "start": results["before"]["token"],
-            "end": results["after"]["token"],
-        }
+        return _EventsAround(
+            events_before=events_before,
+            events_after=events_after,
+            start=results["before"]["token"],
+            end=results["after"]["token"],
+        )
 
     def _get_events_around_txn(
         self,