summary refs log tree commit diff
path: root/synapse/rest/client/room.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest/client/room.py')
-rw-r--r--synapse/rest/client/room.py174
1 files changed, 108 insertions, 66 deletions
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 61e4cf0213..129b6fe6b0 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -57,7 +57,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.rest.client._base import client_patterns
 from synapse.rest.client.transactions import HttpTransactionCache
 from synapse.streams.config import PaginationConfig
-from synapse.types import JsonDict, StreamToken, ThirdPartyInstanceID, UserID
+from synapse.types import JsonDict, Requester, StreamToken, ThirdPartyInstanceID, UserID
 from synapse.types.state import StateFilter
 from synapse.util import json_decoder
 from synapse.util.cancellation import cancellable
@@ -151,15 +151,22 @@ class RoomCreateRestServlet(TransactionRestServlet):
         PATTERNS = "/createRoom"
         register_txn_path(self, PATTERNS, http_server)
 
-    def on_PUT(
+    async def on_PUT(
         self, request: SynapseRequest, txn_id: str
-    ) -> Awaitable[Tuple[int, JsonDict]]:
+    ) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request)
         set_tag("txn_id", txn_id)
-        return self.txns.fetch_or_execute_request(request, self.on_POST, request)
+        return await self.txns.fetch_or_execute_request(
+            request, requester, self._do, request, requester
+        )
 
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
+        return await self._do(request, requester)
 
+    async def _do(
+        self, request: SynapseRequest, requester: Requester
+    ) -> Tuple[int, JsonDict]:
         room_id, _, _ = await self._room_creation_handler.create_room(
             requester, self.get_room_config(request)
         )
@@ -172,9 +179,9 @@ class RoomCreateRestServlet(TransactionRestServlet):
 
 
 # TODO: Needs unit testing for generic events
-class RoomStateEventRestServlet(TransactionRestServlet):
+class RoomStateEventRestServlet(RestServlet):
     def __init__(self, hs: "HomeServer"):
-        super().__init__(hs)
+        super().__init__()
         self.event_creation_handler = hs.get_event_creation_handler()
         self.room_member_handler = hs.get_room_member_handler()
         self.message_handler = hs.get_message_handler()
@@ -324,16 +331,16 @@ class RoomSendEventRestServlet(TransactionRestServlet):
     def register(self, http_server: HttpServer) -> None:
         # /rooms/$roomid/send/$event_type[/$txn_id]
         PATTERNS = "/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)"
-        register_txn_path(self, PATTERNS, http_server, with_get=True)
+        register_txn_path(self, PATTERNS, http_server)
 
-    async def on_POST(
+    async def _do(
         self,
         request: SynapseRequest,
+        requester: Requester,
         room_id: str,
         event_type: str,
-        txn_id: Optional[str] = None,
+        txn_id: Optional[str],
     ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request, allow_guest=True)
         content = parse_json_object_from_request(request)
 
         event_dict: JsonDict = {
@@ -362,18 +369,30 @@ class RoomSendEventRestServlet(TransactionRestServlet):
         set_tag("event_id", event_id)
         return 200, {"event_id": event_id}
 
-    def on_GET(
-        self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str
-    ) -> Tuple[int, str]:
-        return 200, "Not implemented"
+    async def on_POST(
+        self,
+        request: SynapseRequest,
+        room_id: str,
+        event_type: str,
+    ) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
+        return await self._do(request, requester, room_id, event_type, None)
 
-    def on_PUT(
+    async def on_PUT(
         self, request: SynapseRequest, room_id: str, event_type: str, txn_id: str
-    ) -> Awaitable[Tuple[int, JsonDict]]:
+    ) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
         set_tag("txn_id", txn_id)
 
-        return self.txns.fetch_or_execute_request(
-            request, self.on_POST, request, room_id, event_type, txn_id
+        return await self.txns.fetch_or_execute_request(
+            request,
+            requester,
+            self._do,
+            request,
+            requester,
+            room_id,
+            event_type,
+            txn_id,
         )
 
 
@@ -389,14 +408,13 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
         PATTERNS = "/join/(?P<room_identifier>[^/]*)"
         register_txn_path(self, PATTERNS, http_server)
 
-    async def on_POST(
+    async def _do(
         self,
         request: SynapseRequest,
+        requester: Requester,
         room_identifier: str,
-        txn_id: Optional[str] = None,
+        txn_id: Optional[str],
     ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request, allow_guest=True)
-
         content = parse_json_object_from_request(request, allow_empty_body=True)
 
         # twisted.web.server.Request.args is incorrectly defined as Optional[Any]
@@ -420,22 +438,31 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, TransactionRestServlet):
 
         return 200, {"room_id": room_id}
 
-    def on_PUT(
+    async def on_POST(
+        self,
+        request: SynapseRequest,
+        room_identifier: str,
+    ) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
+        return await self._do(request, requester, room_identifier, None)
+
+    async def on_PUT(
         self, request: SynapseRequest, room_identifier: str, txn_id: str
-    ) -> Awaitable[Tuple[int, JsonDict]]:
+    ) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
         set_tag("txn_id", txn_id)
 
-        return self.txns.fetch_or_execute_request(
-            request, self.on_POST, request, room_identifier, txn_id
+        return await self.txns.fetch_or_execute_request(
+            request, requester, self._do, request, requester, room_identifier, txn_id
         )
 
 
 # TODO: Needs unit testing
-class PublicRoomListRestServlet(TransactionRestServlet):
+class PublicRoomListRestServlet(RestServlet):
     PATTERNS = client_patterns("/publicRooms$", v1=True)
 
     def __init__(self, hs: "HomeServer"):
-        super().__init__(hs)
+        super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
 
@@ -907,22 +934,25 @@ class RoomForgetRestServlet(TransactionRestServlet):
         PATTERNS = "/rooms/(?P<room_id>[^/]*)/forget"
         register_txn_path(self, PATTERNS, http_server)
 
-    async def on_POST(
-        self, request: SynapseRequest, room_id: str, txn_id: Optional[str] = None
-    ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request, allow_guest=False)
-
+    async def _do(self, requester: Requester, room_id: str) -> Tuple[int, JsonDict]:
         await self.room_member_handler.forget(user=requester.user, room_id=room_id)
 
         return 200, {}
 
-    def on_PUT(
+    async def on_POST(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request, allow_guest=False)
+        return await self._do(requester, room_id)
+
+    async def on_PUT(
         self, request: SynapseRequest, room_id: str, txn_id: str
-    ) -> Awaitable[Tuple[int, JsonDict]]:
+    ) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request, allow_guest=False)
         set_tag("txn_id", txn_id)
 
-        return self.txns.fetch_or_execute_request(
-            request, self.on_POST, request, room_id, txn_id
+        return await self.txns.fetch_or_execute_request(
+            request, requester, self._do, requester, room_id
         )
 
 
@@ -941,15 +971,14 @@ class RoomMembershipRestServlet(TransactionRestServlet):
         )
         register_txn_path(self, PATTERNS, http_server)
 
-    async def on_POST(
+    async def _do(
         self,
         request: SynapseRequest,
+        requester: Requester,
         room_id: str,
         membership_action: str,
-        txn_id: Optional[str] = None,
+        txn_id: Optional[str],
     ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request, allow_guest=True)
-
         if requester.is_guest and membership_action not in {
             Membership.JOIN,
             Membership.LEAVE,
@@ -1014,13 +1043,30 @@ class RoomMembershipRestServlet(TransactionRestServlet):
 
         return 200, return_value
 
-    def on_PUT(
+    async def on_POST(
+        self,
+        request: SynapseRequest,
+        room_id: str,
+        membership_action: str,
+    ) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
+        return await self._do(request, requester, room_id, membership_action, None)
+
+    async def on_PUT(
         self, request: SynapseRequest, room_id: str, membership_action: str, txn_id: str
-    ) -> Awaitable[Tuple[int, JsonDict]]:
+    ) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
         set_tag("txn_id", txn_id)
 
-        return self.txns.fetch_or_execute_request(
-            request, self.on_POST, request, room_id, membership_action, txn_id
+        return await self.txns.fetch_or_execute_request(
+            request,
+            requester,
+            self._do,
+            request,
+            requester,
+            room_id,
+            membership_action,
+            txn_id,
         )
 
 
@@ -1036,14 +1082,14 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
         PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)"
         register_txn_path(self, PATTERNS, http_server)
 
-    async def on_POST(
+    async def _do(
         self,
         request: SynapseRequest,
+        requester: Requester,
         room_id: str,
         event_id: str,
-        txn_id: Optional[str] = None,
+        txn_id: Optional[str],
     ) -> Tuple[int, JsonDict]:
-        requester = await self.auth.get_user_by_req(request)
         content = parse_json_object_from_request(request)
 
         try:
@@ -1094,13 +1140,23 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
         set_tag("event_id", event_id)
         return 200, {"event_id": event_id}
 
-    def on_PUT(
+    async def on_POST(
+        self,
+        request: SynapseRequest,
+        room_id: str,
+        event_id: str,
+    ) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request)
+        return await self._do(request, requester, room_id, event_id, None)
+
+    async def on_PUT(
         self, request: SynapseRequest, room_id: str, event_id: str, txn_id: str
-    ) -> Awaitable[Tuple[int, JsonDict]]:
+    ) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request)
         set_tag("txn_id", txn_id)
 
-        return self.txns.fetch_or_execute_request(
-            request, self.on_POST, request, room_id, event_id, txn_id
+        return await self.txns.fetch_or_execute_request(
+            request, requester, self._do, request, requester, room_id, event_id, txn_id
         )
 
 
@@ -1224,7 +1280,6 @@ def register_txn_path(
     servlet: RestServlet,
     regex_string: str,
     http_server: HttpServer,
-    with_get: bool = False,
 ) -> None:
     """Registers a transaction-based path.
 
@@ -1236,7 +1291,6 @@ def register_txn_path(
         regex_string: The regex string to register. Must NOT have a
             trailing $ as this string will be appended to.
         http_server: The http_server to register paths with.
-        with_get: True to also register respective GET paths for the PUTs.
     """
     on_POST = getattr(servlet, "on_POST", None)
     on_PUT = getattr(servlet, "on_PUT", None)
@@ -1254,18 +1308,6 @@ def register_txn_path(
         on_PUT,
         servlet.__class__.__name__,
     )
-    on_GET = getattr(servlet, "on_GET", None)
-    if with_get:
-        if on_GET is None:
-            raise RuntimeError(
-                "register_txn_path called with with_get = True, but no on_GET method exists"
-            )
-        http_server.register_paths(
-            "GET",
-            client_patterns(regex_string + "/(?P<txn_id>[^/]*)$", v1=True),
-            on_GET,
-            servlet.__class__.__name__,
-        )
 
 
 class TimestampLookupRestServlet(RestServlet):