diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index db39aeabde..350ec9c03a 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -950,54 +950,35 @@ class FederationHandler:
return event
- async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]:
- """Returns the state at the event. i.e. not including said event."""
-
- event = await self.store.get_event(event_id, check_room_id=room_id)
-
- state_groups = await self.state_store.get_state_groups(room_id, [event_id])
-
- if state_groups:
- _, state = list(state_groups.items()).pop()
- results = {(e.type, e.state_key): e for e in state}
-
- if event.is_state():
- # Get previous state
- if "replaces_state" in event.unsigned:
- prev_id = event.unsigned["replaces_state"]
- if prev_id != event.event_id:
- prev_event = await self.store.get_event(prev_id)
- results[(event.type, event.state_key)] = prev_event
- else:
- del results[(event.type, event.state_key)]
-
- res = list(results.values())
- return res
- else:
- return []
-
async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
"""Returns the state at the event. i.e. not including said event."""
event = await self.store.get_event(event_id, check_room_id=room_id)
+ if event.internal_metadata.outlier:
+ raise NotFoundError("State not known at event %s" % (event_id,))
state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
- if state_groups:
- _, state = list(state_groups.items()).pop()
- results = state
+ # get_state_groups_ids should return exactly one result
+ assert len(state_groups) == 1
- if event.is_state():
- # Get previous state
- if "replaces_state" in event.unsigned:
- prev_id = event.unsigned["replaces_state"]
- if prev_id != event.event_id:
- results[(event.type, event.state_key)] = prev_id
- else:
- results.pop((event.type, event.state_key), None)
+ state_map = next(iter(state_groups.values()))
- return list(results.values())
- else:
- return []
+ state_key = event.get_state_key()
+ if state_key is not None:
+ # the event was not rejected (get_event raises a NotFoundError for rejected
+ # events) so the state at the event should include the event itself.
+ assert (
+ state_map.get((event.type, state_key)) == event.event_id
+ ), "State at event did not include event itself"
+
+ # ... but we need the state *before* that event
+ if "replaces_state" in event.unsigned:
+ prev_id = event.unsigned["replaces_state"]
+ state_map[(event.type, state_key)] = prev_id
+ else:
+ del state_map[(event.type, state_key)]
+
+ return list(state_map.values())
async def on_backfill_request(
self, origin: str, room_id: str, pdu_list: List[str], limit: int
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 60059fec3e..876b879483 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Set
+from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set
import attr
@@ -134,6 +134,7 @@ class PaginationHandler:
self.clock = hs.get_clock()
self._server_name = hs.hostname
self._room_shutdown_handler = hs.get_room_shutdown_handler()
+ self._relations_handler = hs.get_relations_handler()
self.pagination_lock = ReadWriteLock()
# IDs of rooms in which there currently an active purge *or delete* operation.
@@ -422,7 +423,7 @@ class PaginationHandler:
pagin_config: PaginationConfig,
as_client_event: bool = True,
event_filter: Optional[Filter] = None,
- ) -> Dict[str, Any]:
+ ) -> JsonDict:
"""Get messages in a room.
Args:
@@ -431,6 +432,7 @@ class PaginationHandler:
pagin_config: The pagination config rules to apply, if any.
as_client_event: True to get events in client-server format.
event_filter: Filter to apply to results or None
+
Returns:
Pagination API results
"""
@@ -538,7 +540,9 @@ class PaginationHandler:
state_dict = await self.store.get_events(list(state_ids.values()))
state = state_dict.values()
- aggregations = await self.store.get_bundled_aggregations(events, user_id)
+ aggregations = await self._relations_handler.get_bundled_aggregations(
+ events, user_id
+ )
time_now = self.clock.time_msec()
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
new file mode 100644
index 0000000000..57135d4519
--- /dev/null
+++ b/synapse/handlers/relations.py
@@ -0,0 +1,264 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import TYPE_CHECKING, Dict, Iterable, Optional, cast
+
+import attr
+from frozendict import frozendict
+
+from synapse.api.constants import RelationTypes
+from synapse.api.errors import SynapseError
+from synapse.events import EventBase
+from synapse.types import JsonDict, Requester, StreamToken
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+ from synapse.storage.databases.main import DataStore
+
+
+logger = logging.getLogger(__name__)
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _ThreadAggregation:
+ # The latest event in the thread.
+ latest_event: EventBase
+ # The latest edit to the latest event in the thread.
+ latest_edit: Optional[EventBase]
+ # The total number of events in the thread.
+ count: int
+ # True if the current user has sent an event to the thread.
+ current_user_participated: bool
+
+
+@attr.s(slots=True, auto_attribs=True)
+class BundledAggregations:
+ """
+ The bundled aggregations for an event.
+
+ Some values require additional processing during serialization.
+ """
+
+ annotations: Optional[JsonDict] = None
+ references: Optional[JsonDict] = None
+ replace: Optional[EventBase] = None
+ thread: Optional[_ThreadAggregation] = None
+
+ def __bool__(self) -> bool:
+ return bool(self.annotations or self.references or self.replace or self.thread)
+
+
+class RelationsHandler:
+ def __init__(self, hs: "HomeServer"):
+ self._main_store = hs.get_datastores().main
+ self._auth = hs.get_auth()
+ self._clock = hs.get_clock()
+ self._event_handler = hs.get_event_handler()
+ self._event_serializer = hs.get_event_client_serializer()
+
+ async def get_relations(
+ self,
+ requester: Requester,
+ event_id: str,
+ room_id: str,
+ relation_type: Optional[str] = None,
+ event_type: Optional[str] = None,
+ aggregation_key: Optional[str] = None,
+ limit: int = 5,
+ direction: str = "b",
+ from_token: Optional[StreamToken] = None,
+ to_token: Optional[StreamToken] = None,
+ ) -> JsonDict:
+ """Get related events of a event, ordered by topological ordering.
+
+ TODO Accept a PaginationConfig instead of individual pagination parameters.
+
+ Args:
+ requester: The user requesting the relations.
+ event_id: Fetch events that relate to this event ID.
+ room_id: The room the event belongs to.
+ relation_type: Only fetch events with this relation type, if given.
+ event_type: Only fetch events with this event type, if given.
+ aggregation_key: Only fetch events with this aggregation key, if given.
+ limit: Only fetch the most recent `limit` events.
+ direction: Whether to fetch the most recent first (`"b"`) or the
+ oldest first (`"f"`).
+ from_token: Fetch rows from the given token, or from the start if None.
+ to_token: Fetch rows up to the given token, or up to the end if None.
+
+ Returns:
+ The pagination chunk.
+ """
+
+ user_id = requester.user.to_string()
+
+ await self._auth.check_user_in_room_or_world_readable(
+ room_id, user_id, allow_departed_users=True
+ )
+
+ # This gets the original event and checks that a) the event exists and
+ # b) the user is allowed to view it.
+ event = await self._event_handler.get_event(requester.user, room_id, event_id)
+ if event is None:
+ raise SynapseError(404, "Unknown parent event.")
+
+ pagination_chunk = await self._main_store.get_relations_for_event(
+ event_id=event_id,
+ event=event,
+ room_id=room_id,
+ relation_type=relation_type,
+ event_type=event_type,
+ aggregation_key=aggregation_key,
+ limit=limit,
+ direction=direction,
+ from_token=from_token,
+ to_token=to_token,
+ )
+
+ events = await self._main_store.get_events_as_list(
+ [c["event_id"] for c in pagination_chunk.chunk]
+ )
+
+ now = self._clock.time_msec()
+ # Do not bundle aggregations when retrieving the original event because
+ # we want the content before relations are applied to it.
+ original_event = self._event_serializer.serialize_event(
+ event, now, bundle_aggregations=None
+ )
+ # The relations returned for the requested event do include their
+ # bundled aggregations.
+ aggregations = await self.get_bundled_aggregations(
+ events, requester.user.to_string()
+ )
+ serialized_events = self._event_serializer.serialize_events(
+ events, now, bundle_aggregations=aggregations
+ )
+
+ return_value = await pagination_chunk.to_dict(self._main_store)
+ return_value["chunk"] = serialized_events
+ return_value["original_event"] = original_event
+
+ return return_value
+
+ async def _get_bundled_aggregation_for_event(
+ self, event: EventBase, user_id: str
+ ) -> Optional[BundledAggregations]:
+ """Generate bundled aggregations for an event.
+
+ Note that this does not use a cache, but depends on cached methods.
+
+ Args:
+ event: The event to calculate bundled aggregations for.
+ user_id: The user requesting the bundled aggregations.
+
+ Returns:
+ The bundled aggregations for an event, if bundled aggregations are
+ enabled and the event can have bundled aggregations.
+ """
+
+ # Do not bundle aggregations for an event which represents an edit or an
+ # annotation. It does not make sense for them to have related events.
+ relates_to = event.content.get("m.relates_to")
+ if isinstance(relates_to, (dict, frozendict)):
+ relation_type = relates_to.get("rel_type")
+ if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
+ return None
+
+ event_id = event.event_id
+ room_id = event.room_id
+
+ # The bundled aggregations to include, a mapping of relation type to a
+ # type-specific value. Some types include the direct return type here
+ # while others need more processing during serialization.
+ aggregations = BundledAggregations()
+
+ annotations = await self._main_store.get_aggregation_groups_for_event(
+ event_id, room_id
+ )
+ if annotations.chunk:
+ aggregations.annotations = await annotations.to_dict(
+ cast("DataStore", self)
+ )
+
+ references = await self._main_store.get_relations_for_event(
+ event_id, event, room_id, RelationTypes.REFERENCE, direction="f"
+ )
+ if references.chunk:
+ aggregations.references = await references.to_dict(cast("DataStore", self))
+
+ # Store the bundled aggregations in the event metadata for later use.
+ return aggregations
+
+ async def get_bundled_aggregations(
+ self, events: Iterable[EventBase], user_id: str
+ ) -> Dict[str, BundledAggregations]:
+ """Generate bundled aggregations for events.
+
+ Args:
+ events: The iterable of events to calculate bundled aggregations for.
+ user_id: The user requesting the bundled aggregations.
+
+ Returns:
+ A map of event ID to the bundled aggregation for the event. Not all
+ events may have bundled aggregations in the results.
+ """
+ # De-duplicate events by ID to handle the same event requested multiple times.
+ #
+ # State events do not get bundled aggregations.
+ events_by_id = {
+ event.event_id: event for event in events if not event.is_state()
+ }
+
+ # event ID -> bundled aggregation in non-serialized form.
+ results: Dict[str, BundledAggregations] = {}
+
+ # Fetch other relations per event.
+ for event in events_by_id.values():
+ event_result = await self._get_bundled_aggregation_for_event(event, user_id)
+ if event_result:
+ results[event.event_id] = event_result
+
+ # Fetch any edits (but not for redacted events).
+ edits = await self._main_store.get_applicable_edits(
+ [
+ event_id
+ for event_id, event in events_by_id.items()
+ if not event.internal_metadata.is_redacted()
+ ]
+ )
+ for event_id, edit in edits.items():
+ results.setdefault(event_id, BundledAggregations()).replace = edit
+
+ # Fetch thread summaries.
+ summaries = await self._main_store.get_thread_summaries(events_by_id.keys())
+ # Only fetch participated for a limited selection based on what had
+ # summaries.
+ participated = await self._main_store.get_threads_participated(
+ [event_id for event_id, summary in summaries.items() if summary], user_id
+ )
+ for event_id, summary in summaries.items():
+ if summary:
+ thread_count, latest_thread_event, edit = summary
+ results.setdefault(
+ event_id, BundledAggregations()
+ ).thread = _ThreadAggregation(
+ latest_event=latest_thread_event,
+ latest_edit=edit,
+ count=thread_count,
+ # If there's a thread summary it must also exist in the
+ # participated dictionary.
+ current_user_participated=participated[event_id],
+ )
+
+ return results
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index b9735631fc..092e185c99 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -60,8 +60,8 @@ from synapse.events import EventBase
from synapse.events.utils import copy_power_levels_contents
from synapse.federation.federation_client import InvalidResponseError
from synapse.handlers.federation import get_domains_from_state
+from synapse.handlers.relations import BundledAggregations
from synapse.rest.admin._base import assert_user_is_admin
-from synapse.storage.databases.main.relations import BundledAggregations
from synapse.storage.state import StateFilter
from synapse.streams import EventSource
from synapse.types import (
@@ -1118,6 +1118,7 @@ class RoomContextHandler:
self.store = hs.get_datastores().main
self.storage = hs.get_storage()
self.state_store = self.storage.state
+ self._relations_handler = hs.get_relations_handler()
async def get_event_context(
self,
@@ -1190,7 +1191,7 @@ class RoomContextHandler:
event = filtered[0]
# Fetch the aggregations.
- aggregations = await self.store.get_bundled_aggregations(
+ aggregations = await self._relations_handler.get_bundled_aggregations(
itertools.chain(events_before, (event,), events_after),
user.to_string(),
)
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index aa16e417eb..30eddda65f 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -54,6 +54,7 @@ class SearchHandler:
self.clock = hs.get_clock()
self.hs = hs
self._event_serializer = hs.get_event_client_serializer()
+ self._relations_handler = hs.get_relations_handler()
self.storage = hs.get_storage()
self.state_store = self.storage.state
self.auth = hs.get_auth()
@@ -354,7 +355,7 @@ class SearchHandler:
aggregations = None
if self._msc3666_enabled:
- aggregations = await self.store.get_bundled_aggregations(
+ aggregations = await self._relations_handler.get_bundled_aggregations(
# Generate an iterable of EventBase for all the events that will be
# returned, including contextual events.
itertools.chain(
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 0aa3052fd6..6c569cfb1c 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -28,16 +28,16 @@ from typing import (
import attr
from prometheus_client import Counter
-from synapse.api.constants import AccountDataTypes, EventTypes, Membership, ReceiptTypes
+from synapse.api.constants import EventTypes, Membership, ReceiptTypes
from synapse.api.filtering import FilterCollection
from synapse.api.presence import UserPresenceState
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
+from synapse.handlers.relations import BundledAggregations
from synapse.logging.context import current_context
from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
from synapse.push.clientformat import format_push_rules_for_user
from synapse.storage.databases.main.event_push_actions import NotifCounts
-from synapse.storage.databases.main.relations import BundledAggregations
from synapse.storage.roommember import MemberSummary
from synapse.storage.state import StateFilter
from synapse.types import (
@@ -269,6 +269,7 @@ class SyncHandler:
self.store = hs.get_datastores().main
self.notifier = hs.get_notifier()
self.presence_handler = hs.get_presence_handler()
+ self._relations_handler = hs.get_relations_handler()
self.event_sources = hs.get_event_sources()
self.clock = hs.get_clock()
self.state = hs.get_state_handler()
@@ -638,8 +639,10 @@ class SyncHandler:
# as clients will have all the necessary information.
bundled_aggregations = None
if limited or newly_joined_room:
- bundled_aggregations = await self.store.get_bundled_aggregations(
- recents, sync_config.user.to_string()
+ bundled_aggregations = (
+ await self._relations_handler.get_bundled_aggregations(
+ recents, sync_config.user.to_string()
+ )
)
return TimelineBatch(
@@ -1601,7 +1604,7 @@ class SyncHandler:
return set(), set(), set(), set()
# 3. Work out which rooms need reporting in the sync response.
- ignored_users = await self._get_ignored_users(user_id)
+ ignored_users = await self.store.ignored_users(user_id)
if since_token:
room_changes = await self._get_rooms_changed(
sync_result_builder, ignored_users
@@ -1627,7 +1630,6 @@ class SyncHandler:
logger.debug("Generating room entry for %s", room_entry.room_id)
await self._generate_room_entry(
sync_result_builder,
- ignored_users,
room_entry,
ephemeral=ephemeral_by_room.get(room_entry.room_id, []),
tags=tags_by_room.get(room_entry.room_id),
@@ -1657,29 +1659,6 @@ class SyncHandler:
newly_left_users,
)
- async def _get_ignored_users(self, user_id: str) -> FrozenSet[str]:
- """Retrieve the users ignored by the given user from their global account_data.
-
- Returns an empty set if
- - there is no global account_data entry for ignored_users
- - there is such an entry, but it's not a JSON object.
- """
- # TODO: Can we `SELECT ignored_user_id FROM ignored_users WHERE ignorer_user_id=?;` instead?
- ignored_account_data = (
- await self.store.get_global_account_data_by_type_for_user(
- user_id=user_id, data_type=AccountDataTypes.IGNORED_USER_LIST
- )
- )
-
- # If there is ignored users account data and it matches the proper type,
- # then use it.
- ignored_users: FrozenSet[str] = frozenset()
- if ignored_account_data:
- ignored_users_data = ignored_account_data.get("ignored_users", {})
- if isinstance(ignored_users_data, dict):
- ignored_users = frozenset(ignored_users_data.keys())
- return ignored_users
-
async def _have_rooms_changed(
self, sync_result_builder: "SyncResultBuilder"
) -> bool:
@@ -2022,7 +2001,6 @@ class SyncHandler:
async def _generate_room_entry(
self,
sync_result_builder: "SyncResultBuilder",
- ignored_users: FrozenSet[str],
room_builder: "RoomSyncResultBuilder",
ephemeral: List[JsonDict],
tags: Optional[Dict[str, Dict[str, Any]]],
@@ -2051,7 +2029,6 @@ class SyncHandler:
Args:
sync_result_builder
- ignored_users: Set of users ignored by user.
room_builder
ephemeral: List of new ephemeral events for room
tags: List of *all* tags for room, or None if there has been
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index d27ed2be6a..048fd4bb82 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -19,8 +19,8 @@ import synapse.metrics
from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules, Membership
from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.databases.main.user_directory import SearchResult
from synapse.storage.roommember import ProfileInfo
-from synapse.types import JsonDict
from synapse.util.metrics import Measure
if TYPE_CHECKING:
@@ -78,7 +78,7 @@ class UserDirectoryHandler(StateDeltasHandler):
async def search_users(
self, user_id: str, search_term: str, limit: int
- ) -> JsonDict:
+ ) -> SearchResult:
"""Searches for users in directory
Returns:
|