diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 0043cbea17..05c4b3eec0 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -14,12 +14,14 @@
# limitations under the License.
import logging
+from typing import Any, Dict
from canonicaljson import json
from twisted.internet import defer
from synapse.api.errors import SynapseError
+from synapse.logging.context import run_in_background
from synapse.logging.opentracing import (
get_active_span_text_map,
log_kv,
@@ -47,12 +49,14 @@ class DeviceMessageHandler(object):
"m.direct_to_device", self.on_direct_to_device_edu
)
+ self._device_list_updater = hs.get_device_handler().device_list_updater
+
@defer.inlineCallbacks
def on_direct_to_device_edu(self, origin, content):
local_messages = {}
sender_user_id = content["sender"]
if origin != get_domain_from_id(sender_user_id):
- logger.warn(
+ logger.warning(
"Dropping device message from %r with spoofed sender %r",
origin,
sender_user_id,
@@ -65,6 +69,9 @@ class DeviceMessageHandler(object):
logger.warning("Request for keys for non-local user %s", user_id)
raise SynapseError(400, "Not a user here")
+ if not by_device:
+ continue
+
messages_by_device = {
device_id: {
"content": message_content,
@@ -73,8 +80,11 @@ class DeviceMessageHandler(object):
}
for device_id, message_content in by_device.items()
}
- if messages_by_device:
- local_messages[user_id] = messages_by_device
+ local_messages[user_id] = messages_by_device
+
+ yield self._check_for_unknown_devices(
+ message_type, sender_user_id, by_device
+ )
stream_id = yield self.store.add_messages_from_remote_to_device_inbox(
origin, message_id, local_messages
@@ -85,6 +95,55 @@ class DeviceMessageHandler(object):
)
@defer.inlineCallbacks
+ def _check_for_unknown_devices(
+ self,
+ message_type: str,
+ sender_user_id: str,
+ by_device: Dict[str, Dict[str, Any]],
+ ):
+ """Checks inbound device messages for unkown remote devices, and if
+ found marks the remote cache for the user as stale.
+ """
+
+ if message_type != "m.room_key_request":
+ return
+
+ # Get the sending device IDs
+ requesting_device_ids = set()
+ for message_content in by_device.values():
+ device_id = message_content.get("requesting_device_id")
+ requesting_device_ids.add(device_id)
+
+ # Check if we are tracking the devices of the remote user.
+ room_ids = yield self.store.get_rooms_for_user(sender_user_id)
+ if not room_ids:
+ logger.info(
+ "Received device message from remote device we don't"
+ " share a room with: %s %s",
+ sender_user_id,
+ requesting_device_ids,
+ )
+ return
+
+ # If we are tracking check that we know about the sending
+ # devices.
+ cached_devices = yield self.store.get_cached_devices_for_user(sender_user_id)
+
+ unknown_devices = requesting_device_ids - set(cached_devices)
+ if unknown_devices:
+ logger.info(
+ "Received device message from remote device not in our cache: %s %s",
+ sender_user_id,
+ unknown_devices,
+ )
+ yield self.store.mark_remote_user_device_cache_as_stale(sender_user_id)
+
+ # Immediately attempt a resync in the background
+ run_in_background(
+ self._device_list_updater.user_device_resync, sender_user_id
+ )
+
+ @defer.inlineCallbacks
def send_device_message(self, sender_user_id, message_type, messages):
set_tag("number_of_messages", len(messages))
set_tag("sender", sender_user_id)
|