diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 36aa1092f6..b2295fd51f 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -27,7 +27,6 @@ from typing import (
)
import attr
-from frozendict import frozendict
from synapse.api.constants import RelationTypes
from synapse.events import EventBase
@@ -41,45 +40,15 @@ from synapse.storage.database import (
from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.engines import PostgresEngine
from synapse.storage.relations import AggregationPaginationToken, PaginationChunk
-from synapse.types import JsonDict, RoomStreamToken, StreamToken
+from synapse.types import RoomStreamToken, StreamToken
from synapse.util.caches.descriptors import cached, cachedList
if TYPE_CHECKING:
from synapse.server import HomeServer
- from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__)
-@attr.s(slots=True, frozen=True, auto_attribs=True)
-class _ThreadAggregation:
- # The latest event in the thread.
- latest_event: EventBase
- # The latest edit to the latest event in the thread.
- latest_edit: Optional[EventBase]
- # The total number of events in the thread.
- count: int
- # True if the current user has sent an event to the thread.
- current_user_participated: bool
-
-
-@attr.s(slots=True, auto_attribs=True)
-class BundledAggregations:
- """
- The bundled aggregations for an event.
-
- Some values require additional processing during serialization.
- """
-
- annotations: Optional[JsonDict] = None
- references: Optional[JsonDict] = None
- replace: Optional[EventBase] = None
- thread: Optional[_ThreadAggregation] = None
-
- def __bool__(self) -> bool:
- return bool(self.annotations or self.references or self.replace or self.thread)
-
-
class RelationsWorkerStore(SQLBaseStore):
def __init__(
self,
@@ -91,10 +60,11 @@ class RelationsWorkerStore(SQLBaseStore):
self._msc3440_enabled = hs.config.experimental.msc3440_enabled
- @cached(tree=True)
+ @cached(uncached_args=("event",), tree=True)
async def get_relations_for_event(
self,
event_id: str,
+ event: EventBase,
room_id: str,
relation_type: Optional[str] = None,
event_type: Optional[str] = None,
@@ -108,6 +78,7 @@ class RelationsWorkerStore(SQLBaseStore):
Args:
event_id: Fetch events that relate to this event ID.
+ event: The matching EventBase to event_id.
room_id: The room the event belongs to.
relation_type: Only fetch events with this relation type, if given.
event_type: Only fetch events with this event type, if given.
@@ -122,9 +93,13 @@ class RelationsWorkerStore(SQLBaseStore):
List of event IDs that match relations requested. The rows are of
the form `{"event_id": "..."}`.
"""
+ # We don't use `event_id`, it's there so that we can cache based on
+ # it. The `event_id` must match the `event.event_id`.
+ assert event.event_id == event_id
where_clause = ["relates_to_id = ?", "room_id = ?"]
- where_args: List[Union[str, int]] = [event_id, room_id]
+ where_args: List[Union[str, int]] = [event.event_id, room_id]
+ is_redacted = event.internal_metadata.is_redacted()
if relation_type is not None:
where_clause.append("relation_type = ?")
@@ -157,7 +132,7 @@ class RelationsWorkerStore(SQLBaseStore):
order = "ASC"
sql = """
- SELECT event_id, topological_ordering, stream_ordering
+ SELECT event_id, relation_type, topological_ordering, stream_ordering
FROM event_relations
INNER JOIN events USING (event_id)
WHERE %s
@@ -178,9 +153,12 @@ class RelationsWorkerStore(SQLBaseStore):
last_stream_id = None
events = []
for row in txn:
- events.append({"event_id": row[0]})
- last_topo_id = row[1]
- last_stream_id = row[2]
+ # Do not include edits for redacted events as they leak event
+ # content.
+ if not is_redacted or row[1] != RelationTypes.REPLACE:
+ events.append({"event_id": row[0]})
+ last_topo_id = row[2]
+ last_stream_id = row[3]
# If there are more events, generate the next pagination key.
next_token = None
@@ -375,7 +353,7 @@ class RelationsWorkerStore(SQLBaseStore):
raise NotImplementedError()
@cachedList(cached_method_name="get_applicable_edit", list_name="event_ids")
- async def _get_applicable_edits(
+ async def get_applicable_edits(
self, event_ids: Collection[str]
) -> Dict[str, Optional[EventBase]]:
"""Get the most recent edit (if any) that has happened for the given
@@ -464,7 +442,7 @@ class RelationsWorkerStore(SQLBaseStore):
raise NotImplementedError()
@cachedList(cached_method_name="get_thread_summary", list_name="event_ids")
- async def _get_thread_summaries(
+ async def get_thread_summaries(
self, event_ids: Collection[str]
) -> Dict[str, Optional[Tuple[int, EventBase, Optional[EventBase]]]]:
"""Get the number of threaded replies, the latest reply (if any), and the latest edit for that reply for the given event.
@@ -499,7 +477,7 @@ class RelationsWorkerStore(SQLBaseStore):
AND parent.room_id = child.room_id
WHERE
%s
- AND relation_type = ?
+ AND %s
ORDER BY parent.event_id, child.topological_ordering DESC, child.stream_ordering DESC
"""
else:
@@ -514,16 +492,22 @@ class RelationsWorkerStore(SQLBaseStore):
AND parent.room_id = child.room_id
WHERE
%s
- AND relation_type = ?
+ AND %s
ORDER BY child.topological_ordering DESC, child.stream_ordering DESC
"""
clause, args = make_in_list_sql_clause(
txn.database_engine, "relates_to_id", event_ids
)
- args.append(RelationTypes.THREAD)
- txn.execute(sql % (clause,), args)
+ if self._msc3440_enabled:
+ relations_clause = "(relation_type = ? OR relation_type = ?)"
+ args.extend((RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD))
+ else:
+ relations_clause = "relation_type = ?"
+ args.append(RelationTypes.THREAD)
+
+ txn.execute(sql % (clause, relations_clause), args)
latest_event_ids = {}
for parent_event_id, child_event_id in txn:
# Only consider the latest threaded reply (by topological ordering).
@@ -543,7 +527,7 @@ class RelationsWorkerStore(SQLBaseStore):
AND parent.room_id = child.room_id
WHERE
%s
- AND relation_type = ?
+ AND %s
GROUP BY parent.event_id
"""
@@ -552,9 +536,15 @@ class RelationsWorkerStore(SQLBaseStore):
clause, args = make_in_list_sql_clause(
txn.database_engine, "relates_to_id", latest_event_ids.keys()
)
- args.append(RelationTypes.THREAD)
- txn.execute(sql % (clause,), args)
+ if self._msc3440_enabled:
+ relations_clause = "(relation_type = ? OR relation_type = ?)"
+ args.extend((RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD))
+ else:
+ relations_clause = "relation_type = ?"
+ args.append(RelationTypes.THREAD)
+
+ txn.execute(sql % (clause, relations_clause), args)
counts = dict(cast(List[Tuple[str, int]], txn.fetchall()))
return counts, latest_event_ids
@@ -566,7 +556,7 @@ class RelationsWorkerStore(SQLBaseStore):
latest_events = await self.get_events(latest_event_ids.values()) # type: ignore[attr-defined]
# Check to see if any of those events are edited.
- latest_edits = await self._get_applicable_edits(latest_event_ids.values())
+ latest_edits = await self.get_applicable_edits(latest_event_ids.values())
# Map to the event IDs to the thread summary.
#
@@ -589,7 +579,7 @@ class RelationsWorkerStore(SQLBaseStore):
raise NotImplementedError()
@cachedList(cached_method_name="get_thread_participated", list_name="event_ids")
- async def _get_threads_participated(
+ async def get_threads_participated(
self, event_ids: Collection[str], user_id: str
) -> Dict[str, bool]:
"""Get whether the requesting user participated in the given threads.
@@ -617,16 +607,24 @@ class RelationsWorkerStore(SQLBaseStore):
AND parent.room_id = child.room_id
WHERE
%s
- AND relation_type = ?
+ AND %s
AND child.sender = ?
"""
clause, args = make_in_list_sql_clause(
txn.database_engine, "relates_to_id", event_ids
)
- args.extend((RelationTypes.THREAD, user_id))
- txn.execute(sql % (clause,), args)
+ if self._msc3440_enabled:
+ relations_clause = "(relation_type = ? OR relation_type = ?)"
+ args.extend((RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD))
+ else:
+ relations_clause = "relation_type = ?"
+ args.append(RelationTypes.THREAD)
+
+ args.append(user_id)
+
+ txn.execute(sql % (clause, relations_clause), args)
return {row[0] for row in txn.fetchall()}
participated_threads = await self.db_pool.runInteraction(
@@ -737,122 +735,6 @@ class RelationsWorkerStore(SQLBaseStore):
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event
)
- async def _get_bundled_aggregation_for_event(
- self, event: EventBase, user_id: str
- ) -> Optional[BundledAggregations]:
- """Generate bundled aggregations for an event.
-
- Note that this does not use a cache, but depends on cached methods.
-
- Args:
- event: The event to calculate bundled aggregations for.
- user_id: The user requesting the bundled aggregations.
-
- Returns:
- The bundled aggregations for an event, if bundled aggregations are
- enabled and the event can have bundled aggregations.
- """
-
- # Do not bundle aggregations for an event which represents an edit or an
- # annotation. It does not make sense for them to have related events.
- relates_to = event.content.get("m.relates_to")
- if isinstance(relates_to, (dict, frozendict)):
- relation_type = relates_to.get("rel_type")
- if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
- return None
-
- event_id = event.event_id
- room_id = event.room_id
-
- # The bundled aggregations to include, a mapping of relation type to a
- # type-specific value. Some types include the direct return type here
- # while others need more processing during serialization.
- aggregations = BundledAggregations()
-
- annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
- if annotations.chunk:
- aggregations.annotations = await annotations.to_dict(
- cast("DataStore", self)
- )
-
- references = await self.get_relations_for_event(
- event_id, room_id, RelationTypes.REFERENCE, direction="f"
- )
- if references.chunk:
- aggregations.references = await references.to_dict(cast("DataStore", self))
-
- # Store the bundled aggregations in the event metadata for later use.
- return aggregations
-
- async def get_bundled_aggregations(
- self, events: Iterable[EventBase], user_id: str
- ) -> Dict[str, BundledAggregations]:
- """Generate bundled aggregations for events.
-
- Args:
- events: The iterable of events to calculate bundled aggregations for.
- user_id: The user requesting the bundled aggregations.
-
- Returns:
- A map of event ID to the bundled aggregation for the event. Not all
- events may have bundled aggregations in the results.
- """
- # The already processed event IDs. Tracked separately from the result
- # since the result omits events which do not have bundled aggregations.
- seen_event_ids = set()
-
- # State events and redacted events do not get bundled aggregations.
- events = [
- event
- for event in events
- if not event.is_state() and not event.internal_metadata.is_redacted()
- ]
-
- # event ID -> bundled aggregation in non-serialized form.
- results: Dict[str, BundledAggregations] = {}
-
- # Fetch other relations per event.
- for event in events:
- # De-duplicate events by ID to handle the same event requested multiple
- # times. The caches that _get_bundled_aggregation_for_event use should
- # capture this, but best to reduce work.
- if event.event_id in seen_event_ids:
- continue
- seen_event_ids.add(event.event_id)
-
- event_result = await self._get_bundled_aggregation_for_event(event, user_id)
- if event_result:
- results[event.event_id] = event_result
-
- # Fetch any edits.
- edits = await self._get_applicable_edits(seen_event_ids)
- for event_id, edit in edits.items():
- results.setdefault(event_id, BundledAggregations()).replace = edit
-
- # Fetch thread summaries.
- if self._msc3440_enabled:
- summaries = await self._get_thread_summaries(seen_event_ids)
- # Only fetch participated for a limited selection based on what had
- # summaries.
- participated = await self._get_threads_participated(
- summaries.keys(), user_id
- )
- for event_id, summary in summaries.items():
- if summary:
- thread_count, latest_thread_event, edit = summary
- results.setdefault(
- event_id, BundledAggregations()
- ).thread = _ThreadAggregation(
- latest_event=latest_thread_event,
- latest_edit=edit,
- count=thread_count,
- # If there's a thread summary it must also exist in the
- # participated dictionary.
- current_user_participated=participated[event_id],
- )
-
- return results
-
class RelationsStore(RelationsWorkerStore):
pass
|