summary refs log tree commit diff
path: root/synapse/handlers/devicemessage.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/devicemessage.py')
-rw-r--r--synapse/handlers/devicemessage.py63
1 files changed, 61 insertions, 2 deletions
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 73b9e120f5..05c4b3eec0 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -14,12 +14,14 @@
 # limitations under the License.
 
 import logging
+from typing import Any, Dict
 
 from canonicaljson import json
 
 from twisted.internet import defer
 
 from synapse.api.errors import SynapseError
+from synapse.logging.context import run_in_background
 from synapse.logging.opentracing import (
     get_active_span_text_map,
     log_kv,
@@ -47,6 +49,8 @@ class DeviceMessageHandler(object):
             "m.direct_to_device", self.on_direct_to_device_edu
         )
 
+        self._device_list_updater = hs.get_device_handler().device_list_updater
+
     @defer.inlineCallbacks
     def on_direct_to_device_edu(self, origin, content):
         local_messages = {}
@@ -65,6 +69,9 @@ class DeviceMessageHandler(object):
                 logger.warning("Request for keys for non-local user %s", user_id)
                 raise SynapseError(400, "Not a user here")
 
+            if not by_device:
+                continue
+
             messages_by_device = {
                 device_id: {
                     "content": message_content,
@@ -73,8 +80,11 @@ class DeviceMessageHandler(object):
                 }
                 for device_id, message_content in by_device.items()
             }
-            if messages_by_device:
-                local_messages[user_id] = messages_by_device
+            local_messages[user_id] = messages_by_device
+
+            yield self._check_for_unknown_devices(
+                message_type, sender_user_id, by_device
+            )
 
         stream_id = yield self.store.add_messages_from_remote_to_device_inbox(
             origin, message_id, local_messages
@@ -85,6 +95,55 @@ class DeviceMessageHandler(object):
         )
 
     @defer.inlineCallbacks
+    def _check_for_unknown_devices(
+        self,
+        message_type: str,
+        sender_user_id: str,
+        by_device: Dict[str, Dict[str, Any]],
+    ):
+        """Checks inbound device messages for unkown remote devices, and if
+        found marks the remote cache for the user as stale.
+        """
+
+        if message_type != "m.room_key_request":
+            return
+
+        # Get the sending device IDs
+        requesting_device_ids = set()
+        for message_content in by_device.values():
+            device_id = message_content.get("requesting_device_id")
+            requesting_device_ids.add(device_id)
+
+        # Check if we are tracking the devices of the remote user.
+        room_ids = yield self.store.get_rooms_for_user(sender_user_id)
+        if not room_ids:
+            logger.info(
+                "Received device message from remote device we don't"
+                " share a room with: %s %s",
+                sender_user_id,
+                requesting_device_ids,
+            )
+            return
+
+        # If we are tracking check that we know about the sending
+        # devices.
+        cached_devices = yield self.store.get_cached_devices_for_user(sender_user_id)
+
+        unknown_devices = requesting_device_ids - set(cached_devices)
+        if unknown_devices:
+            logger.info(
+                "Received device message from remote device not in our cache: %s %s",
+                sender_user_id,
+                unknown_devices,
+            )
+            yield self.store.mark_remote_user_device_cache_as_stale(sender_user_id)
+
+            # Immediately attempt a resync in the background
+            run_in_background(
+                self._device_list_updater.user_device_resync, sender_user_id
+            )
+
+    @defer.inlineCallbacks
     def send_device_message(self, sender_user_id, message_type, messages):
         set_tag("number_of_messages", len(messages))
         set_tag("sender", sender_user_id)