diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index 06e1d9f033..4c19c02bbc 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+from typing import List, Tuple
from six import iteritems
@@ -31,7 +32,7 @@ from synapse.logging.opentracing import (
)
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import Database
+from synapse.storage.database import Database, LoggingTransaction
from synapse.types import Collection, get_verify_key_from_cross_signing_key
from synapse.util.caches.descriptors import (
Cache,
@@ -574,10 +575,12 @@ class DeviceWorkerStore(SQLBaseStore):
else:
return set()
- def get_all_device_list_changes_for_remotes(self, from_key, to_key):
- """Return a list of `(stream_id, user_id, destination)` which is the
- combined list of changes to devices, and which destinations need to be
- poked. `destination` may be None if no destinations need to be poked.
+ async def get_all_device_list_changes_for_remotes(
+ self, from_key: int, to_key: int
+ ) -> List[Tuple[int, str]]:
+ """Return a list of `(stream_id, entity)` which is the combined list of
+ changes to devices and which destinations need to be poked. Entity is
+ either a user ID (starting with '@') or a remote destination.
"""
# This query Does The Right Thing where it'll correctly apply the
@@ -591,7 +594,7 @@ class DeviceWorkerStore(SQLBaseStore):
WHERE ? < stream_id AND stream_id <= ?
"""
- return self.db.execute(
+ return await self.db.execute(
"get_all_device_list_changes_for_remotes", None, sql, from_key, to_key
)
@@ -1018,11 +1021,19 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
return stream_ids[-1]
- def _add_device_change_to_stream_txn(self, txn, user_id, device_ids, stream_ids):
+ def _add_device_change_to_stream_txn(
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ device_ids: Collection[str],
+ stream_ids: List[str],
+ ):
txn.call_after(
self._device_list_stream_cache.entity_has_changed, user_id, stream_ids[-1],
)
+ min_stream_id = stream_ids[0]
+
# Delete older entries in the table, as we really only care about
# when the latest change happened.
txn.executemany(
@@ -1030,7 +1041,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
DELETE FROM device_lists_stream
WHERE user_id = ? AND device_id = ? AND stream_id < ?
""",
- [(user_id, device_id, stream_ids[0]) for device_id in device_ids],
+ [(user_id, device_id, min_stream_id) for device_id in device_ids],
)
self.db.simple_insert_many_txn(
|