diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 918adeecf8..243696b357 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -14,7 +14,17 @@
# limitations under the License.
import collections.abc
import re
-from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ List,
+ Mapping,
+ Optional,
+ Union,
+)
from frozendict import frozendict
@@ -26,6 +36,10 @@ from synapse.util.frozenutils import unfreeze
from . import EventBase
+if TYPE_CHECKING:
+ from synapse.storage.databases.main.relations import BundledAggregations
+
+
# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
# (?<!stuff) matches if the current position in the string is not preceded
# by a match for 'stuff'.
@@ -376,7 +390,7 @@ class EventClientSerializer:
event: Union[JsonDict, EventBase],
time_now: int,
*,
- bundle_aggregations: Optional[Dict[str, JsonDict]] = None,
+ bundle_aggregations: Optional[Dict[str, "BundledAggregations"]] = None,
**kwargs: Any,
) -> JsonDict:
"""Serializes a single event.
@@ -415,7 +429,7 @@ class EventClientSerializer:
self,
event: EventBase,
time_now: int,
- aggregations: JsonDict,
+ aggregations: "BundledAggregations",
serialized_event: JsonDict,
) -> None:
"""Potentially injects bundled aggregations into the unsigned portion of the serialized event.
@@ -427,13 +441,18 @@ class EventClientSerializer:
serialized_event: The serialized event which may be modified.
"""
- # Make a copy in-case the object is cached.
- aggregations = aggregations.copy()
+ serialized_aggregations = {}
+
+ if aggregations.annotations:
+ serialized_aggregations[RelationTypes.ANNOTATION] = aggregations.annotations
+
+ if aggregations.references:
+ serialized_aggregations[RelationTypes.REFERENCE] = aggregations.references
- if RelationTypes.REPLACE in aggregations:
+ if aggregations.replace:
# If there is an edit replace the content, preserving existing
# relations.
- edit = aggregations[RelationTypes.REPLACE]
+ edit = aggregations.replace
# Ensure we take copies of the edit content, otherwise we risk modifying
# the original event.
@@ -451,24 +470,28 @@ class EventClientSerializer:
else:
serialized_event["content"].pop("m.relates_to", None)
- aggregations[RelationTypes.REPLACE] = {
+ serialized_aggregations[RelationTypes.REPLACE] = {
"event_id": edit.event_id,
"origin_server_ts": edit.origin_server_ts,
"sender": edit.sender,
}
# If this event is the start of a thread, include a summary of the replies.
- if RelationTypes.THREAD in aggregations:
- # Serialize the latest thread event.
- latest_thread_event = aggregations[RelationTypes.THREAD]["latest_event"]
-
- # Don't bundle aggregations as this could recurse forever.
- aggregations[RelationTypes.THREAD]["latest_event"] = self.serialize_event(
- latest_thread_event, time_now, bundle_aggregations=None
- )
+ if aggregations.thread:
+ serialized_aggregations[RelationTypes.THREAD] = {
+ # Don't bundle aggregations as this could recurse forever.
+ "latest_event": self.serialize_event(
+ aggregations.thread.latest_event, time_now, bundle_aggregations=None
+ ),
+ "count": aggregations.thread.count,
+ "current_user_participated": aggregations.thread.current_user_participated,
+ }
# Include the bundled aggregations in the event.
- serialized_event["unsigned"].setdefault("m.relations", {}).update(aggregations)
+ if serialized_aggregations:
+ serialized_event["unsigned"].setdefault("m.relations", {}).update(
+ serialized_aggregations
+ )
def serialize_events(
self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index f963078e59..1420d67729 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -30,6 +30,7 @@ from typing import (
Tuple,
)
+import attr
from typing_extensions import TypedDict
from synapse.api.constants import (
@@ -60,6 +61,7 @@ from synapse.events.utils import copy_power_levels_contents
from synapse.federation.federation_client import InvalidResponseError
from synapse.handlers.federation import get_domains_from_state
from synapse.rest.admin._base import assert_user_is_admin
+from synapse.storage.databases.main.relations import BundledAggregations
from synapse.storage.state import StateFilter
from synapse.streams import EventSource
from synapse.types import (
@@ -90,6 +92,17 @@ id_server_scheme = "https://"
FIVE_MINUTES_IN_MS = 5 * 60 * 1000
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class EventContext:  # NOTE(review): name may collide with the existing event-context type elsewhere in Synapse — verify imports
+    events_before: List[EventBase]  # visible events preceding `event`, already filtered for the requesting user
+    event: EventBase  # the anchor event the context was requested around (may be a pruned copy)
+    events_after: List[EventBase]  # visible events following `event`, already filtered for the requesting user
+    state: List[EventBase]  # room state for the window, filtered for the requesting user
+    aggregations: Dict[str, BundledAggregations]  # bundled aggregations keyed by event ID, covering all returned events
+    start: str  # serialized pagination token for the start of the window
+    end: str  # serialized pagination token for the end of the window
+
+
class RoomCreationHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
@@ -1119,7 +1132,7 @@ class RoomContextHandler:
limit: int,
event_filter: Optional[Filter],
use_admin_priviledge: bool = False,
- ) -> Optional[JsonDict]:
+ ) -> Optional[EventContext]:
"""Retrieves events, pagination tokens and state around a given event
in a room.
@@ -1167,38 +1180,28 @@ class RoomContextHandler:
results = await self.store.get_events_around(
room_id, event_id, before_limit, after_limit, event_filter
)
+ events_before = results.events_before
+ events_after = results.events_after
if event_filter:
- results["events_before"] = await event_filter.filter(
- results["events_before"]
- )
- results["events_after"] = await event_filter.filter(results["events_after"])
+ events_before = await event_filter.filter(events_before)
+ events_after = await event_filter.filter(events_after)
- results["events_before"] = await filter_evts(results["events_before"])
- results["events_after"] = await filter_evts(results["events_after"])
+ events_before = await filter_evts(events_before)
+ events_after = await filter_evts(events_after)
# filter_evts can return a pruned event in case the user is allowed to see that
# there's something there but not see the content, so use the event that's in
# `filtered` rather than the event we retrieved from the datastore.
- results["event"] = filtered[0]
+ event = filtered[0]
# Fetch the aggregations.
aggregations = await self.store.get_bundled_aggregations(
- [results["event"]], user.to_string()
+ itertools.chain(events_before, (event,), events_after),
+ user.to_string(),
)
- aggregations.update(
- await self.store.get_bundled_aggregations(
- results["events_before"], user.to_string()
- )
- )
- aggregations.update(
- await self.store.get_bundled_aggregations(
- results["events_after"], user.to_string()
- )
- )
- results["aggregations"] = aggregations
- if results["events_after"]:
- last_event_id = results["events_after"][-1].event_id
+ if events_after:
+ last_event_id = events_after[-1].event_id
else:
last_event_id = event_id
@@ -1206,9 +1209,9 @@ class RoomContextHandler:
state_filter = StateFilter.from_lazy_load_member_list(
ev.sender
for ev in itertools.chain(
- results["events_before"],
- (results["event"],),
- results["events_after"],
+ events_before,
+ (event,),
+ events_after,
)
)
else:
@@ -1226,21 +1229,23 @@ class RoomContextHandler:
if event_filter:
state_events = await event_filter.filter(state_events)
- results["state"] = await filter_evts(state_events)
-
# We use a dummy token here as we only care about the room portion of
# the token, which we replace.
token = StreamToken.START
- results["start"] = await token.copy_and_replace(
- "room_key", results["start"]
- ).to_string(self.store)
-
- results["end"] = await token.copy_and_replace(
- "room_key", results["end"]
- ).to_string(self.store)
-
- return results
+ return EventContext(
+ events_before=events_before,
+ event=event,
+ events_after=events_after,
+ state=await filter_evts(state_events),
+ aggregations=aggregations,
+ start=await token.copy_and_replace("room_key", results.start).to_string(
+ self.store
+ ),
+ end=await token.copy_and_replace("room_key", results.end).to_string(
+ self.store
+ ),
+ )
class TimestampLookupHandler:
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 0b153a6822..02bb5ae72f 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -361,36 +361,37 @@ class SearchHandler:
logger.info(
"Context for search returned %d and %d events",
- len(res["events_before"]),
- len(res["events_after"]),
+ len(res.events_before),
+ len(res.events_after),
)
- res["events_before"] = await filter_events_for_client(
- self.storage, user.to_string(), res["events_before"]
+ events_before = await filter_events_for_client(
+ self.storage, user.to_string(), res.events_before
)
- res["events_after"] = await filter_events_for_client(
- self.storage, user.to_string(), res["events_after"]
+ events_after = await filter_events_for_client(
+ self.storage, user.to_string(), res.events_after
)
- res["start"] = await now_token.copy_and_replace(
- "room_key", res["start"]
- ).to_string(self.store)
-
- res["end"] = await now_token.copy_and_replace(
- "room_key", res["end"]
- ).to_string(self.store)
+ context = {
+ "events_before": events_before,
+ "events_after": events_after,
+ "start": await now_token.copy_and_replace(
+ "room_key", res.start
+ ).to_string(self.store),
+ "end": await now_token.copy_and_replace(
+ "room_key", res.end
+ ).to_string(self.store),
+ }
if include_profile:
senders = {
ev.sender
- for ev in itertools.chain(
- res["events_before"], [event], res["events_after"]
- )
+ for ev in itertools.chain(events_before, [event], events_after)
}
- if res["events_after"]:
- last_event_id = res["events_after"][-1].event_id
+ if events_after:
+ last_event_id = events_after[-1].event_id
else:
last_event_id = event.event_id
@@ -402,7 +403,7 @@ class SearchHandler:
last_event_id, state_filter
)
- res["profile_info"] = {
+ context["profile_info"] = {
s.state_key: {
"displayname": s.content.get("displayname", None),
"avatar_url": s.content.get("avatar_url", None),
@@ -411,7 +412,7 @@ class SearchHandler:
if s.type == EventTypes.Member and s.state_key in senders
}
- contexts[event.event_id] = res
+ contexts[event.event_id] = context
else:
contexts = {}
@@ -421,10 +422,10 @@ class SearchHandler:
for context in contexts.values():
context["events_before"] = self._event_serializer.serialize_events(
- context["events_before"], time_now
+ context["events_before"], time_now # type: ignore[arg-type]
)
context["events_after"] = self._event_serializer.serialize_events(
- context["events_after"], time_now
+ context["events_after"], time_now # type: ignore[arg-type]
)
state_results = {}
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 7e2a892b63..c72ed7c290 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -37,6 +37,7 @@ from synapse.logging.context import current_context
from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
from synapse.push.clientformat import format_push_rules_for_user
from synapse.storage.databases.main.event_push_actions import NotifCounts
+from synapse.storage.databases.main.relations import BundledAggregations
from synapse.storage.roommember import MemberSummary
from synapse.storage.state import StateFilter
from synapse.types import (
@@ -100,7 +101,7 @@ class TimelineBatch:
limited: bool
# A mapping of event ID to the bundled aggregations for the above events.
# This is only calculated if limited is true.
- bundled_aggregations: Optional[Dict[str, Dict[str, Any]]] = None
+ bundled_aggregations: Optional[Dict[str, BundledAggregations]] = None
def __bool__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index dadfc57413..3df8452eec 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -455,7 +455,7 @@ class Mailer:
}
the_events = await filter_events_for_client(
- self.storage, user_id, results["events_before"]
+ self.storage, user_id, results.events_before
)
the_events.append(notif_event)
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index efe25fe7eb..5b706efbcf 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -729,7 +729,7 @@ class RoomEventContextServlet(RestServlet):
else:
event_filter = None
- results = await self.room_context_handler.get_event_context(
+ event_context = await self.room_context_handler.get_event_context(
requester,
room_id,
event_id,
@@ -738,25 +738,34 @@ class RoomEventContextServlet(RestServlet):
use_admin_priviledge=True,
)
- if not results:
+ if not event_context:
raise SynapseError(
HTTPStatus.NOT_FOUND, "Event not found.", errcode=Codes.NOT_FOUND
)
time_now = self.clock.time_msec()
- aggregations = results.pop("aggregations", None)
- results["events_before"] = self._event_serializer.serialize_events(
- results["events_before"], time_now, bundle_aggregations=aggregations
- )
- results["event"] = self._event_serializer.serialize_event(
- results["event"], time_now, bundle_aggregations=aggregations
- )
- results["events_after"] = self._event_serializer.serialize_events(
- results["events_after"], time_now, bundle_aggregations=aggregations
- )
- results["state"] = self._event_serializer.serialize_events(
- results["state"], time_now
- )
+ results = {
+ "events_before": self._event_serializer.serialize_events(
+ event_context.events_before,
+ time_now,
+ bundle_aggregations=event_context.aggregations,
+ ),
+ "event": self._event_serializer.serialize_event(
+ event_context.event,
+ time_now,
+ bundle_aggregations=event_context.aggregations,
+ ),
+ "events_after": self._event_serializer.serialize_events(
+ event_context.events_after,
+ time_now,
+ bundle_aggregations=event_context.aggregations,
+ ),
+ "state": self._event_serializer.serialize_events(
+ event_context.state, time_now
+ ),
+ "start": event_context.start,
+ "end": event_context.end,
+ }
return HTTPStatus.OK, results
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 90bb9142a0..90355e44b2 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -706,27 +706,36 @@ class RoomEventContextServlet(RestServlet):
else:
event_filter = None
- results = await self.room_context_handler.get_event_context(
+ event_context = await self.room_context_handler.get_event_context(
requester, room_id, event_id, limit, event_filter
)
- if not results:
+ if not event_context:
raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
time_now = self.clock.time_msec()
- aggregations = results.pop("aggregations", None)
- results["events_before"] = self._event_serializer.serialize_events(
- results["events_before"], time_now, bundle_aggregations=aggregations
- )
- results["event"] = self._event_serializer.serialize_event(
- results["event"], time_now, bundle_aggregations=aggregations
- )
- results["events_after"] = self._event_serializer.serialize_events(
- results["events_after"], time_now, bundle_aggregations=aggregations
- )
- results["state"] = self._event_serializer.serialize_events(
- results["state"], time_now
- )
+ results = {
+ "events_before": self._event_serializer.serialize_events(
+ event_context.events_before,
+ time_now,
+ bundle_aggregations=event_context.aggregations,
+ ),
+ "event": self._event_serializer.serialize_event(
+ event_context.event,
+ time_now,
+ bundle_aggregations=event_context.aggregations,
+ ),
+ "events_after": self._event_serializer.serialize_events(
+ event_context.events_after,
+ time_now,
+ bundle_aggregations=event_context.aggregations,
+ ),
+ "state": self._event_serializer.serialize_events(
+ event_context.state, time_now
+ ),
+ "start": event_context.start,
+ "end": event_context.end,
+ }
return 200, results
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index d20ae1421e..f9615da525 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -48,6 +48,7 @@ from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import trace
+from synapse.storage.databases.main.relations import BundledAggregations
from synapse.types import JsonDict, StreamToken
from synapse.util import json_decoder
@@ -526,7 +527,7 @@ class SyncRestServlet(RestServlet):
def serialize(
events: Iterable[EventBase],
- aggregations: Optional[Dict[str, Dict[str, Any]]] = None,
+ aggregations: Optional[Dict[str, BundledAggregations]] = None,
) -> List[JsonDict]:
return self._event_serializer.serialize_events(
events,
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 2cb5d06c13..a9a5dd5f03 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -13,17 +13,7 @@
# limitations under the License.
import logging
-from typing import (
- TYPE_CHECKING,
- Any,
- Dict,
- Iterable,
- List,
- Optional,
- Tuple,
- Union,
- cast,
-)
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union, cast
import attr
from frozendict import frozendict
@@ -43,6 +33,7 @@ from synapse.storage.relations import (
PaginationChunk,
RelationPaginationToken,
)
+from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
@@ -51,6 +42,30 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _ThreadAggregation:  # summary of a thread rooted at an event; serialized without nested aggregations to avoid recursion
+    latest_event: EventBase  # most recent event in the thread
+    count: int  # number of events in the thread — presumably from the thread summary query; confirm
+    current_user_participated: bool  # whether the requesting user has an event in the thread — TODO confirm semantics
+
+
+@attr.s(slots=True, auto_attribs=True)
+class BundledAggregations:
+    """
+    The bundled aggregations for an event.
+
+    Some values require additional processing during serialization.
+    """
+
+    annotations: Optional[JsonDict] = None  # m.annotation aggregation groups, as a serialized chunk dict
+    references: Optional[JsonDict] = None  # m.reference relations, as a serialized chunk dict
+    replace: Optional[EventBase] = None  # the applicable m.replace edit event, if any
+    thread: Optional[_ThreadAggregation] = None  # thread summary when this event is the root of a thread
+
+    def __bool__(self) -> bool:  # truthy iff any aggregation is set; lets callers drop empty results
+        return bool(self.annotations or self.references or self.replace or self.thread)
+
+
class RelationsWorkerStore(SQLBaseStore):
def __init__(
self,
@@ -585,7 +600,7 @@ class RelationsWorkerStore(SQLBaseStore):
async def _get_bundled_aggregation_for_event(
self, event: EventBase, user_id: str
- ) -> Optional[Dict[str, Any]]:
+ ) -> Optional[BundledAggregations]:
"""Generate bundled aggregations for an event.
Note that this does not use a cache, but depends on cached methods.
@@ -616,24 +631,24 @@ class RelationsWorkerStore(SQLBaseStore):
# The bundled aggregations to include, a mapping of relation type to a
# type-specific value. Some types include the direct return type here
# while others need more processing during serialization.
- aggregations: Dict[str, Any] = {}
+ aggregations = BundledAggregations()
annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
if annotations.chunk:
- aggregations[RelationTypes.ANNOTATION] = annotations.to_dict()
+ aggregations.annotations = annotations.to_dict()
references = await self.get_relations_for_event(
event_id, room_id, RelationTypes.REFERENCE, direction="f"
)
if references.chunk:
- aggregations[RelationTypes.REFERENCE] = references.to_dict()
+ aggregations.references = references.to_dict()
edit = None
if event.type == EventTypes.Message:
edit = await self.get_applicable_edit(event_id, room_id)
if edit:
- aggregations[RelationTypes.REPLACE] = edit
+ aggregations.replace = edit
# If this event is the start of a thread, include a summary of the replies.
if self._msc3440_enabled:
@@ -644,11 +659,11 @@ class RelationsWorkerStore(SQLBaseStore):
event_id, room_id, user_id
)
if latest_thread_event:
- aggregations[RelationTypes.THREAD] = {
- "latest_event": latest_thread_event,
- "count": thread_count,
- "current_user_participated": participated,
- }
+ aggregations.thread = _ThreadAggregation(
+ latest_event=latest_thread_event,
+ count=thread_count,
+ current_user_participated=participated,
+ )
# Store the bundled aggregations in the event metadata for later use.
return aggregations
@@ -657,7 +672,7 @@ class RelationsWorkerStore(SQLBaseStore):
self,
events: Iterable[EventBase],
user_id: str,
- ) -> Dict[str, Dict[str, Any]]:
+ ) -> Dict[str, BundledAggregations]:
"""Generate bundled aggregations for events.
Args:
@@ -676,7 +691,7 @@ class RelationsWorkerStore(SQLBaseStore):
results = {}
for event in events:
event_result = await self._get_bundled_aggregation_for_event(event, user_id)
- if event_result is not None:
+ if event_result:
results[event.event_id] = event_result
return results
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 319464b1fa..a898f847e7 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -81,6 +81,14 @@ class _EventDictReturn:
stream_ordering: int
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _EventsAround:  # result of get_events_around: the events and pagination tokens surrounding an anchor event
+    events_before: List[EventBase]  # events before the anchor event
+    events_after: List[EventBase]  # events after the anchor event
+    start: RoomStreamToken  # stream token at the start of the window (from results["before"]["token"])
+    end: RoomStreamToken  # stream token at the end of the window (from results["after"]["token"])
+
+
def generate_pagination_where_clause(
direction: str,
column_names: Tuple[str, str],
@@ -846,7 +854,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
before_limit: int,
after_limit: int,
event_filter: Optional[Filter] = None,
- ) -> dict:
+ ) -> _EventsAround:
"""Retrieve events and pagination tokens around a given event in a
room.
"""
@@ -869,12 +877,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
list(results["after"]["event_ids"]), get_prev_content=True
)
- return {
- "events_before": events_before,
- "events_after": events_after,
- "start": results["before"]["token"],
- "end": results["after"]["token"],
- }
+ return _EventsAround(
+ events_before=events_before,
+ events_after=events_after,
+ start=results["before"]["token"],
+ end=results["after"]["token"],
+ )
def _get_events_around_txn(
self,
|