summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/handlers/message.py218
1 files changed, 174 insertions, 44 deletions
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index bb77f6210c..ee373ea2ac 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -61,6 +61,7 @@ from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.state import StateFilter
 from synapse.types import (
     MutableStateMap,
+    PersistedEventPosition,
     Requester,
     RoomAlias,
     StateMap,
@@ -1316,6 +1317,124 @@ class EventCreationHandler:
                     400, "Cannot start threads from an event with a relation"
                 )
 
+    async def handle_create_room_events(
+        self,
+        requester: Requester,
+        events_and_ctx: List[Tuple[EventBase, EventContext]],
+        ratelimit: bool = True,
+    ) -> EventBase:
+        """
+        Process a batch of room creation events. For each event in the list it checks
+        the authorization and that the event can be serialized. Returns the last event in the
+        list once it has been persisted.
+        Args:
+            requester: the room creator
+            events_and_ctx: a set of events and their associated contexts to persist
+            ratelimit: whether to ratelimit this request
+        """
+        for event, context in events_and_ctx:
+            try:
+                validate_event_for_room_version(event)
+                await self._event_auth_handler.check_auth_rules_from_context(
+                    event, context
+                )
+            except AuthError as err:
+                logger.warning("Denying new event %r because %s", event, err)
+                raise err
+
+            # Ensure that we can round trip before trying to persist in db
+            try:
+                dump = json_encoder.encode(event.content)
+                json_decoder.decode(dump)
+            except Exception:
+                logger.exception("Failed to encode content: %r", event.content)
+                raise
+
+        # We now persist the events
+        try:
+            result = await self._persist_events_batch(
+                requester, events_and_ctx, ratelimit
+            )
+        except Exception as e:
+            logger.info(f"Encountered an error persisting events: {e}")
+
+        return result
+
+    async def _persist_events_batch(
+        self,
+        requestor: Requester,
+        events_and_ctx: List[Tuple[EventBase, EventContext]],
+        ratelimit: bool = True,
+    ) -> EventBase:
+        """
+        Processes the push actions and adds them to the push staging area before attempting to
+        persist the batch of events.
+        See handle_create_room_events for arguments
+        Returns the last event in the list if persisted successfully
+        """
+        for event, context in events_and_ctx:
+            with opentracing.start_active_span("calculate_push_actions"):
+                await self._bulk_push_rule_evaluator.action_for_event_by_user(
+                    event, context
+                )
+        try:
+            last_event = await self.persist_and_notify_batched_events(
+                requestor, events_and_ctx, ratelimit
+            )
+        except Exception:
+            # Ensure that we actually remove the entries in the push actions
+            # staging area, if we calculated them.
+            for event, _ in events_and_ctx:
+                await self.store.remove_push_actions_from_staging(event.event_id)
+            raise
+
+        return last_event
+
+    async def persist_and_notify_batched_events(
+        self,
+        requester: Requester,
+        events_and_ctx: List[Tuple[EventBase, EventContext]],
+        ratelimit: bool = True,
+    ) -> EventBase:
+        """
+        Handles the actual persisting of a batch of events to the DB, and sends the appropriate
+        notifications when this is done.
+        Args:
+            requester: the room creator
+            events_and_ctx: list of events and their associated contexts to persist
+            ratelimit: whether to apply ratelimiting to this request
+        """
+        if ratelimit:
+            await self.request_ratelimiter.ratelimit(requester)
+
+        for event, context in events_and_ctx:
+            await self._actions_by_event_type(event, context)
+
+        assert self._storage_controllers.persistence is not None
+        (
+            persisted_events,
+            max_stream_token,
+        ) = await self._storage_controllers.persistence.persist_events(events_and_ctx)
+
+        stream_ordering = persisted_events[-1].internal_metadata.stream_ordering
+        assert stream_ordering is not None
+        pos = PersistedEventPosition(self._instance_name, stream_ordering)
+
+        async def _notify() -> None:
+            try:
+                await self.notifier.on_new_room_event(
+                    persisted_events[-1], pos, max_stream_token
+                )
+            except Exception:
+                logger.exception(
+                    "Error notifying about new room event %s",
+                    event.event_id,
+                )
+
+        run_in_background(_notify)
+
+        return persisted_events[-1]
+
     @measure_func("handle_new_client_event")
     async def handle_new_client_event(
         self,
@@ -1650,6 +1769,55 @@ class EventCreationHandler:
                 requester, is_admin_redaction=is_admin_redaction
             )
 
+        # run checks/actions on event based on type
+        await self._actions_by_event_type(event, context)
+
+        # Mark any `m.historical` messages as backfilled so they don't appear
+        # in `/sync` and have the proper decrementing `stream_ordering` as we import
+        backfilled = False
+        if event.internal_metadata.is_historical():
+            backfilled = True
+
+        # Note that this returns the event that was persisted, which may not be
+        # the same as we passed in if it was deduplicated due transaction IDs.
+        (
+            event,
+            event_pos,
+            max_stream_token,
+        ) = await self._storage_controllers.persistence.persist_event(
+            event, context=context, backfilled=backfilled
+        )
+
+        if self._ephemeral_events_enabled:
+            # If there's an expiry timestamp on the event, schedule its expiry.
+            self._message_handler.maybe_schedule_expiry(event)
+
+        async def _notify() -> None:
+            try:
+                await self.notifier.on_new_room_event(
+                    event, event_pos, max_stream_token, extra_users=extra_users
+                )
+            except Exception:
+                logger.exception(
+                    "Error notifying about new room event %s",
+                    event.event_id,
+                )
+
+        run_in_background(_notify)
+
+        if event.type == EventTypes.Message:
+            # We don't want to block sending messages on any presence code. This
+            # matters as sometimes presence code can take a while.
+            run_in_background(self._bump_active_time, requester.user)
+
+        return event
+
+    async def _actions_by_event_type(
+        self, event: EventBase, context: EventContext
+    ) -> None:
+        """
+        Helper function to execute actions/checks based on the event type
+        """
         if event.type == EventTypes.Member and event.membership == Membership.JOIN:
             (
                 current_membership,
@@ -1670,11 +1838,13 @@ class EventCreationHandler:
 
             original_event_id = event.unsigned.get("replaces_state")
             if original_event_id:
-                original_event = await self.store.get_event(original_event_id)
+                original_alias_event = await self.store.get_event(original_event_id)
 
-                if original_event:
-                    original_alias = original_event.content.get("alias", None)
-                    original_alt_aliases = original_event.content.get("alt_aliases", [])
+                if original_alias_event:
+                    original_alias = original_alias_event.content.get("alias", None)
+                    original_alt_aliases = original_alias_event.content.get(
+                        "alt_aliases", []
+                    )
 
             # Check the alias is currently valid (if it has changed).
             room_alias_str = event.content.get("alias", None)
@@ -1852,46 +2022,6 @@ class EventCreationHandler:
                         errcode=Codes.INVALID_PARAM,
                     )
 
-        # Mark any `m.historical` messages as backfilled so they don't appear
-        # in `/sync` and have the proper decrementing `stream_ordering` as we import
-        backfilled = False
-        if event.internal_metadata.is_historical():
-            backfilled = True
-
-        # Note that this returns the event that was persisted, which may not be
-        # the same as we passed in if it was deduplicated due transaction IDs.
-        (
-            event,
-            event_pos,
-            max_stream_token,
-        ) = await self._storage_controllers.persistence.persist_event(
-            event, context=context, backfilled=backfilled
-        )
-
-        if self._ephemeral_events_enabled:
-            # If there's an expiry timestamp on the event, schedule its expiry.
-            self._message_handler.maybe_schedule_expiry(event)
-
-        async def _notify() -> None:
-            try:
-                await self.notifier.on_new_room_event(
-                    event, event_pos, max_stream_token, extra_users=extra_users
-                )
-            except Exception:
-                logger.exception(
-                    "Error notifying about new room event %s",
-                    event.event_id,
-                )
-
-        run_in_background(_notify)
-
-        if event.type == EventTypes.Message:
-            # We don't want to block sending messages on any presence code. This
-            # matters as sometimes presence code can take a while.
-            run_in_background(self._bump_active_time, requester.user)
-
-        return event
-
     async def _maybe_kick_guest_users(
         self, event: EventBase, context: EventContext
     ) -> None: