diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 73b9e120f5..5c5fe77be2 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -14,6 +14,7 @@
# limitations under the License.
import logging
+from typing import Any, Dict
from canonicaljson import json
@@ -65,6 +66,9 @@ class DeviceMessageHandler(object):
logger.warning("Request for keys for non-local user %s", user_id)
raise SynapseError(400, "Not a user here")
+ if not by_device:
+ continue
+
messages_by_device = {
device_id: {
"content": message_content,
@@ -73,8 +77,11 @@ class DeviceMessageHandler(object):
}
for device_id, message_content in by_device.items()
}
- if messages_by_device:
- local_messages[user_id] = messages_by_device
+ local_messages[user_id] = messages_by_device
+
+ yield self._check_for_unknown_devices(
+ message_type, sender_user_id, by_device
+ )
stream_id = yield self.store.add_messages_from_remote_to_device_inbox(
origin, message_id, local_messages
@@ -85,6 +92,52 @@ class DeviceMessageHandler(object):
)
@defer.inlineCallbacks
+ def _check_for_unknown_devices(
+ self,
+ message_type: str,
+ sender_user_id: str,
+ by_device: Dict[str, Dict[str, Any]],
+ ):
+ """Checks inbound device messages for unkown remote devices, and if
+ found marks the remote cache for the user as stale.
+ """
+
+ if message_type != "m.room_key_request":
+ return
+
+ # Get the sending device IDs
+ requesting_device_ids = set()
+ for message_content in by_device.values():
+ device_id = message_content.get("requesting_device_id")
+ requesting_device_ids.add(device_id)
+
+ # Check if we are tracking the devices of the remote user.
+ room_ids = yield self.store.get_rooms_for_user(sender_user_id)
+ if not room_ids:
+ logger.info(
+ "Received device message from remote device we don't"
+ " share a room with: %s %s",
+ sender_user_id,
+ requesting_device_ids,
+ )
+ return
+
+ # If we are tracking check that we know about the sending
+ # devices.
+ cached_devices = yield self.store.get_cached_devices_for_user(sender_user_id)
+
+ unknown_devices = requesting_device_ids - set(cached_devices)
+ if unknown_devices:
+ logger.info(
+ "Received device message from remote device not in our cache: %s %s",
+ sender_user_id,
+ unknown_devices,
+ )
+ yield self.store.mark_remote_user_device_cache_as_stale(sender_user_id)
+ # TODO: Poke something to start trying to refetch user's
+ # keys.
+
+ @defer.inlineCallbacks
def send_device_message(self, sender_user_id, message_type, messages):
set_tag("number_of_messages", len(messages))
set_tag("sender", sender_user_id)
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 180f165a7a..a67020a259 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -742,6 +742,26 @@ class FederationHandler(BaseHandler):
user = UserID.from_string(event.state_key)
await self.user_joined_room(user, room_id)
+ # For encrypted messages we check that we know about the sending device,
+ # if we don't then we mark the device cache for that user as stale.
+ if event.type == EventTypes.Encryption:
+ device_id = event.content.get("device_id")
+ if device_id is not None:
+ cached_devices = await self.store.get_cached_devices_for_user(
+ event.sender
+ )
+ if device_id not in cached_devices:
+ logger.info(
+ "Received event from remote device not in our cache: %s %s",
+ event.sender,
+ device_id,
+ )
+ await self.store.mark_remote_user_device_cache_as_stale(
+ event.sender
+ )
+ # TODO: Poke something to start trying to refetch user's
+ # keys.
+
@log_function
async def backfill(self, dest, room_id, limit, extremities):
""" Trigger a backfill request to `dest` for the given `room_id`
|