diff options
-rw-r--r-- | changelog.d/16327.bugfix | 1 | ||||
-rw-r--r-- | synapse/handlers/receipts.py | 26 | ||||
-rw-r--r-- | synapse/rest/client/read_marker.py | 2 | ||||
-rw-r--r-- | synapse/rest/client/receipts.py | 2 | ||||
-rw-r--r-- | tests/rest/client/test_receipts.py | 221 | ||||
-rw-r--r-- | tests/rest/client/test_sync.py | 154 |
6 files changed, 241 insertions, 165 deletions
diff --git a/changelog.d/16327.bugfix b/changelog.d/16327.bugfix new file mode 100644 index 0000000000..be3d1b4f21 --- /dev/null +++ b/changelog.d/16327.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where invalid receipts would be accepted. diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 2bacdebfb5..c7edada353 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -37,6 +37,8 @@ class ReceiptsHandler: self.server_name = hs.config.server.server_name self.store = hs.get_datastores().main self.event_auth_handler = hs.get_event_auth_handler() + self.event_handler = hs.get_event_handler() + self._storage_controllers = hs.get_storage_controllers() self.hs = hs @@ -81,6 +83,20 @@ class ReceiptsHandler: ) continue + # Let's check that the origin server is in the room before accepting the receipt. + # We don't want to block waiting on a partial state so take an + # approximation if needed. + domains = await self._storage_controllers.state.get_current_hosts_in_room_or_partial_state_approximation( + room_id + ) + if origin not in domains: + logger.info( + "Ignoring receipt for room %r from server %s as they're not in the room", + room_id, + origin, + ) + continue + for receipt_type, users in room_values.items(): for user_id, user_values in users.items(): if get_domain_from_id(user_id) != origin: @@ -158,17 +174,23 @@ class ReceiptsHandler: self, room_id: str, receipt_type: str, - user_id: str, + user_id: UserID, event_id: str, thread_id: Optional[str], ) -> None: """Called when a client tells us a local user has read up to the given event_id in the room. """ + + # Ensure the room/event exists, this will raise an error if the user + # cannot view the event. + if not await self.event_handler.get_event(user_id, room_id, event_id): + return + receipt = ReadReceipt( room_id=room_id, receipt_type=receipt_type, - user_id=user_id, + user_id=user_id.to_string(), event_ids=[event_id], thread_id=thread_id, data={"ts": int(self.clock.time_msec())}, diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py index 1707e51972..15e4d56cdb 100644 --- a/synapse/rest/client/read_marker.py +++ b/synapse/rest/client/read_marker.py @@ -84,7 +84,7 @@ class ReadMarkerRestServlet(RestServlet): await self.receipts_handler.received_client_receipt( room_id, receipt_type, - user_id=requester.user.to_string(), + user_id=requester.user, event_id=event_id, # Setting the thread ID is not possible with the /read_markers endpoint. thread_id=None, diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index 869a374459..814d075faf 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -108,7 +108,7 @@ class ReceiptRestServlet(RestServlet): await self.receipts_handler.received_client_receipt( room_id, receipt_type, - user_id=requester.user.to_string(), + user_id=requester.user, event_id=event_id, thread_id=thread_id, ) diff --git a/tests/rest/client/test_receipts.py b/tests/rest/client/test_receipts.py index 2a7fcea386..ec638c89b7 100644 --- a/tests/rest/client/test_receipts.py +++ b/tests/rest/client/test_receipts.py @@ -11,11 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from http import HTTPStatus +from typing import Optional + from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin -from synapse.rest.client import login, receipts, register +from synapse.api.constants import EduTypes, EventTypes, HistoryVisibility, ReceiptTypes +from synapse.rest.client import login, receipts, room, sync from synapse.server import HomeServer +from synapse.types import JsonDict from synapse.util import Clock from tests import unittest @@ -24,30 +29,113 @@ from tests import unittest class ReceiptsTestCase(unittest.HomeserverTestCase): servlets = [ login.register_servlets, - register.register_servlets, receipts.register_servlets, synapse.rest.admin.register_servlets, + room.register_servlets, + sync.register_servlets, ] def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.owner = self.register_user("owner", "pass") - self.owner_tok = self.login("owner", "pass") + self.url = "/sync?since=%s" + self.next_batch = "s0" + + # Register the first user + self.user_id = self.register_user("kermit", "monkey") + self.tok = self.login("kermit", "monkey") + + # Create the room + self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) + + # Register the second user + self.user2 = self.register_user("kermit2", "monkey") + self.tok2 = self.login("kermit2", "monkey") + + # Join the second user + self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) def test_send_receipt(self) -> None: + # Send a message. + res = self.helper.send(self.room_id, body="hello", tok=self.tok) + + # Send a read receipt + channel = self.make_request( + "POST", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ}/{res['event_id']}", + {}, + access_token=self.tok2, + ) + self.assertEqual(channel.code, 200) + self.assertNotEqual(self._get_read_receipt(), None) + + def test_send_receipt_unknown_event(self) -> None: + """Receipts sent for unknown events are ignored to not break message retention.""" + # Attempt to send a receipt to an unknown room. channel = self.make_request( "POST", "/rooms/!abc:beep/receipt/m.read/$def", content={}, - access_token=self.owner_tok, + access_token=self.tok2, + ) + self.assertEqual(channel.code, 200, channel.result) + self.assertIsNone(self._get_read_receipt()) + + # Attempt to send a receipt to an unknown event. + channel = self.make_request( + "POST", + f"/rooms/{self.room_id}/receipt/m.read/$def", + content={}, + access_token=self.tok2, ) self.assertEqual(channel.code, 200, channel.result) + self.assertIsNone(self._get_read_receipt()) + + def test_send_receipt_unviewable_event(self) -> None: + """Receipts sent for unviewable events are errors.""" + # Create a room where new users can't see events from before their join + # & send events into it. + room_id = self.helper.create_room_as( + self.user_id, + tok=self.tok, + extra_content={ + "preset": "private_chat", + "initial_state": [ + { + "content": {"history_visibility": HistoryVisibility.JOINED}, + "state_key": "", + "type": EventTypes.RoomHistoryVisibility, + } + ], + }, + ) + res = self.helper.send(room_id, body="hello", tok=self.tok) + + # Attempt to send a receipt from the wrong user. + channel = self.make_request( + "POST", + f"/rooms/{room_id}/receipt/{ReceiptTypes.READ}/{res['event_id']}", + content={}, + access_token=self.tok2, + ) + self.assertEqual(channel.code, 403, channel.result) + + # Join the user to the room, but they still can't see the event. + self.helper.invite(room_id, self.user_id, self.user2, tok=self.tok) + self.helper.join(room=room_id, user=self.user2, tok=self.tok2) + + channel = self.make_request( + "POST", + f"/rooms/{room_id}/receipt/{ReceiptTypes.READ}/{res['event_id']}", + content={}, + access_token=self.tok2, + ) + self.assertEqual(channel.code, 403, channel.result) def test_send_receipt_invalid_room_id(self) -> None: channel = self.make_request( "POST", "/rooms/not-a-room-id/receipt/m.read/$def", content={}, - access_token=self.owner_tok, + access_token=self.tok, ) self.assertEqual(channel.code, 400, channel.result) self.assertEqual( @@ -59,7 +147,7 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): "POST", "/rooms/!abc:beep/receipt/m.read/not-an-event-id", content={}, - access_token=self.owner_tok, + access_token=self.tok, ) self.assertEqual(channel.code, 400, channel.result) self.assertEqual( @@ -71,6 +159,123 @@ class ReceiptsTestCase(unittest.HomeserverTestCase): "POST", "/rooms/!abc:beep/receipt/invalid-receipt-type/$def", content={}, - access_token=self.owner_tok, + access_token=self.tok, ) self.assertEqual(channel.code, 400, channel.result) + + def test_private_read_receipts(self) -> None: + # Send a message as the first user + res = self.helper.send(self.room_id, body="hello", tok=self.tok) + + # Send a private read receipt to tell the server the first user's message was read + channel = self.make_request( + "POST", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}", + {}, + access_token=self.tok2, + ) + self.assertEqual(channel.code, 200) + + # Test that the first user can't see the other user's private read receipt + self.assertIsNone(self._get_read_receipt()) + + def test_public_receipt_can_override_private(self) -> None: + """ + Sending a public read receipt to the same event which has a private read + receipt should cause that receipt to become public. + """ + # Send a message as the first user + res = self.helper.send(self.room_id, body="hello", tok=self.tok) + + # Send a private read receipt + channel = self.make_request( + "POST", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}", + {}, + access_token=self.tok2, + ) + self.assertEqual(channel.code, 200) + self.assertIsNone(self._get_read_receipt()) + + # Send a public read receipt + channel = self.make_request( + "POST", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ}/{res['event_id']}", + {}, + access_token=self.tok2, + ) + self.assertEqual(channel.code, 200) + + # Test that we did override the private read receipt + self.assertNotEqual(self._get_read_receipt(), None) + + def test_private_receipt_cannot_override_public(self) -> None: + """ + Sending a private read receipt to the same event which has a public read + receipt should cause no change. + """ + # Send a message as the first user + res = self.helper.send(self.room_id, body="hello", tok=self.tok) + + # Send a public read receipt + channel = self.make_request( + "POST", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ}/{res['event_id']}", + {}, + access_token=self.tok2, + ) + self.assertEqual(channel.code, 200) + self.assertNotEqual(self._get_read_receipt(), None) + + # Send a private read receipt + channel = self.make_request( + "POST", + f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}", + {}, + access_token=self.tok2, + ) + self.assertEqual(channel.code, 200) + + # Test that we didn't override the public read receipt + self.assertIsNone(self._get_read_receipt()) + + def test_read_receipt_with_empty_body_is_rejected(self) -> None: + # Send a message as the first user + res = self.helper.send(self.room_id, body="hello", tok=self.tok) + + # Send a read receipt for this message with an empty body + channel = self.make_request( + "POST", + f"/rooms/{self.room_id}/receipt/m.read/{res['event_id']}", + access_token=self.tok2, + ) + self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST) + self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON", channel.json_body) + + def _get_read_receipt(self) -> Optional[JsonDict]: + """Syncs and returns the read receipt.""" + + # Checks if event is a read receipt + def is_read_receipt(event: JsonDict) -> bool: + return event["type"] == EduTypes.RECEIPT + + # Sync + channel = self.make_request( + "GET", + self.url % self.next_batch, + access_token=self.tok, + ) + self.assertEqual(channel.code, 200) + + # Store the next batch for the next request. + self.next_batch = channel.json_body["next_batch"] + + if channel.json_body.get("rooms", None) is None: + return None + + # Return the read receipt + ephemeral_events = channel.json_body["rooms"]["join"][self.room_id][ + "ephemeral" + ]["events"] + receipt_event = filter(is_read_receipt, ephemeral_events) + return next(receipt_event, None) diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index 9c876c7a32..d60665254e 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -13,8 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import json -from http import HTTPStatus -from typing import List, Optional +from typing import List from parameterized import parameterized @@ -22,7 +21,6 @@ from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.api.constants import ( - EduTypes, EventContentFields, EventTypes, ReceiptTypes, @@ -376,156 +374,6 @@ class SyncKnockTestCase(KnockingStrippedStateEventHelperMixin): ) -class ReadReceiptsTestCase(unittest.HomeserverTestCase): - servlets = [ - synapse.rest.admin.register_servlets, - login.register_servlets, - receipts.register_servlets, - room.register_servlets, - sync.register_servlets, - ] - - def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: - config = self.default_config() - - return self.setup_test_homeserver(config=config) - - def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: - self.url = "/sync?since=%s" - self.next_batch = "s0" - - # Register the first user - self.user_id = self.register_user("kermit", "monkey") - self.tok = self.login("kermit", "monkey") - - # Create the room - self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok) - - # Register the second user - self.user2 = self.register_user("kermit2", "monkey") - self.tok2 = self.login("kermit2", "monkey") - - # Join the second user - self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2) - - def test_private_read_receipts(self) -> None: - # Send a message as the first user - res = self.helper.send(self.room_id, body="hello", tok=self.tok) - - # Send a private read receipt to tell the server the first user's message was read - channel = self.make_request( - "POST", - f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}", - {}, - access_token=self.tok2, - ) - self.assertEqual(channel.code, 200) - - # Test that the first user can't see the other user's private read receipt - self.assertIsNone(self._get_read_receipt()) - - def test_public_receipt_can_override_private(self) -> None: - """ - Sending a public read receipt to the same event which has a private read - receipt should cause that receipt to become public. - """ - # Send a message as the first user - res = self.helper.send(self.room_id, body="hello", tok=self.tok) - - # Send a private read receipt - channel = self.make_request( - "POST", - f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}", - {}, - access_token=self.tok2, - ) - self.assertEqual(channel.code, 200) - self.assertIsNone(self._get_read_receipt()) - - # Send a public read receipt - channel = self.make_request( - "POST", - f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ}/{res['event_id']}", - {}, - access_token=self.tok2, - ) - self.assertEqual(channel.code, 200) - - # Test that we did override the private read receipt - self.assertNotEqual(self._get_read_receipt(), None) - - def test_private_receipt_cannot_override_public(self) -> None: - """ - Sending a private read receipt to the same event which has a public read - receipt should cause no change. - """ - # Send a message as the first user - res = self.helper.send(self.room_id, body="hello", tok=self.tok) - - # Send a public read receipt - channel = self.make_request( - "POST", - f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ}/{res['event_id']}", - {}, - access_token=self.tok2, - ) - self.assertEqual(channel.code, 200) - self.assertNotEqual(self._get_read_receipt(), None) - - # Send a private read receipt - channel = self.make_request( - "POST", - f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}", - {}, - access_token=self.tok2, - ) - self.assertEqual(channel.code, 200) - - # Test that we didn't override the public read receipt - self.assertIsNone(self._get_read_receipt()) - - def test_read_receipt_with_empty_body_is_rejected(self) -> None: - # Send a message as the first user - res = self.helper.send(self.room_id, body="hello", tok=self.tok) - - # Send a read receipt for this message with an empty body - channel = self.make_request( - "POST", - f"/rooms/{self.room_id}/receipt/m.read/{res['event_id']}", - access_token=self.tok2, - ) - self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST) - self.assertEqual(channel.json_body["errcode"], "M_NOT_JSON", channel.json_body) - - def _get_read_receipt(self) -> Optional[JsonDict]: - """Syncs and returns the read receipt.""" - - # Checks if event is a read receipt - def is_read_receipt(event: JsonDict) -> bool: - return event["type"] == EduTypes.RECEIPT - - # Sync - channel = self.make_request( - "GET", - self.url % self.next_batch, - access_token=self.tok, - ) - self.assertEqual(channel.code, 200) - - # Store the next batch for the next request. - self.next_batch = channel.json_body["next_batch"] - - if channel.json_body.get("rooms", None) is None: - return None - - # Return the read receipt - ephemeral_events = channel.json_body["rooms"]["join"][self.room_id][ - "ephemeral" - ]["events"] - receipt_event = filter(is_read_receipt, ephemeral_events) - return next(receipt_event, None) - - class UnreadMessagesTestCase(unittest.HomeserverTestCase): servlets = [ synapse.rest.admin.register_servlets, |