summary refs log tree commit diff
diff options
context:
space:
mode:
authorH. Shay <hillerys@element.io>2022-09-01 14:46:25 -0700
committerH. Shay <hillerys@element.io>2022-09-01 14:46:25 -0700
commit059746dec47bf405ab8133aa6e719cc89e90dc09 (patch)
tree5be49cc031ce0e1de0694c5881df9561d7d590c6
parentreduce duplicated code (diff)
downloadsynapse-059746dec47bf405ab8133aa6e719cc89e90dc09.tar.xz
split out creating events for batches and add helper methods for duplicated code
-rw-r--r--synapse/handlers/message.py200
1 files changed, 132 insertions, 68 deletions
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 1f3eca7dfa..3af77a5ed4 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -63,6 +63,7 @@ from synapse.types import (
     MutableStateMap,
     Requester,
     RoomAlias,
+    StateMap,
     StreamToken,
     UserID,
     create_requester,
@@ -567,7 +568,6 @@ class EventCreationHandler:
         outlier: bool = False,
         historical: bool = False,
         depth: Optional[int] = None,
-        for_batch: bool = False,
     ) -> Tuple[EventBase, EventContext]:
         """
         Given a dict from a client, create a new event.
@@ -619,8 +619,6 @@ class EventCreationHandler:
             depth: Override the depth used to order the event in the DAG.
                 Should normally be set to None, which will cause the depth to be calculated
                 based on the prev_events.
-            for_batch: Whether this event is being created for batch sending. Notably events
-                created for batch sending do not have their event context computed
 
         Raises:
             ResourceLimitError if server is blocked to some resource being
@@ -630,49 +628,10 @@ class EventCreationHandler:
         """
         await self.auth_blocking.check_auth_blocking(requester=requester)
 
-        if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "":
-            room_version_id = event_dict["content"]["room_version"]
-            maybe_room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id)
-            if not maybe_room_version_obj:
-                # this can happen if support is withdrawn for a room version
-                raise UnsupportedRoomVersionError(room_version_id)
-            room_version_obj = maybe_room_version_obj
-        else:
-            try:
-                room_version_obj = await self.store.get_room_version(
-                    event_dict["room_id"]
-                )
-            except NotFoundError:
-                raise AuthError(403, "Unknown room")
-
-        builder = self.event_builder_factory.for_room_version(
-            room_version_obj, event_dict
-        )
-
-        self.validator.validate_builder(builder)
+        builder = await self._get_and_validate_builder(event_dict)
 
         if builder.type == EventTypes.Member:
-            membership = builder.content.get("membership", None)
-            target = UserID.from_string(builder.state_key)
-
-            if membership in self.membership_types_to_include_profile_data_in:
-                # If event doesn't include a display name, add one.
-                profile = self.profile_handler
-                content = builder.content
-
-                try:
-                    if "displayname" not in content:
-                        displayname = await profile.get_displayname(target)
-                        if displayname is not None:
-                            content["displayname"] = displayname
-                    if "avatar_url" not in content:
-                        avatar_url = await profile.get_avatar_url(target)
-                        if avatar_url is not None:
-                            content["avatar_url"] = avatar_url
-                except Exception as e:
-                    logger.info(
-                        "Failed to get profile information for %r: %s", target, e
-                    )
+            await self._build_profile_data(builder)
 
         is_exempt = await self._is_exempt_from_privacy_policy(builder, requester)
         if require_consent and not is_exempt:
@@ -688,27 +647,15 @@ class EventCreationHandler:
 
         builder.internal_metadata.historical = historical
 
-        if for_batch:
-            event = await builder.build(
-                prev_event_ids=prev_event_ids,
-                auth_event_ids=auth_event_ids,
-                depth=depth,
-            )
-            # Pass on the outlier property from the builder to the event
-            # after it is created
-            if builder.internal_metadata.outlier:
-                event.internal_metadata.outlier = True
-
-        else:
-            event, context = await self.create_new_client_event(
-                builder=builder,
-                requester=requester,
-                allow_no_prev_events=allow_no_prev_events,
-                prev_event_ids=prev_event_ids,
-                auth_event_ids=auth_event_ids,
-                state_event_ids=state_event_ids,
-                depth=depth,
-            )
+        event, context = await self.create_new_client_event(
+            builder=builder,
+            requester=requester,
+            allow_no_prev_events=allow_no_prev_events,
+            prev_event_ids=prev_event_ids,
+            auth_event_ids=auth_event_ids,
+            state_event_ids=state_event_ids,
+            depth=depth,
+        )
 
         # In an ideal world we wouldn't need the second part of this condition. However,
         # this behaviour isn't spec'd yet, meaning we should be able to deactivate this
@@ -748,10 +695,127 @@ class EventCreationHandler:
 
         self.validator.validate_new(event, self.config)
 
-        if for_batch:
-            return event
+        return event, context
+
+    async def create_event_for_batch(
+        self,
+        requester: Requester,
+        event_dict: dict,
+        prev_event_ids: List[str],
+        depth: int,
+        state_map: StateMap,
+        txn_id: Optional[str] = None,
+        require_consent: bool = True,
+        outlier: bool = False,
+    ) -> EventBase:
+        """
+        Given a dict from a client, create a new event. Notably does not create an event
+        context. Adds display names to Join membership events.
+
+        Args:
+            requester
+            event_dict: An entire event
+            txn_id
+            prev_event_ids:
+                the forward extremities to use as the prev_events for the
+                new event.
+            state_map: a state_map of previously created events for batching. Will be used
+                to calculate the auth_ids for the event, as the previously created events for
+                batching will not yet have been persisted to the db
+            require_consent: Whether to check if the requester has
+                consented to the privacy policy.
+            outlier: Indicates whether the event is an `outlier`, i.e. if
+                it's from an arbitrary point and floating in the DAG as
+                opposed to being inline with the current DAG.
+            depth: Override the depth used to order the event in the DAG.
+
+        Returns:
+            the created event
+        """
+        await self.auth_blocking.check_auth_blocking(requester=requester)
+
+        builder = await self._get_and_validate_builder(event_dict)
+
+        if builder.type == EventTypes.Member:
+            await self._build_profile_data(builder)
+
+        is_exempt = await self._is_exempt_from_privacy_policy(builder, requester)
+        if require_consent and not is_exempt:
+            await self.assert_accepted_privacy_policy(requester)
+
+        if requester.access_token_id is not None:
+            builder.internal_metadata.token_id = requester.access_token_id
+
+        if txn_id is not None:
+            builder.internal_metadata.txn_id = txn_id
+
+        builder.internal_metadata.outlier = outlier
+
+        auth_ids = self._event_auth_handler.compute_auth_events(builder, state_map)
+        event = await builder.build(
+            prev_event_ids=prev_event_ids,
+            auth_event_ids=auth_ids,
+            depth=depth,
+        )
+        # Pass on the outlier property from the builder to the event
+        # after it is created
+        if builder.internal_metadata.outlier:
+            event.internal_metadata.outlier = True
+
+        self.validator.validate_new(event, self.config)
+
+        return event
+
+    async def _build_profile_data(self, builder: EventBuilder) -> None:
+        """
+        Helper method to add profile information to membership event
+        """
+        membership = builder.content.get("membership", None)
+        target = UserID.from_string(builder.state_key)
+
+        if membership in self.membership_types_to_include_profile_data_in:
+            # If event doesn't include a display name, add one.
+            profile = self.profile_handler
+            content = builder.content
+
+            try:
+                if "displayname" not in content:
+                    displayname = await profile.get_displayname(target)
+                    if displayname is not None:
+                        content["displayname"] = displayname
+                if "avatar_url" not in content:
+                    avatar_url = await profile.get_avatar_url(target)
+                    if avatar_url is not None:
+                        content["avatar_url"] = avatar_url
+            except Exception as e:
+                logger.info("Failed to get profile information for %r: %s", target, e)
+
+    async def _get_and_validate_builder(self, event_dict: dict) -> EventBuilder:
+        """
+        Helper method to create and validate a builder object when creating an event
+        """
+        if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "":
+            room_version_id = event_dict["content"]["room_version"]
+            maybe_room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id)
+            if not maybe_room_version_obj:
+                # this can happen if support is withdrawn for a room version
+                raise UnsupportedRoomVersionError(room_version_id)
+            room_version_obj = maybe_room_version_obj
         else:
-            return event, context
+            try:
+                room_version_obj = await self.store.get_room_version(
+                    event_dict["room_id"]
+                )
+            except NotFoundError:
+                raise AuthError(403, "Unknown room")
+
+        builder = self.event_builder_factory.for_room_version(
+            room_version_obj, event_dict
+        )
+
+        self.validator.validate_builder(builder)
+
+        return builder
 
     async def _is_exempt_from_privacy_policy(
         self, builder: EventBuilder, requester: Requester