diff options
-rw-r--r-- | changelog.d/11815.misc | 1 | ||||
-rw-r--r-- | synapse/events/utils.py | 57 | ||||
-rw-r--r-- | synapse/handlers/room.py | 77 | ||||
-rw-r--r-- | synapse/handlers/search.py | 45 | ||||
-rw-r--r-- | synapse/handlers/sync.py | 3 | ||||
-rw-r--r-- | synapse/push/mailer.py | 2 | ||||
-rw-r--r-- | synapse/rest/admin/rooms.py | 39 | ||||
-rw-r--r-- | synapse/rest/client/room.py | 39 | ||||
-rw-r--r-- | synapse/rest/client/sync.py | 3 | ||||
-rw-r--r-- | synapse/storage/databases/main/relations.py | 61 | ||||
-rw-r--r-- | synapse/storage/databases/main/stream.py | 22 | ||||
-rw-r--r-- | tests/rest/client/test_relations.py | 2 |
12 files changed, 212 insertions, 139 deletions
diff --git a/changelog.d/11815.misc b/changelog.d/11815.misc new file mode 100644 index 0000000000..83aa6d6eb0 --- /dev/null +++ b/changelog.d/11815.misc @@ -0,0 +1 @@ +Improve type safety of bundled aggregations code. diff --git a/synapse/events/utils.py b/synapse/events/utils.py index 918adeecf8..243696b357 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -14,7 +14,17 @@ # limitations under the License. import collections.abc import re -from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterable, + List, + Mapping, + Optional, + Union, +) from frozendict import frozendict @@ -26,6 +36,10 @@ from synapse.util.frozenutils import unfreeze from . import EventBase +if TYPE_CHECKING: + from synapse.storage.databases.main.relations import BundledAggregations + + # Split strings on "." but not "\." This uses a negative lookbehind assertion for '\' # (?<!stuff) matches if the current position in the string is not preceded # by a match for 'stuff'. @@ -376,7 +390,7 @@ class EventClientSerializer: event: Union[JsonDict, EventBase], time_now: int, *, - bundle_aggregations: Optional[Dict[str, JsonDict]] = None, + bundle_aggregations: Optional[Dict[str, "BundledAggregations"]] = None, **kwargs: Any, ) -> JsonDict: """Serializes a single event. @@ -415,7 +429,7 @@ class EventClientSerializer: self, event: EventBase, time_now: int, - aggregations: JsonDict, + aggregations: "BundledAggregations", serialized_event: JsonDict, ) -> None: """Potentially injects bundled aggregations into the unsigned portion of the serialized event. @@ -427,13 +441,18 @@ class EventClientSerializer: serialized_event: The serialized event which may be modified. """ - # Make a copy in-case the object is cached. - aggregations = aggregations.copy() + serialized_aggregations = {} + + if aggregations.annotations: + serialized_aggregations[RelationTypes.ANNOTATION] = aggregations.annotations + + if aggregations.references: + serialized_aggregations[RelationTypes.REFERENCE] = aggregations.references - if RelationTypes.REPLACE in aggregations: + if aggregations.replace: # If there is an edit replace the content, preserving existing # relations. - edit = aggregations[RelationTypes.REPLACE] + edit = aggregations.replace # Ensure we take copies of the edit content, otherwise we risk modifying # the original event. @@ -451,24 +470,28 @@ class EventClientSerializer: else: serialized_event["content"].pop("m.relates_to", None) - aggregations[RelationTypes.REPLACE] = { + serialized_aggregations[RelationTypes.REPLACE] = { "event_id": edit.event_id, "origin_server_ts": edit.origin_server_ts, "sender": edit.sender, } # If this event is the start of a thread, include a summary of the replies. - if RelationTypes.THREAD in aggregations: - # Serialize the latest thread event. - latest_thread_event = aggregations[RelationTypes.THREAD]["latest_event"] - - # Don't bundle aggregations as this could recurse forever. - aggregations[RelationTypes.THREAD]["latest_event"] = self.serialize_event( - latest_thread_event, time_now, bundle_aggregations=None - ) + if aggregations.thread: + serialized_aggregations[RelationTypes.THREAD] = { + # Don't bundle aggregations as this could recurse forever. + "latest_event": self.serialize_event( + aggregations.thread.latest_event, time_now, bundle_aggregations=None + ), + "count": aggregations.thread.count, + "current_user_participated": aggregations.thread.current_user_participated, + } # Include the bundled aggregations in the event. - serialized_event["unsigned"].setdefault("m.relations", {}).update(aggregations) + if serialized_aggregations: + serialized_event["unsigned"].setdefault("m.relations", {}).update( + serialized_aggregations + ) def serialize_events( self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index f963078e59..1420d67729 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -30,6 +30,7 @@ from typing import ( Tuple, ) +import attr from typing_extensions import TypedDict from synapse.api.constants import ( @@ -60,6 +61,7 @@ from synapse.events.utils import copy_power_levels_contents from synapse.federation.federation_client import InvalidResponseError from synapse.handlers.federation import get_domains_from_state from synapse.rest.admin._base import assert_user_is_admin +from synapse.storage.databases.main.relations import BundledAggregations from synapse.storage.state import StateFilter from synapse.streams import EventSource from synapse.types import ( @@ -90,6 +92,17 @@ id_server_scheme = "https://" FIVE_MINUTES_IN_MS = 5 * 60 * 1000 +@attr.s(slots=True, frozen=True, auto_attribs=True) +class EventContext: + events_before: List[EventBase] + event: EventBase + events_after: List[EventBase] + state: List[EventBase] + aggregations: Dict[str, BundledAggregations] + start: str + end: str + + class RoomCreationHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() @@ -1119,7 +1132,7 @@ class RoomContextHandler: limit: int, event_filter: Optional[Filter], use_admin_priviledge: bool = False, - ) -> Optional[JsonDict]: + ) -> Optional[EventContext]: """Retrieves events, pagination tokens and state around a given event in a room. @@ -1167,38 +1180,28 @@ class RoomContextHandler: results = await self.store.get_events_around( room_id, event_id, before_limit, after_limit, event_filter ) + events_before = results.events_before + events_after = results.events_after if event_filter: - results["events_before"] = await event_filter.filter( - results["events_before"] - ) - results["events_after"] = await event_filter.filter(results["events_after"]) + events_before = await event_filter.filter(events_before) + events_after = await event_filter.filter(events_after) - results["events_before"] = await filter_evts(results["events_before"]) - results["events_after"] = await filter_evts(results["events_after"]) + events_before = await filter_evts(events_before) + events_after = await filter_evts(events_after) # filter_evts can return a pruned event in case the user is allowed to see that # there's something there but not see the content, so use the event that's in # `filtered` rather than the event we retrieved from the datastore. - results["event"] = filtered[0] + event = filtered[0] # Fetch the aggregations. aggregations = await self.store.get_bundled_aggregations( - [results["event"]], user.to_string() + itertools.chain(events_before, (event,), events_after), + user.to_string(), ) - aggregations.update( - await self.store.get_bundled_aggregations( - results["events_before"], user.to_string() - ) - ) - aggregations.update( - await self.store.get_bundled_aggregations( - results["events_after"], user.to_string() - ) - ) - results["aggregations"] = aggregations - if results["events_after"]: - last_event_id = results["events_after"][-1].event_id + if events_after: + last_event_id = events_after[-1].event_id else: last_event_id = event_id @@ -1206,9 +1209,9 @@ class RoomContextHandler: state_filter = StateFilter.from_lazy_load_member_list( ev.sender for ev in itertools.chain( - results["events_before"], - (results["event"],), - results["events_after"], + events_before, + (event,), + events_after, ) ) else: @@ -1226,21 +1229,23 @@ class RoomContextHandler: if event_filter: state_events = await event_filter.filter(state_events) - results["state"] = await filter_evts(state_events) - # We use a dummy token here as we only care about the room portion of # the token, which we replace. token = StreamToken.START - results["start"] = await token.copy_and_replace( - "room_key", results["start"] - ).to_string(self.store) - - results["end"] = await token.copy_and_replace( - "room_key", results["end"] - ).to_string(self.store) - - return results + return EventContext( + events_before=events_before, + event=event, + events_after=events_after, + state=await filter_evts(state_events), + aggregations=aggregations, + start=await token.copy_and_replace("room_key", results.start).to_string( + self.store + ), + end=await token.copy_and_replace("room_key", results.end).to_string( + self.store + ), + ) class TimestampLookupHandler: diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 0b153a6822..02bb5ae72f 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -361,36 +361,37 @@ class SearchHandler: logger.info( "Context for search returned %d and %d events", - len(res["events_before"]), - len(res["events_after"]), + len(res.events_before), + len(res.events_after), ) - res["events_before"] = await filter_events_for_client( - self.storage, user.to_string(), res["events_before"] + events_before = await filter_events_for_client( + self.storage, user.to_string(), res.events_before ) - res["events_after"] = await filter_events_for_client( - self.storage, user.to_string(), res["events_after"] + events_after = await filter_events_for_client( + self.storage, user.to_string(), res.events_after ) - res["start"] = await now_token.copy_and_replace( - "room_key", res["start"] - ).to_string(self.store) - - res["end"] = await now_token.copy_and_replace( - "room_key", res["end"] - ).to_string(self.store) + context = { + "events_before": events_before, + "events_after": events_after, + "start": await now_token.copy_and_replace( + "room_key", res.start + ).to_string(self.store), + "end": await now_token.copy_and_replace( + "room_key", res.end + ).to_string(self.store), + } if include_profile: senders = { ev.sender - for ev in itertools.chain( - res["events_before"], [event], res["events_after"] - ) + for ev in itertools.chain(events_before, [event], events_after) } - if res["events_after"]: - last_event_id = res["events_after"][-1].event_id + if events_after: + last_event_id = events_after[-1].event_id else: last_event_id = event.event_id @@ -402,7 +403,7 @@ class SearchHandler: last_event_id, state_filter ) - res["profile_info"] = { + context["profile_info"] = { s.state_key: { "displayname": s.content.get("displayname", None), "avatar_url": s.content.get("avatar_url", None), @@ -411,7 +412,7 @@ class SearchHandler: if s.type == EventTypes.Member and s.state_key in senders } - contexts[event.event_id] = res + contexts[event.event_id] = context else: contexts = {} @@ -421,10 +422,10 @@ class SearchHandler: for context in contexts.values(): context["events_before"] = self._event_serializer.serialize_events( - context["events_before"], time_now + context["events_before"], time_now # type: ignore[arg-type] ) context["events_after"] = self._event_serializer.serialize_events( - context["events_after"], time_now + context["events_after"], time_now # type: ignore[arg-type] ) state_results = {} diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index 7e2a892b63..c72ed7c290 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -37,6 +37,7 @@ from synapse.logging.context import current_context from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span from synapse.push.clientformat import format_push_rules_for_user from synapse.storage.databases.main.event_push_actions import NotifCounts +from synapse.storage.databases.main.relations import BundledAggregations from synapse.storage.roommember import MemberSummary from synapse.storage.state import StateFilter from synapse.types import ( @@ -100,7 +101,7 @@ class TimelineBatch: limited: bool # A mapping of event ID to the bundled aggregations for the above events. # This is only calculated if limited is true. - bundled_aggregations: Optional[Dict[str, Dict[str, Any]]] = None + bundled_aggregations: Optional[Dict[str, BundledAggregations]] = None def __bool__(self) -> bool: """Make the result appear empty if there are no updates. This is used diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index dadfc57413..3df8452eec 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -455,7 +455,7 @@ class Mailer: } the_events = await filter_events_for_client( - self.storage, user_id, results["events_before"] + self.storage, user_id, results.events_before ) the_events.append(notif_event) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index efe25fe7eb..5b706efbcf 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -729,7 +729,7 @@ class RoomEventContextServlet(RestServlet): else: event_filter = None - results = await self.room_context_handler.get_event_context( + event_context = await self.room_context_handler.get_event_context( requester, room_id, event_id, @@ -738,25 +738,34 @@ class RoomEventContextServlet(RestServlet): use_admin_priviledge=True, ) - if not results: + if not event_context: raise SynapseError( HTTPStatus.NOT_FOUND, "Event not found.", errcode=Codes.NOT_FOUND ) time_now = self.clock.time_msec() - aggregations = results.pop("aggregations", None) - results["events_before"] = self._event_serializer.serialize_events( - results["events_before"], time_now, bundle_aggregations=aggregations - ) - results["event"] = self._event_serializer.serialize_event( - results["event"], time_now, bundle_aggregations=aggregations - ) - results["events_after"] = self._event_serializer.serialize_events( - results["events_after"], time_now, bundle_aggregations=aggregations - ) - results["state"] = self._event_serializer.serialize_events( - results["state"], time_now - ) + results = { + "events_before": self._event_serializer.serialize_events( + event_context.events_before, + time_now, + bundle_aggregations=event_context.aggregations, + ), + "event": self._event_serializer.serialize_event( + event_context.event, + time_now, + bundle_aggregations=event_context.aggregations, + ), + "events_after": self._event_serializer.serialize_events( + event_context.events_after, + time_now, + bundle_aggregations=event_context.aggregations, + ), + "state": self._event_serializer.serialize_events( + event_context.state, time_now + ), + "start": event_context.start, + "end": event_context.end, + } return HTTPStatus.OK, results diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 90bb9142a0..90355e44b2 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -706,27 +706,36 @@ class RoomEventContextServlet(RestServlet): else: event_filter = None - results = await self.room_context_handler.get_event_context( + event_context = await self.room_context_handler.get_event_context( requester, room_id, event_id, limit, event_filter ) - if not results: + if not event_context: raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) time_now = self.clock.time_msec() - aggregations = results.pop("aggregations", None) - results["events_before"] = self._event_serializer.serialize_events( - results["events_before"], time_now, bundle_aggregations=aggregations - ) - results["event"] = self._event_serializer.serialize_event( - results["event"], time_now, bundle_aggregations=aggregations - ) - results["events_after"] = self._event_serializer.serialize_events( - results["events_after"], time_now, bundle_aggregations=aggregations - ) - results["state"] = self._event_serializer.serialize_events( - results["state"], time_now - ) + results = { + "events_before": self._event_serializer.serialize_events( + event_context.events_before, + time_now, + bundle_aggregations=event_context.aggregations, + ), + "event": self._event_serializer.serialize_event( + event_context.event, + time_now, + bundle_aggregations=event_context.aggregations, + ), + "events_after": self._event_serializer.serialize_events( + event_context.events_after, + time_now, + bundle_aggregations=event_context.aggregations, + ), + "state": self._event_serializer.serialize_events( + event_context.state, time_now + ), + "start": event_context.start, + "end": event_context.end, + } return 200, results diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index d20ae1421e..f9615da525 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -48,6 +48,7 @@ from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.http.site import SynapseRequest from synapse.logging.opentracing import trace +from synapse.storage.databases.main.relations import BundledAggregations from synapse.types import JsonDict, StreamToken from synapse.util import json_decoder @@ -526,7 +527,7 @@ class SyncRestServlet(RestServlet): def serialize( events: Iterable[EventBase], - aggregations: Optional[Dict[str, Dict[str, Any]]] = None, + aggregations: Optional[Dict[str, BundledAggregations]] = None, ) -> List[JsonDict]: return self._event_serializer.serialize_events( events, diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 2cb5d06c13..a9a5dd5f03 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -13,17 +13,7 @@ # limitations under the License. import logging -from typing import ( - TYPE_CHECKING, - Any, - Dict, - Iterable, - List, - Optional, - Tuple, - Union, - cast, -) +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union, cast import attr from frozendict import frozendict @@ -43,6 +33,7 @@ from synapse.storage.relations import ( PaginationChunk, RelationPaginationToken, ) +from synapse.types import JsonDict from synapse.util.caches.descriptors import cached if TYPE_CHECKING: @@ -51,6 +42,30 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _ThreadAggregation: + latest_event: EventBase + count: int + current_user_participated: bool + + +@attr.s(slots=True, auto_attribs=True) +class BundledAggregations: + """ + The bundled aggregations for an event. + + Some values require additional processing during serialization. + """ + + annotations: Optional[JsonDict] = None + references: Optional[JsonDict] = None + replace: Optional[EventBase] = None + thread: Optional[_ThreadAggregation] = None + + def __bool__(self) -> bool: + return bool(self.annotations or self.references or self.replace or self.thread) + + class RelationsWorkerStore(SQLBaseStore): def __init__( self, @@ -585,7 +600,7 @@ class RelationsWorkerStore(SQLBaseStore): async def _get_bundled_aggregation_for_event( self, event: EventBase, user_id: str - ) -> Optional[Dict[str, Any]]: + ) -> Optional[BundledAggregations]: """Generate bundled aggregations for an event. Note that this does not use a cache, but depends on cached methods. @@ -616,24 +631,24 @@ class RelationsWorkerStore(SQLBaseStore): # The bundled aggregations to include, a mapping of relation type to a # type-specific value. Some types include the direct return type here # while others need more processing during serialization. - aggregations: Dict[str, Any] = {} + aggregations = BundledAggregations() annotations = await self.get_aggregation_groups_for_event(event_id, room_id) if annotations.chunk: - aggregations[RelationTypes.ANNOTATION] = annotations.to_dict() + aggregations.annotations = annotations.to_dict() references = await self.get_relations_for_event( event_id, room_id, RelationTypes.REFERENCE, direction="f" ) if references.chunk: - aggregations[RelationTypes.REFERENCE] = references.to_dict() + aggregations.references = references.to_dict() edit = None if event.type == EventTypes.Message: edit = await self.get_applicable_edit(event_id, room_id) if edit: - aggregations[RelationTypes.REPLACE] = edit + aggregations.replace = edit # If this event is the start of a thread, include a summary of the replies. if self._msc3440_enabled: @@ -644,11 +659,11 @@ class RelationsWorkerStore(SQLBaseStore): event_id, room_id, user_id ) if latest_thread_event: - aggregations[RelationTypes.THREAD] = { - "latest_event": latest_thread_event, - "count": thread_count, - "current_user_participated": participated, - } + aggregations.thread = _ThreadAggregation( + latest_event=latest_thread_event, + count=thread_count, + current_user_participated=participated, + ) # Store the bundled aggregations in the event metadata for later use. return aggregations @@ -657,7 +672,7 @@ class RelationsWorkerStore(SQLBaseStore): self, events: Iterable[EventBase], user_id: str, - ) -> Dict[str, Dict[str, Any]]: + ) -> Dict[str, BundledAggregations]: """Generate bundled aggregations for events. Args: @@ -676,7 +691,7 @@ class RelationsWorkerStore(SQLBaseStore): results = {} for event in events: event_result = await self._get_bundled_aggregation_for_event(event, user_id) - if event_result is not None: + if event_result: results[event.event_id] = event_result return results diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 319464b1fa..a898f847e7 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -81,6 +81,14 @@ class _EventDictReturn: stream_ordering: int +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _EventsAround: + events_before: List[EventBase] + events_after: List[EventBase] + start: RoomStreamToken + end: RoomStreamToken + + def generate_pagination_where_clause( direction: str, column_names: Tuple[str, str], @@ -846,7 +854,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): before_limit: int, after_limit: int, event_filter: Optional[Filter] = None, - ) -> dict: + ) -> _EventsAround: """Retrieve events and pagination tokens around a given event in a room. """ @@ -869,12 +877,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): list(results["after"]["event_ids"]), get_prev_content=True ) - return { - "events_before": events_before, - "events_after": events_after, - "start": results["before"]["token"], - "end": results["after"]["token"], - } + return _EventsAround( + events_before=events_before, + events_after=events_after, + start=results["before"]["token"], + end=results["after"]["token"], + ) def _get_events_around_txn( self, diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index c9b220e73d..96ae7790bb 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -577,7 +577,7 @@ class RelationsTestCase(unittest.HomeserverTestCase): self.assertEquals(200, channel.code, channel.json_body) room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] self.assertTrue(room_timeline["limited"]) - self._find_event_in_chunk(room_timeline["events"]) + assert_bundle(self._find_event_in_chunk(room_timeline["events"])) def test_aggregation_get_event_for_annotation(self): """Test that annotations do not get bundled aggregations included |