summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/storage/databases/main/devices.py4
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py67
2 files changed, 69 insertions, 2 deletions
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index a67fdb3c22..f677d048aa 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -1941,6 +1941,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             user_id,
             stream_ids[-1],
         )
+        txn.call_after(
+            self._get_e2e_device_keys_for_federation_query_inner.invalidate,
+            (user_id,),
+        )
 
         min_stream_id = stream_ids[0]
 
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 4bc391f213..91ae9c457d 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -16,6 +16,7 @@
 import abc
 from typing import (
     TYPE_CHECKING,
+    Any,
     Collection,
     Dict,
     Iterable,
@@ -39,6 +40,7 @@ from synapse.appservice import (
     TransactionUnusedFallbackKeys,
 )
 from synapse.logging.opentracing import log_kv, set_tag, trace
+from synapse.replication.tcp.streams._base import DeviceListsStream
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import (
     DatabasePool,
@@ -104,6 +106,23 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
             self.hs.config.federation.allow_device_name_lookup_over_federation
         )
 
+    def process_replication_rows(
+        self,
+        stream_name: str,
+        instance_name: str,
+        token: int,
+        rows: Iterable[Any],
+    ) -> None:
+        if stream_name == DeviceListsStream.NAME:
+            for row in rows:
+                assert isinstance(row, DeviceListsStream.DeviceListsStreamRow)
+                if row.entity.startswith("@"):
+                    self._get_e2e_device_keys_for_federation_query_inner.invalidate(
+                        (row.entity,)
+                    )
+
+        super().process_replication_rows(stream_name, instance_name, token, rows)
+
     async def get_e2e_device_keys_for_federation_query(
         self, user_id: str
     ) -> Tuple[int, List[JsonDict]]:
@@ -114,6 +133,50 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         """
         now_stream_id = self.get_device_stream_token()
 
+        # We need to be careful with the caching here, as we need to always
+        # return *all* persisted devices, however there may be a lag between a
+        # new device being persisted and the cache being invalidated.
+        cached_results = (
+            self._get_e2e_device_keys_for_federation_query_inner.cache.get_immediate(
+                user_id, None
+            )
+        )
+        if cached_results is not None:
+            # Check that there have been no new devices added by another worker
+            # after the cache. This should be quick as there should be few rows
+            # with a higher stream ordering.
+            #
+            # Note that we invalidate based on the device stream, so we only
+            # have to check for potential invalidations after the
+            # `now_stream_id`.
+            sql = """
+                SELECT user_id FROM device_lists_stream
+                WHERE stream_id >= ? AND user_id = ?
+            """
+            rows = await self.db_pool.execute(
+                "get_e2e_device_keys_for_federation_query_check",
+                None,
+                sql,
+                now_stream_id,
+                user_id,
+            )
+            if not rows:
+                # No new rows, so cache is still valid.
+                return now_stream_id, cached_results
+
+            # There has, so let's invalidate the cache and run the query.
+            self._get_e2e_device_keys_for_federation_query_inner.invalidate((user_id,))
+
+        results = await self._get_e2e_device_keys_for_federation_query_inner(user_id)
+
+        return now_stream_id, results
+
+    @cached(iterable=True)
+    async def _get_e2e_device_keys_for_federation_query_inner(
+        self, user_id: str
+    ) -> List[JsonDict]:
+        """Get all devices (with any device keys) for a user"""
+
         devices = await self.get_e2e_device_keys_and_signatures([(user_id, None)])
 
         if devices:
@@ -134,9 +197,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
 
                 results.append(result)
 
-            return now_stream_id, results
+            return results
 
-        return now_stream_id, []
+        return []
 
     @trace
     @cancellable