diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index b2c764bfe8..561a6f4b22 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -774,7 +774,7 @@ class FederationSenderHandler(object):
# ... as well as device updates and messages
elif stream_name == DeviceListsStream.NAME:
- hosts = {row.destination for row in rows}
+ hosts = {row.entity for row in rows if not row.entity.startswith("@")}
for host in hosts:
self.federation_sender.send_device_messages(host)
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index bf46cc4f8a..01a4f85884 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -61,23 +61,24 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
def process_replication_rows(self, stream_name, token, rows):
if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(token)
- for row in rows:
- self._invalidate_caches_for_devices(token, row.user_id, row.destination)
+ self._invalidate_caches_for_devices(token, rows)
elif stream_name == UserSignatureStream.NAME:
+ self._device_list_id_gen.advance(token)
for row in rows:
self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
return super(SlavedDeviceStore, self).process_replication_rows(
stream_name, token, rows
)
- def _invalidate_caches_for_devices(self, token, user_id, destination):
- self._device_list_stream_cache.entity_has_changed(user_id, token)
-
- if destination:
- self._device_list_federation_stream_cache.entity_has_changed(
- destination, token
- )
+ def _invalidate_caches_for_devices(self, token, rows):
+ for row in rows:
+ if row.entity.startswith("@"):
+ self._device_list_stream_cache.entity_has_changed(row.entity, token)
+ self.get_cached_devices_for_user.invalidate((row.entity,))
+ self._get_cached_user_device.invalidate_many((row.entity,))
+ self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,))
- self.get_cached_devices_for_user.invalidate((user_id,))
- self._get_cached_user_device.invalidate_many((user_id,))
- self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))
+ else:
+ self._device_list_federation_stream_cache.entity_has_changed(
+ row.entity, token
+ )
diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py
index 208e8a667b..7a8b6e9df1 100644
--- a/synapse/replication/tcp/streams/_base.py
+++ b/synapse/replication/tcp/streams/_base.py
@@ -94,9 +94,13 @@ PublicRoomsStreamRow = namedtuple(
"network_id", # str, optional
),
)
-DeviceListsStreamRow = namedtuple(
- "DeviceListsStreamRow", ("user_id", "destination") # str # str
-)
+
+
+@attr.s
+class DeviceListsStreamRow:
+ entity = attr.ib(type=str)
+
+
ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", ("entity",)) # str
TagAccountDataStreamRow = namedtuple(
"TagAccountDataStreamRow", ("user_id", "room_id", "data") # str # str # dict
@@ -363,7 +367,8 @@ class PublicRoomsStream(Stream):
class DeviceListsStream(Stream):
- """Someone added/changed/removed a device
+ """Either a user has updated their devices or a remote server needs to be
+ told about a device update.
"""
NAME = "device_lists"
diff --git a/synapse/storage/data_stores/main/devices.py b/synapse/storage/data_stores/main/devices.py
index 3299607910..768afe7a6c 100644
--- a/synapse/storage/data_stores/main/devices.py
+++ b/synapse/storage/data_stores/main/devices.py
@@ -612,15 +612,18 @@ class DeviceWorkerStore(SQLBaseStore):
combined list of changes to devices, and which destinations need to be
poked. `destination` may be None if no destinations need to be poked.
"""
- # We do a group by here as there can be a large number of duplicate
- # entries, since we throw away device IDs.
+
+ # This query Does The Right Thing where it'll correctly apply the
+ # bounds to the inner queries.
sql = """
- SELECT MAX(stream_id) AS stream_id, user_id, destination
- FROM device_lists_stream
- LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
+ SELECT stream_id, entity FROM (
+ SELECT stream_id, user_id AS entity FROM device_lists_stream
+ UNION ALL
+ SELECT stream_id, destination AS entity FROM device_lists_outbound_pokes
+ ) AS e
WHERE ? < stream_id AND stream_id <= ?
- GROUP BY user_id, destination
"""
+
return self.db.execute(
"get_all_device_list_changes_for_remotes", None, sql, from_key, to_key
)
|