diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 9c58e3689e..ebf4e32230 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -29,6 +29,7 @@ from synapse.api.errors import (
SynapseError,
)
from synapse.api.filtering import Filter
+from synapse.appservice import ApplicationService
from synapse.events.utils import format_event_for_client_v2
from synapse.http.servlet import (
RestServlet,
@@ -47,11 +48,13 @@ from synapse.storage.state import StateFilter
from synapse.streams.config import PaginationConfig
from synapse.types import (
JsonDict,
+ Requester,
RoomAlias,
RoomID,
StreamToken,
ThirdPartyInstanceID,
UserID,
+ create_requester,
)
from synapse.util import json_decoder
from synapse.util.stringutils import parse_and_validate_server_name, random_string
@@ -309,7 +312,7 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet):
self.room_member_handler = hs.get_room_member_handler()
self.auth = hs.get_auth()
- async def inherit_depth_from_prev_ids(self, prev_event_ids) -> int:
+ async def _inherit_depth_from_prev_ids(self, prev_event_ids) -> int:
(
most_recent_prev_event_id,
most_recent_prev_event_depth,
@@ -378,6 +381,25 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet):
return insertion_event
+ async def _create_requester_for_user_id_from_app_service(
+ self, user_id: str, app_service: ApplicationService
+ ) -> Requester:
+ """Creates a new requester for the given user_id
+ and validates that the app service is allowed to control
+ the given user.
+
+ Args:
+ user_id: The author MXID that the app service is controlling
+ app_service: The app service that controls the user
+
+ Returns:
+ Requester object
+ """
+
+ await self.auth.validate_appservice_can_control_user_id(app_service, user_id)
+
+ return create_requester(user_id, app_service=app_service)
+
async def on_POST(self, request, room_id):
requester = await self.auth.get_user_by_req(request, allow_guest=False)
@@ -443,7 +465,9 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet):
if event_dict["type"] == EventTypes.Member:
membership = event_dict["content"].get("membership", None)
event_id, _ = await self.room_member_handler.update_membership(
- requester,
+ await self._create_requester_for_user_id_from_app_service(
+ state_event["sender"], requester.app_service
+ ),
target=UserID.from_string(event_dict["state_key"]),
room_id=room_id,
action=membership,
@@ -463,7 +487,9 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet):
event,
_,
) = await self.event_creation_handler.create_and_send_nonmember_event(
- requester,
+ await self._create_requester_for_user_id_from_app_service(
+ state_event["sender"], requester.app_service
+ ),
event_dict,
outlier=True,
prev_event_ids=[fake_prev_event_id],
@@ -479,7 +505,9 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet):
events_to_create = body["events"]
prev_event_ids = prev_events_from_query
- inherited_depth = await self.inherit_depth_from_prev_ids(prev_events_from_query)
+ inherited_depth = await self._inherit_depth_from_prev_ids(
+ prev_events_from_query
+ )
# Figure out which chunk to connect to. If they passed in
# chunk_id_from_query let's use it. The chunk ID passed in comes
@@ -509,7 +537,10 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet):
base_insertion_event,
_,
) = await self.event_creation_handler.create_and_send_nonmember_event(
- requester,
+ await self._create_requester_for_user_id_from_app_service(
+ base_insertion_event_dict["sender"],
+ requester.app_service,
+ ),
base_insertion_event_dict,
prev_event_ids=base_insertion_event_dict.get("prev_events"),
auth_event_ids=auth_event_ids,
@@ -558,7 +589,9 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet):
}
event, context = await self.event_creation_handler.create_event(
- requester,
+ await self._create_requester_for_user_id_from_app_service(
+ ev["sender"], requester.app_service
+ ),
event_dict,
prev_event_ids=event_dict.get("prev_events"),
auth_event_ids=auth_event_ids,
@@ -588,7 +621,9 @@ class RoomBatchSendEventRestServlet(TransactionRestServlet):
# where topological_ordering is just depth.
for (event, context) in reversed(events_to_persist):
ev = await self.event_creation_handler.handle_new_client_event(
- requester=requester,
+ await self._create_requester_for_user_id_from_app_service(
+ event["sender"], requester.app_service
+ ),
event=event,
context=context,
)
|