diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 61e4cf0213..129b6fe6b0 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -57,7 +57,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.rest.client._base import client_patterns
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.streams.config import PaginationConfig
-from synapse.types import JsonDict, StreamToken, ThirdPartyInstanceID, UserID
+from synapse.types import JsonDict, Requester, StreamToken, ThirdPartyInstanceID, UserID
from synapse.types.state import StateFilter
from synapse.util import json_decoder
from synapse.util.cancellation import cancellable
@@ -151,15 +151,22 @@ class RoomCreateRestServlet(TransactionRestServlet):
PATTERNS = "/createRoom"
register_txn_path(self, PATTERNS, http_server)
- def on_PUT(
+ async def on_PUT(
self, request: SynapseRequest, txn_id: str
- ) -> Awaitable[Tuple[int, JsonDict]]:
+ ) -> Tuple[int, JsonDict]:
+ requester = await self.auth.get_user_by_req(request)
set_tag("txn_id", txn_id)
- return self.txns.fetch_or_execute_request(request, self.on_POST, request)
+ return await self.txns.fetch_or_execute_request(
+ request, requester, self._do, request, requester
+ )
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
+ return await self._do(request, requester)
+ async def _do(
+ self, request: SynapseRequest, requester: Requester
+ ) -> Tuple[int, JsonDict]:
room_id, _, _ = await self._room_creation_handler.create_room(
requester, self.get_room_config(request)
)
@@ -172,9 +179,9 @@ class RoomCreateRestServlet(TransactionRestServlet):
# TODO: Needs unit testing for generic events
-class RoomStateEventRestServlet(TransactionRestServlet):
+class RoomStateEventRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
+ super().__init__()
self.event_creation_handler = hs.get_event_creation_handler()
self.room_member_handler = hs.get_room_member_handler()
self.message_handler = hs.get_message_handler()
@@ -324,16 +331,16 @@ class RoomSendEventRestServlet(TransactionRestServlet):
def register(self, http_server: HttpServer) -> None:
# /rooms/$roomid/send/$event_type[/$txn_id]
PATTERNS = "/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)"
- register_txn_path(self, PATTERNS, http_server, with_get=True)
+ register_txn_path(self, PATTERNS, http_server)
- async def on_POST(
+ async def _do(
self,
request: SynapseRequest,
+ requester: Requester,
room_id: str,
event_type: str,
- txn_id: Optional[str] = None,
+ txn_id: Optional[str],
) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request, allow_guest=True)
content = parse_json_object_from_request(request)
event_dict: JsonDict = {
@@ -362,18 +369,30 @@ class RoomSendEventRestServlet(TransactionRestServlet):
set_tag("event_id", event_id)
return 200, {"event_id": event_id}
- def on_GET(
- self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str
- ) -> Tuple[int, str]:
- return 200, "Not implemented"
+ async def on_POST(
+ self,
+ request: SynapseRequest,
+ room_id: str,
+ event_type: str,
+ ) -> Tuple[int, JsonDict]:
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
+ return await self._do(request, requester, room_id, event_type, None)
- def on_PUT(
+ async def on_PUT(
self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str
- ) -> Awaitable[Tuple[int, JsonDict]]:
+ ) -> Tuple[int, JsonDict]:
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
set_tag("txn_id", txn_id)
- return self.txns.fetch_or_execute_request(
- request, self.on_POST, request, room_id, event_type, txn_id
+ return await self.txns.fetch_or_execute_request(
+ request,
+ requester,
+ self._do,
+ request,
+ requester,
+ room_id,
+ event_type,
+ txn_id,
)
@@ -389,14 +408,13 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
PATTERNS = "/join/(?P<room_identifier>[^/]*)"
register_txn_path(self, PATTERNS, http_server)
- async def on_POST(
+ async def _do(
self,
request: SynapseRequest,
+ requester: Requester,
room_identifier: str,
- txn_id: Optional[str] = None,
+ txn_id: Optional[str],
) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request, allow_guest=True)
-
content = parse_json_object_from_request(request, allow_empty_body=True)
# twisted.web.server.Request.args is incorrectly defined as Optional[Any]
@@ -420,22 +438,31 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
return 200, {"room_id": room_id}
- def on_PUT(
+ async def on_POST(
+ self,
+ request: SynapseRequest,
+ room_identifier: str,
+ ) -> Tuple[int, JsonDict]:
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
+ return await self._do(request, requester, room_identifier, None)
+
+ async def on_PUT(
self, request: SynapseRequest, room_identifier: str, txn_id: str
- ) -> Awaitable[Tuple[int, JsonDict]]:
+ ) -> Tuple[int, JsonDict]:
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
set_tag("txn_id", txn_id)
- return self.txns.fetch_or_execute_request(
- request, self.on_POST, request, room_identifier, txn_id
+ return await self.txns.fetch_or_execute_request(
+ request, requester, self._do, request, requester, room_identifier, txn_id
)
# TODO: Needs unit testing
-class PublicRoomListRestServlet(TransactionRestServlet):
+class PublicRoomListRestServlet(RestServlet):
PATTERNS = client_patterns("/publicRooms$", v1=True)
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
+ super().__init__()
self.hs = hs
self.auth = hs.get_auth()
@@ -907,22 +934,25 @@ class RoomForgetRestServlet(TransactionRestServlet):
PATTERNS = "/rooms/(?P<room_id>[^/]*)/forget"
register_txn_path(self, PATTERNS, http_server)
- async def on_POST(
- self, request: SynapseRequest, room_id: str, txn_id: Optional[str] = None
- ) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request, allow_guest=False)
-
+ async def _do(self, requester: Requester, room_id: str) -> Tuple[int, JsonDict]:
await self.room_member_handler.forget(user=requester.user, room_id=room_id)
return 200, {}
- def on_PUT(
+ async def on_POST(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
+ requester = await self.auth.get_user_by_req(request, allow_guest=False)
+ return await self._do(requester, room_id)
+
+ async def on_PUT(
self, request: SynapseRequest, room_id: str, txn_id: str
- ) -> Awaitable[Tuple[int, JsonDict]]:
+ ) -> Tuple[int, JsonDict]:
+ requester = await self.auth.get_user_by_req(request, allow_guest=False)
set_tag("txn_id", txn_id)
- return self.txns.fetch_or_execute_request(
- request, self.on_POST, request, room_id, txn_id
+ return await self.txns.fetch_or_execute_request(
+ request, requester, self._do, requester, room_id
)
@@ -941,15 +971,14 @@ class RoomMembershipRestServlet(TransactionRestServlet):
)
register_txn_path(self, PATTERNS, http_server)
- async def on_POST(
+ async def _do(
self,
request: SynapseRequest,
+ requester: Requester,
room_id: str,
membership_action: str,
- txn_id: Optional[str] = None,
+ txn_id: Optional[str],
) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request, allow_guest=True)
-
if requester.is_guest and membership_action not in {
Membership.JOIN,
Membership.LEAVE,
@@ -1014,13 +1043,30 @@ class RoomMembershipRestServlet(TransactionRestServlet):
return 200, return_value
- def on_PUT(
+ async def on_POST(
+ self,
+ request: SynapseRequest,
+ room_id: str,
+ membership_action: str,
+ ) -> Tuple[int, JsonDict]:
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
+ return await self._do(request, requester, room_id, membership_action, None)
+
+ async def on_PUT(
self, request: SynapseRequest, room_id: str, membership_action: str, txn_id: str
- ) -> Awaitable[Tuple[int, JsonDict]]:
+ ) -> Tuple[int, JsonDict]:
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
set_tag("txn_id", txn_id)
- return self.txns.fetch_or_execute_request(
- request, self.on_POST, request, room_id, membership_action, txn_id
+ return await self.txns.fetch_or_execute_request(
+ request,
+ requester,
+ self._do,
+ request,
+ requester,
+ room_id,
+ membership_action,
+ txn_id,
)
@@ -1036,14 +1082,14 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)"
register_txn_path(self, PATTERNS, http_server)
- async def on_POST(
+ async def _do(
self,
request: SynapseRequest,
+ requester: Requester,
room_id: str,
event_id: str,
- txn_id: Optional[str] = None,
+ txn_id: Optional[str],
) -> Tuple[int, JsonDict]:
- requester = await self.auth.get_user_by_req(request)
content = parse_json_object_from_request(request)
try:
@@ -1094,13 +1140,23 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
set_tag("event_id", event_id)
return 200, {"event_id": event_id}
- def on_PUT(
+ async def on_POST(
+ self,
+ request: SynapseRequest,
+ room_id: str,
+ event_id: str,
+ ) -> Tuple[int, JsonDict]:
+ requester = await self.auth.get_user_by_req(request)
+ return await self._do(request, requester, room_id, event_id, None)
+
+ async def on_PUT(
self, request: SynapseRequest, room_id: str, event_id: str, txn_id: str
- ) -> Awaitable[Tuple[int, JsonDict]]:
+ ) -> Tuple[int, JsonDict]:
+ requester = await self.auth.get_user_by_req(request)
set_tag("txn_id", txn_id)
- return self.txns.fetch_or_execute_request(
- request, self.on_POST, request, room_id, event_id, txn_id
+ return await self.txns.fetch_or_execute_request(
+ request, requester, self._do, request, requester, room_id, event_id, txn_id
)
@@ -1224,7 +1280,6 @@ def register_txn_path(
servlet: RestServlet,
regex_string: str,
http_server: HttpServer,
- with_get: bool = False,
) -> None:
"""Registers a transaction-based path.
@@ -1236,7 +1291,6 @@ def register_txn_path(
regex_string: The regex string to register. Must NOT have a
trailing $ as this string will be appended to.
http_server: The http_server to register paths with.
- with_get: True to also register respective GET paths for the PUTs.
"""
on_POST = getattr(servlet, "on_POST", None)
on_PUT = getattr(servlet, "on_PUT", None)
@@ -1254,18 +1308,6 @@ def register_txn_path(
on_PUT,
servlet.__class__.__name__,
)
- on_GET = getattr(servlet, "on_GET", None)
- if with_get:
- if on_GET is None:
- raise RuntimeError(
- "register_txn_path called with with_get = True, but no on_GET method exists"
- )
- http_server.register_paths(
- "GET",
- client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
- on_GET,
- servlet.__class__.__name__,
- )
class TimestampLookupRestServlet(RestServlet):
|