diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 3caf9b31cc..ae97bb60af 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import json
import logging
from typing import TYPE_CHECKING, Any, Dict
@@ -90,6 +91,8 @@ class DeviceMessageHandler:
burst_count=hs.config.ratelimiting.rc_key_requests.burst_count,
)
+ self._msc3944_enabled = hs.config.experimental.msc3944_enabled
+
async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None:
"""
Handle receiving to-device messages from remote homeservers.
@@ -220,7 +223,7 @@ class DeviceMessageHandler:
set_tag(SynapseTags.TO_DEVICE_TYPE, message_type)
set_tag(SynapseTags.TO_DEVICE_SENDER, sender_user_id)
- local_messages = {}
+ local_messages: Dict[str, Dict[str, JsonDict]] = {}
remote_messages: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for user_id, by_device in messages.items():
# add an opentracing log entry for each message
@@ -255,16 +258,56 @@ class DeviceMessageHandler:
# we use UserID.from_string to catch invalid user ids
if self.is_mine(UserID.from_string(user_id)):
- messages_by_device = {
- device_id: {
+ for device_id, message_content in by_device.items():
+ # Drop any previous identical (same request_id and requesting_device_id)
+ # room_key_request, ignoring the action property when comparing.
+ # This handles dropping previous identical and cancelled requests.
+ if (
+ self._msc3944_enabled
+ and message_type == ToDeviceEventTypes.RoomKeyRequest
+ and user_id == sender_user_id
+ ):
+ req_id = message_content.get("request_id")
+ requesting_device_id = message_content.get(
+ "requesting_device_id"
+ )
+ if req_id and requesting_device_id:
+ previous_request_deleted = False
+ for (
+ stream_id,
+ message_json,
+ ) in await self.store.get_all_device_messages(
+ user_id, device_id
+ ):
+ orig_message = json.loads(message_json)
+ if (
+ orig_message["type"]
+ == ToDeviceEventTypes.RoomKeyRequest
+ ):
+ content = orig_message.get("content", {})
+ if (
+ content.get("request_id") == req_id
+ and content.get("requesting_device_id")
+ == requesting_device_id
+ ):
+ if await self.store.delete_device_message(
+ stream_id
+ ):
+ previous_request_deleted = True
+
+ if (
+ message_content.get("action") == "request_cancellation"
+ and previous_request_deleted
+ ):
+ # Do not store the cancellation since we deleted the matching
+ # request(s) before it reaches the device.
+ continue
+ message = {
"content": message_content,
"type": message_type,
"sender": sender_user_id,
}
- for device_id, message_content in by_device.items()
- }
- if messages_by_device:
- local_messages[user_id] = messages_by_device
+ local_messages.setdefault(user_id, {})[device_id] = message
else:
destination = get_domain_from_id(user_id)
remote_messages.setdefault(destination, {})[user_id] = by_device
|