diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 2e2e5261de..05c4b3eec0 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -14,10 +14,20 @@
# limitations under the License.
import logging
+from typing import Any, Dict
+
+from canonicaljson import json
from twisted.internet import defer
from synapse.api.errors import SynapseError
+from synapse.logging.context import run_in_background
+from synapse.logging.opentracing import (
+ get_active_span_text_map,
+ log_kv,
+ set_tag,
+ start_active_span,
+)
from synapse.types import UserID, get_domain_from_id
from synapse.util.stringutils import random_string
@@ -25,7 +35,6 @@ logger = logging.getLogger(__name__)
class DeviceMessageHandler(object):
-
def __init__(self, hs):
"""
Args:
@@ -40,24 +49,29 @@ class DeviceMessageHandler(object):
"m.direct_to_device", self.on_direct_to_device_edu
)
+ self._device_list_updater = hs.get_device_handler().device_list_updater
+
@defer.inlineCallbacks
def on_direct_to_device_edu(self, origin, content):
local_messages = {}
sender_user_id = content["sender"]
if origin != get_domain_from_id(sender_user_id):
- logger.warn(
+ logger.warning(
"Dropping device message from %r with spoofed sender %r",
- origin, sender_user_id
+ origin,
+ sender_user_id,
)
message_type = content["type"]
message_id = content["message_id"]
for user_id, by_device in content["messages"].items():
# we use UserID.from_string to catch invalid user ids
if not self.is_mine(UserID.from_string(user_id)):
- logger.warning("Request for keys for non-local user %s",
- user_id)
+ logger.warning("Request for keys for non-local user %s", user_id)
raise SynapseError(400, "Not a user here")
+ if not by_device:
+ continue
+
messages_by_device = {
device_id: {
"content": message_content,
@@ -66,8 +80,11 @@ class DeviceMessageHandler(object):
}
for device_id, message_content in by_device.items()
}
- if messages_by_device:
- local_messages[user_id] = messages_by_device
+ local_messages[user_id] = messages_by_device
+
+ yield self._check_for_unknown_devices(
+ message_type, sender_user_id, by_device
+ )
stream_id = yield self.store.add_messages_from_remote_to_device_inbox(
origin, message_id, local_messages
@@ -78,8 +95,58 @@ class DeviceMessageHandler(object):
)
@defer.inlineCallbacks
- def send_device_message(self, sender_user_id, message_type, messages):
+ def _check_for_unknown_devices(
+ self,
+ message_type: str,
+ sender_user_id: str,
+ by_device: Dict[str, Dict[str, Any]],
+ ):
+ """Checks inbound device messages for unkown remote devices, and if
+ found marks the remote cache for the user as stale.
+ """
+
+ if message_type != "m.room_key_request":
+ return
+
+ # Get the sending device IDs
+ requesting_device_ids = set()
+ for message_content in by_device.values():
+ device_id = message_content.get("requesting_device_id")
+ requesting_device_ids.add(device_id)
+
+ # Check if we are tracking the devices of the remote user.
+ room_ids = yield self.store.get_rooms_for_user(sender_user_id)
+ if not room_ids:
+ logger.info(
+ "Received device message from remote device we don't"
+ " share a room with: %s %s",
+ sender_user_id,
+ requesting_device_ids,
+ )
+ return
+
+ # If we are tracking check that we know about the sending
+ # devices.
+ cached_devices = yield self.store.get_cached_devices_for_user(sender_user_id)
+
+ unknown_devices = requesting_device_ids - set(cached_devices)
+ if unknown_devices:
+ logger.info(
+ "Received device message from remote device not in our cache: %s %s",
+ sender_user_id,
+ unknown_devices,
+ )
+ yield self.store.mark_remote_user_device_cache_as_stale(sender_user_id)
+
+ # Immediately attempt a resync in the background
+ run_in_background(
+ self._device_list_updater.user_device_resync, sender_user_id
+ )
+ @defer.inlineCallbacks
+ def send_device_message(self, sender_user_id, message_type, messages):
+ set_tag("number_of_messages", len(messages))
+ set_tag("sender", sender_user_id)
local_messages = {}
remote_messages = {}
for user_id, by_device in messages.items():
@@ -101,15 +168,21 @@ class DeviceMessageHandler(object):
message_id = random_string(16)
+ context = get_active_span_text_map()
+
remote_edu_contents = {}
for destination, messages in remote_messages.items():
- remote_edu_contents[destination] = {
- "messages": messages,
- "sender": sender_user_id,
- "type": message_type,
- "message_id": message_id,
- }
+ with start_active_span("to_device_for_user"):
+ set_tag("destination", destination)
+ remote_edu_contents[destination] = {
+ "messages": messages,
+ "sender": sender_user_id,
+ "type": message_type,
+ "message_id": message_id,
+ "org.matrix.opentracing_context": json.dumps(context),
+ }
+ log_kv({"local_messages": local_messages})
stream_id = yield self.store.add_messages_to_device_inbox(
local_messages, remote_edu_contents
)
@@ -118,6 +191,7 @@ class DeviceMessageHandler(object):
"to_device_key", stream_id, users=local_messages.keys()
)
+ log_kv({"remote_messages": remote_messages})
for destination in remote_messages.keys():
# Enqueue a new federation transaction to send the new
# device messages to each remote destination.
|