diff options
Diffstat (limited to 'synapse/handlers/devicemessage.py')
-rw-r--r-- | synapse/handlers/devicemessage.py | 57 |
1 files changed, 50 insertions, 7 deletions
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 3caf9b31cc..ae97bb60af 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import logging from typing import TYPE_CHECKING, Any, Dict @@ -90,6 +91,8 @@ class DeviceMessageHandler: burst_count=hs.config.ratelimiting.rc_key_requests.burst_count, ) + self._msc3944_enabled = hs.config.experimental.msc3944_enabled + async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None: """ Handle receiving to-device messages from remote homeservers. @@ -220,7 +223,7 @@ class DeviceMessageHandler: set_tag(SynapseTags.TO_DEVICE_TYPE, message_type) set_tag(SynapseTags.TO_DEVICE_SENDER, sender_user_id) - local_messages = {} + local_messages: Dict[str, Dict[str, JsonDict]] = {} remote_messages: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} for user_id, by_device in messages.items(): # add an opentracing log entry for each message @@ -255,16 +258,56 @@ class DeviceMessageHandler: # we use UserID.from_string to catch invalid user ids if self.is_mine(UserID.from_string(user_id)): - messages_by_device = { - device_id: { + for device_id, message_content in by_device.items(): + # Drop any previous identical (same request_id and requesting_device_id) + # room_key_request, ignoring the action property when comparing. + # This handles dropping previous identical and cancelled requests. + if ( + self._msc3944_enabled + and message_type == ToDeviceEventTypes.RoomKeyRequest + and user_id == sender_user_id + ): + req_id = message_content.get("request_id") + requesting_device_id = message_content.get( + "requesting_device_id" + ) + if req_id and requesting_device_id: + previous_request_deleted = False + for ( + stream_id, + message_json, + ) in await self.store.get_all_device_messages( + user_id, device_id + ): + orig_message = json.loads(message_json) + if ( + orig_message["type"] + == ToDeviceEventTypes.RoomKeyRequest + ): + content = orig_message.get("content", {}) + if ( + content.get("request_id") == req_id + and content.get("requesting_device_id") + == requesting_device_id + ): + if await self.store.delete_device_message( + stream_id + ): + previous_request_deleted = True + + if ( + message_content.get("action") == "request_cancellation" + and previous_request_deleted + ): + # Do not store the cancellation since we deleted the matching + # request(s) before it reaches the device. + continue + message = { "content": message_content, "type": message_type, "sender": sender_user_id, } - for device_id, message_content in by_device.items() - } - if messages_by_device: - local_messages[user_id] = messages_by_device + local_messages.setdefault(user_id, {})[device_id] = message else: destination = get_domain_from_id(user_id) remote_messages.setdefault(destination, {})[user_id] = by_device |