diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 73b9e120f5..5c5fe77be2 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
+from typing import Any, Dict
from canonicaljson import json
@@ -65,6 +66,9 @@ class DeviceMessageHandler(object):
logger.warning("Request for keys for non-local user %s", user_id)
raise SynapseError(400, "Not a user here")
+ if not by_device:
+ continue
+
messages_by_device = {
device_id: {
"content": message_content,
@@ -73,8 +77,11 @@ class DeviceMessageHandler(object):
}
for device_id, message_content in by_device.items()
}
- if messages_by_device:
- local_messages[user_id] = messages_by_device
+ local_messages[user_id] = messages_by_device
+
+ yield self._check_for_unknown_devices(
+ message_type, sender_user_id, by_device
+ )
stream_id = yield self.store.add_messages_from_remote_to_device_inbox(
origin, message_id, local_messages
@@ -85,6 +92,52 @@ class DeviceMessageHandler(object):
)
@defer.inlineCallbacks
+ def _check_for_unknown_devices(
+ self,
+ message_type: str,
+ sender_user_id: str,
+ by_device: Dict[str, Dict[str, Any]],
+ ):
+ """Checks inbound device messages for unkown remote devices, and if
+ found marks the remote cache for the user as stale.
+ """
+
+ if message_type != "m.room_key_request":
+ return
+
+ # Get the sending device IDs
+ requesting_device_ids = set()
+ for message_content in by_device.values():
+ device_id = message_content.get("requesting_device_id")
+ requesting_device_ids.add(device_id)
+
+ # Check if we are tracking the devices of the remote user.
+ room_ids = yield self.store.get_rooms_for_user(sender_user_id)
+ if not room_ids:
+ logger.info(
+ "Received device message from remote device we don't"
+ " share a room with: %s %s",
+ sender_user_id,
+ requesting_device_ids,
+ )
+ return
+
+ # If we are tracking check that we know about the sending
+ # devices.
+ cached_devices = yield self.store.get_cached_devices_for_user(sender_user_id)
+
+ unknown_devices = requesting_device_ids - set(cached_devices)
+ if unknown_devices:
+ logger.info(
+ "Received device message from remote device not in our cache: %s %s",
+ sender_user_id,
+ unknown_devices,
+ )
+ yield self.store.mark_remote_user_device_cache_as_stale(sender_user_id)
+ # TODO: Poke something to start trying to refetch user's
+ # keys.
+
+ @defer.inlineCallbacks
def send_device_message(self, sender_user_id, message_type, messages):
set_tag("number_of_messages", len(messages))
set_tag("sender", sender_user_id)
|