diff options
-rw-r--r-- | changelog.d/15842.feature | 1 | ||||
-rw-r--r-- | synapse/config/experimental.py | 3 | ||||
-rw-r--r-- | synapse/handlers/devicemessage.py | 57 | ||||
-rw-r--r-- | synapse/storage/databases/main/deviceinbox.py | 41 | ||||
-rw-r--r-- | tests/handlers/test_device.py | 121 |
5 files changed, 215 insertions, 8 deletions
diff --git a/changelog.d/15842.feature b/changelog.d/15842.feature new file mode 100644 index 0000000000..a2fc20162d --- /dev/null +++ b/changelog.d/15842.feature @@ -0,0 +1 @@ +Implements bullets 1 and 2 of [MSC 3944](https://github.com/matrix-org/matrix-spec-proposals/pull/3944) related to dropping cancelled and duplicated `m.room_key_request`. diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 8e0f5356b4..1ac87bc304 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -389,3 +389,6 @@ class ExperimentalConfig(Config): self.msc4010_push_rules_account_data = experimental.get( "msc4010_push_rules_account_data", False ) + + # MSC3944: Dropping stale send-to-device messages + self.msc3944_enabled: bool = experimental.get("msc3944_enabled", False) diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 3caf9b31cc..ae97bb60af 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import logging from typing import TYPE_CHECKING, Any, Dict @@ -90,6 +91,8 @@ class DeviceMessageHandler: burst_count=hs.config.ratelimiting.rc_key_requests.burst_count, ) + self._msc3944_enabled = hs.config.experimental.msc3944_enabled + async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None: """ Handle receiving to-device messages from remote homeservers. @@ -220,7 +223,7 @@ class DeviceMessageHandler: set_tag(SynapseTags.TO_DEVICE_TYPE, message_type) set_tag(SynapseTags.TO_DEVICE_SENDER, sender_user_id) - local_messages = {} + local_messages: Dict[str, Dict[str, JsonDict]] = {} remote_messages: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} for user_id, by_device in messages.items(): # add an opentracing log entry for each message @@ -255,16 +258,56 @@ class DeviceMessageHandler: # we use UserID.from_string to catch invalid user ids if self.is_mine(UserID.from_string(user_id)): - messages_by_device = { - device_id: { + for device_id, message_content in by_device.items(): + # Drop any previous identical (same request_id and requesting_device_id) + # room_key_request, ignoring the action property when comparing. + # This handles dropping previous identical and cancelled requests. + if ( + self._msc3944_enabled + and message_type == ToDeviceEventTypes.RoomKeyRequest + and user_id == sender_user_id + ): + req_id = message_content.get("request_id") + requesting_device_id = message_content.get( + "requesting_device_id" + ) + if req_id and requesting_device_id: + previous_request_deleted = False + for ( + stream_id, + message_json, + ) in await self.store.get_all_device_messages( + user_id, device_id + ): + orig_message = json.loads(message_json) + if ( + orig_message["type"] + == ToDeviceEventTypes.RoomKeyRequest + ): + content = orig_message.get("content", {}) + if ( + content.get("request_id") == req_id + and content.get("requesting_device_id") + == requesting_device_id + ): + if await self.store.delete_device_message( + stream_id + ): + previous_request_deleted = True + + if ( + message_content.get("action") == "request_cancellation" + and previous_request_deleted + ): + # Do not store the cancellation since we deleted the matching + # request(s) before it reaches the device. + continue + message = { "content": message_content, "type": message_type, "sender": sender_user_id, } - for device_id, message_content in by_device.items() - } - if messages_by_device: - local_messages[user_id] = messages_by_device + local_messages.setdefault(user_id, {})[device_id] = message else: destination = get_domain_from_id(user_id) remote_messages.setdefault(destination, {})[user_id] = by_device diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index b471fcb064..c08cc53661 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -27,6 +27,7 @@ from typing import ( ) from synapse.api.constants import EventContentFields +from synapse.api.errors import StoreError from synapse.logging import issue9533_logger from synapse.logging.opentracing import ( SynapseTags, @@ -891,6 +892,46 @@ class DeviceInboxWorkerStore(SQLBaseStore): ], ) + async def delete_device_message(self, stream_id: int) -> bool: + """Delete a specific device message from the message inbox. + + Args: + stream_id: the stream ID identifying the message. + Returns: + True if the message has been deleted, False if it didn't exist. + """ + try: + await self.db_pool.simple_delete_one( + "device_inbox", + keyvalues={"stream_id": stream_id}, + desc="delete_device_message", + ) + except StoreError: + # Deletion failed because device message does not exist + return False + return True + + async def get_all_device_messages( + self, + user_id: str, + device_id: str, + ) -> List[Tuple[int, str]]: + """Get all device messages in the inbox from a specific device. + + Args: + user_id: the user ID of the device we want to query. + device_id: the device ID of the device we want to query. + Returns: + A list of (stream ID, message content) tuples. + """ + rows = await self.db_pool.simple_select_list( + table="device_inbox", + keyvalues={"user_id": user_id, "device_id": device_id}, + retcols=("stream_id", "message_json"), + desc="get_all_device_messages", + ) + return [(r["stream_id"], r["message_json"]) for r in rows] + class DeviceInboxBackgroundUpdateStore(SQLBaseStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py index ee48f9e546..0e34ee208a 100644 --- a/tests/handlers/test_device.py +++ b/tests/handlers/test_device.py @@ -19,13 +19,14 @@ from unittest import mock from twisted.test.proto_helpers import MemoryReactor +import synapse from synapse.api.constants import RoomEncryptionAlgorithms from synapse.api.errors import NotFoundError, SynapseError from synapse.appservice import ApplicationService from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler from synapse.server import HomeServer from synapse.storage.databases.main.appservice import _make_exclusive_regex -from synapse.types import JsonDict +from synapse.types import JsonDict, create_requester from synapse.util import Clock from tests import unittest @@ -37,6 +38,11 @@ user2 = "@theresa:bbb" class DeviceTestCase(unittest.HomeserverTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + synapse.rest.client.login.register_servlets, + ] + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.appservice_api = mock.Mock() hs = self.setup_test_homeserver( @@ -47,6 +53,8 @@ class DeviceTestCase(unittest.HomeserverTestCase): handler = hs.get_device_handler() assert isinstance(handler, DeviceHandler) self.handler = handler + self.msg_handler = hs.get_device_message_handler() + self.event_sources = hs.get_event_sources() self.store = hs.get_datastores().main return hs @@ -398,6 +406,117 @@ class DeviceTestCase(unittest.HomeserverTestCase): ], ) + @override_config({"experimental_features": {"msc3944_enabled": True}}) + def test_duplicated_and_cancelled_room_key_request(self) -> None: + myuser = self.register_user("myuser", "pass") + self.login("myuser", "pass", "device") + self.login("myuser", "pass", "device2") + self.login("myuser", "pass", "device3") + + requester = requester = create_requester(myuser) + + from_token = self.event_sources.get_current_token() + + # This room_key_request is for device3 and should not be deleted. + self.get_success( + self.msg_handler.send_device_message( + requester, + "m.room_key_request", + { + myuser: { + "device3": { + "action": "request", + "request_id": "request_id", + "requesting_device_id": "device", + } + } + }, + ) + ) + + for _ in range(0, 2): + self.get_success( + self.msg_handler.send_device_message( + requester, + "m.room_key_request", + { + myuser: { + "device2": { + "action": "request", + "request_id": "request_id", + "requesting_device_id": "device", + } + } + }, + ) + ) + + to_token = self.event_sources.get_current_token() + + # Test that if we queue 2 identical room_key_request, + # only one is delivered to the device. + res = self.get_success( + self.store.get_messages_for_device( + myuser, + "device2", + from_token.to_device_key, + to_token.to_device_key, + ) + ) + self.assertEqual(len(res[0]), 1) + + # room_key_request for device3 should still be around. + res = self.get_success( + self.store.get_messages_for_device( + myuser, + "device3", + from_token.to_device_key, + to_token.to_device_key, + ) + ) + self.assertEqual(len(res[0]), 1) + + self.get_success( + self.msg_handler.send_device_message( + requester, + "m.room_key_request", + { + myuser: { + "device2": { + "action": "request_cancellation", + "request_id": "request_id", + "requesting_device_id": "device", + } + } + }, + ) + ) + + to_token = self.event_sources.get_current_token() + + # Test that if we cancel a room_key_request, both previous matching + # requests and the cancelled request are not delivered to the device. + res = self.get_success( + self.store.get_messages_for_device( + myuser, + "device2", + from_token.to_device_key, + to_token.to_device_key, + ) + ) + self.assertEqual(len(res[0]), 0) + + # room_key_request for device3 should still be around. + res = self.get_success( + self.store.get_messages_for_device( + myuser, + "device3", + from_token.to_device_key, + to_token.to_device_key, + ) + ) + self.assertEqual(len(res[0]), 1) + class DehydrationTestCase(unittest.HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: |