summary refs log tree commit diff
path: root/synapse/storage/databases/main/devices.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/devices.py')
-rw-r--r--synapse/storage/databases/main/devices.py93
1 files changed, 58 insertions, 35 deletions
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 40187496e2..5eeca6165d 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -164,22 +164,24 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             prefilled_cache=user_signature_stream_prefill,
         )
 
-        (
-            device_list_federation_prefill,
-            device_list_federation_list_id,
-        ) = self.db_pool.get_cache_dict(
-            db_conn,
-            "device_lists_outbound_pokes",
-            entity_column="destination",
-            stream_column="stream_id",
-            max_value=device_list_max,
-            limit=10000,
-        )
-        self._device_list_federation_stream_cache = StreamChangeCache(
-            "DeviceListFederationStreamChangeCache",
-            device_list_federation_list_id,
-            prefilled_cache=device_list_federation_prefill,
-        )
+        self._device_list_federation_stream_cache = None
+        if hs.should_send_federation():
+            (
+                device_list_federation_prefill,
+                device_list_federation_list_id,
+            ) = self.db_pool.get_cache_dict(
+                db_conn,
+                "device_lists_outbound_pokes",
+                entity_column="destination",
+                stream_column="stream_id",
+                max_value=device_list_max,
+                limit=10000,
+            )
+            self._device_list_federation_stream_cache = StreamChangeCache(
+                "DeviceListFederationStreamChangeCache",
+                device_list_federation_list_id,
+                prefilled_cache=device_list_federation_prefill,
+            )
 
         if hs.config.worker.run_background_tasks:
             self._clock.looping_call(
@@ -207,23 +209,30 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
     ) -> None:
         for row in rows:
             if row.is_signature:
-                self._user_signature_stream_cache.entity_has_changed(row.entity, token)
+                self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
                 continue
 
             # The entities are either user IDs (starting with '@') whose devices
             # have changed, or remote servers that we need to tell about
             # changes.
-            if row.entity.startswith("@"):
-                self._device_list_stream_cache.entity_has_changed(row.entity, token)
-                self.get_cached_devices_for_user.invalidate((row.entity,))
-                self._get_cached_user_device.invalidate((row.entity,))
-                self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,))
-
-            else:
-                self._device_list_federation_stream_cache.entity_has_changed(
-                    row.entity, token
+            if not row.hosts_calculated:
+                self._device_list_stream_cache.entity_has_changed(row.user_id, token)
+                self.get_cached_devices_for_user.invalidate((row.user_id,))
+                self._get_cached_user_device.invalidate((row.user_id,))
+                self.get_device_list_last_stream_id_for_remote.invalidate(
+                    (row.user_id,)
                 )
 
+    def device_lists_outbound_pokes_have_changed(
+        self, destinations: StrCollection, token: int
+    ) -> None:
+        assert self._device_list_federation_stream_cache is not None
+
+        for destination in destinations:
+            self._device_list_federation_stream_cache.entity_has_changed(
+                destination, token
+            )
+
     def device_lists_in_rooms_have_changed(
         self, room_ids: StrCollection, token: int
     ) -> None:
@@ -363,6 +372,11 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
               EDU contents.
         """
         now_stream_id = self.get_device_stream_token()
+        if from_stream_id == now_stream_id:
+            return now_stream_id, []
+
+        if self._device_list_federation_stream_cache is None:
+            raise Exception("Func can only be used on federation senders")
 
         has_changed = self._device_list_federation_stream_cache.has_entity_changed(
             destination, int(from_stream_id)
@@ -1018,10 +1032,10 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             # This query Does The Right Thing where it'll correctly apply the
             # bounds to the inner queries.
             sql = """
-                SELECT stream_id, entity FROM (
-                    SELECT stream_id, user_id AS entity FROM device_lists_stream
+                SELECT stream_id, user_id, hosts FROM (
+                    SELECT stream_id, user_id, false AS hosts FROM device_lists_stream
                     UNION ALL
-                    SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
+                    SELECT DISTINCT stream_id, user_id, true AS hosts FROM device_lists_outbound_pokes
                 ) AS e
                 WHERE ? < stream_id AND stream_id <= ?
                 ORDER BY stream_id ASC
@@ -1577,6 +1591,14 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             get_device_list_changes_in_room_txn,
         )
 
+    async def get_destinations_for_device(self, stream_id: int) -> StrCollection:
+        return await self.db_pool.simple_select_onecol(
+            table="device_lists_outbound_pokes",
+            keyvalues={"stream_id": stream_id},
+            retcol="destination",
+            desc="get_destinations_for_device",
+        )
+
 
 class DeviceBackgroundUpdateStore(SQLBaseStore):
     def __init__(
@@ -2112,12 +2134,13 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         stream_ids: List[int],
         context: Optional[Dict[str, str]],
     ) -> None:
-        for host in hosts:
-            txn.call_after(
-                self._device_list_federation_stream_cache.entity_has_changed,
-                host,
-                stream_ids[-1],
-            )
+        if self._device_list_federation_stream_cache:
+            for host in hosts:
+                txn.call_after(
+                    self._device_list_federation_stream_cache.entity_has_changed,
+                    host,
+                    stream_ids[-1],
+                )
 
         now = self._clock.time_msec()
         stream_id_iterator = iter(stream_ids)