diff options
-rw-r--r-- | changelog.d/13198.misc | 1 | ||||
-rw-r--r-- | synapse/rest/client/read_marker.py | 56 | ||||
-rw-r--r-- | synapse/rest/client/receipts.py | 20 |
3 files changed, 33 insertions, 44 deletions
diff --git a/changelog.d/13198.misc b/changelog.d/13198.misc new file mode 100644 index 0000000000..5aef2432df --- /dev/null +++ b/changelog.d/13198.misc @@ -0,0 +1 @@ +Refactor receipts servlet logic to avoid duplicated code. diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py index 3644705e6a..8896f2df50 100644 --- a/synapse/rest/client/read_marker.py +++ b/synapse/rest/client/read_marker.py @@ -40,6 +40,10 @@ class ReadMarkerRestServlet(RestServlet): self.read_marker_handler = hs.get_read_marker_handler() self.presence_handler = hs.get_presence_handler() + self._known_receipt_types = {ReceiptTypes.READ, ReceiptTypes.FULLY_READ} + if hs.config.experimental.msc2285_enabled: + self._known_receipt_types.add(ReceiptTypes.READ_PRIVATE) + async def on_POST( self, request: SynapseRequest, room_id: str ) -> Tuple[int, JsonDict]: @@ -49,13 +53,7 @@ class ReadMarkerRestServlet(RestServlet): body = parse_json_object_from_request(request) - valid_receipt_types = { - ReceiptTypes.READ, - ReceiptTypes.FULLY_READ, - ReceiptTypes.READ_PRIVATE, - } - - unrecognized_types = set(body.keys()) - valid_receipt_types + unrecognized_types = set(body.keys()) - self._known_receipt_types if unrecognized_types: # It's fine if there are unrecognized receipt types, but let's log # it to help debug clients that have typoed the receipt type. @@ -65,31 +63,25 @@ class ReadMarkerRestServlet(RestServlet): # types. logger.info("Ignoring unrecognized receipt types: %s", unrecognized_types) - read_event_id = body.get(ReceiptTypes.READ, None) - if read_event_id: - await self.receipts_handler.received_client_receipt( - room_id, - ReceiptTypes.READ, - user_id=requester.user.to_string(), - event_id=read_event_id, - ) - - read_private_event_id = body.get(ReceiptTypes.READ_PRIVATE, None) - if read_private_event_id and self.config.experimental.msc2285_enabled: - await self.receipts_handler.received_client_receipt( - room_id, - ReceiptTypes.READ_PRIVATE, - user_id=requester.user.to_string(), - event_id=read_private_event_id, - ) - - read_marker_event_id = body.get(ReceiptTypes.FULLY_READ, None) - if read_marker_event_id: - await self.read_marker_handler.received_client_read_marker( - room_id, - user_id=requester.user.to_string(), - event_id=read_marker_event_id, - ) + for receipt_type in self._known_receipt_types: + event_id = body.get(receipt_type, None) + # TODO Add validation to reject non-string event IDs. + if not event_id: + continue + + if receipt_type == ReceiptTypes.FULLY_READ: + await self.read_marker_handler.received_client_read_marker( + room_id, + user_id=requester.user.to_string(), + event_id=event_id, + ) + else: + await self.receipts_handler.received_client_receipt( + room_id, + receipt_type, + user_id=requester.user.to_string(), + event_id=event_id, + ) return 200, {} diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index 4b03eb876b..409bfd43c1 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -39,31 +39,27 @@ class ReceiptRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): super().__init__() - self.hs = hs self.auth = hs.get_auth() self.receipts_handler = hs.get_receipts_handler() self.read_marker_handler = hs.get_read_marker_handler() self.presence_handler = hs.get_presence_handler() + self._known_receipt_types = {ReceiptTypes.READ} + if hs.config.experimental.msc2285_enabled: + self._known_receipt_types.update( + (ReceiptTypes.READ_PRIVATE, ReceiptTypes.FULLY_READ) + ) + async def on_POST( self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) - if self.hs.config.experimental.msc2285_enabled and receipt_type not in [ - ReceiptTypes.READ, - ReceiptTypes.READ_PRIVATE, - ReceiptTypes.FULLY_READ, - ]: + if receipt_type not in self._known_receipt_types: raise SynapseError( 400, - "Receipt type must be 'm.read', 'org.matrix.msc2285.read.private' or 'm.fully_read'", + f"Receipt type must be {', '.join(self._known_receipt_types)}", ) - elif ( - not self.hs.config.experimental.msc2285_enabled - and receipt_type != ReceiptTypes.READ - ): - raise SynapseError(400, "Receipt type must be 'm.read'") parse_json_object_from_request(request, allow_empty_body=False) |