diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 1c77687eea..9d8067342f 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -29,7 +29,13 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
self.hs = hs
self._device_list_id_gen = SlavedIdTracker(
- db_conn, "device_lists_stream", "stream_id"
+ db_conn,
+ "device_lists_stream",
+ "stream_id",
+ extra_tables=[
+ ("user_signature_stream", "stream_id"),
+ ("device_lists_outbound_pokes", "stream_id"),
+ ],
)
device_list_max = self._device_list_id_gen.get_current_token()
self._device_list_stream_cache = StreamChangeCache(
@@ -42,36 +48,28 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
"DeviceListFederationStreamChangeCache", device_list_max
)
- def stream_positions(self):
- result = super(SlavedDeviceStore, self).stream_positions()
- # The user signature stream uses the same stream ID generator as the
- # device list stream, so set them both to the device list ID
- # generator's current token.
- current_token = self._device_list_id_gen.get_current_token()
- result[DeviceListsStream.NAME] = current_token
- result[UserSignatureStream.NAME] = current_token
- return result
-
- def process_replication_rows(self, stream_name, token, rows):
+ def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(token)
- for row in rows:
- self._invalidate_caches_for_devices(token, row.user_id, row.destination)
+ self._invalidate_caches_for_devices(token, rows)
elif stream_name == UserSignatureStream.NAME:
+ self._device_list_id_gen.advance(token)
for row in rows:
self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
- return super(SlavedDeviceStore, self).process_replication_rows(
- stream_name, token, rows
- )
-
- def _invalidate_caches_for_devices(self, token, user_id, destination):
- self._device_list_stream_cache.entity_has_changed(user_id, token)
+ return super().process_replication_rows(stream_name, instance_name, token, rows)
- if destination:
- self._device_list_federation_stream_cache.entity_has_changed(
- destination, token
- )
+ def _invalidate_caches_for_devices(self, token, rows):
+ for row in rows:
+ # The entities are either user IDs (starting with '@') whose devices
+ # have changed, or remote servers that we need to tell about
+ # changes.
+ if row.entity.startswith("@"):
+ self._device_list_stream_cache.entity_has_changed(row.entity, token)
+ self.get_cached_devices_for_user.invalidate((row.entity,))
+ self._get_cached_user_device.invalidate_many((row.entity,))
+ self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,))
- self.get_cached_devices_for_user.invalidate((user_id,))
- self._get_cached_user_device.invalidate_many((user_id,))
- self.get_device_list_last_stream_id_for_remote.invalidate((user_id,))
+ else:
+ self._device_list_federation_stream_cache.entity_has_changed(
+ row.entity, token
+ )
|