summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/appservice/api.py24
-rw-r--r--synapse/events/utils.py81
-rw-r--r--synapse/handlers/events.py3
-rw-r--r--synapse/handlers/initial_sync.py9
-rw-r--r--synapse/handlers/pagination.py7
-rw-r--r--synapse/rest/client/notifications.py9
-rw-r--r--synapse/rest/client/sync.py132
7 files changed, 125 insertions, 140 deletions
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index a0ea958af6..98fe354014 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -25,7 +25,7 @@ from synapse.appservice import (
     TransactionUnusedFallbackKeys,
 )
 from synapse.events import EventBase
-from synapse.events.utils import serialize_event
+from synapse.events.utils import SerializeEventConfig, serialize_event
 from synapse.http.client import SimpleHttpClient
 from synapse.types import JsonDict, ThirdPartyInstanceID
 from synapse.util.caches.response_cache import ResponseCache
@@ -321,16 +321,18 @@ class ApplicationServiceApi(SimpleHttpClient):
             serialize_event(
                 e,
                 time_now,
-                as_client_event=True,
-                # If this is an invite or a knock membership event, and we're interested
-                # in this user, then include any stripped state alongside the event.
-                include_stripped_room_state=(
-                    e.type == EventTypes.Member
-                    and (
-                        e.membership == Membership.INVITE
-                        or e.membership == Membership.KNOCK
-                    )
-                    and service.is_interested_in_user(e.state_key)
+                config=SerializeEventConfig(
+                    as_client_event=True,
+                    # If this is an invite or a knock membership event, and we're interested
+                    # in this user, then include any stripped state alongside the event.
+                    include_stripped_room_state=(
+                        e.type == EventTypes.Member
+                        and (
+                            e.membership == Membership.INVITE
+                            or e.membership == Membership.KNOCK
+                        )
+                        and service.is_interested_in_user(e.state_key)
+                    ),
                 ),
             )
             for e in events
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 9386fa29dd..ee34cb46e4 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -26,6 +26,7 @@ from typing import (
     Union,
 )
 
+import attr
 from frozendict import frozendict
 
 from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
@@ -303,29 +304,37 @@ def format_event_for_client_v2_without_room_id(d: JsonDict) -> JsonDict:
     return d
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class SerializeEventConfig:
+    as_client_event: bool = True
+    # Function to convert from federation format to client format
+    event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1
+    # ID of the user's auth token - used for namespacing of transaction IDs
+    token_id: Optional[int] = None
+    # List of event fields to include. If empty, all fields will be returned.
+    only_event_fields: Optional[List[str]] = None
+    # Some events can have stripped room state stored in the `unsigned` field.
+    # This is required for invite and knock functionality. If this option is
+    # False, that state will be removed from the event before it is returned.
+    # Otherwise, it will be kept.
+    include_stripped_room_state: bool = False
+
+
+_DEFAULT_SERIALIZE_EVENT_CONFIG = SerializeEventConfig()
+
+
 def serialize_event(
     e: Union[JsonDict, EventBase],
     time_now_ms: int,
     *,
-    as_client_event: bool = True,
-    event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1,
-    token_id: Optional[str] = None,
-    only_event_fields: Optional[List[str]] = None,
-    include_stripped_room_state: bool = False,
+    config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG,
 ) -> JsonDict:
     """Serialize event for clients
 
     Args:
         e
         time_now_ms
-        as_client_event
-        event_format
-        token_id
-        only_event_fields
-        include_stripped_room_state: Some events can have stripped room state
-            stored in the `unsigned` field. This is required for invite and knock
-            functionality. If this option is False, that state will be removed from the
-            event before it is returned. Otherwise, it will be kept.
+        config: Event serialization config
 
     Returns:
         The serialized event dictionary.
@@ -348,11 +357,11 @@ def serialize_event(
 
     if "redacted_because" in e.unsigned:
         d["unsigned"]["redacted_because"] = serialize_event(
-            e.unsigned["redacted_because"], time_now_ms, event_format=event_format
+            e.unsigned["redacted_because"], time_now_ms, config=config
         )
 
-    if token_id is not None:
-        if token_id == getattr(e.internal_metadata, "token_id", None):
+    if config.token_id is not None:
+        if config.token_id == getattr(e.internal_metadata, "token_id", None):
             txn_id = getattr(e.internal_metadata, "txn_id", None)
             if txn_id is not None:
                 d["unsigned"]["transaction_id"] = txn_id
@@ -361,13 +370,14 @@ def serialize_event(
     # that are meant to provide metadata about a room to an invitee/knocker. They are
     # intended to only be included in specific circumstances, such as down sync, and
     # should not be included in any other case.
-    if not include_stripped_room_state:
+    if not config.include_stripped_room_state:
         d["unsigned"].pop("invite_room_state", None)
         d["unsigned"].pop("knock_room_state", None)
 
-    if as_client_event:
-        d = event_format(d)
+    if config.as_client_event:
+        d = config.event_format(d)
 
+    only_event_fields = config.only_event_fields
     if only_event_fields:
         if not isinstance(only_event_fields, list) or not all(
             isinstance(f, str) for f in only_event_fields
@@ -390,18 +400,18 @@ class EventClientSerializer:
         event: Union[JsonDict, EventBase],
         time_now: int,
         *,
+        config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG,
         bundle_aggregations: Optional[Dict[str, "BundledAggregations"]] = None,
-        **kwargs: Any,
     ) -> JsonDict:
         """Serializes a single event.
 
         Args:
             event: The event being serialized.
             time_now: The current time in milliseconds
+            config: Event serialization config
             bundle_aggregations: Whether to include the bundled aggregations for this
                 event. Only applies to non-state events. (State events never include
                 bundled aggregations.)
-            **kwargs: Arguments to pass to `serialize_event`
 
         Returns:
             The serialized event
@@ -410,7 +420,7 @@ class EventClientSerializer:
         if not isinstance(event, EventBase):
             return event
 
-        serialized_event = serialize_event(event, time_now, **kwargs)
+        serialized_event = serialize_event(event, time_now, config=config)
 
         # Check if there are any bundled aggregations to include with the event.
         if bundle_aggregations:
@@ -419,6 +429,7 @@ class EventClientSerializer:
                 self._inject_bundled_aggregations(
                     event,
                     time_now,
+                    config,
                     bundle_aggregations[event.event_id],
                     serialized_event,
                 )
@@ -456,6 +467,7 @@ class EventClientSerializer:
         self,
         event: EventBase,
         time_now: int,
+        config: SerializeEventConfig,
         aggregations: "BundledAggregations",
         serialized_event: JsonDict,
     ) -> None:
@@ -466,6 +478,7 @@ class EventClientSerializer:
             time_now: The current time in milliseconds
             aggregations: The bundled aggregation to serialize.
             serialized_event: The serialized event which may be modified.
+            config: Event serialization config
 
         """
         serialized_aggregations = {}
@@ -493,8 +506,8 @@ class EventClientSerializer:
             thread = aggregations.thread
 
             # Don't bundle aggregations as this could recurse forever.
-            serialized_latest_event = self.serialize_event(
-                thread.latest_event, time_now, bundle_aggregations=None
+            serialized_latest_event = serialize_event(
+                thread.latest_event, time_now, config=config
             )
             # Manually apply an edit, if one exists.
             if thread.latest_edit:
@@ -515,20 +528,34 @@ class EventClientSerializer:
             )
 
     def serialize_events(
-        self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any
+        self,
+        events: Iterable[Union[JsonDict, EventBase]],
+        time_now: int,
+        *,
+        config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG,
+        bundle_aggregations: Optional[Dict[str, "BundledAggregations"]] = None,
     ) -> List[JsonDict]:
         """Serializes multiple events.
 
         Args:
             event
             time_now: The current time in milliseconds
-            **kwargs: Arguments to pass to `serialize_event`
+            config: Event serialization config
+            bundle_aggregations: Whether to include the bundled aggregations for this
+                event. Only applies to non-state events. (State events never include
+                bundled aggregations.)
 
         Returns:
             The list of serialized events
         """
         return [
-            self.serialize_event(event, time_now=time_now, **kwargs) for event in events
+            self.serialize_event(
+                event,
+                time_now,
+                config=config,
+                bundle_aggregations=bundle_aggregations,
+            )
+            for event in events
         ]
 
 
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 97e75e60c3..d2ccb5c5d3 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Iterable, List, Optional
 from synapse.api.constants import EduTypes, EventTypes, Membership
 from synapse.api.errors import AuthError, SynapseError
 from synapse.events import EventBase
+from synapse.events.utils import SerializeEventConfig
 from synapse.handlers.presence import format_user_presence_state
 from synapse.streams.config import PaginationConfig
 from synapse.types import JsonDict, UserID
@@ -120,7 +121,7 @@ class EventStreamHandler:
             chunks = self._event_serializer.serialize_events(
                 events,
                 time_now,
-                as_client_event=as_client_event,
+                config=SerializeEventConfig(as_client_event=as_client_event),
             )
 
             chunk = {
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 344f20f37c..316cfae24f 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -18,6 +18,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, cast
 from synapse.api.constants import EduTypes, EventTypes, Membership
 from synapse.api.errors import SynapseError
 from synapse.events import EventBase
+from synapse.events.utils import SerializeEventConfig
 from synapse.events.validator import EventValidator
 from synapse.handlers.presence import format_user_presence_state
 from synapse.handlers.receipts import ReceiptEventSource
@@ -156,6 +157,8 @@ class InitialSyncHandler:
         if limit is None:
             limit = 10
 
+        serializer_options = SerializeEventConfig(as_client_event=as_client_event)
+
         async def handle_room(event: RoomsForUser) -> None:
             d: JsonDict = {
                 "room_id": event.room_id,
@@ -173,7 +176,7 @@ class InitialSyncHandler:
                 d["invite"] = self._event_serializer.serialize_event(
                     invite_event,
                     time_now,
-                    as_client_event=as_client_event,
+                    config=serializer_options,
                 )
 
             rooms_ret.append(d)
@@ -225,7 +228,7 @@ class InitialSyncHandler:
                         self._event_serializer.serialize_events(
                             messages,
                             time_now=time_now,
-                            as_client_event=as_client_event,
+                            config=serializer_options,
                         )
                     ),
                     "start": await start_token.to_string(self.store),
@@ -235,7 +238,7 @@ class InitialSyncHandler:
                 d["state"] = self._event_serializer.serialize_events(
                     current_state.values(),
                     time_now=time_now,
-                    as_client_event=as_client_event,
+                    config=serializer_options,
                 )
 
                 account_data_events = []
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 5c01a426ff..183fabcfc0 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -22,6 +22,7 @@ from twisted.python.failure import Failure
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import SynapseError
 from synapse.api.filtering import Filter
+from synapse.events.utils import SerializeEventConfig
 from synapse.handlers.room import ShutdownRoomResponse
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.state import StateFilter
@@ -541,13 +542,15 @@ class PaginationHandler:
 
         time_now = self.clock.time_msec()
 
+        serialize_options = SerializeEventConfig(as_client_event=as_client_event)
+
         chunk = {
             "chunk": (
                 self._event_serializer.serialize_events(
                     events,
                     time_now,
+                    config=serialize_options,
                     bundle_aggregations=aggregations,
-                    as_client_event=as_client_event,
                 )
             ),
             "start": await from_token.to_string(self.store),
@@ -556,7 +559,7 @@ class PaginationHandler:
 
         if state:
             chunk["state"] = self._event_serializer.serialize_events(
-                state, time_now, as_client_event=as_client_event
+                state, time_now, config=serialize_options
             )
 
         return chunk
diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py
index 20377a9ac6..ff040de6b8 100644
--- a/synapse/rest/client/notifications.py
+++ b/synapse/rest/client/notifications.py
@@ -16,7 +16,10 @@ import logging
 from typing import TYPE_CHECKING, Tuple
 
 from synapse.api.constants import ReceiptTypes
-from synapse.events.utils import format_event_for_client_v2_without_room_id
+from synapse.events.utils import (
+    SerializeEventConfig,
+    format_event_for_client_v2_without_room_id,
+)
 from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_integer, parse_string
 from synapse.http.site import SynapseRequest
@@ -75,7 +78,9 @@ class NotificationsServlet(RestServlet):
                     self._event_serializer.serialize_event(
                         notif_events[pa.event_id],
                         self.clock.time_msec(),
-                        event_format=format_event_for_client_v2_without_room_id,
+                        config=SerializeEventConfig(
+                            event_format=format_event_for_client_v2_without_room_id
+                        ),
                     )
                 ),
             }
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index f3018ff690..53c385a86c 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -14,24 +14,14 @@
 import itertools
 import logging
 from collections import defaultdict
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Callable,
-    Dict,
-    Iterable,
-    List,
-    Optional,
-    Tuple,
-    Union,
-)
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
 from synapse.api.constants import Membership, PresenceState
 from synapse.api.errors import Codes, StoreError, SynapseError
 from synapse.api.filtering import FilterCollection
 from synapse.api.presence import UserPresenceState
-from synapse.events import EventBase
 from synapse.events.utils import (
+    SerializeEventConfig,
     format_event_for_client_v2_without_room_id,
     format_event_raw,
 )
@@ -48,7 +38,6 @@ from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
 from synapse.http.site import SynapseRequest
 from synapse.logging.opentracing import trace
-from synapse.storage.databases.main.relations import BundledAggregations
 from synapse.types import JsonDict, StreamToken
 from synapse.util import json_decoder
 
@@ -239,28 +228,31 @@ class SyncRestServlet(RestServlet):
         else:
             raise Exception("Unknown event format %s" % (filter.event_format,))
 
+        serialize_options = SerializeEventConfig(
+            event_format=event_formatter,
+            token_id=access_token_id,
+            only_event_fields=filter.event_fields,
+        )
+        stripped_serialize_options = SerializeEventConfig(
+            event_format=event_formatter,
+            token_id=access_token_id,
+            include_stripped_room_state=True,
+        )
+
         joined = await self.encode_joined(
-            sync_result.joined,
-            time_now,
-            access_token_id,
-            filter.event_fields,
-            event_formatter,
+            sync_result.joined, time_now, serialize_options
         )
 
         invited = await self.encode_invited(
-            sync_result.invited, time_now, access_token_id, event_formatter
+            sync_result.invited, time_now, stripped_serialize_options
         )
 
         knocked = await self.encode_knocked(
-            sync_result.knocked, time_now, access_token_id, event_formatter
+            sync_result.knocked, time_now, stripped_serialize_options
         )
 
         archived = await self.encode_archived(
-            sync_result.archived,
-            time_now,
-            access_token_id,
-            filter.event_fields,
-            event_formatter,
+            sync_result.archived, time_now, serialize_options
         )
 
         logger.debug("building sync response dict")
@@ -339,9 +331,7 @@ class SyncRestServlet(RestServlet):
         self,
         rooms: List[JoinedSyncResult],
         time_now: int,
-        token_id: Optional[int],
-        event_fields: List[str],
-        event_formatter: Callable[[JsonDict], JsonDict],
+        serialize_options: SerializeEventConfig,
     ) -> JsonDict:
         """
         Encode the joined rooms in a sync result
@@ -349,24 +339,14 @@ class SyncRestServlet(RestServlet):
         Args:
             rooms: list of sync results for rooms this user is joined to
             time_now: current time - used as a baseline for age calculations
-            token_id: ID of the user's auth token - used for namespacing
-                of transaction IDs
-            event_fields: List of event fields to include. If empty,
-                all fields will be returned.
-            event_formatter: function to convert from federation format
-                to client format
+            serialize_options: Event serializer options
         Returns:
             The joined rooms list, in our response format
         """
         joined = {}
         for room in rooms:
             joined[room.room_id] = await self.encode_room(
-                room,
-                time_now,
-                token_id,
-                joined=True,
-                only_fields=event_fields,
-                event_formatter=event_formatter,
+                room, time_now, joined=True, serialize_options=serialize_options
             )
 
         return joined
@@ -376,8 +356,7 @@ class SyncRestServlet(RestServlet):
         self,
         rooms: List[InvitedSyncResult],
         time_now: int,
-        token_id: Optional[int],
-        event_formatter: Callable[[JsonDict], JsonDict],
+        serialize_options: SerializeEventConfig,
     ) -> JsonDict:
         """
         Encode the invited rooms in a sync result
@@ -385,10 +364,7 @@ class SyncRestServlet(RestServlet):
         Args:
             rooms: list of sync results for rooms this user is invited to
             time_now: current time - used as a baseline for age calculations
-            token_id: ID of the user's auth token - used for namespacing
-                of transaction IDs
-            event_formatter: function to convert from federation format
-                to client format
+            serialize_options: Event serializer options
 
         Returns:
             The invited rooms list, in our response format
@@ -396,11 +372,7 @@ class SyncRestServlet(RestServlet):
         invited = {}
         for room in rooms:
             invite = self._event_serializer.serialize_event(
-                room.invite,
-                time_now,
-                token_id=token_id,
-                event_format=event_formatter,
-                include_stripped_room_state=True,
+                room.invite, time_now, config=serialize_options
             )
             unsigned = dict(invite.get("unsigned", {}))
             invite["unsigned"] = unsigned
@@ -415,8 +387,7 @@ class SyncRestServlet(RestServlet):
         self,
         rooms: List[KnockedSyncResult],
         time_now: int,
-        token_id: Optional[int],
-        event_formatter: Callable[[Dict], Dict],
+        serialize_options: SerializeEventConfig,
     ) -> Dict[str, Dict[str, Any]]:
         """
         Encode the rooms we've knocked on in a sync result.
@@ -424,8 +395,7 @@ class SyncRestServlet(RestServlet):
         Args:
             rooms: list of sync results for rooms this user is knocking on
             time_now: current time - used as a baseline for age calculations
-            token_id: ID of the user's auth token - used for namespacing of transaction IDs
-            event_formatter: function to convert from federation format to client format
+            serialize_options: Event serializer options
 
         Returns:
             The list of rooms the user has knocked on, in our response format.
@@ -433,11 +403,7 @@ class SyncRestServlet(RestServlet):
         knocked = {}
         for room in rooms:
             knock = self._event_serializer.serialize_event(
-                room.knock,
-                time_now,
-                token_id=token_id,
-                event_format=event_formatter,
-                include_stripped_room_state=True,
+                room.knock, time_now, config=serialize_options
             )
 
             # Extract the `unsigned` key from the knock event.
@@ -470,9 +436,7 @@ class SyncRestServlet(RestServlet):
         self,
         rooms: List[ArchivedSyncResult],
         time_now: int,
-        token_id: Optional[int],
-        event_fields: List[str],
-        event_formatter: Callable[[JsonDict], JsonDict],
+        serialize_options: SerializeEventConfig,
     ) -> JsonDict:
         """
         Encode the archived rooms in a sync result
@@ -480,23 +444,14 @@ class SyncRestServlet(RestServlet):
         Args:
             rooms: list of sync results for rooms this user is joined to
             time_now: current time - used as a baseline for age calculations
-            token_id: ID of the user's auth token - used for namespacing
-                of transaction IDs
-            event_fields: List of event fields to include. If empty,
-                all fields will be returned.
-            event_formatter: function to convert from federation format to client format
+            serialize_options: Event serializer options
         Returns:
             The archived rooms list, in our response format
         """
         joined = {}
         for room in rooms:
             joined[room.room_id] = await self.encode_room(
-                room,
-                time_now,
-                token_id,
-                joined=False,
-                only_fields=event_fields,
-                event_formatter=event_formatter,
+                room, time_now, joined=False, serialize_options=serialize_options
             )
 
         return joined
@@ -505,10 +460,8 @@ class SyncRestServlet(RestServlet):
         self,
         room: Union[JoinedSyncResult, ArchivedSyncResult],
         time_now: int,
-        token_id: Optional[int],
         joined: bool,
-        only_fields: Optional[List[str]],
-        event_formatter: Callable[[JsonDict], JsonDict],
+        serialize_options: SerializeEventConfig,
     ) -> JsonDict:
         """
         Args:
@@ -524,20 +477,6 @@ class SyncRestServlet(RestServlet):
         Returns:
             The room, encoded in our response format
         """
-
-        def serialize(
-            events: Iterable[EventBase],
-            aggregations: Optional[Dict[str, BundledAggregations]] = None,
-        ) -> List[JsonDict]:
-            return self._event_serializer.serialize_events(
-                events,
-                time_now=time_now,
-                bundle_aggregations=aggregations,
-                token_id=token_id,
-                event_format=event_formatter,
-                only_event_fields=only_fields,
-            )
-
         state_dict = room.state
         timeline_events = room.timeline.events
 
@@ -554,9 +493,14 @@ class SyncRestServlet(RestServlet):
                     event.room_id,
                 )
 
-        serialized_state = serialize(state_events)
-        serialized_timeline = serialize(
-            timeline_events, room.timeline.bundled_aggregations
+        serialized_state = self._event_serializer.serialize_events(
+            state_events, time_now, config=serialize_options
+        )
+        serialized_timeline = self._event_serializer.serialize_events(
+            timeline_events,
+            time_now,
+            config=serialize_options,
+            bundle_aggregations=room.timeline.bundled_aggregations,
         )
 
         account_data = room.account_data