summary refs log tree commit diff
path: root/synapse/rest/client
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2022-07-13 13:23:16 -0400
committerGitHub <noreply@github.com>2022-07-13 13:23:16 -0400
commit1d5c80b16188c587427d663c3bec57e9c196dd1b (patch)
treee92bf886e44579a5f64d504728c677cb8a2d47f4 /synapse/rest/client
parentAdd prometheus counters for content types other than events (#13175) (diff)
downloadsynapse-1d5c80b16188c587427d663c3bec57e9c196dd1b.tar.xz
Reduce duplicate code in receipts servlets. (#13198)
Diffstat (limited to 'synapse/rest/client')
-rw-r--r--synapse/rest/client/read_marker.py56
-rw-r--r--synapse/rest/client/receipts.py20
2 files changed, 32 insertions, 44 deletions
diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py
index 3644705e6a..8896f2df50 100644
--- a/synapse/rest/client/read_marker.py
+++ b/synapse/rest/client/read_marker.py
@@ -40,6 +40,10 @@ class ReadMarkerRestServlet(RestServlet):
         self.read_marker_handler = hs.get_read_marker_handler()
         self.presence_handler = hs.get_presence_handler()
 
+        self._known_receipt_types = {ReceiptTypes.READ, ReceiptTypes.FULLY_READ}
+        if hs.config.experimental.msc2285_enabled:
+            self._known_receipt_types.add(ReceiptTypes.READ_PRIVATE)
+
     async def on_POST(
         self, request: SynapseRequest, room_id: str
     ) -> Tuple[int, JsonDict]:
@@ -49,13 +53,7 @@ class ReadMarkerRestServlet(RestServlet):
 
         body = parse_json_object_from_request(request)
 
-        valid_receipt_types = {
-            ReceiptTypes.READ,
-            ReceiptTypes.FULLY_READ,
-            ReceiptTypes.READ_PRIVATE,
-        }
-
-        unrecognized_types = set(body.keys()) - valid_receipt_types
+        unrecognized_types = set(body.keys()) - self._known_receipt_types
         if unrecognized_types:
             # It's fine if there are unrecognized receipt types, but let's log
             # it to help debug clients that have typoed the receipt type.
@@ -65,31 +63,25 @@ class ReadMarkerRestServlet(RestServlet):
             # types.
             logger.info("Ignoring unrecognized receipt types: %s", unrecognized_types)
 
-        read_event_id = body.get(ReceiptTypes.READ, None)
-        if read_event_id:
-            await self.receipts_handler.received_client_receipt(
-                room_id,
-                ReceiptTypes.READ,
-                user_id=requester.user.to_string(),
-                event_id=read_event_id,
-            )
-
-        read_private_event_id = body.get(ReceiptTypes.READ_PRIVATE, None)
-        if read_private_event_id and self.config.experimental.msc2285_enabled:
-            await self.receipts_handler.received_client_receipt(
-                room_id,
-                ReceiptTypes.READ_PRIVATE,
-                user_id=requester.user.to_string(),
-                event_id=read_private_event_id,
-            )
-
-        read_marker_event_id = body.get(ReceiptTypes.FULLY_READ, None)
-        if read_marker_event_id:
-            await self.read_marker_handler.received_client_read_marker(
-                room_id,
-                user_id=requester.user.to_string(),
-                event_id=read_marker_event_id,
-            )
+        for receipt_type in self._known_receipt_types:
+            event_id = body.get(receipt_type, None)
+            # TODO Add validation to reject non-string event IDs.
+            if not event_id:
+                continue
+
+            if receipt_type == ReceiptTypes.FULLY_READ:
+                await self.read_marker_handler.received_client_read_marker(
+                    room_id,
+                    user_id=requester.user.to_string(),
+                    event_id=event_id,
+                )
+            else:
+                await self.receipts_handler.received_client_receipt(
+                    room_id,
+                    receipt_type,
+                    user_id=requester.user.to_string(),
+                    event_id=event_id,
+                )
 
         return 200, {}
 
diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py
index 4b03eb876b..409bfd43c1 100644
--- a/synapse/rest/client/receipts.py
+++ b/synapse/rest/client/receipts.py
@@ -39,31 +39,27 @@ class ReceiptRestServlet(RestServlet):
 
     def __init__(self, hs: "HomeServer"):
         super().__init__()
-        self.hs = hs
         self.auth = hs.get_auth()
         self.receipts_handler = hs.get_receipts_handler()
         self.read_marker_handler = hs.get_read_marker_handler()
         self.presence_handler = hs.get_presence_handler()
 
+        self._known_receipt_types = {ReceiptTypes.READ}
+        if hs.config.experimental.msc2285_enabled:
+            self._known_receipt_types.update(
+                (ReceiptTypes.READ_PRIVATE, ReceiptTypes.FULLY_READ)
+            )
+
     async def on_POST(
         self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
 
-        if self.hs.config.experimental.msc2285_enabled and receipt_type not in [
-            ReceiptTypes.READ,
-            ReceiptTypes.READ_PRIVATE,
-            ReceiptTypes.FULLY_READ,
-        ]:
+        if receipt_type not in self._known_receipt_types:
             raise SynapseError(
                 400,
-                "Receipt type must be 'm.read', 'org.matrix.msc2285.read.private' or 'm.fully_read'",
+                f"Receipt type must be {', '.join(self._known_receipt_types)}",
             )
-        elif (
-            not self.hs.config.experimental.msc2285_enabled
-            and receipt_type != ReceiptTypes.READ
-        ):
-            raise SynapseError(400, "Receipt type must be 'm.read'")
 
         parse_json_object_from_request(request, allow_empty_body=False)