summary refs log tree commit diff
path: root/synapse/events/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/events/utils.py')
-rw-r--r--synapse/events/utils.py119
1 files changed, 69 insertions, 50 deletions
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 520edbbf61..3f3eba86a8 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -13,18 +13,32 @@
 # limitations under the License.
 import collections.abc
 import re
-from typing import Any, Mapping, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Mapping,
+    Optional,
+    Union,
+)
 
 from frozendict import frozendict
 
 from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
 from synapse.api.errors import Codes, SynapseError
 from synapse.api.room_versions import RoomVersion
+from synapse.types import JsonDict
 from synapse.util.async_helpers import yieldable_gather_results
 from synapse.util.frozenutils import unfreeze
 
 from . import EventBase
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 # Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
 # (?<!stuff) matches if the current position in the string is not preceded
 # by a match for 'stuff'.
@@ -65,7 +79,7 @@ def prune_event(event: EventBase) -> EventBase:
     return pruned_event
 
 
-def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
+def prune_event_dict(room_version: RoomVersion, event_dict: JsonDict) -> JsonDict:
     """Redacts the event_dict in the same way as `prune_event`, except it
     operates on dicts rather than event objects
 
@@ -97,7 +111,7 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
 
     new_content = {}
 
-    def add_fields(*fields):
+    def add_fields(*fields: str) -> None:
         for field in fields:
             if field in event_dict["content"]:
                 new_content[field] = event_dict["content"][field]
@@ -151,7 +165,7 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
 
     allowed_fields["content"] = new_content
 
-    unsigned = {}
+    unsigned: JsonDict = {}
     allowed_fields["unsigned"] = unsigned
 
     event_unsigned = event_dict.get("unsigned", {})
@@ -164,16 +178,16 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
     return allowed_fields
 
 
-def _copy_field(src, dst, field):
+def _copy_field(src: JsonDict, dst: JsonDict, field: List[str]) -> None:
     """Copy the field in 'src' to 'dst'.
 
     For example, if src={"foo":{"bar":5}} and dst={}, and field=["foo","bar"]
     then dst={"foo":{"bar":5}}.
 
     Args:
-        src(dict): The dict to read from.
-        dst(dict): The dict to modify.
-        field(list<str>): List of keys to drill down to in 'src'.
+        src: The dict to read from.
+        dst: The dict to modify.
+        field: List of keys to drill down to in 'src'.
     """
     if len(field) == 0:  # this should be impossible
         return
@@ -205,7 +219,7 @@ def _copy_field(src, dst, field):
     sub_out_dict[key_to_move] = sub_dict[key_to_move]
 
 
-def only_fields(dictionary, fields):
+def only_fields(dictionary: JsonDict, fields: List[str]) -> JsonDict:
     """Return a new dict with only the fields in 'dictionary' which are present
     in 'fields'.
 
@@ -215,11 +229,11 @@ def only_fields(dictionary, fields):
     A literal '.' character in a field name may be escaped using a '\'.
 
     Args:
-        dictionary(dict): The dictionary to read from.
-        fields(list<str>): A list of fields to copy over. Only shallow refs are
+        dictionary: The dictionary to read from.
+        fields: A list of fields to copy over. Only shallow refs are
         taken.
     Returns:
-        dict: A new dictionary with only the given fields. If fields was empty,
+        A new dictionary with only the given fields. If fields was empty,
         the same dictionary is returned.
     """
     if len(fields) == 0:
@@ -235,17 +249,17 @@ def only_fields(dictionary, fields):
         [f.replace(r"\.", r".") for f in field_array] for field_array in split_fields
     ]
 
-    output = {}
+    output: JsonDict = {}
     for field_array in split_fields:
         _copy_field(dictionary, output, field_array)
     return output
 
 
-def format_event_raw(d):
+def format_event_raw(d: JsonDict) -> JsonDict:
     return d
 
 
-def format_event_for_client_v1(d):
+def format_event_for_client_v1(d: JsonDict) -> JsonDict:
     d = format_event_for_client_v2(d)
 
     sender = d.get("sender")
@@ -267,7 +281,7 @@ def format_event_for_client_v1(d):
     return d
 
 
-def format_event_for_client_v2(d):
+def format_event_for_client_v2(d: JsonDict) -> JsonDict:
     drop_keys = (
         "auth_events",
         "prev_events",
@@ -282,37 +296,37 @@ def format_event_for_client_v2(d):
     return d
 
 
-def format_event_for_client_v2_without_room_id(d):
+def format_event_for_client_v2_without_room_id(d: JsonDict) -> JsonDict:
     d = format_event_for_client_v2(d)
     d.pop("room_id", None)
     return d
 
 
 def serialize_event(
-    e,
-    time_now_ms,
-    as_client_event=True,
-    event_format=format_event_for_client_v1,
-    token_id=None,
-    only_event_fields=None,
-    include_stripped_room_state=False,
-):
+    e: Union[JsonDict, EventBase],
+    time_now_ms: int,
+    as_client_event: bool = True,
+    event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1,
+    token_id: Optional[str] = None,
+    only_event_fields: Optional[List[str]] = None,
+    include_stripped_room_state: bool = False,
+) -> JsonDict:
     """Serialize event for clients
 
     Args:
-        e (EventBase)
-        time_now_ms (int)
-        as_client_event (bool)
+        e
+        time_now_ms
+        as_client_event
         event_format
         token_id
         only_event_fields
-        include_stripped_room_state (bool): Some events can have stripped room state
+        include_stripped_room_state: Some events can have stripped room state
             stored in the `unsigned` field. This is required for invite and knock
             functionality. If this option is False, that state will be removed from the
             event before it is returned. Otherwise, it will be kept.
 
     Returns:
-        dict
+        The serialized event dictionary.
     """
 
     # FIXME(erikj): To handle the case of presence events and the like
@@ -369,25 +383,27 @@ class EventClientSerializer:
     clients.
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
-        self.experimental_msc1849_support_enabled = (
-            hs.config.server.experimental_msc1849_support_enabled
-        )
+        self._msc1849_enabled = hs.config.experimental.msc1849_enabled
 
     async def serialize_event(
-        self, event, time_now, bundle_aggregations=True, **kwargs
-    ):
+        self,
+        event: Union[JsonDict, EventBase],
+        time_now: int,
+        bundle_aggregations: bool = True,
+        **kwargs: Any,
+    ) -> JsonDict:
         """Serializes a single event.
 
         Args:
-            event (EventBase)
-            time_now (int): The current time in milliseconds
-            bundle_aggregations (bool): Whether to bundle in related events
+            event
+            time_now: The current time in milliseconds
+            bundle_aggregations: Whether to bundle in related events
             **kwargs: Arguments to pass to `serialize_event`
 
         Returns:
-            dict: The serialized event
+            The serialized event
         """
         # To handle the case of presence events and the like
         if not isinstance(event, EventBase):
@@ -400,7 +416,7 @@ class EventClientSerializer:
         # we need to bundle in with the event.
         # Do not bundle relations if the event has been redacted
         if not event.internal_metadata.is_redacted() and (
-            self.experimental_msc1849_support_enabled and bundle_aggregations
+            self._msc1849_enabled and bundle_aggregations
         ):
             annotations = await self.store.get_aggregation_groups_for_event(event_id)
             references = await self.store.get_relations_for_event(
@@ -448,25 +464,27 @@ class EventClientSerializer:
 
         return serialized_event
 
-    def serialize_events(self, events, time_now, **kwargs):
+    async def serialize_events(
+        self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any
+    ) -> List[JsonDict]:
         """Serializes multiple events.
 
         Args:
-            event (iter[EventBase])
-            time_now (int): The current time in milliseconds
+            event
+            time_now: The current time in milliseconds
             **kwargs: Arguments to pass to `serialize_event`
 
         Returns:
-            Deferred[list[dict]]: The list of serialized events
+            The list of serialized events
         """
-        return yieldable_gather_results(
+        return await yieldable_gather_results(
             self.serialize_event, events, time_now=time_now, **kwargs
         )
 
 
 def copy_power_levels_contents(
     old_power_levels: Mapping[str, Union[int, Mapping[str, int]]]
-):
+) -> Dict[str, Union[int, Dict[str, int]]]:
     """Copy the content of a power_levels event, unfreezing frozendicts along the way
 
     Raises:
@@ -475,7 +493,7 @@ def copy_power_levels_contents(
     if not isinstance(old_power_levels, collections.abc.Mapping):
         raise TypeError("Not a valid power-levels content: %r" % (old_power_levels,))
 
-    power_levels = {}
+    power_levels: Dict[str, Union[int, Dict[str, int]]] = {}
     for k, v in old_power_levels.items():
 
         if isinstance(v, int):
@@ -483,7 +501,8 @@ def copy_power_levels_contents(
             continue
 
         if isinstance(v, collections.abc.Mapping):
-            power_levels[k] = h = {}
+            h: Dict[str, int] = {}
+            power_levels[k] = h
             for k1, v1 in v.items():
                 # we should only have one level of nesting
                 if not isinstance(v1, int):
@@ -498,7 +517,7 @@ def copy_power_levels_contents(
     return power_levels
 
 
-def validate_canonicaljson(value: Any):
+def validate_canonicaljson(value: Any) -> None:
     """
     Ensure that the JSON object is valid according to the rules of canonical JSON.