summary refs log tree commit diff
path: root/synapse/handlers/devicemessage.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/devicemessage.py')
-rw-r--r--synapse/handlers/devicemessage.py102
1 files changed, 88 insertions, 14 deletions
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py

index 2e2e5261de..05c4b3eec0 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py
@@ -14,10 +14,20 @@ # limitations under the License. import logging +from typing import Any, Dict + +from canonicaljson import json from twisted.internet import defer from synapse.api.errors import SynapseError +from synapse.logging.context import run_in_background +from synapse.logging.opentracing import ( + get_active_span_text_map, + log_kv, + set_tag, + start_active_span, +) from synapse.types import UserID, get_domain_from_id from synapse.util.stringutils import random_string @@ -25,7 +35,6 @@ logger = logging.getLogger(__name__) class DeviceMessageHandler(object): - def __init__(self, hs): """ Args: @@ -40,24 +49,29 @@ class DeviceMessageHandler(object): "m.direct_to_device", self.on_direct_to_device_edu ) + self._device_list_updater = hs.get_device_handler().device_list_updater + @defer.inlineCallbacks def on_direct_to_device_edu(self, origin, content): local_messages = {} sender_user_id = content["sender"] if origin != get_domain_from_id(sender_user_id): - logger.warn( + logger.warning( "Dropping device message from %r with spoofed sender %r", - origin, sender_user_id + origin, + sender_user_id, ) message_type = content["type"] message_id = content["message_id"] for user_id, by_device in content["messages"].items(): # we use UserID.from_string to catch invalid user ids if not self.is_mine(UserID.from_string(user_id)): - logger.warning("Request for keys for non-local user %s", - user_id) + logger.warning("Request for keys for non-local user %s", user_id) raise SynapseError(400, "Not a user here") + if not by_device: + continue + messages_by_device = { device_id: { "content": message_content, @@ -66,8 +80,11 @@ class DeviceMessageHandler(object): } for device_id, message_content in by_device.items() } - if messages_by_device: - local_messages[user_id] = messages_by_device + local_messages[user_id] = messages_by_device + + yield self._check_for_unknown_devices( + message_type, sender_user_id, by_device + ) stream_id = yield self.store.add_messages_from_remote_to_device_inbox( origin, message_id, local_messages @@ -78,8 +95,58 @@ class DeviceMessageHandler(object): ) @defer.inlineCallbacks - def send_device_message(self, sender_user_id, message_type, messages): + def _check_for_unknown_devices( + self, + message_type: str, + sender_user_id: str, + by_device: Dict[str, Dict[str, Any]], + ): + """Checks inbound device messages for unkown remote devices, and if + found marks the remote cache for the user as stale. + """ + + if message_type != "m.room_key_request": + return + + # Get the sending device IDs + requesting_device_ids = set() + for message_content in by_device.values(): + device_id = message_content.get("requesting_device_id") + requesting_device_ids.add(device_id) + + # Check if we are tracking the devices of the remote user. + room_ids = yield self.store.get_rooms_for_user(sender_user_id) + if not room_ids: + logger.info( + "Received device message from remote device we don't" + " share a room with: %s %s", + sender_user_id, + requesting_device_ids, + ) + return + + # If we are tracking check that we know about the sending + # devices. + cached_devices = yield self.store.get_cached_devices_for_user(sender_user_id) + + unknown_devices = requesting_device_ids - set(cached_devices) + if unknown_devices: + logger.info( + "Received device message from remote device not in our cache: %s %s", + sender_user_id, + unknown_devices, + ) + yield self.store.mark_remote_user_device_cache_as_stale(sender_user_id) + + # Immediately attempt a resync in the background + run_in_background( + self._device_list_updater.user_device_resync, sender_user_id + ) + @defer.inlineCallbacks + def send_device_message(self, sender_user_id, message_type, messages): + set_tag("number_of_messages", len(messages)) + set_tag("sender", sender_user_id) local_messages = {} remote_messages = {} for user_id, by_device in messages.items(): @@ -101,15 +168,21 @@ class DeviceMessageHandler(object): message_id = random_string(16) + context = get_active_span_text_map() + remote_edu_contents = {} for destination, messages in remote_messages.items(): - remote_edu_contents[destination] = { - "messages": messages, - "sender": sender_user_id, - "type": message_type, - "message_id": message_id, - } + with start_active_span("to_device_for_user"): + set_tag("destination", destination) + remote_edu_contents[destination] = { + "messages": messages, + "sender": sender_user_id, + "type": message_type, + "message_id": message_id, + "org.matrix.opentracing_context": json.dumps(context), + } + log_kv({"local_messages": local_messages}) stream_id = yield self.store.add_messages_to_device_inbox( local_messages, remote_edu_contents ) @@ -118,6 +191,7 @@ class DeviceMessageHandler(object): "to_device_key", stream_id, users=local_messages.keys() ) + log_kv({"remote_messages": remote_messages}) for destination in remote_messages.keys(): # Enqueue a new federation transaction to send the new # device messages to each remote destination.