summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2022-03-16 10:39:15 -0400
committerGitHub <noreply@github.com>2022-03-16 10:39:15 -0400
commitfc9bd620ce94b64af46737e25a524336967782a1 (patch)
tree355f7d10a96a8ef9e47445b266d16b45f0ad0f63 /synapse
parentAdd some missing type hints to cache datastore. (#12216) (diff)
downloadsynapse-fc9bd620ce94b64af46737e25a524336967782a1.tar.xz
Add a relations handler to avoid duplication. (#12227)
Adds a handler layer between the REST and datastore layers for relations.
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/pagination.py5
-rw-r--r--synapse/handlers/relations.py117
-rw-r--r--synapse/rest/client/relations.py75
-rw-r--r--synapse/server.py5
4 files changed, 133 insertions, 69 deletions
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 60059fec3e..41679f7f86 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Set
+from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set
 
 import attr
 
@@ -422,7 +422,7 @@ class PaginationHandler:
         pagin_config: PaginationConfig,
         as_client_event: bool = True,
         event_filter: Optional[Filter] = None,
-    ) -> Dict[str, Any]:
+    ) -> JsonDict:
         """Get messages in a room.
 
         Args:
@@ -431,6 +431,7 @@ class PaginationHandler:
             pagin_config: The pagination config rules to apply, if any.
             as_client_event: True to get events in client-server format.
             event_filter: Filter to apply to results or None
+
         Returns:
             Pagination API results
         """
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
new file mode 100644
index 0000000000..8e475475ad
--- /dev/null
+++ b/synapse/handlers/relations.py
@@ -0,0 +1,117 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import TYPE_CHECKING, Optional
+
+from synapse.api.errors import SynapseError
+from synapse.types import JsonDict, Requester, StreamToken
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+
+logger = logging.getLogger(__name__)
+
+
+class RelationsHandler:
+    def __init__(self, hs: "HomeServer"):
+        self._main_store = hs.get_datastores().main
+        self._auth = hs.get_auth()
+        self._clock = hs.get_clock()
+        self._event_handler = hs.get_event_handler()
+        self._event_serializer = hs.get_event_client_serializer()
+
+    async def get_relations(
+        self,
+        requester: Requester,
+        event_id: str,
+        room_id: str,
+        relation_type: Optional[str] = None,
+        event_type: Optional[str] = None,
+        aggregation_key: Optional[str] = None,
+        limit: int = 5,
+        direction: str = "b",
+        from_token: Optional[StreamToken] = None,
+        to_token: Optional[StreamToken] = None,
+    ) -> JsonDict:
+        """Get related events of a event, ordered by topological ordering.
+
+        TODO Accept a PaginationConfig instead of individual pagination parameters.
+
+        Args:
+            requester: The user requesting the relations.
+            event_id: Fetch events that relate to this event ID.
+            room_id: The room the event belongs to.
+            relation_type: Only fetch events with this relation type, if given.
+            event_type: Only fetch events with this event type, if given.
+            aggregation_key: Only fetch events with this aggregation key, if given.
+            limit: Only fetch the most recent `limit` events.
+            direction: Whether to fetch the most recent first (`"b"`) or the
+                oldest first (`"f"`).
+            from_token: Fetch rows from the given token, or from the start if None.
+            to_token: Fetch rows up to the given token, or up to the end if None.
+
+        Returns:
+            The pagination chunk.
+        """
+
+        user_id = requester.user.to_string()
+
+        await self._auth.check_user_in_room_or_world_readable(
+            room_id, user_id, allow_departed_users=True
+        )
+
+        # This gets the original event and checks that a) the event exists and
+        # b) the user is allowed to view it.
+        event = await self._event_handler.get_event(requester.user, room_id, event_id)
+        if event is None:
+            raise SynapseError(404, "Unknown parent event.")
+
+        pagination_chunk = await self._main_store.get_relations_for_event(
+            event_id=event_id,
+            event=event,
+            room_id=room_id,
+            relation_type=relation_type,
+            event_type=event_type,
+            aggregation_key=aggregation_key,
+            limit=limit,
+            direction=direction,
+            from_token=from_token,
+            to_token=to_token,
+        )
+
+        events = await self._main_store.get_events_as_list(
+            [c["event_id"] for c in pagination_chunk.chunk]
+        )
+
+        now = self._clock.time_msec()
+        # Do not bundle aggregations when retrieving the original event because
+        # we want the content before relations are applied to it.
+        original_event = self._event_serializer.serialize_event(
+            event, now, bundle_aggregations=None
+        )
+        # The relations returned for the requested event do include their
+        # bundled aggregations.
+        aggregations = await self._main_store.get_bundled_aggregations(
+            events, requester.user.to_string()
+        )
+        serialized_events = self._event_serializer.serialize_events(
+            events, now, bundle_aggregations=aggregations
+        )
+
+        return_value = await pagination_chunk.to_dict(self._main_store)
+        return_value["chunk"] = serialized_events
+        return_value["original_event"] = original_event
+
+        return return_value
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index d9a6be43f7..c16078b187 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -51,9 +51,7 @@ class RelationPaginationServlet(RestServlet):
         super().__init__()
         self.auth = hs.get_auth()
         self.store = hs.get_datastores().main
-        self.clock = hs.get_clock()
-        self._event_serializer = hs.get_event_client_serializer()
-        self.event_handler = hs.get_event_handler()
+        self._relations_handler = hs.get_relations_handler()
 
     async def on_GET(
         self,
@@ -65,16 +63,6 @@ class RelationPaginationServlet(RestServlet):
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
 
-        await self.auth.check_user_in_room_or_world_readable(
-            room_id, requester.user.to_string(), allow_departed_users=True
-        )
-
-        # This gets the original event and checks that a) the event exists and
-        # b) the user is allowed to view it.
-        event = await self.event_handler.get_event(requester.user, room_id, parent_id)
-        if event is None:
-            raise SynapseError(404, "Unknown parent event.")
-
         limit = parse_integer(request, "limit", default=5)
         direction = parse_string(
             request, "org.matrix.msc3715.dir", default="b", allowed_values=["f", "b"]
@@ -90,9 +78,9 @@ class RelationPaginationServlet(RestServlet):
         if to_token_str:
             to_token = await StreamToken.from_string(self.store, to_token_str)
 
-        pagination_chunk = await self.store.get_relations_for_event(
+        result = await self._relations_handler.get_relations(
+            requester=requester,
             event_id=parent_id,
-            event=event,
             room_id=room_id,
             relation_type=relation_type,
             event_type=event_type,
@@ -102,30 +90,7 @@ class RelationPaginationServlet(RestServlet):
             to_token=to_token,
         )
 
-        events = await self.store.get_events_as_list(
-            [c["event_id"] for c in pagination_chunk.chunk]
-        )
-
-        now = self.clock.time_msec()
-        # Do not bundle aggregations when retrieving the original event because
-        # we want the content before relations are applied to it.
-        original_event = self._event_serializer.serialize_event(
-            event, now, bundle_aggregations=None
-        )
-        # The relations returned for the requested event do include their
-        # bundled aggregations.
-        aggregations = await self.store.get_bundled_aggregations(
-            events, requester.user.to_string()
-        )
-        serialized_events = self._event_serializer.serialize_events(
-            events, now, bundle_aggregations=aggregations
-        )
-
-        return_value = await pagination_chunk.to_dict(self.store)
-        return_value["chunk"] = serialized_events
-        return_value["original_event"] = original_event
-
-        return 200, return_value
+        return 200, result
 
 
 class RelationAggregationPaginationServlet(RestServlet):
@@ -245,9 +210,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
         super().__init__()
         self.auth = hs.get_auth()
         self.store = hs.get_datastores().main
-        self.clock = hs.get_clock()
-        self._event_serializer = hs.get_event_client_serializer()
-        self.event_handler = hs.get_event_handler()
+        self._relations_handler = hs.get_relations_handler()
 
     async def on_GET(
         self,
@@ -260,18 +223,6 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
 
-        await self.auth.check_user_in_room_or_world_readable(
-            room_id,
-            requester.user.to_string(),
-            allow_departed_users=True,
-        )
-
-        # This checks that a) the event exists and b) the user is allowed to
-        # view it.
-        event = await self.event_handler.get_event(requester.user, room_id, parent_id)
-        if event is None:
-            raise SynapseError(404, "Unknown parent event.")
-
         if relation_type != RelationTypes.ANNOTATION:
             raise SynapseError(400, "Relation type must be 'annotation'")
 
@@ -286,9 +237,9 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
         if to_token_str:
             to_token = await StreamToken.from_string(self.store, to_token_str)
 
-        result = await self.store.get_relations_for_event(
+        result = await self._relations_handler.get_relations(
+            requester=requester,
             event_id=parent_id,
-            event=event,
             room_id=room_id,
             relation_type=relation_type,
             event_type=event_type,
@@ -298,17 +249,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
             to_token=to_token,
         )
 
-        events = await self.store.get_events_as_list(
-            [c["event_id"] for c in result.chunk]
-        )
-
-        now = self.clock.time_msec()
-        serialized_events = self._event_serializer.serialize_events(events, now)
-
-        return_value = await result.to_dict(self.store)
-        return_value["chunk"] = serialized_events
-
-        return 200, return_value
+        return 200, result
 
 
 def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
diff --git a/synapse/server.py b/synapse/server.py
index 2fcf18a7a6..380369db92 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -94,6 +94,7 @@ from synapse.handlers.profile import ProfileHandler
 from synapse.handlers.read_marker import ReadMarkerHandler
 from synapse.handlers.receipts import ReceiptsHandler
 from synapse.handlers.register import RegistrationHandler
+from synapse.handlers.relations import RelationsHandler
 from synapse.handlers.room import (
     RoomContextHandler,
     RoomCreationHandler,
@@ -720,6 +721,10 @@ class HomeServer(metaclass=abc.ABCMeta):
         return PaginationHandler(self)
 
     @cache_in_self
+    def get_relations_handler(self) -> RelationsHandler:
+        return RelationsHandler(self)
+
+    @cache_in_self
     def get_room_context_handler(self) -> RoomContextHandler:
         return RoomContextHandler(self)