diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index b067664473..cd186c8472 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -38,7 +38,7 @@ from synapse.logging.opentracing import (
whitelisted_homeserver,
)
from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.replication.tcp.streams._base import DeviceListsStream, UserSignatureStream
+from synapse.replication.tcp.streams._base import DeviceListsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
@@ -163,9 +163,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
) -> None:
if stream_name == DeviceListsStream.NAME:
self._invalidate_caches_for_devices(token, rows)
- elif stream_name == UserSignatureStream.NAME:
- for row in rows:
- self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
+
return super().process_replication_rows(stream_name, instance_name, token, rows)
def process_replication_position(
@@ -173,14 +171,17 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
) -> None:
if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(instance_name, token)
- elif stream_name == UserSignatureStream.NAME:
- self._device_list_id_gen.advance(instance_name, token)
+
super().process_replication_position(stream_name, instance_name, token)
def _invalidate_caches_for_devices(
self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow]
) -> None:
for row in rows:
+ if row.is_signature:
+ self._user_signature_stream_cache.entity_has_changed(row.entity, token)
+ continue
+
# The entities are either user IDs (starting with '@') whose devices
# have changed, or remote servers that we need to tell about
# changes.
|