diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index df3cdc8fba..54293d0b9c 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -45,7 +45,7 @@ from synapse.util.retryutils import NotRetryingDestination
from ._base import BaseHandler
if TYPE_CHECKING:
- from synapse.app.homeserver import HomeServer
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -166,7 +166,7 @@ class DeviceWorkerHandler(BaseHandler):
# Fetch the current state at the time.
try:
- event_ids = await self.store.get_forward_extremeties_for_room(
+ event_ids = await self.store.get_forward_extremities_for_room_at_stream_ordering(
room_id, stream_ordering=stream_ordering
)
except errors.StoreError:
@@ -907,6 +907,7 @@ class DeviceListUpdater:
master_key = result.get("master_key")
self_signing_key = result.get("self_signing_key")
+ ignore_devices = False
# If the remote server has more than ~1000 devices for this user
# we assume that something is going horribly wrong (e.g. a bot
# that logs in and creates a new device every time it tries to
@@ -925,6 +926,12 @@ class DeviceListUpdater:
len(devices),
)
devices = []
+ ignore_devices = True
+ else:
+ cached_devices = await self.store.get_cached_devices_for_user(user_id)
+ if cached_devices == {d["device_id"]: d for d in devices}:
+ devices = []
+ ignore_devices = True
for device in devices:
logger.debug(
@@ -934,7 +941,10 @@ class DeviceListUpdater:
stream_id,
)
- await self.store.update_remote_device_list_cache(user_id, devices, stream_id)
+ if not ignore_devices:
+ await self.store.update_remote_device_list_cache(
+ user_id, devices, stream_id
+ )
device_ids = [device["device_id"] for device in devices]
# Handle cross-signing keys.
@@ -945,7 +955,8 @@ class DeviceListUpdater:
)
device_ids = device_ids + cross_signing_device_ids
- await self.device_handler.notify_device_update(user_id, device_ids)
+ if device_ids:
+ await self.device_handler.notify_device_update(user_id, device_ids)
# We clobber the seen updates since we've re-synced from a given
# point.
@@ -973,14 +984,17 @@ class DeviceListUpdater:
"""
device_ids = []
- if master_key:
+ current_keys_map = await self.store.get_e2e_cross_signing_keys_bulk([user_id])
+ current_keys = current_keys_map.get(user_id) or {}
+
+ if master_key and master_key != current_keys.get("master"):
await self.store.set_e2e_cross_signing_key(user_id, "master", master_key)
_, verify_key = get_verify_key_from_cross_signing_key(master_key)
# verify_key is a VerifyKey from signedjson, which uses
# .version to denote the portion of the key ID after the
# algorithm and colon, which is the device ID
device_ids.append(verify_key.version)
- if self_signing_key:
+ if self_signing_key and self_signing_key != current_keys.get("self_signing"):
await self.store.set_e2e_cross_signing_key(
user_id, "self_signing", self_signing_key
)
|