diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 38f3cf4d33..9acb3c0cc4 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -315,10 +315,11 @@ class EventBase(metaclass=abc.ABCMeta):
redacts: DefaultDictProperty[Optional[str]] = DefaultDictProperty("redacts", None)
room_id: DictProperty[str] = DictProperty("room_id")
sender: DictProperty[str] = DictProperty("sender")
- # TODO state_key should be Optional[str], this is generally asserted in Synapse
- # by calling is_state() first (which ensures this), but it is hard (not possible?)
+ # TODO state_key should be Optional[str]. This is generally asserted in Synapse
+ # by calling is_state() first (which ensures it is not None), but it is hard (not possible?)
# to properly annotate that calling is_state() asserts that state_key exists
- # and is non-None.
+ # and is non-None. It would be better to replace such direct references with
+ # get_state_key() (and a check for None).
state_key: DictProperty[str] = DictProperty("state_key")
type: DictProperty[str] = DictProperty("type")
user_id: DictProperty[str] = DictProperty("sender")
@@ -332,7 +333,11 @@ class EventBase(metaclass=abc.ABCMeta):
return self.content["membership"]
def is_state(self) -> bool:
- return hasattr(self, "state_key") and self.state_key is not None
+ return self.get_state_key() is not None
+
+ def get_state_key(self) -> Optional[str]:
+ """Get the state key of this event, or None if it's not a state event"""
+ return self._dict.get("state_key")
def get_dict(self) -> JsonDict:
d = dict(self._dict)
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 0eab1aefd6..5833fee25f 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -163,7 +163,7 @@ class EventContext:
return {
"prev_state_id": prev_state_id,
"event_type": event.type,
- "event_state_key": event.state_key if event.is_state() else None,
+ "event_state_key": event.get_state_key(),
"state_group": self._state_group,
"state_group_before_event": self.state_group_before_event,
"rejected": self.rejected,
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 918adeecf8..243696b357 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -14,7 +14,17 @@
# limitations under the License.
import collections.abc
import re
-from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Dict,
+ Iterable,
+ List,
+ Mapping,
+ Optional,
+ Union,
+)
from frozendict import frozendict
@@ -26,6 +36,10 @@ from synapse.util.frozenutils import unfreeze
from . import EventBase
+if TYPE_CHECKING:
+ from synapse.storage.databases.main.relations import BundledAggregations
+
+
# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
# (?<!stuff) matches if the current position in the string is not preceded
# by a match for 'stuff'.
@@ -376,7 +390,7 @@ class EventClientSerializer:
event: Union[JsonDict, EventBase],
time_now: int,
*,
- bundle_aggregations: Optional[Dict[str, JsonDict]] = None,
+ bundle_aggregations: Optional[Dict[str, "BundledAggregations"]] = None,
**kwargs: Any,
) -> JsonDict:
"""Serializes a single event.
@@ -415,7 +429,7 @@ class EventClientSerializer:
self,
event: EventBase,
time_now: int,
- aggregations: JsonDict,
+ aggregations: "BundledAggregations",
serialized_event: JsonDict,
) -> None:
"""Potentially injects bundled aggregations into the unsigned portion of the serialized event.
@@ -427,13 +441,18 @@ class EventClientSerializer:
serialized_event: The serialized event which may be modified.
"""
- # Make a copy in-case the object is cached.
- aggregations = aggregations.copy()
+ serialized_aggregations = {}
+
+ if aggregations.annotations:
+ serialized_aggregations[RelationTypes.ANNOTATION] = aggregations.annotations
+
+ if aggregations.references:
+ serialized_aggregations[RelationTypes.REFERENCE] = aggregations.references
- if RelationTypes.REPLACE in aggregations:
+ if aggregations.replace:
# If there is an edit replace the content, preserving existing
# relations.
- edit = aggregations[RelationTypes.REPLACE]
+ edit = aggregations.replace
# Ensure we take copies of the edit content, otherwise we risk modifying
# the original event.
@@ -451,24 +470,28 @@ class EventClientSerializer:
else:
serialized_event["content"].pop("m.relates_to", None)
- aggregations[RelationTypes.REPLACE] = {
+ serialized_aggregations[RelationTypes.REPLACE] = {
"event_id": edit.event_id,
"origin_server_ts": edit.origin_server_ts,
"sender": edit.sender,
}
# If this event is the start of a thread, include a summary of the replies.
- if RelationTypes.THREAD in aggregations:
- # Serialize the latest thread event.
- latest_thread_event = aggregations[RelationTypes.THREAD]["latest_event"]
-
- # Don't bundle aggregations as this could recurse forever.
- aggregations[RelationTypes.THREAD]["latest_event"] = self.serialize_event(
- latest_thread_event, time_now, bundle_aggregations=None
- )
+ if aggregations.thread:
+ serialized_aggregations[RelationTypes.THREAD] = {
+ # Don't bundle aggregations as this could recurse forever.
+ "latest_event": self.serialize_event(
+ aggregations.thread.latest_event, time_now, bundle_aggregations=None
+ ),
+ "count": aggregations.thread.count,
+ "current_user_participated": aggregations.thread.current_user_participated,
+ }
# Include the bundled aggregations in the event.
- serialized_event["unsigned"].setdefault("m.relations", {}).update(aggregations)
+ if serialized_aggregations:
+ serialized_event["unsigned"].setdefault("m.relations", {}).update(
+ serialized_aggregations
+ )
def serialize_events(
self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any
diff --git a/synapse/events/validator.py b/synapse/events/validator.py
index cf86934968..360d24274a 100644
--- a/synapse/events/validator.py
+++ b/synapse/events/validator.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import collections.abc
-from typing import Iterable, Union
+from typing import Iterable, Type, Union
import jsonschema
@@ -246,7 +246,7 @@ POWER_LEVELS_SCHEMA = {
# This could return something newer than Draft 7, but that's the current "latest"
# validator.
-def _create_power_level_validator() -> jsonschema.Draft7Validator:
+def _create_power_level_validator() -> Type[jsonschema.Draft7Validator]:
validator = jsonschema.validators.validator_for(POWER_LEVELS_SCHEMA)
# by default jsonschema does not consider a frozendict to be an object so
|