diff --git a/changelog.d/14723.bugfix b/changelog.d/14723.bugfix
new file mode 100644
index 0000000000..e1f89cee35
--- /dev/null
+++ b/changelog.d/14723.bugfix
@@ -0,0 +1 @@
+Ensure stream IDs are only advanced after the corresponding caches have been invalidated when running with workers. Contributed by Nick @ Beeper (@fizzadar).
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 658d89210d..b5e40da533 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -152,6 +152,9 @@ class ReplicationDataHandler:
rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
"""
self.store.process_replication_rows(stream_name, instance_name, token, rows)
+ # NOTE: this must be called after process_replication_rows to ensure any
+ # cache invalidations are handled before any stream ID advances.
+ self.store.process_replication_position(stream_name, instance_name, token)
if self.send_handler:
await self.send_handler.process_replication_rows(stream_name, token, rows)
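Reviewer note: this ordering is the heart of the patch. Below is a minimal, self-contained sketch (toy class and data, not Synapse code) of the hazard it avoids: if the advertised token moved first, a reader could observe the new position while a stale cache entry was still being served.

```python
class ToyStore:
    """Toy store: a cache plus an advertised stream position."""

    def __init__(self) -> None:
        self.cache = {"@alice:example.com": "stale"}
        self.current_token = 1

    def process_replication_rows(self, token: int, rows: list) -> None:
        # Invalidate first, while readers still see the old token.
        for user_id in rows:
            self.cache.pop(user_id, None)

    def process_replication_position(self, token: int) -> None:
        # Only now advertise the new position.
        self.current_token = token


store = ToyStore()
store.process_replication_rows(2, ["@alice:example.com"])
store.process_replication_position(2)
# Anyone who sees token 2 can no longer be handed the stale entry.
assert "@alice:example.com" not in store.cache
assert store.current_token == 2
```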
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 69abf6fa87..41d9111019 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -57,7 +57,22 @@ class SQLBaseStore(metaclass=ABCMeta):
token: int,
rows: Iterable[Any],
) -> None:
- pass
+ """
+ Used by storage classes to invalidate caches based on incoming replication data. These
+ must not advance any ID generators; that is handled by `process_replication_position`.
+ """
+
+ def process_replication_position( # noqa: B027 (no-op by design)
+ self,
+ stream_name: str,
+ instance_name: str,
+ token: int,
+ ) -> None:
+ """
+ Used by storage classes to advance ID generators based on incoming replication data. This
+ is called after `process_replication_rows` so that caches are invalidated before any token
+ positions advance.
+ """
def _invalidate_state_caches(
self, room_id: str, members_changed: Collection[str]
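Both hooks are deliberately empty so that each store mixin can override them and chain via `super()`. A hypothetical sketch of that pattern (`FooStore`/`BarStore` are invented names, standing in for the worker stores below):

```python
class Base:
    def process_replication_position(
        self, stream_name: str, instance_name: str, token: int
    ) -> None:
        """No-op terminator, so every mixin can unconditionally call super()."""


class FooStore(Base):
    def process_replication_position(
        self, stream_name: str, instance_name: str, token: int
    ) -> None:
        if stream_name == "foo":
            print(f"foo position -> {token}")
        super().process_replication_position(stream_name, instance_name, token)


class BarStore(Base):
    def process_replication_position(
        self, stream_name: str, instance_name: str, token: int
    ) -> None:
        if stream_name == "bar":
            print(f"bar position -> {token}")
        super().process_replication_position(stream_name, instance_name, token)


class DataStore(FooStore, BarStore):
    """Mirrors how Synapse composes its store from many worker mixins."""


# Every mixin in the MRO sees every call; each reacts only to its own stream.
DataStore().process_replication_position("bar", "worker1", 7)  # prints: bar position -> 7
```

Because the base implementation is a no-op terminator, each mixin advances its own generators and passes the call along without needing to know which other mixins exist.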
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index e59776f434..86032897f5 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -436,10 +436,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
token: int,
rows: Iterable[Any],
) -> None:
- if stream_name == TagAccountDataStream.NAME:
- self._account_data_id_gen.advance(instance_name, token)
- elif stream_name == AccountDataStream.NAME:
- self._account_data_id_gen.advance(instance_name, token)
+ if stream_name == AccountDataStream.NAME:
for row in rows:
if not row.room_id:
self.get_global_account_data_by_type_for_user.invalidate(
@@ -454,6 +451,15 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
super().process_replication_rows(stream_name, instance_name, token, rows)
+ def process_replication_position(
+ self, stream_name: str, instance_name: str, token: int
+ ) -> None:
+ if stream_name == TagAccountDataStream.NAME:
+ self._account_data_id_gen.advance(instance_name, token)
+ elif stream_name == AccountDataStream.NAME:
+ self._account_data_id_gen.advance(instance_name, token)
+ super().process_replication_position(stream_name, instance_name, token)
+
async def add_account_data_to_room(
self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
) -> int:
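The two branches above are intentionally identical: tags ride the account-data stream ID space, so positions from either stream advance the same generator. A toy, self-contained illustration (`ToyIdGen` is invented; Synapse's real generators are far richer) of why feeding one generator from two streams is safe — per-writer positions only ever move forwards:

```python
class ToyIdGen:
    """Toy per-instance position tracker."""

    def __init__(self) -> None:
        self._positions: dict = {}

    def advance(self, instance_name: str, token: int) -> None:
        # Never move a writer's position backwards, whichever stream reported it.
        current = self._positions.get(instance_name, 0)
        self._positions[instance_name] = max(current, token)

    def get_current_token(self) -> int:
        return max(self._positions.values(), default=0)


gen = ToyIdGen()
gen.advance("worker1", 5)  # position from TagAccountDataStream
gen.advance("worker1", 9)  # position from AccountDataStream
gen.advance("worker1", 7)  # late/duplicate delivery is harmless
assert gen.get_current_token() == 9
```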
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index a58668a380..2179a8bf59 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -164,9 +164,6 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
backfilled=True,
)
elif stream_name == CachesStream.NAME:
- if self._cache_id_gen:
- self._cache_id_gen.advance(instance_name, token)
-
for row in rows:
if row.cache_func == CURRENT_STATE_CACHE_NAME:
if row.keys is None:
@@ -182,6 +179,14 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
super().process_replication_rows(stream_name, instance_name, token, rows)
+ def process_replication_position(
+ self, stream_name: str, instance_name: str, token: int
+ ) -> None:
+ if stream_name == CachesStream.NAME:
+ if self._cache_id_gen:
+ self._cache_id_gen.advance(instance_name, token)
+ super().process_replication_position(stream_name, instance_name, token)
+
def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None:
data = row.data
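`_cache_id_gen` is optional, hence the `None` check; where it exists, it is a multi-writer generator, since several workers may write to the caches stream. A toy sketch (invented, far simpler than Synapse's `MultiWriterIdGenerator`) of what `advance(instance_name, token)` is for: each writer's progress is tracked separately, and the publicly visible token only moves once every known writer has reached it.

```python
class ToyMultiWriterIdGen:
    def __init__(self, writers: list) -> None:
        self._positions = {w: 0 for w in writers}

    def advance(self, instance_name: str, token: int) -> None:
        self._positions[instance_name] = max(self._positions[instance_name], token)

    def get_current_token(self) -> int:
        # A safe lower bound: every writer is known to have reached this point,
        # so no reader can skip past an update that is still in flight.
        return min(self._positions.values())


gen = ToyMultiWriterIdGen(["worker1", "worker2"])
gen.advance("worker1", 10)
assert gen.get_current_token() == 0   # worker2 might still be behind
gen.advance("worker2", 8)
assert gen.get_current_token() == 8   # now everything up to 8 is safe
```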
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 48a54d9cb8..713be91c5d 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -157,6 +157,13 @@ class DeviceInboxWorkerStore(SQLBaseStore):
)
return super().process_replication_rows(stream_name, instance_name, token, rows)
+ def process_replication_position(
+ self, stream_name: str, instance_name: str, token: int
+ ) -> None:
+ if stream_name == ToDeviceStream.NAME:
+ self._device_inbox_id_gen.advance(instance_name, token)
+ super().process_replication_position(stream_name, instance_name, token)
+
def get_to_device_stream_token(self) -> int:
return self._device_inbox_id_gen.get_current_token()
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index a5bb4d404e..db877e3f13 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -162,14 +162,21 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
) -> None:
if stream_name == DeviceListsStream.NAME:
- self._device_list_id_gen.advance(instance_name, token)
self._invalidate_caches_for_devices(token, rows)
elif stream_name == UserSignatureStream.NAME:
- self._device_list_id_gen.advance(instance_name, token)
for row in rows:
self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)
+ def process_replication_position(
+ self, stream_name: str, instance_name: str, token: int
+ ) -> None:
+ if stream_name == DeviceListsStream.NAME:
+ self._device_list_id_gen.advance(instance_name, token)
+ elif stream_name == UserSignatureStream.NAME:
+ self._device_list_id_gen.advance(instance_name, token)
+ super().process_replication_position(stream_name, instance_name, token)
+
def _invalidate_caches_for_devices(
self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow]
) -> None:
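Note that both `DeviceListsStream` and `UserSignatureStream` advance `_device_list_id_gen`: cross-signing signature updates share the device-list ID space, so the two branches are deliberately the same. The `entity_has_changed` call in the rows handler uses the stream-change-cache idiom; a toy version (invented, simpler than Synapse's `StreamChangeCache`) showing what it buys — answering "has anything changed for this entity since token T?" without a database hit:

```python
class ToyStreamChangeCache:
    def __init__(self) -> None:
        self._last_change: dict = {}  # entity -> last stream token it changed at

    def entity_has_changed(self, entity: str, token: int) -> None:
        current = self._last_change.get(entity, 0)
        self._last_change[entity] = max(current, token)

    def has_entity_changed(self, entity: str, since_token: int) -> bool:
        # The real cache must err towards True for entities it has no data
        # for; the toy just reports on what it has seen.
        return self._last_change.get(entity, 0) > since_token


cache = ToyStreamChangeCache()
cache.entity_has_changed("@alice:example.com", 42)
assert cache.has_entity_changed("@alice:example.com", since_token=41)
assert not cache.has_entity_changed("@alice:example.com", since_token=42)
```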
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 761b15a815..d150fa8a94 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -388,11 +388,7 @@ class EventsWorkerStore(SQLBaseStore):
token: int,
rows: Iterable[Any],
) -> None:
- if stream_name == EventsStream.NAME:
- self._stream_id_gen.advance(instance_name, token)
- elif stream_name == BackfillStream.NAME:
- self._backfill_id_gen.advance(instance_name, -token)
- elif stream_name == UnPartialStatedEventStream.NAME:
+ if stream_name == UnPartialStatedEventStream.NAME:
for row in rows:
assert isinstance(row, UnPartialStatedEventStreamRow)
@@ -405,6 +401,15 @@ class EventsWorkerStore(SQLBaseStore):
super().process_replication_rows(stream_name, instance_name, token, rows)
+ def process_replication_position(
+ self, stream_name: str, instance_name: str, token: int
+ ) -> None:
+ if stream_name == EventsStream.NAME:
+ self._stream_id_gen.advance(instance_name, token)
+ elif stream_name == BackfillStream.NAME:
+ self._backfill_id_gen.advance(instance_name, -token)
+ super().process_replication_position(stream_name, instance_name, token)
+
async def have_censored_event(self, event_id: str) -> bool:
"""Check if an event has been censored, i.e. if the content of the event has been erased
from the database due to a redaction.
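The sign flip on the backfill branch is easy to misread: the backfill generator works in negative stream IDs (history grows downwards), while the replication stream carries the positive magnitude, hence `advance(instance_name, -token)`. A toy illustration (invented) of a generator whose positions grow downwards:

```python
class ToyBackfillIdGen:
    """Toy generator whose positions grow downwards, mirroring backfill."""

    def __init__(self) -> None:
        self._current = 0  # most negative position seen so far

    def advance(self, token: int) -> None:
        # token arrives already negated, e.g. -5 for replication token 5.
        self._current = min(self._current, token)

    def get_current_token(self) -> int:
        return self._current


gen = ToyBackfillIdGen()
replication_token = 5            # positive, as carried on the stream
gen.advance(-replication_token)  # mirrors _backfill_id_gen.advance(instance_name, -token)
assert gen.get_current_token() == -5
```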
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 9769a18a9d..7b60815043 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -439,8 +439,14 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
rows: Iterable[Any],
) -> None:
if stream_name == PresenceStream.NAME:
- self._presence_id_gen.advance(instance_name, token)
for row in rows:
self.presence_stream_cache.entity_has_changed(row.user_id, token)
self._get_presence_for_user.invalidate((row.user_id,))
return super().process_replication_rows(stream_name, instance_name, token, rows)
+
+ def process_replication_position(
+ self, stream_name: str, instance_name: str, token: int
+ ) -> None:
+ if stream_name == PresenceStream.NAME:
+ self._presence_id_gen.advance(instance_name, token)
+ super().process_replication_position(stream_name, instance_name, token)
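The `self._get_presence_for_user.invalidate((row.user_id,))` call in the rows handler relies on Synapse's `@cached` decorator exposing an `invalidate` method on the wrapped function. A much-simplified toy version of that idiom (invented; note the per-decoration cache is shared across instances, unlike the real thing):

```python
import functools


def toy_cached(func):
    cache: dict = {}

    @functools.wraps(func)
    def wrapper(self, *key):
        if key not in cache:
            cache[key] = func(self, *key)
        return cache[key]

    # Replication handlers call this with a key tuple, e.g. (user_id,).
    wrapper.invalidate = lambda key: cache.pop(tuple(key), None)
    return wrapper


class ToyPresenceStore:
    def __init__(self) -> None:
        self.db_hits = 0

    @toy_cached
    def get_presence_for_user(self, user_id: str) -> dict:
        self.db_hits += 1
        return {"user_id": user_id, "state": "online"}


store = ToyPresenceStore()
store.get_presence_for_user("@alice:example.com")
store.get_presence_for_user("@alice:example.com")
assert store.db_hits == 1  # second call served from the cache
store.get_presence_for_user.invalidate(("@alice:example.com",))
store.get_presence_for_user("@alice:example.com")
assert store.db_hits == 2  # refetched after invalidation
```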
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index d4c64c46ad..d4e4b777da 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -154,6 +154,13 @@ class PushRulesWorkerStore(
self.push_rules_stream_cache.entity_has_changed(row.user_id, token)
return super().process_replication_rows(stream_name, instance_name, token, rows)
+ def process_replication_position(
+ self, stream_name: str, instance_name: str, token: int
+ ) -> None:
+ if stream_name == PushRulesStream.NAME:
+ self._push_rules_stream_id_gen.advance(instance_name, token)
+ super().process_replication_position(stream_name, instance_name, token)
+
@cached(max_entries=5000)
async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules:
rows = await self.db_pool.simple_select_list(
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 40fd781a6a..7f24a3b6ec 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -111,12 +111,12 @@ class PusherWorkerStore(SQLBaseStore):
def get_pushers_stream_token(self) -> int:
return self._pushers_id_gen.get_current_token()
- def process_replication_rows(
- self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
+ def process_replication_position(
+ self, stream_name: str, instance_name: str, token: int
) -> None:
if stream_name == PushersStream.NAME:
self._pushers_id_gen.advance(instance_name, token)
- return super().process_replication_rows(stream_name, instance_name, token, rows)
+ super().process_replication_position(stream_name, instance_name, token)
async def get_pushers_by_app_id_and_pushkey(
self, app_id: str, pushkey: str
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index e06725f69c..86f5bce5f0 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -588,6 +588,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
return super().process_replication_rows(stream_name, instance_name, token, rows)
+ def process_replication_position(
+ self, stream_name: str, instance_name: str, token: int
+ ) -> None:
+ if stream_name == ReceiptsStream.NAME:
+ self._receipts_id_gen.advance(instance_name, token)
+ super().process_replication_position(stream_name, instance_name, token)
+
def _insert_linearized_receipt_txn(
self,
txn: LoggingTransaction,
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index b0f5de67a3..e23c927e02 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -300,13 +300,19 @@ class TagsWorkerStore(AccountDataWorkerStore):
rows: Iterable[Any],
) -> None:
if stream_name == TagAccountDataStream.NAME:
- self._account_data_id_gen.advance(instance_name, token)
for row in rows:
self.get_tags_for_user.invalidate((row.user_id,))
self._account_data_stream_cache.entity_has_changed(row.user_id, token)
super().process_replication_rows(stream_name, instance_name, token, rows)
+ def process_replication_position(
+ self, stream_name: str, instance_name: str, token: int
+ ) -> None:
+ if stream_name == TagAccountDataStream.NAME:
+ self._account_data_id_gen.advance(instance_name, token)
+ super().process_replication_position(stream_name, instance_name, token)
+
class TagsStore(TagsWorkerStore):
pass
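Finally, a test-style sketch (invented; standard-library `unittest.mock` only) of the ordering contract the `client.py` change establishes — roughly what a regression test for this patch might assert:

```python
from unittest.mock import MagicMock, call

store = MagicMock()
manager = MagicMock()
manager.attach_mock(store.process_replication_rows, "rows")
manager.attach_mock(store.process_replication_position, "position")


def on_rdata(stream_name: str, instance_name: str, token: int, rows: list) -> None:
    # Mirrors the ordering ReplicationDataHandler now guarantees for each batch.
    store.process_replication_rows(stream_name, instance_name, token, rows)
    store.process_replication_position(stream_name, instance_name, token)


on_rdata("caches", "worker1", 3, [])
manager.assert_has_calls(
    [
        call.rows("caches", "worker1", 3, []),
        call.position("caches", "worker1", 3),
    ]
)
```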