 changelog.d/12338.misc                       |  1 +
 synapse/handlers/relations.py                | 35 +++++++++++++++++++++++------------
 synapse/storage/databases/main/relations.py  | 17 ++++++++---------
 synapse/storage/relations.py                 | 53 -----------------------------------------------------
 4 files changed, 32 insertions(+), 74 deletions(-)
diff --git a/changelog.d/12338.misc b/changelog.d/12338.misc
new file mode 100644
index 0000000000..376089f327
--- /dev/null
+++ b/changelog.d/12338.misc
@@ -0,0 +1 @@
+Refactor relations code to remove an unnecessary class.
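
For orientation before the code diffs: the storage method stops returning the one-off PaginationChunk wrapper and hands back a plain tuple instead. A minimal sketch of the old versus new return shape (types simplified; the StreamToken stub below is a stand-in for synapse.types.StreamToken, not its implementation):

    from typing import List, Optional, Tuple

    class StreamToken:  # stand-in stub, only to keep this sketch self-contained
        ...

    # Before: a PaginationChunk whose .chunk held rows shaped like
    # {"event_id": "$a"}, plus .next_batch / .prev_batch token attributes.

    # After: bare event ID strings plus the next token (None on the last page).
    def get_relations_sketch() -> Tuple[List[str], Optional[StreamToken]]:
        return ["$a", "$b"], None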
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index b9497ff3f3..a36936b520 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Dict, Iterable, Optional, cast
+from typing import TYPE_CHECKING, Dict, Iterable, Optional
 
 import attr
 from frozendict import frozendict
@@ -25,7 +25,6 @@ from synapse.visibility import filter_events_for_client
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
-    from synapse.storage.databases.main import DataStore
 
 
 logger = logging.getLogger(__name__)
@@ -116,7 +115,7 @@ class RelationsHandler:
         if event is None:
             raise SynapseError(404, "Unknown parent event.")
 
-        pagination_chunk = await self._main_store.get_relations_for_event(
+        related_events, next_token = await self._main_store.get_relations_for_event(
             event_id=event_id,
             event=event,
             room_id=room_id,
@@ -129,9 +128,7 @@ class RelationsHandler:
             to_token=to_token,
         )
 
-        events = await self._main_store.get_events_as_list(
-            [c["event_id"] for c in pagination_chunk.chunk]
-        )
+        events = await self._main_store.get_events_as_list(related_events)
 
         events = await filter_events_for_client(
             self._storage, user_id, events, is_peeking=(member_event_id is None)
@@ -152,9 +149,16 @@ class RelationsHandler:
             events, now, bundle_aggregations=aggregations
         )
 
-        return_value = await pagination_chunk.to_dict(self._main_store)
-        return_value["chunk"] = serialized_events
-        return_value["original_event"] = original_event
+        return_value = {
+            "chunk": serialized_events,
+            "original_event": original_event,
+        }
+
+        if next_token:
+            return_value["next_batch"] = await next_token.to_string(self._main_store)
+
+        if from_token:
+            return_value["prev_batch"] = await from_token.to_string(self._main_store)
 
         return return_value
 
@@ -196,11 +200,18 @@ class RelationsHandler:
         if annotations:
             aggregations.annotations = {"chunk": annotations}
 
-        references = await self._main_store.get_relations_for_event(
+        references, next_token = await self._main_store.get_relations_for_event(
             event_id, event, room_id, RelationTypes.REFERENCE, direction="f"
         )
-        if references.chunk:
-            aggregations.references = await references.to_dict(cast("DataStore", self))
+        if references:
+            aggregations.references = {
+                "chunk": [{"event_id": event_id} for event_id in references]
+            }
+
+            if next_token:
+                aggregations.references["next_batch"] = await next_token.to_string(
+                    self._main_store
+                )
 
         # Store the bundled aggregations in the event metadata for later use.
         return aggregations
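
With PaginationChunk.to_dict gone, the handler assembles the response envelope inline, as the hunks above show. A sketch of the resulting /relations response dict (event IDs and token strings are invented; real tokens come from StreamToken.to_string):

    # Hypothetical /relations response after the refactor (values invented;
    # real "chunk" entries are fully serialised events, truncated to IDs here).
    return_value = {
        "chunk": [{"event_id": "$reply1"}, {"event_id": "$reply2"}],
        "original_event": {"event_id": "$parent"},
        # Only present when the corresponding token exists:
        "next_batch": "t47-123_456",
        "prev_batch": "t42-100_400",
    }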
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 3285450742..64a7808140 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -37,7 +37,6 @@ from synapse.storage.database import (
 )
 from synapse.storage.databases.main.stream import generate_pagination_where_clause
 from synapse.storage.engines import PostgresEngine
-from synapse.storage.relations import PaginationChunk
 from synapse.types import JsonDict, RoomStreamToken, StreamToken
 from synapse.util.caches.descriptors import cached, cachedList
 
@@ -71,7 +70,7 @@ class RelationsWorkerStore(SQLBaseStore):
         direction: str = "b",
         from_token: Optional[StreamToken] = None,
         to_token: Optional[StreamToken] = None,
-    ) -> PaginationChunk:
+    ) -> Tuple[List[str], Optional[StreamToken]]:
         """Get a list of relations for an event, ordered by topological ordering.
 
         Args:
@@ -88,8 +87,10 @@ class RelationsWorkerStore(SQLBaseStore):
             to_token: Fetch rows up to the given token, or up to the end if None.
 
         Returns:
-            List of event IDs that match relations requested. The rows are of
-            the form `{"event_id": "..."}`.
+            A tuple of:
+                A list of related event IDs
+
+                The next stream token, if one exists.
         """
         # We don't use `event_id`, it's there so that we can cache based on
         # it. The `event_id` must match the `event.event_id`.
@@ -144,7 +145,7 @@ class RelationsWorkerStore(SQLBaseStore):
 
         def _get_recent_references_for_event_txn(
             txn: LoggingTransaction,
-        ) -> PaginationChunk:
+        ) -> Tuple[List[str], Optional[StreamToken]]:
             txn.execute(sql, where_args + [limit + 1])
 
             last_topo_id = None
@@ -154,7 +155,7 @@ class RelationsWorkerStore(SQLBaseStore):
                 # Do not include edits for redacted events as they leak event
                 # content.
                 if not is_redacted or row[1] != RelationTypes.REPLACE:
-                    events.append({"event_id": row[0]})
+                    events.append(row[0])
                 last_topo_id = row[2]
                 last_stream_id = row[3]
 
@@ -177,9 +178,7 @@ class RelationsWorkerStore(SQLBaseStore):
                         groups_key=0,
                     )
 
-            return PaginationChunk(
-                chunk=list(events[:limit]), next_batch=next_token, prev_batch=from_token
-            )
+            return events[:limit], next_token
 
         return await self.db_pool.runInteraction(
             "get_recent_references_for_event", _get_recent_references_for_event_txn
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
deleted file mode 100644
index b9d2b46799..0000000000
--- a/synapse/storage/relations.py
+++ /dev/null
@@ -1,53 +0,0 @@
-# Copyright 2019 New Vector Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
-
-import attr
-
-from synapse.types import JsonDict
-
-if TYPE_CHECKING:
-    from synapse.storage.databases.main import DataStore
-
-logger = logging.getLogger(__name__)
-
-
-@attr.s(slots=True, auto_attribs=True)
-class PaginationChunk:
-    """Returned by relation pagination APIs.
-
-    Attributes:
-        chunk: The rows returned by pagination
-        next_batch: Token to fetch next set of results with, if
-            None then there are no more results.
-        prev_batch: Token to fetch previous set of results with, if
-            None then there are no previous results.
-    """
-
-    chunk: List[JsonDict]
-    next_batch: Optional[Any] = None
-    prev_batch: Optional[Any] = None
-
-    async def to_dict(self, store: "DataStore") -> Dict[str, Any]:
-        d = {"chunk": self.chunk}
-
-        if self.next_batch:
-            d["next_batch"] = await self.next_batch.to_string(store)
-
-        if self.prev_batch:
-            d["prev_batch"] = await self.prev_batch.to_string(store)
-
-        return d