summary refs log tree commit diff
path: root/synapse/events
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/events')
-rw-r--r--synapse/events/__init__.py45
-rw-r--r--synapse/events/snapshot.py180
-rw-r--r--synapse/events/spamcheck.py81
3 files changed, 137 insertions, 169 deletions
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index c238376caf..39ad2793d9 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 
 import abc
+import collections.abc
 import os
 from typing import (
     TYPE_CHECKING,
@@ -32,9 +33,11 @@ from typing import (
     overload,
 )
 
+import attr
 from typing_extensions import Literal
 from unpaddedbase64 import encode_base64
 
+from synapse.api.constants import RelationTypes
 from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions
 from synapse.types import JsonDict, RoomStreamToken
 from synapse.util.caches import intern_dict
@@ -615,3 +618,45 @@ def make_event_from_dict(
     return event_type(
         event_dict, room_version, internal_metadata_dict or {}, rejected_reason
     )
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _EventRelation:
+    # The target event of the relation.
+    parent_id: str
+    # The relation type.
+    rel_type: str
+    # The aggregation key. Will be None if the rel_type is not m.annotation or is
+    # not a string.
+    aggregation_key: Optional[str]
+
+
+def relation_from_event(event: EventBase) -> Optional[_EventRelation]:
+    """
+    Attempt to parse relation information an event.
+
+    Returns:
+        The event relation information, if it is valid. None, otherwise.
+    """
+    relation = event.content.get("m.relates_to")
+    if not relation or not isinstance(relation, collections.abc.Mapping):
+        # No relation information.
+        return None
+
+    # Relations must have a type and parent event ID.
+    rel_type = relation.get("rel_type")
+    if not isinstance(rel_type, str):
+        return None
+
+    parent_id = relation.get("event_id")
+    if not isinstance(parent_id, str):
+        return None
+
+    # Annotations have a key field.
+    aggregation_key = None
+    if rel_type == RelationTypes.ANNOTATION:
+        aggregation_key = relation.get("key")
+        if not isinstance(aggregation_key, str):
+            aggregation_key = None
+
+    return _EventRelation(parent_id, rel_type, aggregation_key)
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index 46042b2bf7..9ccd24b298 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -15,12 +15,10 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
 import attr
 from frozendict import frozendict
-
-from twisted.internet.defer import Deferred
+from typing_extensions import Literal
 
 from synapse.appservice import ApplicationService
 from synapse.events import EventBase
-from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.types import JsonDict, StateMap
 
 if TYPE_CHECKING:
@@ -60,6 +58,9 @@ class EventContext:
             If ``state_group`` is None (ie, the event is an outlier),
             ``state_group_before_event`` will always also be ``None``.
 
+        state_delta_due_to_event: If `state_group` and `state_group_before_event` are not None
+            then this is the delta of the state between the two groups.
+
         prev_group: If it is known, ``state_group``'s prev_group. Note that this being
             None does not necessarily mean that ``state_group`` does not have
             a prev_group!
@@ -78,73 +79,47 @@ class EventContext:
         app_service: If this event is being sent by a (local) application service, that
             app service.
 
-        _current_state_ids: The room state map, including this event - ie, the state
-            in ``state_group``.
-
-            (type, state_key) -> event_id
-
-            For an outlier, this is {}
-
-            Note that this is a private attribute: it should be accessed via
-            ``get_current_state_ids``. _AsyncEventContext impl calculates this
-            on-demand: it will be None until that happens.
-
-        _prev_state_ids: The room state map, excluding this event - ie, the state
-            in ``state_group_before_event``. For a non-state
-            event, this will be the same as _current_state_events.
-
-            Note that it is a completely different thing to prev_group!
-
-            (type, state_key) -> event_id
-
-            For an outlier, this is {}
-
-            As with _current_state_ids, this is a private attribute. It should be
-            accessed via get_prev_state_ids.
-
         partial_state: if True, we may be storing this event with a temporary,
             incomplete state.
     """
 
-    rejected: Union[bool, str] = False
+    _storage: "Storage"
+    rejected: Union[Literal[False], str] = False
     _state_group: Optional[int] = None
     state_group_before_event: Optional[int] = None
+    _state_delta_due_to_event: Optional[StateMap[str]] = None
     prev_group: Optional[int] = None
     delta_ids: Optional[StateMap[str]] = None
     app_service: Optional[ApplicationService] = None
 
-    _current_state_ids: Optional[StateMap[str]] = None
-    _prev_state_ids: Optional[StateMap[str]] = None
-
     partial_state: bool = False
 
     @staticmethod
     def with_state(
+        storage: "Storage",
         state_group: Optional[int],
         state_group_before_event: Optional[int],
-        current_state_ids: Optional[StateMap[str]],
-        prev_state_ids: Optional[StateMap[str]],
+        state_delta_due_to_event: Optional[StateMap[str]],
         partial_state: bool,
         prev_group: Optional[int] = None,
         delta_ids: Optional[StateMap[str]] = None,
     ) -> "EventContext":
         return EventContext(
-            current_state_ids=current_state_ids,
-            prev_state_ids=prev_state_ids,
+            storage=storage,
             state_group=state_group,
             state_group_before_event=state_group_before_event,
+            state_delta_due_to_event=state_delta_due_to_event,
             prev_group=prev_group,
             delta_ids=delta_ids,
             partial_state=partial_state,
         )
 
     @staticmethod
-    def for_outlier() -> "EventContext":
+    def for_outlier(
+        storage: "Storage",
+    ) -> "EventContext":
         """Return an EventContext instance suitable for persisting an outlier event"""
-        return EventContext(
-            current_state_ids={},
-            prev_state_ids={},
-        )
+        return EventContext(storage=storage)
 
     async def serialize(self, event: EventBase, store: "DataStore") -> JsonDict:
         """Converts self to a type that can be serialized as JSON, and then
@@ -157,24 +132,14 @@ class EventContext:
             The serialized event.
         """
 
-        # We don't serialize the full state dicts, instead they get pulled out
-        # of the DB on the other side. However, the other side can't figure out
-        # the prev_state_ids, so if we're a state event we include the event
-        # id that we replaced in the state.
-        if event.is_state():
-            prev_state_ids = await self.get_prev_state_ids()
-            prev_state_id = prev_state_ids.get((event.type, event.state_key))
-        else:
-            prev_state_id = None
-
         return {
-            "prev_state_id": prev_state_id,
-            "event_type": event.type,
-            "event_state_key": event.get_state_key(),
             "state_group": self._state_group,
             "state_group_before_event": self.state_group_before_event,
             "rejected": self.rejected,
             "prev_group": self.prev_group,
+            "state_delta_due_to_event": _encode_state_dict(
+                self._state_delta_due_to_event
+            ),
             "delta_ids": _encode_state_dict(self.delta_ids),
             "app_service_id": self.app_service.id if self.app_service else None,
             "partial_state": self.partial_state,
@@ -192,16 +157,16 @@ class EventContext:
         Returns:
             The event context.
         """
-        context = _AsyncEventContextImpl(
+        context = EventContext(
             # We use the state_group and prev_state_id stuff to pull the
             # current_state_ids out of the DB and construct prev_state_ids.
             storage=storage,
-            prev_state_id=input["prev_state_id"],
-            event_type=input["event_type"],
-            event_state_key=input["event_state_key"],
             state_group=input["state_group"],
             state_group_before_event=input["state_group_before_event"],
             prev_group=input["prev_group"],
+            state_delta_due_to_event=_decode_state_dict(
+                input["state_delta_due_to_event"]
+            ),
             delta_ids=_decode_state_dict(input["delta_ids"]),
             rejected=input["rejected"],
             partial_state=input.get("partial_state", False),
@@ -249,8 +214,15 @@ class EventContext:
         if self.rejected:
             raise RuntimeError("Attempt to access state_ids of rejected event")
 
-        await self._ensure_fetched()
-        return self._current_state_ids
+        assert self._state_delta_due_to_event is not None
+
+        prev_state_ids = await self.get_prev_state_ids()
+
+        if self._state_delta_due_to_event:
+            prev_state_ids = dict(prev_state_ids)
+            prev_state_ids.update(self._state_delta_due_to_event)
+
+        return prev_state_ids
 
     async def get_prev_state_ids(self) -> StateMap[str]:
         """
@@ -265,94 +237,10 @@ class EventContext:
             Maps a (type, state_key) to the event ID of the state event matching
             this tuple.
         """
-        await self._ensure_fetched()
-        # There *should* be previous state IDs now.
-        assert self._prev_state_ids is not None
-        return self._prev_state_ids
-
-    def get_cached_current_state_ids(self) -> Optional[StateMap[str]]:
-        """Gets the current state IDs if we have them already cached.
-
-        It is an error to access this for a rejected event, since rejected state should
-        not make it into the room state. This method will raise an exception if
-        ``rejected`` is set.
-
-        Returns:
-            Returns None if we haven't cached the state or if state_group is None
-            (which happens when the associated event is an outlier).
-
-            Otherwise, returns the the current state IDs.
-        """
-        if self.rejected:
-            raise RuntimeError("Attempt to access state_ids of rejected event")
-
-        return self._current_state_ids
-
-    async def _ensure_fetched(self) -> None:
-        return None
-
-
-@attr.s(slots=True)
-class _AsyncEventContextImpl(EventContext):
-    """
-    An implementation of EventContext which fetches _current_state_ids and
-    _prev_state_ids from the database on demand.
-
-    Attributes:
-
-        _storage
-
-        _fetching_state_deferred: Resolves when *_state_ids have been calculated.
-            None if we haven't started calculating yet
-
-        _event_type: The type of the event the context is associated with.
-
-        _event_state_key: The state_key of the event the context is associated with.
-
-        _prev_state_id: If the event associated with the context is a state event,
-            then `_prev_state_id` is the event_id of the state that was replaced.
-    """
-
-    # This needs to have a default as we're inheriting
-    _storage: "Storage" = attr.ib(default=None)
-    _prev_state_id: Optional[str] = attr.ib(default=None)
-    _event_type: str = attr.ib(default=None)
-    _event_state_key: Optional[str] = attr.ib(default=None)
-    _fetching_state_deferred: Optional["Deferred[None]"] = attr.ib(default=None)
-
-    async def _ensure_fetched(self) -> None:
-        if not self._fetching_state_deferred:
-            self._fetching_state_deferred = run_in_background(self._fill_out_state)
-
-        await make_deferred_yieldable(self._fetching_state_deferred)
-
-    async def _fill_out_state(self) -> None:
-        """Called to populate the _current_state_ids and _prev_state_ids
-        attributes by loading from the database.
-        """
-        if self.state_group is None:
-            # No state group means the event is an outlier. Usually the state_ids dicts are also
-            # pre-set to empty dicts, but they get reset when the context is serialized, so set
-            # them to empty dicts again here.
-            self._current_state_ids = {}
-            self._prev_state_ids = {}
-            return
-
-        current_state_ids = await self._storage.state.get_state_ids_for_group(
-            self.state_group
+        assert self.state_group_before_event is not None
+        return await self._storage.state.get_state_ids_for_group(
+            self.state_group_before_event
         )
-        # Set this separately so mypy knows current_state_ids is not None.
-        self._current_state_ids = current_state_ids
-        if self._event_state_key is not None:
-            self._prev_state_ids = dict(current_state_ids)
-
-            key = (self._event_type, self._event_state_key)
-            if self._prev_state_id:
-                self._prev_state_ids[key] = self._prev_state_id
-            else:
-                self._prev_state_ids.pop(key, None)
-        else:
-            self._prev_state_ids = current_state_ids
 
 
 def _encode_state_dict(
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 3b6795d40f..f30207376a 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -32,6 +32,7 @@ from synapse.rest.media.v1.media_storage import ReadableFileWrapper
 from synapse.spam_checker_api import RegistrationBehaviour
 from synapse.types import RoomAlias, UserProfile
 from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
+from synapse.util.metrics import Measure
 
 if TYPE_CHECKING:
     import synapse.events
@@ -162,7 +163,10 @@ def load_legacy_spam_checkers(hs: "synapse.server.HomeServer") -> None:
 
 
 class SpamChecker:
-    def __init__(self) -> None:
+    def __init__(self, hs: "synapse.server.HomeServer") -> None:
+        self.hs = hs
+        self.clock = hs.get_clock()
+
         self._check_event_for_spam_callbacks: List[CHECK_EVENT_FOR_SPAM_CALLBACK] = []
         self._user_may_join_room_callbacks: List[USER_MAY_JOIN_ROOM_CALLBACK] = []
         self._user_may_invite_callbacks: List[USER_MAY_INVITE_CALLBACK] = []
@@ -255,7 +259,10 @@ class SpamChecker:
             will be used as the error message returned to the user.
         """
         for callback in self._check_event_for_spam_callbacks:
-            res: Union[bool, str] = await delay_cancellation(callback(event))
+            with Measure(
+                self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
+            ):
+                res: Union[bool, str] = await delay_cancellation(callback(event))
             if res:
                 return res
 
@@ -276,9 +283,12 @@ class SpamChecker:
             Whether the user may join the room
         """
         for callback in self._user_may_join_room_callbacks:
-            may_join_room = await delay_cancellation(
-                callback(user_id, room_id, is_invited)
-            )
+            with Measure(
+                self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
+            ):
+                may_join_room = await delay_cancellation(
+                    callback(user_id, room_id, is_invited)
+                )
             if may_join_room is False:
                 return False
 
@@ -300,9 +310,12 @@ class SpamChecker:
             True if the user may send an invite, otherwise False
         """
         for callback in self._user_may_invite_callbacks:
-            may_invite = await delay_cancellation(
-                callback(inviter_userid, invitee_userid, room_id)
-            )
+            with Measure(
+                self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
+            ):
+                may_invite = await delay_cancellation(
+                    callback(inviter_userid, invitee_userid, room_id)
+                )
             if may_invite is False:
                 return False
 
@@ -328,9 +341,12 @@ class SpamChecker:
             True if the user may send the invite, otherwise False
         """
         for callback in self._user_may_send_3pid_invite_callbacks:
-            may_send_3pid_invite = await delay_cancellation(
-                callback(inviter_userid, medium, address, room_id)
-            )
+            with Measure(
+                self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
+            ):
+                may_send_3pid_invite = await delay_cancellation(
+                    callback(inviter_userid, medium, address, room_id)
+                )
             if may_send_3pid_invite is False:
                 return False
 
@@ -348,7 +364,10 @@ class SpamChecker:
             True if the user may create a room, otherwise False
         """
         for callback in self._user_may_create_room_callbacks:
-            may_create_room = await delay_cancellation(callback(userid))
+            with Measure(
+                self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
+            ):
+                may_create_room = await delay_cancellation(callback(userid))
             if may_create_room is False:
                 return False
 
@@ -369,9 +388,12 @@ class SpamChecker:
             True if the user may create a room alias, otherwise False
         """
         for callback in self._user_may_create_room_alias_callbacks:
-            may_create_room_alias = await delay_cancellation(
-                callback(userid, room_alias)
-            )
+            with Measure(
+                self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
+            ):
+                may_create_room_alias = await delay_cancellation(
+                    callback(userid, room_alias)
+                )
             if may_create_room_alias is False:
                 return False
 
@@ -390,7 +412,10 @@ class SpamChecker:
             True if the user may publish the room, otherwise False
         """
         for callback in self._user_may_publish_room_callbacks:
-            may_publish_room = await delay_cancellation(callback(userid, room_id))
+            with Measure(
+                self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
+            ):
+                may_publish_room = await delay_cancellation(callback(userid, room_id))
             if may_publish_room is False:
                 return False
 
@@ -412,9 +437,13 @@ class SpamChecker:
             True if the user is spammy.
         """
         for callback in self._check_username_for_spam_callbacks:
-            # Make a copy of the user profile object to ensure the spam checker cannot
-            # modify it.
-            if await delay_cancellation(callback(user_profile.copy())):
+            with Measure(
+                self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
+            ):
+                # Make a copy of the user profile object to ensure the spam checker cannot
+                # modify it.
+                res = await delay_cancellation(callback(user_profile.copy()))
+            if res:
                 return True
 
         return False
@@ -442,9 +471,12 @@ class SpamChecker:
         """
 
         for callback in self._check_registration_for_spam_callbacks:
-            behaviour = await delay_cancellation(
-                callback(email_threepid, username, request_info, auth_provider_id)
-            )
+            with Measure(
+                self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
+            ):
+                behaviour = await delay_cancellation(
+                    callback(email_threepid, username, request_info, auth_provider_id)
+                )
             assert isinstance(behaviour, RegistrationBehaviour)
             if behaviour != RegistrationBehaviour.ALLOW:
                 return behaviour
@@ -486,7 +518,10 @@ class SpamChecker:
         """
 
         for callback in self._check_media_file_for_spam_callbacks:
-            spam = await delay_cancellation(callback(file_wrapper, file_info))
+            with Measure(
+                self.clock, "{}.{}".format(callback.__module__, callback.__qualname__)
+            ):
+                spam = await delay_cancellation(callback(file_wrapper, file_info))
             if spam:
                 return True