diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 520edbbf61..23bd24d963 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -13,18 +13,32 @@
# limitations under the License.
import collections.abc
import re
-from typing import Any, Mapping, Union
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ List,
+ Mapping,
+ Optional,
+ Union,
+)
from frozendict import frozendict
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import RoomVersion
+from synapse.types import JsonDict
from synapse.util.async_helpers import yieldable_gather_results
from synapse.util.frozenutils import unfreeze
from . import EventBase
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
# (?<!stuff) matches if the current position in the string is not preceded
# by a match for 'stuff'.
@@ -65,7 +79,7 @@ def prune_event(event: EventBase) -> EventBase:
return pruned_event
-def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
+def prune_event_dict(room_version: RoomVersion, event_dict: JsonDict) -> JsonDict:
"""Redacts the event_dict in the same way as `prune_event`, except it
operates on dicts rather than event objects
@@ -97,7 +111,7 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
new_content = {}
- def add_fields(*fields):
+ def add_fields(*fields: str) -> None:
for field in fields:
if field in event_dict["content"]:
new_content[field] = event_dict["content"][field]
@@ -151,7 +165,7 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
allowed_fields["content"] = new_content
- unsigned = {}
+ unsigned: JsonDict = {}
allowed_fields["unsigned"] = unsigned
event_unsigned = event_dict.get("unsigned", {})
@@ -164,16 +178,16 @@ def prune_event_dict(room_version: RoomVersion, event_dict: dict) -> dict:
return allowed_fields
-def _copy_field(src, dst, field):
+def _copy_field(src: JsonDict, dst: JsonDict, field: List[str]) -> None:
"""Copy the field in 'src' to 'dst'.
For example, if src={"foo":{"bar":5}} and dst={}, and field=["foo","bar"]
then dst={"foo":{"bar":5}}.
Args:
- src(dict): The dict to read from.
- dst(dict): The dict to modify.
- field(list<str>): List of keys to drill down to in 'src'.
+ src: The dict to read from.
+ dst: The dict to modify.
+ field: List of keys to drill down to in 'src'.
"""
if len(field) == 0: # this should be impossible
return
@@ -205,7 +219,7 @@ def _copy_field(src, dst, field):
sub_out_dict[key_to_move] = sub_dict[key_to_move]
-def only_fields(dictionary, fields):
+def only_fields(dictionary: JsonDict, fields: List[str]) -> JsonDict:
"""Return a new dict with only the fields in 'dictionary' which are present
in 'fields'.
@@ -215,11 +229,11 @@ def only_fields(dictionary, fields):
A literal '.' character in a field name may be escaped using a '\'.
Args:
- dictionary(dict): The dictionary to read from.
- fields(list<str>): A list of fields to copy over. Only shallow refs are
+ dictionary: The dictionary to read from.
+ fields: A list of fields to copy over. Only shallow refs are
taken.
Returns:
- dict: A new dictionary with only the given fields. If fields was empty,
+ A new dictionary with only the given fields. If fields was empty,
the same dictionary is returned.
"""
if len(fields) == 0:
@@ -235,17 +249,17 @@ def only_fields(dictionary, fields):
[f.replace(r"\.", r".") for f in field_array] for field_array in split_fields
]
- output = {}
+ output: JsonDict = {}
for field_array in split_fields:
_copy_field(dictionary, output, field_array)
return output
-def format_event_raw(d):
+def format_event_raw(d: JsonDict) -> JsonDict:
return d
-def format_event_for_client_v1(d):
+def format_event_for_client_v1(d: JsonDict) -> JsonDict:
d = format_event_for_client_v2(d)
sender = d.get("sender")
@@ -267,7 +281,7 @@ def format_event_for_client_v1(d):
return d
-def format_event_for_client_v2(d):
+def format_event_for_client_v2(d: JsonDict) -> JsonDict:
drop_keys = (
"auth_events",
"prev_events",
@@ -282,37 +296,37 @@ def format_event_for_client_v2(d):
return d
-def format_event_for_client_v2_without_room_id(d):
+def format_event_for_client_v2_without_room_id(d: JsonDict) -> JsonDict:
d = format_event_for_client_v2(d)
d.pop("room_id", None)
return d
def serialize_event(
- e,
- time_now_ms,
- as_client_event=True,
- event_format=format_event_for_client_v1,
- token_id=None,
- only_event_fields=None,
- include_stripped_room_state=False,
-):
+ e: Union[JsonDict, EventBase],
+ time_now_ms: int,
+ as_client_event: bool = True,
+ event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1,
+ token_id: Optional[str] = None,
+ only_event_fields: Optional[List[str]] = None,
+ include_stripped_room_state: bool = False,
+) -> JsonDict:
"""Serialize event for clients
Args:
- e (EventBase)
- time_now_ms (int)
- as_client_event (bool)
+ e
+ time_now_ms
+ as_client_event
event_format
token_id
only_event_fields
- include_stripped_room_state (bool): Some events can have stripped room state
+ include_stripped_room_state: Some events can have stripped room state
stored in the `unsigned` field. This is required for invite and knock
functionality. If this option is False, that state will be removed from the
event before it is returned. Otherwise, it will be kept.
Returns:
- dict
+ The serialized event dictionary.
"""
# FIXME(erikj): To handle the case of presence events and the like
@@ -369,25 +383,29 @@ class EventClientSerializer:
clients.
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.experimental_msc1849_support_enabled = (
hs.config.server.experimental_msc1849_support_enabled
)
async def serialize_event(
- self, event, time_now, bundle_aggregations=True, **kwargs
- ):
+ self,
+ event: Union[JsonDict, EventBase],
+ time_now: int,
+ bundle_aggregations: bool = True,
+ **kwargs: Any,
+ ) -> JsonDict:
"""Serializes a single event.
Args:
- event (EventBase)
- time_now (int): The current time in milliseconds
- bundle_aggregations (bool): Whether to bundle in related events
+ event
+ time_now: The current time in milliseconds
+ bundle_aggregations: Whether to bundle in related events
**kwargs: Arguments to pass to `serialize_event`
Returns:
- dict: The serialized event
+ The serialized event
"""
# To handle the case of presence events and the like
if not isinstance(event, EventBase):
@@ -448,25 +466,27 @@ class EventClientSerializer:
return serialized_event
- def serialize_events(self, events, time_now, **kwargs):
+ async def serialize_events(
+ self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any
+ ) -> List[JsonDict]:
"""Serializes multiple events.
Args:
- event (iter[EventBase])
- time_now (int): The current time in milliseconds
+ event
+ time_now: The current time in milliseconds
**kwargs: Arguments to pass to `serialize_event`
Returns:
- Deferred[list[dict]]: The list of serialized events
+ The list of serialized events
"""
- return yieldable_gather_results(
+ return await yieldable_gather_results(
self.serialize_event, events, time_now=time_now, **kwargs
)
def copy_power_levels_contents(
old_power_levels: Mapping[str, Union[int, Mapping[str, int]]]
-):
+) -> Dict[str, Union[int, Dict[str, int]]]:
"""Copy the content of a power_levels event, unfreezing frozendicts along the way
Raises:
@@ -475,7 +495,7 @@ def copy_power_levels_contents(
if not isinstance(old_power_levels, collections.abc.Mapping):
raise TypeError("Not a valid power-levels content: %r" % (old_power_levels,))
- power_levels = {}
+ power_levels: Dict[str, Union[int, Dict[str, int]]] = {}
for k, v in old_power_levels.items():
if isinstance(v, int):
@@ -483,7 +503,8 @@ def copy_power_levels_contents(
continue
if isinstance(v, collections.abc.Mapping):
- power_levels[k] = h = {}
+ h: Dict[str, int] = {}
+ power_levels[k] = h
for k1, v1 in v.items():
# we should only have one level of nesting
if not isinstance(v1, int):
@@ -498,7 +519,7 @@ def copy_power_levels_contents(
return power_levels
-def validate_canonicaljson(value: Any):
+def validate_canonicaljson(value: Any) -> None:
"""
Ensure that the JSON object is valid according to the rules of canonical JSON.
|