summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/device.py10
-rw-r--r--synapse/storage/databases/main/devices.py8
2 files changed, 15 insertions, 3 deletions
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 7665425232..b184a48cb1 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -948,8 +948,16 @@ class DeviceListUpdater:
             devices = []
             ignore_devices = True
         else:
+            prev_stream_id = await self.store.get_device_list_last_stream_id_for_remote(
+                user_id
+            )
             cached_devices = await self.store.get_cached_devices_for_user(user_id)
-            if cached_devices == {d["device_id"]: d for d in devices}:
+
+            # To ensure that a user with no devices is cached, we skip the resync only
+            # if we have a stream_id from previously writing a cache entry.
+            if prev_stream_id is not None and cached_devices == {
+                d["device_id"]: d for d in devices
+            }:
                 logging.info(
                     "Skipping device list resync for %s, as our cache matches already",
                     user_id,
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 273adb61fd..52fbf50db6 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -713,7 +713,7 @@ class DeviceWorkerStore(SQLBaseStore):
     @cached(max_entries=10000)
     async def get_device_list_last_stream_id_for_remote(
         self, user_id: str
-    ) -> Optional[Any]:
+    ) -> Optional[str]:
         """Get the last stream_id we got for a user. May be None if we haven't
         got any information for them.
         """
@@ -729,7 +729,9 @@ class DeviceWorkerStore(SQLBaseStore):
         cached_method_name="get_device_list_last_stream_id_for_remote",
         list_name="user_ids",
     )
-    async def get_device_list_last_stream_id_for_remotes(self, user_ids: Iterable[str]):
+    async def get_device_list_last_stream_id_for_remotes(
+        self, user_ids: Iterable[str]
+    ) -> Dict[str, Optional[str]]:
         rows = await self.db_pool.simple_select_many_batch(
             table="device_lists_remote_extremeties",
             column="user_id",
@@ -1316,6 +1318,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         content: JsonDict,
         stream_id: str,
     ) -> None:
+        """Delete, update or insert a cache entry for this (user, device) pair."""
         if content.get("deleted"):
             self.db_pool.simple_delete_txn(
                 txn,
@@ -1375,6 +1378,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
     def _update_remote_device_list_cache_txn(
         self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int
     ) -> None:
+        """Replace the list of cached devices for this user with the given list."""
         self.db_pool.simple_delete_txn(
             txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
         )