summary refs log tree commit diff
path: root/synapse/events
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/events')
-rw-r--r--synapse/events/__init__.py227
-rw-r--r--synapse/events/third_party_rules.py40
-rw-r--r--synapse/events/validator.py2
3 files changed, 184 insertions, 85 deletions
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 157669ea88..38f3cf4d33 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -16,8 +16,23 @@
 
 import abc
 import os
-from typing import Dict, Optional, Tuple, Type
-
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Generic,
+    Iterable,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+    overload,
+)
+
+from typing_extensions import Literal
 from unpaddedbase64 import encode_base64
 
 from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions
@@ -26,6 +41,9 @@ from synapse.util.caches import intern_dict
 from synapse.util.frozenutils import freeze
 from synapse.util.stringutils import strtobool
 
+if TYPE_CHECKING:
+    from synapse.events.builder import EventBuilder
+
 # Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents
 # bugs where we accidentally share e.g. signature dicts. However, converting a
 # dict to frozen_dicts is expensive.
@@ -37,7 +55,23 @@ from synapse.util.stringutils import strtobool
 USE_FROZEN_DICTS = strtobool(os.environ.get("SYNAPSE_USE_FROZEN_DICTS", "0"))
 
 
-class DictProperty:
+T = TypeVar("T")
+
+
+# DictProperty (and DefaultDictProperty) require the classes they're used with to
+# have a _dict property to pull properties from.
+#
+# TODO _DictPropertyInstance should not include EventBuilder but due to
+# https://github.com/python/mypy/issues/5570 it thinks the DictProperty and
+# DefaultDictProperty get applied to EventBuilder when it is in a Union with
+# EventBase. This is the least invasive hack to get mypy to comply.
+#
+# Note that DictProperty/DefaultDictProperty cannot actually be used with
+# EventBuilder as it lacks a _dict property.
+_DictPropertyInstance = Union["_EventInternalMetadata", "EventBase", "EventBuilder"]
+
+
+class DictProperty(Generic[T]):
     """An object property which delegates to the `_dict` within its parent object."""
 
     __slots__ = ["key"]
@@ -45,12 +79,33 @@ class DictProperty:
     def __init__(self, key: str):
         self.key = key
 
-    def __get__(self, instance, owner=None):
+    @overload
+    def __get__(
+        self,
+        instance: Literal[None],
+        owner: Optional[Type[_DictPropertyInstance]] = None,
+    ) -> "DictProperty":
+        ...
+
+    @overload
+    def __get__(
+        self,
+        instance: _DictPropertyInstance,
+        owner: Optional[Type[_DictPropertyInstance]] = None,
+    ) -> T:
+        ...
+
+    def __get__(
+        self,
+        instance: Optional[_DictPropertyInstance],
+        owner: Optional[Type[_DictPropertyInstance]] = None,
+    ) -> Union[T, "DictProperty"]:
         # if the property is accessed as a class property rather than an instance
         # property, return the property itself rather than the value
         if instance is None:
             return self
         try:
+            assert isinstance(instance, (EventBase, _EventInternalMetadata))
             return instance._dict[self.key]
         except KeyError as e1:
             # We want this to look like a regular attribute error (mostly so that
@@ -65,10 +120,12 @@ class DictProperty:
                 "'%s' has no '%s' property" % (type(instance), self.key)
             ) from e1.__context__
 
-    def __set__(self, instance, v):
+    def __set__(self, instance: _DictPropertyInstance, v: T) -> None:
+        assert isinstance(instance, (EventBase, _EventInternalMetadata))
         instance._dict[self.key] = v
 
-    def __delete__(self, instance):
+    def __delete__(self, instance: _DictPropertyInstance) -> None:
+        assert isinstance(instance, (EventBase, _EventInternalMetadata))
         try:
             del instance._dict[self.key]
         except KeyError as e1:
@@ -77,7 +134,7 @@ class DictProperty:
             ) from e1.__context__
 
 
-class DefaultDictProperty(DictProperty):
+class DefaultDictProperty(DictProperty, Generic[T]):
     """An extension of DictProperty which provides a default if the property is
     not present in the parent's _dict.
 
@@ -86,13 +143,34 @@ class DefaultDictProperty(DictProperty):
 
     __slots__ = ["default"]
 
-    def __init__(self, key, default):
+    def __init__(self, key: str, default: T):
         super().__init__(key)
         self.default = default
 
-    def __get__(self, instance, owner=None):
+    @overload
+    def __get__(
+        self,
+        instance: Literal[None],
+        owner: Optional[Type[_DictPropertyInstance]] = None,
+    ) -> "DefaultDictProperty":
+        ...
+
+    @overload
+    def __get__(
+        self,
+        instance: _DictPropertyInstance,
+        owner: Optional[Type[_DictPropertyInstance]] = None,
+    ) -> T:
+        ...
+
+    def __get__(
+        self,
+        instance: Optional[_DictPropertyInstance],
+        owner: Optional[Type[_DictPropertyInstance]] = None,
+    ) -> Union[T, "DefaultDictProperty"]:
         if instance is None:
             return self
+        assert isinstance(instance, (EventBase, _EventInternalMetadata))
         return instance._dict.get(self.key, self.default)
 
 
@@ -111,22 +189,22 @@ class _EventInternalMetadata:
         # in the DAG)
         self.outlier = False
 
-    out_of_band_membership: bool = DictProperty("out_of_band_membership")
-    send_on_behalf_of: str = DictProperty("send_on_behalf_of")
-    recheck_redaction: bool = DictProperty("recheck_redaction")
-    soft_failed: bool = DictProperty("soft_failed")
-    proactively_send: bool = DictProperty("proactively_send")
-    redacted: bool = DictProperty("redacted")
-    txn_id: str = DictProperty("txn_id")
-    token_id: int = DictProperty("token_id")
-    historical: bool = DictProperty("historical")
+    out_of_band_membership: DictProperty[bool] = DictProperty("out_of_band_membership")
+    send_on_behalf_of: DictProperty[str] = DictProperty("send_on_behalf_of")
+    recheck_redaction: DictProperty[bool] = DictProperty("recheck_redaction")
+    soft_failed: DictProperty[bool] = DictProperty("soft_failed")
+    proactively_send: DictProperty[bool] = DictProperty("proactively_send")
+    redacted: DictProperty[bool] = DictProperty("redacted")
+    txn_id: DictProperty[str] = DictProperty("txn_id")
+    token_id: DictProperty[int] = DictProperty("token_id")
+    historical: DictProperty[bool] = DictProperty("historical")
 
     # XXX: These are set by StreamWorkerStore._set_before_and_after.
     # I'm pretty sure that these are never persisted to the database, so shouldn't
     # be here
-    before: RoomStreamToken = DictProperty("before")
-    after: RoomStreamToken = DictProperty("after")
-    order: Tuple[int, int] = DictProperty("order")
+    before: DictProperty[RoomStreamToken] = DictProperty("before")
+    after: DictProperty[RoomStreamToken] = DictProperty("after")
+    order: DictProperty[Tuple[int, int]] = DictProperty("order")
 
     def get_dict(self) -> JsonDict:
         return dict(self._dict)
@@ -162,9 +240,6 @@ class _EventInternalMetadata:
 
         If the sender of the redaction event is allowed to redact any event
         due to auth rules, then this will always return false.
-
-        Returns:
-            bool
         """
         return self._dict.get("recheck_redaction", False)
 
@@ -176,32 +251,23 @@ class _EventInternalMetadata:
                sent to clients.
             2. They should not be added to the forward extremities (and
                therefore not to current state).
-
-        Returns:
-            bool
         """
         return self._dict.get("soft_failed", False)
 
-    def should_proactively_send(self):
+    def should_proactively_send(self) -> bool:
         """Whether the event, if ours, should be sent to other clients and
         servers.
 
         This is used for sending dummy events internally. Servers and clients
         can still explicitly fetch the event.
-
-        Returns:
-            bool
         """
         return self._dict.get("proactively_send", True)
 
-    def is_redacted(self):
+    def is_redacted(self) -> bool:
         """Whether the event has been redacted.
 
         This is used for efficiently checking whether an event has been
         marked as redacted without needing to make another database call.
-
-        Returns:
-            bool
         """
         return self._dict.get("redacted", False)
 
@@ -241,29 +307,31 @@ class EventBase(metaclass=abc.ABCMeta):
 
         self.internal_metadata = _EventInternalMetadata(internal_metadata_dict)
 
-    auth_events = DictProperty("auth_events")
-    depth = DictProperty("depth")
-    content = DictProperty("content")
-    hashes = DictProperty("hashes")
-    origin = DictProperty("origin")
-    origin_server_ts = DictProperty("origin_server_ts")
-    prev_events = DictProperty("prev_events")
-    redacts = DefaultDictProperty("redacts", None)
-    room_id = DictProperty("room_id")
-    sender = DictProperty("sender")
-    state_key = DictProperty("state_key")
-    type = DictProperty("type")
-    user_id = DictProperty("sender")
+    depth: DictProperty[int] = DictProperty("depth")
+    content: DictProperty[JsonDict] = DictProperty("content")
+    hashes: DictProperty[Dict[str, str]] = DictProperty("hashes")
+    origin: DictProperty[str] = DictProperty("origin")
+    origin_server_ts: DictProperty[int] = DictProperty("origin_server_ts")
+    redacts: DefaultDictProperty[Optional[str]] = DefaultDictProperty("redacts", None)
+    room_id: DictProperty[str] = DictProperty("room_id")
+    sender: DictProperty[str] = DictProperty("sender")
+    # TODO state_key should be Optional[str], this is generally asserted in Synapse
+    # by calling is_state() first (which ensures this), but it is hard (not possible?)
+    # to properly annotate that calling is_state() asserts that state_key exists
+    # and is non-None.
+    state_key: DictProperty[str] = DictProperty("state_key")
+    type: DictProperty[str] = DictProperty("type")
+    user_id: DictProperty[str] = DictProperty("sender")
 
     @property
     def event_id(self) -> str:
         raise NotImplementedError()
 
     @property
-    def membership(self):
+    def membership(self) -> str:
         return self.content["membership"]
 
-    def is_state(self):
+    def is_state(self) -> bool:
         return hasattr(self, "state_key") and self.state_key is not None
 
     def get_dict(self) -> JsonDict:
@@ -272,13 +340,13 @@ class EventBase(metaclass=abc.ABCMeta):
 
         return d
 
-    def get(self, key, default=None):
+    def get(self, key: str, default: Optional[Any] = None) -> Any:
         return self._dict.get(key, default)
 
-    def get_internal_metadata_dict(self):
+    def get_internal_metadata_dict(self) -> JsonDict:
         return self.internal_metadata.get_dict()
 
-    def get_pdu_json(self, time_now=None) -> JsonDict:
+    def get_pdu_json(self, time_now: Optional[int] = None) -> JsonDict:
         pdu_json = self.get_dict()
 
         if time_now is not None and "age_ts" in pdu_json["unsigned"]:
@@ -305,49 +373,46 @@ class EventBase(metaclass=abc.ABCMeta):
 
         return template_json
 
-    def __set__(self, instance, value):
-        raise AttributeError("Unrecognized attribute %s" % (instance,))
-
-    def __getitem__(self, field):
+    def __getitem__(self, field: str) -> Optional[Any]:
         return self._dict[field]
 
-    def __contains__(self, field):
+    def __contains__(self, field: str) -> bool:
         return field in self._dict
 
-    def items(self):
+    def items(self) -> List[Tuple[str, Optional[Any]]]:
         return list(self._dict.items())
 
-    def keys(self):
+    def keys(self) -> Iterable[str]:
         return self._dict.keys()
 
-    def prev_event_ids(self):
+    def prev_event_ids(self) -> Sequence[str]:
         """Returns the list of prev event IDs. The order matches the order
         specified in the event, though there is no meaning to it.
 
         Returns:
-            list[str]: The list of event IDs of this event's prev_events
+            The list of event IDs of this event's prev_events
         """
-        return [e for e, _ in self.prev_events]
+        return [e for e, _ in self._dict["prev_events"]]
 
-    def auth_event_ids(self):
+    def auth_event_ids(self) -> Sequence[str]:
         """Returns the list of auth event IDs. The order matches the order
         specified in the event, though there is no meaning to it.
 
         Returns:
-            list[str]: The list of event IDs of this event's auth_events
+            The list of event IDs of this event's auth_events
         """
-        return [e for e, _ in self.auth_events]
+        return [e for e, _ in self._dict["auth_events"]]
 
-    def freeze(self):
+    def freeze(self) -> None:
         """'Freeze' the event dict, so it cannot be modified by accident"""
 
         # this will be a no-op if the event dict is already frozen.
         self._dict = freeze(self._dict)
 
-    def __str__(self):
+    def __str__(self) -> str:
         return self.__repr__()
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         rejection = f"REJECTED={self.rejected_reason}, " if self.rejected_reason else ""
 
         return (
@@ -443,7 +508,7 @@ class FrozenEventV2(EventBase):
         else:
             frozen_dict = event_dict
 
-        self._event_id = None
+        self._event_id: Optional[str] = None
 
         super().__init__(
             frozen_dict,
@@ -455,7 +520,7 @@ class FrozenEventV2(EventBase):
         )
 
     @property
-    def event_id(self):
+    def event_id(self) -> str:
         # We have to import this here as otherwise we get an import loop which
         # is hard to break.
         from synapse.crypto.event_signing import compute_event_reference_hash
@@ -465,23 +530,23 @@ class FrozenEventV2(EventBase):
         self._event_id = "$" + encode_base64(compute_event_reference_hash(self)[1])
         return self._event_id
 
-    def prev_event_ids(self):
+    def prev_event_ids(self) -> Sequence[str]:
         """Returns the list of prev event IDs. The order matches the order
         specified in the event, though there is no meaning to it.
 
         Returns:
-            list[str]: The list of event IDs of this event's prev_events
+            The list of event IDs of this event's prev_events
         """
-        return self.prev_events
+        return self._dict["prev_events"]
 
-    def auth_event_ids(self):
+    def auth_event_ids(self) -> Sequence[str]:
         """Returns the list of auth event IDs. The order matches the order
         specified in the event, though there is no meaning to it.
 
         Returns:
-            list[str]: The list of event IDs of this event's auth_events
+            The list of event IDs of this event's auth_events
         """
-        return self.auth_events
+        return self._dict["auth_events"]
 
 
 class FrozenEventV3(FrozenEventV2):
@@ -490,7 +555,7 @@ class FrozenEventV3(FrozenEventV2):
     format_version = EventFormatVersions.V3  # All events of this type are V3
 
     @property
-    def event_id(self):
+    def event_id(self) -> str:
         # We have to import this here as otherwise we get an import loop which
         # is hard to break.
         from synapse.crypto.event_signing import compute_event_reference_hash
@@ -503,12 +568,14 @@ class FrozenEventV3(FrozenEventV2):
         return self._event_id
 
 
-def _event_type_from_format_version(format_version: int) -> Type[EventBase]:
+def _event_type_from_format_version(
+    format_version: int,
+) -> Type[Union[FrozenEvent, FrozenEventV2, FrozenEventV3]]:
     """Returns the python type to use to construct an Event object for the
     given event format version.
 
     Args:
-        format_version (int): The event format version
+        format_version: The event format version
 
     Returns:
         type: A type that can be initialized as per the initializer of
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index 2a6dabdab6..1bb8ca7145 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -14,7 +14,7 @@
 import logging
 from typing import TYPE_CHECKING, Any, Awaitable, Callable, List, Optional, Tuple
 
-from synapse.api.errors import SynapseError
+from synapse.api.errors import ModuleFailedException, SynapseError
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.types import Requester, StateMap
@@ -36,6 +36,7 @@ CHECK_THREEPID_CAN_BE_INVITED_CALLBACK = Callable[
 CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK = Callable[
     [str, StateMap[EventBase], str], Awaitable[bool]
 ]
+ON_NEW_EVENT_CALLBACK = Callable[[EventBase, StateMap[EventBase]], Awaitable]
 
 
 def load_legacy_third_party_event_rules(hs: "HomeServer") -> None:
@@ -152,6 +153,7 @@ class ThirdPartyEventRules:
         self._check_visibility_can_be_modified_callbacks: List[
             CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK
         ] = []
+        self._on_new_event_callbacks: List[ON_NEW_EVENT_CALLBACK] = []
 
     def register_third_party_rules_callbacks(
         self,
@@ -163,6 +165,7 @@ class ThirdPartyEventRules:
         check_visibility_can_be_modified: Optional[
             CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK
         ] = None,
+        on_new_event: Optional[ON_NEW_EVENT_CALLBACK] = None,
     ) -> None:
         """Register callbacks from modules for each hook."""
         if check_event_allowed is not None:
@@ -181,6 +184,9 @@ class ThirdPartyEventRules:
                 check_visibility_can_be_modified,
             )
 
+        if on_new_event is not None:
+            self._on_new_event_callbacks.append(on_new_event)
+
     async def check_event_allowed(
         self, event: EventBase, context: EventContext
     ) -> Tuple[bool, Optional[dict]]:
@@ -227,9 +233,10 @@ class ThirdPartyEventRules:
                 # This module callback needs a rework so that hacks such as
                 # this one are not necessary.
                 raise e
-            except Exception as e:
-                logger.warning("Failed to run module API callback %s: %s", callback, e)
-                continue
+            except Exception:
+                raise ModuleFailedException(
+                    "Failed to run `check_event_allowed` module API callback"
+                )
 
             # Return if the event shouldn't be allowed or if the module came up with a
             # replacement dict for the event.
@@ -321,6 +328,31 @@ class ThirdPartyEventRules:
 
         return True
 
+    async def on_new_event(self, event_id: str) -> None:
+        """Let modules act on events after they've been sent (e.g. auto-accepting
+        invites, etc.)
+
+        Args:
+            event_id: The ID of the event.
+
+        Raises:
+            ModuleFailureError if a callback raised any exception.
+        """
+        # Bail out early without hitting the store if we don't have any callbacks
+        if len(self._on_new_event_callbacks) == 0:
+            return
+
+        event = await self.store.get_event(event_id)
+        state_events = await self._get_state_map_for_room(event.room_id)
+
+        for callback in self._on_new_event_callbacks:
+            try:
+                await callback(event, state_events)
+            except Exception as e:
+                logger.exception(
+                    "Failed to run module API callback %s: %s", callback, e
+                )
+
     async def _get_state_map_for_room(self, room_id: str) -> StateMap[EventBase]:
         """Given a room ID, return the state events of that room.
 
diff --git a/synapse/events/validator.py b/synapse/events/validator.py
index 4d459c17f1..cf86934968 100644
--- a/synapse/events/validator.py
+++ b/synapse/events/validator.py
@@ -55,7 +55,7 @@ class EventValidator:
         ]
 
         for k in required:
-            if not hasattr(event, k):
+            if k not in event:
                 raise SynapseError(400, "Event does not have key %s" % (k,))
 
         # Check that the following keys have string values