summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/12237.misc1
-rw-r--r--synapse/events/utils.py2
-rw-r--r--synapse/handlers/pagination.py5
-rw-r--r--synapse/handlers/relations.py151
-rw-r--r--synapse/handlers/room.py5
-rw-r--r--synapse/handlers/search.py3
-rw-r--r--synapse/handlers/sync.py9
-rw-r--r--synapse/rest/client/room.py3
-rw-r--r--synapse/storage/databases/main/relations.py151
9 files changed, 173 insertions, 157 deletions
diff --git a/changelog.d/12237.misc b/changelog.d/12237.misc
new file mode 100644
index 0000000000..41c9dcbd37
--- /dev/null
+++ b/changelog.d/12237.misc
@@ -0,0 +1 @@
+Refactor the relations endpoints to add a `RelationsHandler`.
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index a0520068e0..7120062127 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -38,8 +38,8 @@ from synapse.util.frozenutils import unfreeze
 from . import EventBase
 
 if TYPE_CHECKING:
+    from synapse.handlers.relations import BundledAggregations
     from synapse.server import HomeServer
-    from synapse.storage.databases.main.relations import BundledAggregations
 
 
 # Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 41679f7f86..876b879483 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -134,6 +134,7 @@ class PaginationHandler:
         self.clock = hs.get_clock()
         self._server_name = hs.hostname
         self._room_shutdown_handler = hs.get_room_shutdown_handler()
+        self._relations_handler = hs.get_relations_handler()
 
         self.pagination_lock = ReadWriteLock()
         # IDs of rooms in which there currently an active purge *or delete* operation.
@@ -539,7 +540,9 @@ class PaginationHandler:
                 state_dict = await self.store.get_events(list(state_ids.values()))
                 state = state_dict.values()
 
-        aggregations = await self.store.get_bundled_aggregations(events, user_id)
+        aggregations = await self._relations_handler.get_bundled_aggregations(
+            events, user_id
+        )
 
         time_now = self.clock.time_msec()
 
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index 8e475475ad..57135d4519 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -12,18 +12,53 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Dict, Iterable, Optional, cast
 
+import attr
+from frozendict import frozendict
+
+from synapse.api.constants import RelationTypes
 from synapse.api.errors import SynapseError
+from synapse.events import EventBase
 from synapse.types import JsonDict, Requester, StreamToken
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
+    from synapse.storage.databases.main import DataStore
 
 
 logger = logging.getLogger(__name__)
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _ThreadAggregation:
+    # The latest event in the thread.
+    latest_event: EventBase
+    # The latest edit to the latest event in the thread.
+    latest_edit: Optional[EventBase]
+    # The total number of events in the thread.
+    count: int
+    # True if the current user has sent an event to the thread.
+    current_user_participated: bool
+
+
+@attr.s(slots=True, auto_attribs=True)
+class BundledAggregations:
+    """
+    The bundled aggregations for an event.
+
+    Some values require additional processing during serialization.
+    """
+
+    annotations: Optional[JsonDict] = None
+    references: Optional[JsonDict] = None
+    replace: Optional[EventBase] = None
+    thread: Optional[_ThreadAggregation] = None
+
+    def __bool__(self) -> bool:
+        return bool(self.annotations or self.references or self.replace or self.thread)
+
+
 class RelationsHandler:
     def __init__(self, hs: "HomeServer"):
         self._main_store = hs.get_datastores().main
@@ -103,7 +138,7 @@ class RelationsHandler:
         )
         # The relations returned for the requested event do include their
         # bundled aggregations.
-        aggregations = await self._main_store.get_bundled_aggregations(
+        aggregations = await self.get_bundled_aggregations(
             events, requester.user.to_string()
         )
         serialized_events = self._event_serializer.serialize_events(
@@ -115,3 +150,115 @@ class RelationsHandler:
         return_value["original_event"] = original_event
 
         return return_value
+
+    async def _get_bundled_aggregation_for_event(
+        self, event: EventBase, user_id: str
+    ) -> Optional[BundledAggregations]:
+        """Generate bundled aggregations for an event.
+
+        Note that this does not use a cache, but depends on cached methods.
+
+        Args:
+            event: The event to calculate bundled aggregations for.
+            user_id: The user requesting the bundled aggregations.
+
+        Returns:
+            The bundled aggregations for an event, if bundled aggregations are
+            enabled and the event can have bundled aggregations.
+        """
+
+        # Do not bundle aggregations for an event which represents an edit or an
+        # annotation. It does not make sense for them to have related events.
+        relates_to = event.content.get("m.relates_to")
+        if isinstance(relates_to, (dict, frozendict)):
+            relation_type = relates_to.get("rel_type")
+            if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
+                return None
+
+        event_id = event.event_id
+        room_id = event.room_id
+
+        # The bundled aggregations to include, a mapping of relation type to a
+        # type-specific value. Some types include the direct return type here
+        # while others need more processing during serialization.
+        aggregations = BundledAggregations()
+
+        annotations = await self._main_store.get_aggregation_groups_for_event(
+            event_id, room_id
+        )
+        if annotations.chunk:
+            aggregations.annotations = await annotations.to_dict(
+                cast("DataStore", self)
+            )
+
+        references = await self._main_store.get_relations_for_event(
+            event_id, event, room_id, RelationTypes.REFERENCE, direction="f"
+        )
+        if references.chunk:
+            aggregations.references = await references.to_dict(cast("DataStore", self))
+
+        # Store the bundled aggregations in the event metadata for later use.
+        return aggregations
+
+    async def get_bundled_aggregations(
+        self, events: Iterable[EventBase], user_id: str
+    ) -> Dict[str, BundledAggregations]:
+        """Generate bundled aggregations for events.
+
+        Args:
+            events: The iterable of events to calculate bundled aggregations for.
+            user_id: The user requesting the bundled aggregations.
+
+        Returns:
+            A map of event ID to the bundled aggregation for the event. Not all
+            events may have bundled aggregations in the results.
+        """
+        # De-duplicate events by ID to handle the same event requested multiple times.
+        #
+        # State events do not get bundled aggregations.
+        events_by_id = {
+            event.event_id: event for event in events if not event.is_state()
+        }
+
+        # event ID -> bundled aggregation in non-serialized form.
+        results: Dict[str, BundledAggregations] = {}
+
+        # Fetch other relations per event.
+        for event in events_by_id.values():
+            event_result = await self._get_bundled_aggregation_for_event(event, user_id)
+            if event_result:
+                results[event.event_id] = event_result
+
+        # Fetch any edits (but not for redacted events).
+        edits = await self._main_store.get_applicable_edits(
+            [
+                event_id
+                for event_id, event in events_by_id.items()
+                if not event.internal_metadata.is_redacted()
+            ]
+        )
+        for event_id, edit in edits.items():
+            results.setdefault(event_id, BundledAggregations()).replace = edit
+
+        # Fetch thread summaries.
+        summaries = await self._main_store.get_thread_summaries(events_by_id.keys())
+        # Only fetch participated for a limited selection based on what had
+        # summaries.
+        participated = await self._main_store.get_threads_participated(
+            [event_id for event_id, summary in summaries.items() if summary], user_id
+        )
+        for event_id, summary in summaries.items():
+            if summary:
+                thread_count, latest_thread_event, edit = summary
+                results.setdefault(
+                    event_id, BundledAggregations()
+                ).thread = _ThreadAggregation(
+                    latest_event=latest_thread_event,
+                    latest_edit=edit,
+                    count=thread_count,
+                    # If there's a thread summary it must also exist in the
+                    # participated dictionary.
+                    current_user_participated=participated[event_id],
+                )
+
+        return results
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index b9735631fc..092e185c99 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -60,8 +60,8 @@ from synapse.events import EventBase
 from synapse.events.utils import copy_power_levels_contents
 from synapse.federation.federation_client import InvalidResponseError
 from synapse.handlers.federation import get_domains_from_state
+from synapse.handlers.relations import BundledAggregations
 from synapse.rest.admin._base import assert_user_is_admin
-from synapse.storage.databases.main.relations import BundledAggregations
 from synapse.storage.state import StateFilter
 from synapse.streams import EventSource
 from synapse.types import (
@@ -1118,6 +1118,7 @@ class RoomContextHandler:
         self.store = hs.get_datastores().main
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
+        self._relations_handler = hs.get_relations_handler()
 
     async def get_event_context(
         self,
@@ -1190,7 +1191,7 @@ class RoomContextHandler:
         event = filtered[0]
 
         # Fetch the aggregations.
-        aggregations = await self.store.get_bundled_aggregations(
+        aggregations = await self._relations_handler.get_bundled_aggregations(
             itertools.chain(events_before, (event,), events_after),
             user.to_string(),
         )
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index aa16e417eb..30eddda65f 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -54,6 +54,7 @@ class SearchHandler:
         self.clock = hs.get_clock()
         self.hs = hs
         self._event_serializer = hs.get_event_client_serializer()
+        self._relations_handler = hs.get_relations_handler()
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
         self.auth = hs.get_auth()
@@ -354,7 +355,7 @@ class SearchHandler:
 
         aggregations = None
         if self._msc3666_enabled:
-            aggregations = await self.store.get_bundled_aggregations(
+            aggregations = await self._relations_handler.get_bundled_aggregations(
                 # Generate an iterable of EventBase for all the events that will be
                 # returned, including contextual events.
                 itertools.chain(
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index c9d6a18bd7..6c569cfb1c 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -33,11 +33,11 @@ from synapse.api.filtering import FilterCollection
 from synapse.api.presence import UserPresenceState
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events import EventBase
+from synapse.handlers.relations import BundledAggregations
 from synapse.logging.context import current_context
 from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
 from synapse.push.clientformat import format_push_rules_for_user
 from synapse.storage.databases.main.event_push_actions import NotifCounts
-from synapse.storage.databases.main.relations import BundledAggregations
 from synapse.storage.roommember import MemberSummary
 from synapse.storage.state import StateFilter
 from synapse.types import (
@@ -269,6 +269,7 @@ class SyncHandler:
         self.store = hs.get_datastores().main
         self.notifier = hs.get_notifier()
         self.presence_handler = hs.get_presence_handler()
+        self._relations_handler = hs.get_relations_handler()
         self.event_sources = hs.get_event_sources()
         self.clock = hs.get_clock()
         self.state = hs.get_state_handler()
@@ -638,8 +639,10 @@ class SyncHandler:
         # as clients will have all the necessary information.
         bundled_aggregations = None
         if limited or newly_joined_room:
-            bundled_aggregations = await self.store.get_bundled_aggregations(
-                recents, sync_config.user.to_string()
+            bundled_aggregations = (
+                await self._relations_handler.get_bundled_aggregations(
+                    recents, sync_config.user.to_string()
+                )
             )
 
         return TimelineBatch(
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 8a06ab8c5f..47e152c8cc 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -645,6 +645,7 @@ class RoomEventServlet(RestServlet):
         self._store = hs.get_datastores().main
         self.event_handler = hs.get_event_handler()
         self._event_serializer = hs.get_event_client_serializer()
+        self._relations_handler = hs.get_relations_handler()
         self.auth = hs.get_auth()
 
     async def on_GET(
@@ -663,7 +664,7 @@ class RoomEventServlet(RestServlet):
 
         if event:
             # Ensure there are bundled aggregations available.
-            aggregations = await self._store.get_bundled_aggregations(
+            aggregations = await self._relations_handler.get_bundled_aggregations(
                 [event], requester.user.to_string()
             )
 
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index af2334a65e..b2295fd51f 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -27,7 +27,6 @@ from typing import (
 )
 
 import attr
-from frozendict import frozendict
 
 from synapse.api.constants import RelationTypes
 from synapse.events import EventBase
@@ -41,45 +40,15 @@ from synapse.storage.database import (
 from synapse.storage.databases.main.stream import generate_pagination_where_clause
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.relations import AggregationPaginationToken, PaginationChunk
-from synapse.types import JsonDict, RoomStreamToken, StreamToken
+from synapse.types import RoomStreamToken, StreamToken
 from synapse.util.caches.descriptors import cached, cachedList
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
-    from synapse.storage.databases.main import DataStore
 
 logger = logging.getLogger(__name__)
 
 
-@attr.s(slots=True, frozen=True, auto_attribs=True)
-class _ThreadAggregation:
-    # The latest event in the thread.
-    latest_event: EventBase
-    # The latest edit to the latest event in the thread.
-    latest_edit: Optional[EventBase]
-    # The total number of events in the thread.
-    count: int
-    # True if the current user has sent an event to the thread.
-    current_user_participated: bool
-
-
-@attr.s(slots=True, auto_attribs=True)
-class BundledAggregations:
-    """
-    The bundled aggregations for an event.
-
-    Some values require additional processing during serialization.
-    """
-
-    annotations: Optional[JsonDict] = None
-    references: Optional[JsonDict] = None
-    replace: Optional[EventBase] = None
-    thread: Optional[_ThreadAggregation] = None
-
-    def __bool__(self) -> bool:
-        return bool(self.annotations or self.references or self.replace or self.thread)
-
-
 class RelationsWorkerStore(SQLBaseStore):
     def __init__(
         self,
@@ -384,7 +353,7 @@ class RelationsWorkerStore(SQLBaseStore):
         raise NotImplementedError()
 
     @cachedList(cached_method_name="get_applicable_edit", list_name="event_ids")
-    async def _get_applicable_edits(
+    async def get_applicable_edits(
         self, event_ids: Collection[str]
     ) -> Dict[str, Optional[EventBase]]:
         """Get the most recent edit (if any) that has happened for the given
@@ -473,7 +442,7 @@ class RelationsWorkerStore(SQLBaseStore):
         raise NotImplementedError()
 
     @cachedList(cached_method_name="get_thread_summary", list_name="event_ids")
-    async def _get_thread_summaries(
+    async def get_thread_summaries(
         self, event_ids: Collection[str]
     ) -> Dict[str, Optional[Tuple[int, EventBase, Optional[EventBase]]]]:
         """Get the number of threaded replies, the latest reply (if any), and the latest edit for that reply for the given event.
@@ -587,7 +556,7 @@ class RelationsWorkerStore(SQLBaseStore):
         latest_events = await self.get_events(latest_event_ids.values())  # type: ignore[attr-defined]
 
         # Check to see if any of those events are edited.
-        latest_edits = await self._get_applicable_edits(latest_event_ids.values())
+        latest_edits = await self.get_applicable_edits(latest_event_ids.values())
 
         # Map to the event IDs to the thread summary.
         #
@@ -610,7 +579,7 @@ class RelationsWorkerStore(SQLBaseStore):
         raise NotImplementedError()
 
     @cachedList(cached_method_name="get_thread_participated", list_name="event_ids")
-    async def _get_threads_participated(
+    async def get_threads_participated(
         self, event_ids: Collection[str], user_id: str
     ) -> Dict[str, bool]:
         """Get whether the requesting user participated in the given threads.
@@ -766,116 +735,6 @@ class RelationsWorkerStore(SQLBaseStore):
             "get_if_user_has_annotated_event", _get_if_user_has_annotated_event
         )
 
-    async def _get_bundled_aggregation_for_event(
-        self, event: EventBase, user_id: str
-    ) -> Optional[BundledAggregations]:
-        """Generate bundled aggregations for an event.
-
-        Note that this does not use a cache, but depends on cached methods.
-
-        Args:
-            event: The event to calculate bundled aggregations for.
-            user_id: The user requesting the bundled aggregations.
-
-        Returns:
-            The bundled aggregations for an event, if bundled aggregations are
-            enabled and the event can have bundled aggregations.
-        """
-
-        # Do not bundle aggregations for an event which represents an edit or an
-        # annotation. It does not make sense for them to have related events.
-        relates_to = event.content.get("m.relates_to")
-        if isinstance(relates_to, (dict, frozendict)):
-            relation_type = relates_to.get("rel_type")
-            if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
-                return None
-
-        event_id = event.event_id
-        room_id = event.room_id
-
-        # The bundled aggregations to include, a mapping of relation type to a
-        # type-specific value. Some types include the direct return type here
-        # while others need more processing during serialization.
-        aggregations = BundledAggregations()
-
-        annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
-        if annotations.chunk:
-            aggregations.annotations = await annotations.to_dict(
-                cast("DataStore", self)
-            )
-
-        references = await self.get_relations_for_event(
-            event_id, event, room_id, RelationTypes.REFERENCE, direction="f"
-        )
-        if references.chunk:
-            aggregations.references = await references.to_dict(cast("DataStore", self))
-
-        # Store the bundled aggregations in the event metadata for later use.
-        return aggregations
-
-    async def get_bundled_aggregations(
-        self, events: Iterable[EventBase], user_id: str
-    ) -> Dict[str, BundledAggregations]:
-        """Generate bundled aggregations for events.
-
-        Args:
-            events: The iterable of events to calculate bundled aggregations for.
-            user_id: The user requesting the bundled aggregations.
-
-        Returns:
-            A map of event ID to the bundled aggregation for the event. Not all
-            events may have bundled aggregations in the results.
-        """
-        # De-duplicate events by ID to handle the same event requested multiple times.
-        #
-        # State events do not get bundled aggregations.
-        events_by_id = {
-            event.event_id: event for event in events if not event.is_state()
-        }
-
-        # event ID -> bundled aggregation in non-serialized form.
-        results: Dict[str, BundledAggregations] = {}
-
-        # Fetch other relations per event.
-        for event in events_by_id.values():
-            event_result = await self._get_bundled_aggregation_for_event(event, user_id)
-            if event_result:
-                results[event.event_id] = event_result
-
-        # Fetch any edits (but not for redacted events).
-        edits = await self._get_applicable_edits(
-            [
-                event_id
-                for event_id, event in events_by_id.items()
-                if not event.internal_metadata.is_redacted()
-            ]
-        )
-        for event_id, edit in edits.items():
-            results.setdefault(event_id, BundledAggregations()).replace = edit
-
-        # Fetch thread summaries.
-        summaries = await self._get_thread_summaries(events_by_id.keys())
-        # Only fetch participated for a limited selection based on what had
-        # summaries.
-        participated = await self._get_threads_participated(
-            [event_id for event_id, summary in summaries.items() if summary], user_id
-        )
-        for event_id, summary in summaries.items():
-            if summary:
-                thread_count, latest_thread_event, edit = summary
-                results.setdefault(
-                    event_id, BundledAggregations()
-                ).thread = _ThreadAggregation(
-                    latest_event=latest_thread_event,
-                    latest_edit=edit,
-                    count=thread_count,
-                    # If there's a thread summary it must also exist in the
-                    # participated dictionary.
-                    current_user_participated=participated[event_id],
-                )
-
-        return results
-
 
 class RelationsStore(RelationsWorkerStore):
     pass