summary refs log tree commit diff
diff options
context:
space:
mode:
authorNick Mills-Barrett <nick@beeper.com>2023-01-04 11:49:26 +0000
committerGitHub <noreply@github.com>2023-01-04 11:49:26 +0000
commitdb1cfe9c80a707995fcad8f3faa839acb247068a (patch)
tree691c711006765e770056d97db624043d5b87b781
parentAdd experimental support for MSC3391: deleting account data (#14714) (diff)
downloadsynapse-db1cfe9c80a707995fcad8f3faa839acb247068a.tar.xz
Update all stream IDs after processing replication rows (#14723)
This creates a new store method, `process_replication_position` that
is called after `process_replication_rows`. By moving stream ID advances
here this guarantees any relevant cache invalidations will have been
applied before the stream is advanced.

This avoids race conditions where Python switches between threads mid
way through processing the `process_replication_rows` method where stream
IDs may be advanced before caches are invalidated due to class resolution
ordering.

See this comment/issue for further discussion:
	https://github.com/matrix-org/synapse/issues/14158#issuecomment-1344048703
-rw-r--r--changelog.d/14723.bugfix1
-rw-r--r--synapse/replication/tcp/client.py3
-rw-r--r--synapse/storage/_base.py17
-rw-r--r--synapse/storage/databases/main/account_data.py14
-rw-r--r--synapse/storage/databases/main/cache.py11
-rw-r--r--synapse/storage/databases/main/deviceinbox.py7
-rw-r--r--synapse/storage/databases/main/devices.py11
-rw-r--r--synapse/storage/databases/main/events_worker.py15
-rw-r--r--synapse/storage/databases/main/presence.py8
-rw-r--r--synapse/storage/databases/main/push_rule.py7
-rw-r--r--synapse/storage/databases/main/pusher.py6
-rw-r--r--synapse/storage/databases/main/receipts.py7
-rw-r--r--synapse/storage/databases/main/tags.py8
13 files changed, 95 insertions, 20 deletions
diff --git a/changelog.d/14723.bugfix b/changelog.d/14723.bugfix
new file mode 100644
index 0000000000..e1f89cee35
--- /dev/null
+++ b/changelog.d/14723.bugfix
@@ -0,0 +1 @@
+Ensure stream IDs are always updated after caches get invalidated with workers. Contributed by Nick @ Beeper (@fizzadar).
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 658d89210d..b5e40da533 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -152,6 +152,9 @@ class ReplicationDataHandler:
             rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
         """
         self.store.process_replication_rows(stream_name, instance_name, token, rows)
+        # NOTE: this must be called after process_replication_rows to ensure any
+        # cache invalidations are first handled before any stream ID advances.
+        self.store.process_replication_position(stream_name, instance_name, token)
 
         if self.send_handler:
             await self.send_handler.process_replication_rows(stream_name, token, rows)
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 69abf6fa87..41d9111019 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -57,7 +57,22 @@ class SQLBaseStore(metaclass=ABCMeta):
         token: int,
         rows: Iterable[Any],
     ) -> None:
-        pass
+        """
+        Used by storage classes to invalidate caches based on incoming replication data. These
+        must not update any ID generators, use `process_replication_position`.
+        """
+
+    def process_replication_position(  # noqa: B027 (no-op by design)
+        self,
+        stream_name: str,
+        instance_name: str,
+        token: int,
+    ) -> None:
+        """
+        Used by storage classes to advance ID generators based on incoming replication data. This
+        is called after process_replication_rows such that caches are invalidated before any token
+        positions advance.
+        """
 
     def _invalidate_state_caches(
         self, room_id: str, members_changed: Collection[str]
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index e59776f434..86032897f5 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -436,10 +436,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
         token: int,
         rows: Iterable[Any],
     ) -> None:
-        if stream_name == TagAccountDataStream.NAME:
-            self._account_data_id_gen.advance(instance_name, token)
-        elif stream_name == AccountDataStream.NAME:
-            self._account_data_id_gen.advance(instance_name, token)
+        if stream_name == AccountDataStream.NAME:
             for row in rows:
                 if not row.room_id:
                     self.get_global_account_data_by_type_for_user.invalidate(
@@ -454,6 +451,15 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
 
         super().process_replication_rows(stream_name, instance_name, token, rows)
 
+    def process_replication_position(
+        self, stream_name: str, instance_name: str, token: int
+    ) -> None:
+        if stream_name == TagAccountDataStream.NAME:
+            self._account_data_id_gen.advance(instance_name, token)
+        elif stream_name == AccountDataStream.NAME:
+            self._account_data_id_gen.advance(instance_name, token)
+        super().process_replication_position(stream_name, instance_name, token)
+
     async def add_account_data_to_room(
         self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
     ) -> int:
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index a58668a380..2179a8bf59 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -164,9 +164,6 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
                     backfilled=True,
                 )
         elif stream_name == CachesStream.NAME:
-            if self._cache_id_gen:
-                self._cache_id_gen.advance(instance_name, token)
-
             for row in rows:
                 if row.cache_func == CURRENT_STATE_CACHE_NAME:
                     if row.keys is None:
@@ -182,6 +179,14 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
 
         super().process_replication_rows(stream_name, instance_name, token, rows)
 
+    def process_replication_position(
+        self, stream_name: str, instance_name: str, token: int
+    ) -> None:
+        if stream_name == CachesStream.NAME:
+            if self._cache_id_gen:
+                self._cache_id_gen.advance(instance_name, token)
+        super().process_replication_position(stream_name, instance_name, token)
+
     def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None:
         data = row.data
 
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 48a54d9cb8..713be91c5d 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -157,6 +157,13 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                     )
         return super().process_replication_rows(stream_name, instance_name, token, rows)
 
+    def process_replication_position(
+        self, stream_name: str, instance_name: str, token: int
+    ) -> None:
+        if stream_name == ToDeviceStream.NAME:
+            self._device_inbox_id_gen.advance(instance_name, token)
+        super().process_replication_position(stream_name, instance_name, token)
+
     def get_to_device_stream_token(self) -> int:
         return self._device_inbox_id_gen.get_current_token()
 
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index a5bb4d404e..db877e3f13 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -162,14 +162,21 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
         self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
     ) -> None:
         if stream_name == DeviceListsStream.NAME:
-            self._device_list_id_gen.advance(instance_name, token)
             self._invalidate_caches_for_devices(token, rows)
         elif stream_name == UserSignatureStream.NAME:
-            self._device_list_id_gen.advance(instance_name, token)
             for row in rows:
                 self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
         return super().process_replication_rows(stream_name, instance_name, token, rows)
 
+    def process_replication_position(
+        self, stream_name: str, instance_name: str, token: int
+    ) -> None:
+        if stream_name == DeviceListsStream.NAME:
+            self._device_list_id_gen.advance(instance_name, token)
+        elif stream_name == UserSignatureStream.NAME:
+            self._device_list_id_gen.advance(instance_name, token)
+        super().process_replication_position(stream_name, instance_name, token)
+
     def _invalidate_caches_for_devices(
         self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow]
     ) -> None:
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 761b15a815..d150fa8a94 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -388,11 +388,7 @@ class EventsWorkerStore(SQLBaseStore):
         token: int,
         rows: Iterable[Any],
     ) -> None:
-        if stream_name == EventsStream.NAME:
-            self._stream_id_gen.advance(instance_name, token)
-        elif stream_name == BackfillStream.NAME:
-            self._backfill_id_gen.advance(instance_name, -token)
-        elif stream_name == UnPartialStatedEventStream.NAME:
+        if stream_name == UnPartialStatedEventStream.NAME:
             for row in rows:
                 assert isinstance(row, UnPartialStatedEventStreamRow)
 
@@ -405,6 +401,15 @@ class EventsWorkerStore(SQLBaseStore):
 
         super().process_replication_rows(stream_name, instance_name, token, rows)
 
+    def process_replication_position(
+        self, stream_name: str, instance_name: str, token: int
+    ) -> None:
+        if stream_name == EventsStream.NAME:
+            self._stream_id_gen.advance(instance_name, token)
+        elif stream_name == BackfillStream.NAME:
+            self._backfill_id_gen.advance(instance_name, -token)
+        super().process_replication_position(stream_name, instance_name, token)
+
     async def have_censored_event(self, event_id: str) -> bool:
         """Check if an event has been censored, i.e. if the content of the event has been erased
         from the database due to a redaction.
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 9769a18a9d..7b60815043 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -439,8 +439,14 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
         rows: Iterable[Any],
     ) -> None:
         if stream_name == PresenceStream.NAME:
-            self._presence_id_gen.advance(instance_name, token)
             for row in rows:
                 self.presence_stream_cache.entity_has_changed(row.user_id, token)
                 self._get_presence_for_user.invalidate((row.user_id,))
         return super().process_replication_rows(stream_name, instance_name, token, rows)
+
+    def process_replication_position(
+        self, stream_name: str, instance_name: str, token: int
+    ) -> None:
+        if stream_name == PresenceStream.NAME:
+            self._presence_id_gen.advance(instance_name, token)
+        super().process_replication_position(stream_name, instance_name, token)
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index d4c64c46ad..d4e4b777da 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -154,6 +154,13 @@ class PushRulesWorkerStore(
                 self.push_rules_stream_cache.entity_has_changed(row.user_id, token)
         return super().process_replication_rows(stream_name, instance_name, token, rows)
 
+    def process_replication_position(
+        self, stream_name: str, instance_name: str, token: int
+    ) -> None:
+        if stream_name == PushRulesStream.NAME:
+            self._push_rules_stream_id_gen.advance(instance_name, token)
+        super().process_replication_position(stream_name, instance_name, token)
+
     @cached(max_entries=5000)
     async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules:
         rows = await self.db_pool.simple_select_list(
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 40fd781a6a..7f24a3b6ec 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -111,12 +111,12 @@ class PusherWorkerStore(SQLBaseStore):
     def get_pushers_stream_token(self) -> int:
         return self._pushers_id_gen.get_current_token()
 
-    def process_replication_rows(
-        self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
+    def process_replication_position(
+        self, stream_name: str, instance_name: str, token: int
     ) -> None:
         if stream_name == PushersStream.NAME:
             self._pushers_id_gen.advance(instance_name, token)
-        return super().process_replication_rows(stream_name, instance_name, token, rows)
+        super().process_replication_position(stream_name, instance_name, token)
 
     async def get_pushers_by_app_id_and_pushkey(
         self, app_id: str, pushkey: str
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index e06725f69c..86f5bce5f0 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -588,6 +588,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
         return super().process_replication_rows(stream_name, instance_name, token, rows)
 
+    def process_replication_position(
+        self, stream_name: str, instance_name: str, token: int
+    ) -> None:
+        if stream_name == ReceiptsStream.NAME:
+            self._receipts_id_gen.advance(instance_name, token)
+        super().process_replication_position(stream_name, instance_name, token)
+
     def _insert_linearized_receipt_txn(
         self,
         txn: LoggingTransaction,
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index b0f5de67a3..e23c927e02 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -300,13 +300,19 @@ class TagsWorkerStore(AccountDataWorkerStore):
         rows: Iterable[Any],
     ) -> None:
         if stream_name == TagAccountDataStream.NAME:
-            self._account_data_id_gen.advance(instance_name, token)
             for row in rows:
                 self.get_tags_for_user.invalidate((row.user_id,))
                 self._account_data_stream_cache.entity_has_changed(row.user_id, token)
 
         super().process_replication_rows(stream_name, instance_name, token, rows)
 
+    def process_replication_position(
+        self, stream_name: str, instance_name: str, token: int
+    ) -> None:
+        if stream_name == TagAccountDataStream.NAME:
+            self._account_data_id_gen.advance(instance_name, token)
+        super().process_replication_position(stream_name, instance_name, token)
+
 
 class TagsStore(TagsWorkerStore):
     pass