summary refs log tree commit diff
path: root/synapse/handlers/devicemessage.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/handlers/devicemessage.py57
1 files changed, 50 insertions, 7 deletions
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 3caf9b31cc..ae97bb60af 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import json
 import logging
 from typing import TYPE_CHECKING, Any, Dict
 
@@ -90,6 +91,8 @@ class DeviceMessageHandler:
             burst_count=hs.config.ratelimiting.rc_key_requests.burst_count,
         )
 
+        self._msc3944_enabled = hs.config.experimental.msc3944_enabled
+
     async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None:
         """
         Handle receiving to-device messages from remote homeservers.
@@ -220,7 +223,7 @@ class DeviceMessageHandler:
 
         set_tag(SynapseTags.TO_DEVICE_TYPE, message_type)
         set_tag(SynapseTags.TO_DEVICE_SENDER, sender_user_id)
-        local_messages = {}
+        local_messages: Dict[str, Dict[str, JsonDict]] = {}
         remote_messages: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
         for user_id, by_device in messages.items():
             # add an opentracing log entry for each message
@@ -255,16 +258,56 @@ class DeviceMessageHandler:
 
             # we use UserID.from_string to catch invalid user ids
             if self.is_mine(UserID.from_string(user_id)):
-                messages_by_device = {
-                    device_id: {
+                for device_id, message_content in by_device.items():
+                    # Drop any previous identical (same request_id and requesting_device_id)
+                    # room_key_request, ignoring the action property when comparing.
+                    # This handles dropping previous identical and cancelled requests.
+                    if (
+                        self._msc3944_enabled
+                        and message_type == ToDeviceEventTypes.RoomKeyRequest
+                        and user_id == sender_user_id
+                    ):
+                        req_id = message_content.get("request_id")
+                        requesting_device_id = message_content.get(
+                            "requesting_device_id"
+                        )
+                        if req_id and requesting_device_id:
+                            previous_request_deleted = False
+                            for (
+                                stream_id,
+                                message_json,
+                            ) in await self.store.get_all_device_messages(
+                                user_id, device_id
+                            ):
+                                orig_message = json.loads(message_json)
+                                if (
+                                    orig_message["type"]
+                                    == ToDeviceEventTypes.RoomKeyRequest
+                                ):
+                                    content = orig_message.get("content", {})
+                                    if (
+                                        content.get("request_id") == req_id
+                                        and content.get("requesting_device_id")
+                                        == requesting_device_id
+                                    ):
+                                        if await self.store.delete_device_message(
+                                            stream_id
+                                        ):
+                                            previous_request_deleted = True
+
+                            if (
+                                message_content.get("action") == "request_cancellation"
+                                and previous_request_deleted
+                            ):
+                                # Do not store the cancellation since we deleted the matching
+                                # request(s) before it reaches the device.
+                                continue
+                    message = {
                         "content": message_content,
                         "type": message_type,
                         "sender": sender_user_id,
                     }
-                    for device_id, message_content in by_device.items()
-                }
-                if messages_by_device:
-                    local_messages[user_id] = messages_by_device
+                    local_messages.setdefault(user_id, {})[device_id] = message
             else:
                 destination = get_domain_from_id(user_id)
                 remote_messages.setdefault(destination, {})[user_id] = by_device