summary refs log tree commit diff
path: root/synapse/events
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/events')
-rw-r--r--synapse/events/spamcheck.py11
-rw-r--r--synapse/events/third_party_rules.py65
-rw-r--r--synapse/events/utils.py101
3 files changed, 136 insertions, 41 deletions
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 04afd48274..cd80fcf9d1 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -21,7 +21,6 @@ from typing import (
     Awaitable,
     Callable,
     Collection,
-    Dict,
     List,
     Optional,
     Tuple,
@@ -31,7 +30,7 @@ from typing import (
 from synapse.rest.media.v1._base import FileInfo
 from synapse.rest.media.v1.media_storage import ReadableFileWrapper
 from synapse.spam_checker_api import RegistrationBehaviour
-from synapse.types import RoomAlias
+from synapse.types import RoomAlias, UserProfile
 from synapse.util.async_helpers import maybe_awaitable
 
 if TYPE_CHECKING:
@@ -50,7 +49,7 @@ USER_MAY_SEND_3PID_INVITE_CALLBACK = Callable[[str, str, str, str], Awaitable[bo
 USER_MAY_CREATE_ROOM_CALLBACK = Callable[[str], Awaitable[bool]]
 USER_MAY_CREATE_ROOM_ALIAS_CALLBACK = Callable[[str, RoomAlias], Awaitable[bool]]
 USER_MAY_PUBLISH_ROOM_CALLBACK = Callable[[str, str], Awaitable[bool]]
-CHECK_USERNAME_FOR_SPAM_CALLBACK = Callable[[Dict[str, str]], Awaitable[bool]]
+CHECK_USERNAME_FOR_SPAM_CALLBACK = Callable[[UserProfile], Awaitable[bool]]
 LEGACY_CHECK_REGISTRATION_FOR_SPAM_CALLBACK = Callable[
     [
         Optional[dict],
@@ -245,8 +244,8 @@ class SpamChecker:
         """Checks if a given event is considered "spammy" by this server.
 
         If the server considers an event spammy, then it will be rejected if
-        sent by a local user. If it is sent by a user on another server, then
-        users receive a blank event.
+        sent by a local user. If it is sent by a user on another server, the
+        event is soft-failed.
 
         Args:
             event: the event to be checked
@@ -383,7 +382,7 @@ class SpamChecker:
 
         return True
 
-    async def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool:
+    async def check_username_for_spam(self, user_profile: UserProfile) -> bool:
         """Checks if a user ID or display name are considered "spammy" by this server.
 
         If the server considers a username spammy, then it will not be included in
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py
index dd3104faf3..bfca454f51 100644
--- a/synapse/events/third_party_rules.py
+++ b/synapse/events/third_party_rules.py
@@ -38,6 +38,8 @@ CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK = Callable[
     [str, StateMap[EventBase], str], Awaitable[bool]
 ]
 ON_NEW_EVENT_CALLBACK = Callable[[EventBase, StateMap[EventBase]], Awaitable]
+CHECK_CAN_SHUTDOWN_ROOM_CALLBACK = Callable[[str, str], Awaitable[bool]]
+CHECK_CAN_DEACTIVATE_USER_CALLBACK = Callable[[str, bool], Awaitable[bool]]
 ON_PROFILE_UPDATE_CALLBACK = Callable[[str, ProfileInfo, bool, bool], Awaitable]
 ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK = Callable[[str, bool, bool], Awaitable]
 
@@ -157,6 +159,12 @@ class ThirdPartyEventRules:
             CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK
         ] = []
         self._on_new_event_callbacks: List[ON_NEW_EVENT_CALLBACK] = []
+        self._check_can_shutdown_room_callbacks: List[
+            CHECK_CAN_SHUTDOWN_ROOM_CALLBACK
+        ] = []
+        self._check_can_deactivate_user_callbacks: List[
+            CHECK_CAN_DEACTIVATE_USER_CALLBACK
+        ] = []
         self._on_profile_update_callbacks: List[ON_PROFILE_UPDATE_CALLBACK] = []
         self._on_user_deactivation_status_changed_callbacks: List[
             ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK
@@ -173,8 +181,12 @@ class ThirdPartyEventRules:
             CHECK_VISIBILITY_CAN_BE_MODIFIED_CALLBACK
         ] = None,
         on_new_event: Optional[ON_NEW_EVENT_CALLBACK] = None,
+        check_can_shutdown_room: Optional[CHECK_CAN_SHUTDOWN_ROOM_CALLBACK] = None,
+        check_can_deactivate_user: Optional[CHECK_CAN_DEACTIVATE_USER_CALLBACK] = None,
         on_profile_update: Optional[ON_PROFILE_UPDATE_CALLBACK] = None,
-        on_deactivation: Optional[ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK] = None,
+        on_user_deactivation_status_changed: Optional[
+            ON_USER_DEACTIVATION_STATUS_CHANGED_CALLBACK
+        ] = None,
     ) -> None:
         """Register callbacks from modules for each hook."""
         if check_event_allowed is not None:
@@ -196,11 +208,18 @@ class ThirdPartyEventRules:
         if on_new_event is not None:
             self._on_new_event_callbacks.append(on_new_event)
 
+        if check_can_shutdown_room is not None:
+            self._check_can_shutdown_room_callbacks.append(check_can_shutdown_room)
+
+        if check_can_deactivate_user is not None:
+            self._check_can_deactivate_user_callbacks.append(check_can_deactivate_user)
         if on_profile_update is not None:
             self._on_profile_update_callbacks.append(on_profile_update)
 
-        if on_deactivation is not None:
-            self._on_user_deactivation_status_changed_callbacks.append(on_deactivation)
+        if on_user_deactivation_status_changed is not None:
+            self._on_user_deactivation_status_changed_callbacks.append(
+                on_user_deactivation_status_changed,
+            )
 
     async def check_event_allowed(
         self, event: EventBase, context: EventContext
@@ -365,6 +384,46 @@ class ThirdPartyEventRules:
                     "Failed to run module API callback %s: %s", callback, e
                 )
 
+    async def check_can_shutdown_room(self, user_id: str, room_id: str) -> bool:
+        """Intercept requests to shutdown a room. If `False` is returned, the
+         room must not be shut down.
+
+        Args:
+            requester: The ID of the user requesting the shutdown.
+            room_id: The ID of the room.
+        """
+        for callback in self._check_can_shutdown_room_callbacks:
+            try:
+                if await callback(user_id, room_id) is False:
+                    return False
+            except Exception as e:
+                logger.exception(
+                    "Failed to run module API callback %s: %s", callback, e
+                )
+        return True
+
+    async def check_can_deactivate_user(
+        self,
+        user_id: str,
+        by_admin: bool,
+    ) -> bool:
+        """Intercept requests to deactivate a user. If `False` is returned, the
+        user should not be deactivated.
+
+        Args:
+            requester
+            user_id: The ID of the room.
+        """
+        for callback in self._check_can_deactivate_user_callbacks:
+            try:
+                if await callback(user_id, by_admin) is False:
+                    return False
+            except Exception as e:
+                logger.exception(
+                    "Failed to run module API callback %s: %s", callback, e
+                )
+        return True
+
     async def _get_state_map_for_room(self, room_id: str) -> StateMap[EventBase]:
         """Given a room ID, return the state events of that room.
 
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 9386fa29dd..7120062127 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -26,6 +26,7 @@ from typing import (
     Union,
 )
 
+import attr
 from frozendict import frozendict
 
 from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
@@ -37,7 +38,8 @@ from synapse.util.frozenutils import unfreeze
 from . import EventBase
 
 if TYPE_CHECKING:
-    from synapse.storage.databases.main.relations import BundledAggregations
+    from synapse.handlers.relations import BundledAggregations
+    from synapse.server import HomeServer
 
 
 # Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
@@ -303,29 +305,37 @@ def format_event_for_client_v2_without_room_id(d: JsonDict) -> JsonDict:
     return d
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class SerializeEventConfig:
+    as_client_event: bool = True
+    # Function to convert from federation format to client format
+    event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1
+    # ID of the user's auth token - used for namespacing of transaction IDs
+    token_id: Optional[int] = None
+    # List of event fields to include. If empty, all fields will be returned.
+    only_event_fields: Optional[List[str]] = None
+    # Some events can have stripped room state stored in the `unsigned` field.
+    # This is required for invite and knock functionality. If this option is
+    # False, that state will be removed from the event before it is returned.
+    # Otherwise, it will be kept.
+    include_stripped_room_state: bool = False
+
+
+_DEFAULT_SERIALIZE_EVENT_CONFIG = SerializeEventConfig()
+
+
 def serialize_event(
     e: Union[JsonDict, EventBase],
     time_now_ms: int,
     *,
-    as_client_event: bool = True,
-    event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1,
-    token_id: Optional[str] = None,
-    only_event_fields: Optional[List[str]] = None,
-    include_stripped_room_state: bool = False,
+    config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG,
 ) -> JsonDict:
     """Serialize event for clients
 
     Args:
         e
         time_now_ms
-        as_client_event
-        event_format
-        token_id
-        only_event_fields
-        include_stripped_room_state: Some events can have stripped room state
-            stored in the `unsigned` field. This is required for invite and knock
-            functionality. If this option is False, that state will be removed from the
-            event before it is returned. Otherwise, it will be kept.
+        config: Event serialization config
 
     Returns:
         The serialized event dictionary.
@@ -348,11 +358,11 @@ def serialize_event(
 
     if "redacted_because" in e.unsigned:
         d["unsigned"]["redacted_because"] = serialize_event(
-            e.unsigned["redacted_because"], time_now_ms, event_format=event_format
+            e.unsigned["redacted_because"], time_now_ms, config=config
         )
 
-    if token_id is not None:
-        if token_id == getattr(e.internal_metadata, "token_id", None):
+    if config.token_id is not None:
+        if config.token_id == getattr(e.internal_metadata, "token_id", None):
             txn_id = getattr(e.internal_metadata, "txn_id", None)
             if txn_id is not None:
                 d["unsigned"]["transaction_id"] = txn_id
@@ -361,13 +371,14 @@ def serialize_event(
     # that are meant to provide metadata about a room to an invitee/knocker. They are
     # intended to only be included in specific circumstances, such as down sync, and
     # should not be included in any other case.
-    if not include_stripped_room_state:
+    if not config.include_stripped_room_state:
         d["unsigned"].pop("invite_room_state", None)
         d["unsigned"].pop("knock_room_state", None)
 
-    if as_client_event:
-        d = event_format(d)
+    if config.as_client_event:
+        d = config.event_format(d)
 
+    only_event_fields = config.only_event_fields
     if only_event_fields:
         if not isinstance(only_event_fields, list) or not all(
             isinstance(f, str) for f in only_event_fields
@@ -385,23 +396,26 @@ class EventClientSerializer:
     clients.
     """
 
+    def __init__(self, hs: "HomeServer"):
+        self._msc3440_enabled = hs.config.experimental.msc3440_enabled
+
     def serialize_event(
         self,
         event: Union[JsonDict, EventBase],
         time_now: int,
         *,
+        config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG,
         bundle_aggregations: Optional[Dict[str, "BundledAggregations"]] = None,
-        **kwargs: Any,
     ) -> JsonDict:
         """Serializes a single event.
 
         Args:
             event: The event being serialized.
             time_now: The current time in milliseconds
+            config: Event serialization config
             bundle_aggregations: Whether to include the bundled aggregations for this
                 event. Only applies to non-state events. (State events never include
                 bundled aggregations.)
-            **kwargs: Arguments to pass to `serialize_event`
 
         Returns:
             The serialized event
@@ -410,7 +424,7 @@ class EventClientSerializer:
         if not isinstance(event, EventBase):
             return event
 
-        serialized_event = serialize_event(event, time_now, **kwargs)
+        serialized_event = serialize_event(event, time_now, config=config)
 
         # Check if there are any bundled aggregations to include with the event.
         if bundle_aggregations:
@@ -419,6 +433,7 @@ class EventClientSerializer:
                 self._inject_bundled_aggregations(
                     event,
                     time_now,
+                    config,
                     bundle_aggregations[event.event_id],
                     serialized_event,
                 )
@@ -456,6 +471,7 @@ class EventClientSerializer:
         self,
         event: EventBase,
         time_now: int,
+        config: SerializeEventConfig,
         aggregations: "BundledAggregations",
         serialized_event: JsonDict,
     ) -> None:
@@ -466,6 +482,7 @@ class EventClientSerializer:
             time_now: The current time in milliseconds
             aggregations: The bundled aggregation to serialize.
             serialized_event: The serialized event which may be modified.
+            config: Event serialization config
 
         """
         serialized_aggregations = {}
@@ -493,8 +510,8 @@ class EventClientSerializer:
             thread = aggregations.thread
 
             # Don't bundle aggregations as this could recurse forever.
-            serialized_latest_event = self.serialize_event(
-                thread.latest_event, time_now, bundle_aggregations=None
+            serialized_latest_event = serialize_event(
+                thread.latest_event, time_now, config=config
             )
             # Manually apply an edit, if one exists.
             if thread.latest_edit:
@@ -502,33 +519,53 @@ class EventClientSerializer:
                     thread.latest_event, serialized_latest_event, thread.latest_edit
                 )
 
-            serialized_aggregations[RelationTypes.THREAD] = {
+            thread_summary = {
                 "latest_event": serialized_latest_event,
                 "count": thread.count,
                 "current_user_participated": thread.current_user_participated,
             }
+            serialized_aggregations[RelationTypes.THREAD] = thread_summary
+            if self._msc3440_enabled:
+                serialized_aggregations[RelationTypes.UNSTABLE_THREAD] = thread_summary
 
         # Include the bundled aggregations in the event.
         if serialized_aggregations:
-            serialized_event["unsigned"].setdefault("m.relations", {}).update(
-                serialized_aggregations
-            )
+            # There is likely already an "unsigned" field, but a filter might
+            # have stripped it off (via the event_fields option). The server is
+            # allowed to return additional fields, so add it back.
+            serialized_event.setdefault("unsigned", {}).setdefault(
+                "m.relations", {}
+            ).update(serialized_aggregations)
 
     def serialize_events(
-        self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any
+        self,
+        events: Iterable[Union[JsonDict, EventBase]],
+        time_now: int,
+        *,
+        config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG,
+        bundle_aggregations: Optional[Dict[str, "BundledAggregations"]] = None,
     ) -> List[JsonDict]:
         """Serializes multiple events.
 
         Args:
             event
             time_now: The current time in milliseconds
-            **kwargs: Arguments to pass to `serialize_event`
+            config: Event serialization config
+            bundle_aggregations: Whether to include the bundled aggregations for this
+                event. Only applies to non-state events. (State events never include
+                bundled aggregations.)
 
         Returns:
             The list of serialized events
         """
         return [
-            self.serialize_event(event, time_now=time_now, **kwargs) for event in events
+            self.serialize_event(
+                event,
+                time_now,
+                config=config,
+                bundle_aggregations=bundle_aggregations,
+            )
+            for event in events
         ]