summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/app/generic_worker.py3
-rw-r--r--synapse/replication/slave/storage/devices.py3
-rw-r--r--synapse/storage/data_stores/main/devices.py27
3 files changed, 25 insertions, 8 deletions
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index d596852419..cdc078cf11 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -775,6 +775,9 @@ class FederationSenderHandler(object):
 
         # ... as well as device updates and messages
         elif stream_name == DeviceListsStream.NAME:
+            # The entities are either user IDs (starting with '@') whose devices
+            # have changed, or remote servers that we need to tell about
+            # changes.
             hosts = {row.entity for row in rows if not row.entity.startswith("@")}
             for host in hosts:
                 self.federation_sender.send_device_messages(host)
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 01a4f85884..23b1650e41 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -72,6 +72,9 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
 
     def _invalidate_caches_for_devices(self, token, rows):
         for row in rows:
+            # The entities are either user IDs (starting with '@') whose devices
+            # have changed, or remote servers that we need to tell about
+            # changes.
             if row.entity.startswith("@"):
                 self._device_list_stream_cache.entity_has_changed(row.entity, token)
                 self.get_cached_devices_for_user.invalidate((row.entity,))
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index 06e1d9f033..4c19c02bbc 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -15,6 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from typing import List, Tuple
 
 from six import iteritems
 
@@ -31,7 +32,7 @@ from synapse.logging.opentracing import (
 )
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import Database
+from synapse.storage.database import Database, LoggingTransaction
 from synapse.types import Collection, get_verify_key_from_cross_signing_key
 from synapse.util.caches.descriptors import (
     Cache,
@@ -574,10 +575,12 @@ class DeviceWorkerStore(SQLBaseStore):
         else:
             return set()
 
-    def get_all_device_list_changes_for_remotes(self, from_key, to_key):
-        """Return a list of `(stream_id, user_id, destination)` which is the
-        combined list of changes to devices, and which destinations need to be
-        poked. `destination` may be None if no destinations need to be poked.
+    async def get_all_device_list_changes_for_remotes(
+        self, from_key: int, to_key: int
+    ) -> List[Tuple[int, str]]:
+        """Return a list of `(stream_id, entity)` which is the combined list of
+        changes to devices and which destinations need to be poked. Entity is
+        either a user ID (starting with '@') or a remote destination.
         """
 
         # This query Does The Right Thing where it'll correctly apply the
@@ -591,7 +594,7 @@ class DeviceWorkerStore(SQLBaseStore):
             WHERE ? < stream_id AND stream_id <= ?
         """
 
-        return self.db.execute(
+        return await self.db.execute(
             "get_all_device_list_changes_for_remotes", None, sql, from_key, to_key
         )
 
@@ -1018,11 +1021,19 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
 
         return stream_ids[-1]
 
-    def _add_device_change_to_stream_txn(self, txn, user_id, device_ids, stream_ids):
+    def _add_device_change_to_stream_txn(
+        self,
+        txn: LoggingTransaction,
+        user_id: str,
+        device_ids: Collection[str],
+        stream_ids: List[str],
+    ):
         txn.call_after(
             self._device_list_stream_cache.entity_has_changed, user_id, stream_ids[-1],
         )
 
+        min_stream_id = stream_ids[0]
+
         # Delete older entries in the table, as we really only care about
         # when the latest change happened.
         txn.executemany(
@@ -1030,7 +1041,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             DELETE FROM device_lists_stream
             WHERE user_id = ? AND device_id = ? AND stream_id < ?
             """,
-            [(user_id, device_id, stream_ids[0]) for device_id in device_ids],
+            [(user_id, device_id, min_stream_id) for device_id in device_ids],
         )
 
         self.db.simple_insert_many_txn(