summary refs log tree commit diff
diff options
context:
space:
mode:
authorMathieu Velten <mathieuv@matrix.org>2023-06-27 00:39:10 +0200
committerMathieu Velten <mathieuv@matrix.org>2023-06-27 09:49:42 +0200
commite25c15ea0f8c7356600c4ade27f7d92a420bea31 (patch)
tree40e4435cf01622e639bc88fd767ba88ddef705fd
parentBump serde_json from 1.0.97 to 1.0.99 (#15832) (diff)
downloadsynapse-mv/msc3944.tar.xz
Implements part of MSC 3944 by dropping cancelled&duplicated `m.room_key_request` github/mv/msc3944 mv/msc3944
-rw-r--r--changelog.d/15842.feature1
-rw-r--r--synapse/config/experimental.py3
-rw-r--r--synapse/handlers/devicemessage.py57
-rw-r--r--synapse/storage/databases/main/deviceinbox.py41
-rw-r--r--tests/handlers/test_device.py121
5 files changed, 215 insertions, 8 deletions
diff --git a/changelog.d/15842.feature b/changelog.d/15842.feature
new file mode 100644
index 0000000000..a2fc20162d
--- /dev/null
+++ b/changelog.d/15842.feature
@@ -0,0 +1 @@
+Implements bullets 1 and 2 of [MSC 3944](https://github.com/matrix-org/matrix-spec-proposals/pull/3944) related to dropping cancelled and duplicated `m.room_key_request`.
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 8e0f5356b4..1ac87bc304 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -389,3 +389,6 @@ class ExperimentalConfig(Config):
         self.msc4010_push_rules_account_data = experimental.get(
             "msc4010_push_rules_account_data", False
         )
+
+        # MSC3944: Dropping stale send-to-device messages
+        self.msc3944_enabled: bool = experimental.get("msc3944_enabled", False)
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 3caf9b31cc..ae97bb60af 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import json
 import logging
 from typing import TYPE_CHECKING, Any, Dict
 
@@ -90,6 +91,8 @@ class DeviceMessageHandler:
             burst_count=hs.config.ratelimiting.rc_key_requests.burst_count,
         )
 
+        self._msc3944_enabled = hs.config.experimental.msc3944_enabled
+
     async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None:
         """
         Handle receiving to-device messages from remote homeservers.
@@ -220,7 +223,7 @@ class DeviceMessageHandler:
 
         set_tag(SynapseTags.TO_DEVICE_TYPE, message_type)
         set_tag(SynapseTags.TO_DEVICE_SENDER, sender_user_id)
-        local_messages = {}
+        local_messages: Dict[str, Dict[str, JsonDict]] = {}
         remote_messages: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
         for user_id, by_device in messages.items():
             # add an opentracing log entry for each message
@@ -255,16 +258,56 @@ class DeviceMessageHandler:
 
             # we use UserID.from_string to catch invalid user ids
             if self.is_mine(UserID.from_string(user_id)):
-                messages_by_device = {
-                    device_id: {
+                for device_id, message_content in by_device.items():
+                    # Drop any previous identical (same request_id and requesting_device_id)
+                    # room_key_request, ignoring the action property when comparing.
+                    # This handles dropping previous identical and cancelled requests.
+                    if (
+                        self._msc3944_enabled
+                        and message_type == ToDeviceEventTypes.RoomKeyRequest
+                        and user_id == sender_user_id
+                    ):
+                        req_id = message_content.get("request_id")
+                        requesting_device_id = message_content.get(
+                            "requesting_device_id"
+                        )
+                        if req_id and requesting_device_id:
+                            previous_request_deleted = False
+                            for (
+                                stream_id,
+                                message_json,
+                            ) in await self.store.get_all_device_messages(
+                                user_id, device_id
+                            ):
+                                orig_message = json.loads(message_json)
+                                if (
+                                    orig_message["type"]
+                                    == ToDeviceEventTypes.RoomKeyRequest
+                                ):
+                                    content = orig_message.get("content", {})
+                                    if (
+                                        content.get("request_id") == req_id
+                                        and content.get("requesting_device_id")
+                                        == requesting_device_id
+                                    ):
+                                        if await self.store.delete_device_message(
+                                            stream_id
+                                        ):
+                                            previous_request_deleted = True
+
+                            if (
+                                message_content.get("action") == "request_cancellation"
+                                and previous_request_deleted
+                            ):
+                                # Do not store the cancellation since we deleted the matching
+                                # request(s) before it reaches the device.
+                                continue
+                    message = {
                         "content": message_content,
                         "type": message_type,
                         "sender": sender_user_id,
                     }
-                    for device_id, message_content in by_device.items()
-                }
-                if messages_by_device:
-                    local_messages[user_id] = messages_by_device
+                    local_messages.setdefault(user_id, {})[device_id] = message
             else:
                 destination = get_domain_from_id(user_id)
                 remote_messages.setdefault(destination, {})[user_id] = by_device
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index b471fcb064..c08cc53661 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -27,6 +27,7 @@ from typing import (
 )
 
 from synapse.api.constants import EventContentFields
+from synapse.api.errors import StoreError
 from synapse.logging import issue9533_logger
 from synapse.logging.opentracing import (
     SynapseTags,
@@ -891,6 +892,46 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                 ],
             )
 
+    async def delete_device_message(self, stream_id: int) -> bool:
+        """Delete a specific device message from the message inbox.
+
+        Args:
+            stream_id: the stream ID identifying the message.
+        Returns:
+            True if the message has been deleted, False if it didn't exist.
+        """
+        try:
+            await self.db_pool.simple_delete_one(
+                "device_inbox",
+                keyvalues={"stream_id": stream_id},
+                desc="delete_device_message",
+            )
+        except StoreError:
+            # Deletion failed because device message does not exist
+            return False
+        return True
+
+    async def get_all_device_messages(
+        self,
+        user_id: str,
+        device_id: str,
+    ) -> List[Tuple[int, str]]:
+        """Get all device messages in the inbox from a specific device.
+
+        Args:
+            user_id: the user ID of the device we want to query.
+            device_id: the device ID of the device we want to query.
+        Returns:
+            A list of (stream ID, message content) tuples.
+        """
+        rows = await self.db_pool.simple_select_list(
+            table="device_inbox",
+            keyvalues={"user_id": user_id, "device_id": device_id},
+            retcols=("stream_id", "message_json"),
+            desc="get_all_device_messages",
+        )
+        return [(r["stream_id"], r["message_json"]) for r in rows]
+
 
 class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
     DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index ee48f9e546..0e34ee208a 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -19,13 +19,14 @@ from unittest import mock
 
 from twisted.test.proto_helpers import MemoryReactor
 
+import synapse
 from synapse.api.constants import RoomEncryptionAlgorithms
 from synapse.api.errors import NotFoundError, SynapseError
 from synapse.appservice import ApplicationService
 from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler
 from synapse.server import HomeServer
 from synapse.storage.databases.main.appservice import _make_exclusive_regex
-from synapse.types import JsonDict
+from synapse.types import JsonDict, create_requester
 from synapse.util import Clock
 
 from tests import unittest
@@ -37,6 +38,11 @@ user2 = "@theresa:bbb"
 
 
 class DeviceTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        synapse.rest.admin.register_servlets,
+        synapse.rest.client.login.register_servlets,
+    ]
+
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         self.appservice_api = mock.Mock()
         hs = self.setup_test_homeserver(
@@ -47,6 +53,8 @@ class DeviceTestCase(unittest.HomeserverTestCase):
         handler = hs.get_device_handler()
         assert isinstance(handler, DeviceHandler)
         self.handler = handler
+        self.msg_handler = hs.get_device_message_handler()
+        self.event_sources = hs.get_event_sources()
         self.store = hs.get_datastores().main
         return hs
 
@@ -398,6 +406,117 @@ class DeviceTestCase(unittest.HomeserverTestCase):
             ],
         )
 
+    @override_config({"experimental_features": {"msc3944_enabled": True}})
+    def test_duplicated_and_cancelled_room_key_request(self) -> None:
+        myuser = self.register_user("myuser", "pass")
+        self.login("myuser", "pass", "device")
+        self.login("myuser", "pass", "device2")
+        self.login("myuser", "pass", "device3")
+
+        requester = requester = create_requester(myuser)
+
+        from_token = self.event_sources.get_current_token()
+
+        # This room_key_request is for device3 and should not be deleted.
+        self.get_success(
+            self.msg_handler.send_device_message(
+                requester,
+                "m.room_key_request",
+                {
+                    myuser: {
+                        "device3": {
+                            "action": "request",
+                            "request_id": "request_id",
+                            "requesting_device_id": "device",
+                        }
+                    }
+                },
+            )
+        )
+
+        for _ in range(0, 2):
+            self.get_success(
+                self.msg_handler.send_device_message(
+                    requester,
+                    "m.room_key_request",
+                    {
+                        myuser: {
+                            "device2": {
+                                "action": "request",
+                                "request_id": "request_id",
+                                "requesting_device_id": "device",
+                            }
+                        }
+                    },
+                )
+            )
+
+            to_token = self.event_sources.get_current_token()
+
+            # Test that if we queue 2 identical room_key_request,
+            # only one is delivered to the device.
+            res = self.get_success(
+                self.store.get_messages_for_device(
+                    myuser,
+                    "device2",
+                    from_token.to_device_key,
+                    to_token.to_device_key,
+                )
+            )
+            self.assertEqual(len(res[0]), 1)
+
+        # room_key_request for device3 should still be around.
+        res = self.get_success(
+            self.store.get_messages_for_device(
+                myuser,
+                "device3",
+                from_token.to_device_key,
+                to_token.to_device_key,
+            )
+        )
+        self.assertEqual(len(res[0]), 1)
+
+        self.get_success(
+            self.msg_handler.send_device_message(
+                requester,
+                "m.room_key_request",
+                {
+                    myuser: {
+                        "device2": {
+                            "action": "request_cancellation",
+                            "request_id": "request_id",
+                            "requesting_device_id": "device",
+                        }
+                    }
+                },
+            )
+        )
+
+        to_token = self.event_sources.get_current_token()
+
+        # Test that if we cancel a room_key_request, both previous matching
+        # requests and the cancelled request are not delivered to the device.
+        res = self.get_success(
+            self.store.get_messages_for_device(
+                myuser,
+                "device2",
+                from_token.to_device_key,
+                to_token.to_device_key,
+            )
+        )
+        self.assertEqual(len(res[0]), 0)
+
+        # room_key_request for device3 should still be around.
+        res = self.get_success(
+            self.store.get_messages_for_device(
+                myuser,
+                "device3",
+                from_token.to_device_key,
+                to_token.to_device_key,
+            )
+        )
+        self.assertEqual(len(res[0]), 1)
+
 
 class DehydrationTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: