diff options
Diffstat (limited to 'synapse/rest/client/room.py')
-rw-r--r-- | synapse/rest/client/room.py | 174 |
1 files changed, 108 insertions, 66 deletions
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 61e4cf0213..129b6fe6b0 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -57,7 +57,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.client._base import client_patterns from synapse.rest.client.transactions import HttpTransactionCache from synapse.streams.config import PaginationConfig -from synapse.types import JsonDict, StreamToken, ThirdPartyInstanceID, UserID +from synapse.types import JsonDict, Requester, StreamToken, ThirdPartyInstanceID, UserID from synapse.types.state import StateFilter from synapse.util import json_decoder from synapse.util.cancellation import cancellable @@ -151,15 +151,22 @@ class RoomCreateRestServlet(TransactionRestServlet): PATTERNS = "/createRoom" register_txn_path(self, PATTERNS, http_server) - def on_PUT( + async def on_PUT( self, request: SynapseRequest, txn_id: str - ) -> Awaitable[Tuple[int, JsonDict]]: + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) set_tag("txn_id", txn_id) - return self.txns.fetch_or_execute_request(request, self.on_POST, request) + return await self.txns.fetch_or_execute_request( + request, requester, self._do, request, requester + ) async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) + return await self._do(request, requester) + async def _do( + self, request: SynapseRequest, requester: Requester + ) -> Tuple[int, JsonDict]: room_id, _, _ = await self._room_creation_handler.create_room( requester, self.get_room_config(request) ) @@ -172,9 +179,9 @@ class RoomCreateRestServlet(TransactionRestServlet): # TODO: Needs unit testing for generic events -class RoomStateEventRestServlet(TransactionRestServlet): +class RoomStateEventRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): - super().__init__(hs) + super().__init__() self.event_creation_handler = hs.get_event_creation_handler() self.room_member_handler = hs.get_room_member_handler() self.message_handler = hs.get_message_handler() @@ -324,16 +331,16 @@ class RoomSendEventRestServlet(TransactionRestServlet): def register(self, http_server: HttpServer) -> None: # /rooms/$roomid/send/$event_type[/$txn_id] PATTERNS = "/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)" - register_txn_path(self, PATTERNS, http_server, with_get=True) + register_txn_path(self, PATTERNS, http_server) - async def on_POST( + async def _do( self, request: SynapseRequest, + requester: Requester, room_id: str, event_type: str, - txn_id: Optional[str] = None, + txn_id: Optional[str], ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) content = parse_json_object_from_request(request) event_dict: JsonDict = { @@ -362,18 +369,30 @@ class RoomSendEventRestServlet(TransactionRestServlet): set_tag("event_id", event_id) return 200, {"event_id": event_id} - def on_GET( - self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str - ) -> Tuple[int, str]: - return 200, "Not implemented" + async def on_POST( + self, + request: SynapseRequest, + room_id: str, + event_type: str, + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request, allow_guest=True) + return await self._do(request, requester, room_id, event_type, None) - def on_PUT( + async def on_PUT( self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str - ) -> Awaitable[Tuple[int, JsonDict]]: + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request, allow_guest=True) set_tag("txn_id", txn_id) - return self.txns.fetch_or_execute_request( - request, self.on_POST, request, room_id, event_type, txn_id + return await self.txns.fetch_or_execute_request( + request, + requester, + self._do, + request, + requester, + room_id, + event_type, + txn_id, ) @@ -389,14 +408,13 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet): PATTERNS = "/join/(?P<room_identifier>[^/]*)" register_txn_path(self, PATTERNS, http_server) - async def on_POST( + async def _do( self, request: SynapseRequest, + requester: Requester, room_identifier: str, - txn_id: Optional[str] = None, + txn_id: Optional[str], ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) - content = parse_json_object_from_request(request, allow_empty_body=True) # twisted.web.server.Request.args is incorrectly defined as Optional[Any] @@ -420,22 +438,31 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet): return 200, {"room_id": room_id} - def on_PUT( + async def on_POST( + self, + request: SynapseRequest, + room_identifier: str, + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request, allow_guest=True) + return await self._do(request, requester, room_identifier, None) + + async def on_PUT( self, request: SynapseRequest, room_identifier: str, txn_id: str - ) -> Awaitable[Tuple[int, JsonDict]]: + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request, allow_guest=True) set_tag("txn_id", txn_id) - return self.txns.fetch_or_execute_request( - request, self.on_POST, request, room_identifier, txn_id + return await self.txns.fetch_or_execute_request( + request, requester, self._do, request, requester, room_identifier, txn_id ) # TODO: Needs unit testing -class PublicRoomListRestServlet(TransactionRestServlet): +class PublicRoomListRestServlet(RestServlet): PATTERNS = client_patterns("/publicRooms$", v1=True) def __init__(self, hs: "HomeServer"): - super().__init__(hs) + super().__init__() self.hs = hs self.auth = hs.get_auth() @@ -907,22 +934,25 @@ class RoomForgetRestServlet(TransactionRestServlet): PATTERNS = "/rooms/(?P<room_id>[^/]*)/forget" register_txn_path(self, PATTERNS, http_server) - async def on_POST( - self, request: SynapseRequest, room_id: str, txn_id: Optional[str] = None - ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=False) - + async def _do(self, requester: Requester, room_id: str) -> Tuple[int, JsonDict]: await self.room_member_handler.forget(user=requester.user, room_id=room_id) return 200, {} - def on_PUT( + async def on_POST( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request, allow_guest=False) + return await self._do(requester, room_id) + + async def on_PUT( self, request: SynapseRequest, room_id: str, txn_id: str - ) -> Awaitable[Tuple[int, JsonDict]]: + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request, allow_guest=False) set_tag("txn_id", txn_id) - return self.txns.fetch_or_execute_request( - request, self.on_POST, request, room_id, txn_id + return await self.txns.fetch_or_execute_request( + request, requester, self._do, requester, room_id ) @@ -941,15 +971,14 @@ class RoomMembershipRestServlet(TransactionRestServlet): ) register_txn_path(self, PATTERNS, http_server) - async def on_POST( + async def _do( self, request: SynapseRequest, + requester: Requester, room_id: str, membership_action: str, - txn_id: Optional[str] = None, + txn_id: Optional[str], ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request, allow_guest=True) - if requester.is_guest and membership_action not in { Membership.JOIN, Membership.LEAVE, @@ -1014,13 +1043,30 @@ class RoomMembershipRestServlet(TransactionRestServlet): return 200, return_value - def on_PUT( + async def on_POST( + self, + request: SynapseRequest, + room_id: str, + membership_action: str, + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request, allow_guest=True) + return await self._do(request, requester, room_id, membership_action, None) + + async def on_PUT( self, request: SynapseRequest, room_id: str, membership_action: str, txn_id: str - ) -> Awaitable[Tuple[int, JsonDict]]: + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request, allow_guest=True) set_tag("txn_id", txn_id) - return self.txns.fetch_or_execute_request( - request, self.on_POST, request, room_id, membership_action, txn_id + return await self.txns.fetch_or_execute_request( + request, + requester, + self._do, + request, + requester, + room_id, + membership_action, + txn_id, ) @@ -1036,14 +1082,14 @@ class RoomRedactEventRestServlet(TransactionRestServlet): PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)" register_txn_path(self, PATTERNS, http_server) - async def on_POST( + async def _do( self, request: SynapseRequest, + requester: Requester, room_id: str, event_id: str, - txn_id: Optional[str] = None, + txn_id: Optional[str], ) -> Tuple[int, JsonDict]: - requester = await self.auth.get_user_by_req(request) content = parse_json_object_from_request(request) try: @@ -1094,13 +1140,23 @@ class RoomRedactEventRestServlet(TransactionRestServlet): set_tag("event_id", event_id) return 200, {"event_id": event_id} - def on_PUT( + async def on_POST( + self, + request: SynapseRequest, + room_id: str, + event_id: str, + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + return await self._do(request, requester, room_id, event_id, None) + + async def on_PUT( self, request: SynapseRequest, room_id: str, event_id: str, txn_id: str - ) -> Awaitable[Tuple[int, JsonDict]]: + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) set_tag("txn_id", txn_id) - return self.txns.fetch_or_execute_request( - request, self.on_POST, request, room_id, event_id, txn_id + return await self.txns.fetch_or_execute_request( + request, requester, self._do, request, requester, room_id, event_id, txn_id ) @@ -1224,7 +1280,6 @@ def register_txn_path( servlet: RestServlet, regex_string: str, http_server: HttpServer, - with_get: bool = False, ) -> None: """Registers a transaction-based path. @@ -1236,7 +1291,6 @@ def register_txn_path( regex_string: The regex string to register. Must NOT have a trailing $ as this string will be appended to. http_server: The http_server to register paths with. - with_get: True to also register respective GET paths for the PUTs. """ on_POST = getattr(servlet, "on_POST", None) on_PUT = getattr(servlet, "on_PUT", None) @@ -1254,18 +1308,6 @@ def register_txn_path( on_PUT, servlet.__class__.__name__, ) - on_GET = getattr(servlet, "on_GET", None) - if with_get: - if on_GET is None: - raise RuntimeError( - "register_txn_path called with with_get = True, but no on_GET method exists" - ) - http_server.register_paths( - "GET", - client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True), - on_GET, - servlet.__class__.__name__, - ) class TimestampLookupRestServlet(RestServlet): |