diff options
author | Mathieu Velten <mathieuv@matrix.org> | 2023-06-27 00:39:10 +0200 |
---|---|---|
committer | Mathieu Velten <mathieuv@matrix.org> | 2023-06-27 09:49:42 +0200 |
commit | e25c15ea0f8c7356600c4ade27f7d92a420bea31 (patch) | |
tree | 40e4435cf01622e639bc88fd767ba88ddef705fd /synapse | |
parent | Bump serde_json from 1.0.97 to 1.0.99 (#15832) (diff) | |
download | synapse-mv/msc3944.tar.xz |
Implements part of MSC 3944 by dropping cancelled&duplicated `m.room_key_request` github/mv/msc3944 mv/msc3944
Diffstat (limited to '')
-rw-r--r-- | synapse/config/experimental.py | 3 | ||||
-rw-r--r-- | synapse/handlers/devicemessage.py | 57 | ||||
-rw-r--r-- | synapse/storage/databases/main/deviceinbox.py | 41 |
3 files changed, 94 insertions, 7 deletions
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 8e0f5356b4..1ac87bc304 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -389,3 +389,6 @@ class ExperimentalConfig(Config): self.msc4010_push_rules_account_data = experimental.get( "msc4010_push_rules_account_data", False ) + + # MSC3944: Dropping stale send-to-device messages + self.msc3944_enabled: bool = experimental.get("msc3944_enabled", False) diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 3caf9b31cc..ae97bb60af 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json import logging from typing import TYPE_CHECKING, Any, Dict @@ -90,6 +91,8 @@ class DeviceMessageHandler: burst_count=hs.config.ratelimiting.rc_key_requests.burst_count, ) + self._msc3944_enabled = hs.config.experimental.msc3944_enabled + async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None: """ Handle receiving to-device messages from remote homeservers. @@ -220,7 +223,7 @@ class DeviceMessageHandler: set_tag(SynapseTags.TO_DEVICE_TYPE, message_type) set_tag(SynapseTags.TO_DEVICE_SENDER, sender_user_id) - local_messages = {} + local_messages: Dict[str, Dict[str, JsonDict]] = {} remote_messages: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} for user_id, by_device in messages.items(): # add an opentracing log entry for each message @@ -255,16 +258,56 @@ class DeviceMessageHandler: # we use UserID.from_string to catch invalid user ids if self.is_mine(UserID.from_string(user_id)): - messages_by_device = { - device_id: { + for device_id, message_content in by_device.items(): + # Drop any previous identical (same request_id and requesting_device_id) + # room_key_request, ignoring the action property when comparing. + # This handles dropping previous identical and cancelled requests. + if ( + self._msc3944_enabled + and message_type == ToDeviceEventTypes.RoomKeyRequest + and user_id == sender_user_id + ): + req_id = message_content.get("request_id") + requesting_device_id = message_content.get( + "requesting_device_id" + ) + if req_id and requesting_device_id: + previous_request_deleted = False + for ( + stream_id, + message_json, + ) in await self.store.get_all_device_messages( + user_id, device_id + ): + orig_message = json.loads(message_json) + if ( + orig_message["type"] + == ToDeviceEventTypes.RoomKeyRequest + ): + content = orig_message.get("content", {}) + if ( + content.get("request_id") == req_id + and content.get("requesting_device_id") + == requesting_device_id + ): + if await self.store.delete_device_message( + stream_id + ): + previous_request_deleted = True + + if ( + message_content.get("action") == "request_cancellation" + and previous_request_deleted + ): + # Do not store the cancellation since we deleted the matching + # request(s) before it reaches the device. + continue + message = { "content": message_content, "type": message_type, "sender": sender_user_id, } - for device_id, message_content in by_device.items() - } - if messages_by_device: - local_messages[user_id] = messages_by_device + local_messages.setdefault(user_id, {})[device_id] = message else: destination = get_domain_from_id(user_id) remote_messages.setdefault(destination, {})[user_id] = by_device diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index b471fcb064..c08cc53661 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -27,6 +27,7 @@ from typing import ( ) from synapse.api.constants import EventContentFields +from synapse.api.errors import StoreError from synapse.logging import issue9533_logger from synapse.logging.opentracing import ( SynapseTags, @@ -891,6 +892,46 @@ class DeviceInboxWorkerStore(SQLBaseStore): ], ) + async def delete_device_message(self, stream_id: int) -> bool: + """Delete a specific device message from the message inbox. + + Args: + stream_id: the stream ID identifying the message. + Returns: + True if the message has been deleted, False if it didn't exist. + """ + try: + await self.db_pool.simple_delete_one( + "device_inbox", + keyvalues={"stream_id": stream_id}, + desc="delete_device_message", + ) + except StoreError: + # Deletion failed because device message does not exist + return False + return True + + async def get_all_device_messages( + self, + user_id: str, + device_id: str, + ) -> List[Tuple[int, str]]: + """Get all device messages in the inbox from a specific device. + + Args: + user_id: the user ID of the device we want to query. + device_id: the device ID of the device we want to query. + Returns: + A list of (stream ID, message content) tuples. + """ + rows = await self.db_pool.simple_select_list( + table="device_inbox", + keyvalues={"user_id": user_id, "device_id": device_id}, + retcols=("stream_id", "message_json"), + desc="get_all_device_messages", + ) + return [(r["stream_id"], r["message_json"]) for r in rows] + class DeviceInboxBackgroundUpdateStore(SQLBaseStore): DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" |