summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/replication/tcp/client.py3
-rw-r--r--synapse/storage/_base.py17
-rw-r--r--synapse/storage/databases/main/account_data.py14
-rw-r--r--synapse/storage/databases/main/cache.py11
-rw-r--r--synapse/storage/databases/main/deviceinbox.py7
-rw-r--r--synapse/storage/databases/main/devices.py11
-rw-r--r--synapse/storage/databases/main/events_worker.py15
-rw-r--r--synapse/storage/databases/main/presence.py8
-rw-r--r--synapse/storage/databases/main/push_rule.py7
-rw-r--r--synapse/storage/databases/main/pusher.py6
-rw-r--r--synapse/storage/databases/main/receipts.py7
-rw-r--r--synapse/storage/databases/main/tags.py8
12 files changed, 94 insertions, 20 deletions
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 658d89210d..b5e40da533 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -152,6 +152,9 @@ class ReplicationDataHandler:
             rows: a list of Stream.ROW_TYPE objects as returned by Stream.parse_row.
         """
         self.store.process_replication_rows(stream_name, instance_name, token, rows)
+        # NOTE: this must be called after process_replication_rows to ensure any
+        # cache invalidations are first handled before any stream ID advances.
+        self.store.process_replication_position(stream_name, instance_name, token)
 
         if self.send_handler:
             await self.send_handler.process_replication_rows(stream_name, token, rows)
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index 69abf6fa87..41d9111019 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -57,7 +57,22 @@ class SQLBaseStore(metaclass=ABCMeta):
         token: int,
         rows: Iterable[Any],
     ) -> None:
-        pass
+        """
+        Used by storage classes to invalidate caches based on incoming replication data. These
+        must not update any ID generators, use `process_replication_position`.
+        """
+
+    def process_replication_position(  # noqa: B027 (no-op by design)
+        self,
+        stream_name: str,
+        instance_name: str,
+        token: int,
+    ) -> None:
+        """
+        Used by storage classes to advance ID generators based on incoming replication data. This
+        is called after process_replication_rows such that caches are invalidated before any token
+        positions advance.
+        """
 
     def _invalidate_state_caches(
         self, room_id: str, members_changed: Collection[str]
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index e59776f434..86032897f5 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -436,10 +436,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
         token: int,
         rows: Iterable[Any],
     ) -> None:
-        if stream_name == TagAccountDataStream.NAME:
-            self._account_data_id_gen.advance(instance_name, token)
-        elif stream_name == AccountDataStream.NAME:
-            self._account_data_id_gen.advance(instance_name, token)
+        if stream_name == AccountDataStream.NAME:
             for row in rows:
                 if not row.room_id:
                     self.get_global_account_data_by_type_for_user.invalidate(
@@ -454,6 +451,15 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
 
         super().process_replication_rows(stream_name, instance_name, token, rows)
 
+    def process_replication_position(
+        self, stream_name: str, instance_name: str, token: int
+    ) -> None:
+        if stream_name == TagAccountDataStream.NAME:
+            self._account_data_id_gen.advance(instance_name, token)
+        elif stream_name == AccountDataStream.NAME:
+            self._account_data_id_gen.advance(instance_name, token)
+        super().process_replication_position(stream_name, instance_name, token)
+
     async def add_account_data_to_room(
         self, user_id: str, room_id: str, account_data_type: str, content: JsonDict
     ) -> int:
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index a58668a380..2179a8bf59 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -164,9 +164,6 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
                     backfilled=True,
                 )
         elif stream_name == CachesStream.NAME:
-            if self._cache_id_gen:
-                self._cache_id_gen.advance(instance_name, token)
-
             for row in rows:
                 if row.cache_func == CURRENT_STATE_CACHE_NAME:
                     if row.keys is None:
@@ -182,6 +179,14 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
 
         super().process_replication_rows(stream_name, instance_name, token, rows)
 
+    def process_replication_position(
+        self, stream_name: str, instance_name: str, token: int
+    ) -> None:
+        if stream_name == CachesStream.NAME:
+            if self._cache_id_gen:
+                self._cache_id_gen.advance(instance_name, token)
+        super().process_replication_position(stream_name, instance_name, token)
+
     def _process_event_stream_row(self, token: int, row: EventsStreamRow) -> None:
         data = row.data
 
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 48a54d9cb8..713be91c5d 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -157,6 +157,13 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                     )
         return super().process_replication_rows(stream_name, instance_name, token, rows)
 
+    def process_replication_position(
+        self, stream_name: str, instance_name: str, token: int
+    ) -> None:
+        if stream_name == ToDeviceStream.NAME:
+            self._device_inbox_id_gen.advance(instance_name, token)
+        super().process_replication_position(stream_name, instance_name, token)
+
     def get_to_device_stream_token(self) -> int:
         return self._device_inbox_id_gen.get_current_token()
 
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index a5bb4d404e..db877e3f13 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -162,14 +162,21 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
         self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
     ) -> None:
         if stream_name == DeviceListsStream.NAME:
-            self._device_list_id_gen.advance(instance_name, token)
             self._invalidate_caches_for_devices(token, rows)
         elif stream_name == UserSignatureStream.NAME:
-            self._device_list_id_gen.advance(instance_name, token)
             for row in rows:
                 self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
         return super().process_replication_rows(stream_name, instance_name, token, rows)
 
+    def process_replication_position(
+        self, stream_name: str, instance_name: str, token: int
+    ) -> None:
+        if stream_name == DeviceListsStream.NAME:
+            self._device_list_id_gen.advance(instance_name, token)
+        elif stream_name == UserSignatureStream.NAME:
+            self._device_list_id_gen.advance(instance_name, token)
+        super().process_replication_position(stream_name, instance_name, token)
+
     def _invalidate_caches_for_devices(
         self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow]
     ) -> None:
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 761b15a815..d150fa8a94 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -388,11 +388,7 @@ class EventsWorkerStore(SQLBaseStore):
         token: int,
         rows: Iterable[Any],
     ) -> None:
-        if stream_name == EventsStream.NAME:
-            self._stream_id_gen.advance(instance_name, token)
-        elif stream_name == BackfillStream.NAME:
-            self._backfill_id_gen.advance(instance_name, -token)
-        elif stream_name == UnPartialStatedEventStream.NAME:
+        if stream_name == UnPartialStatedEventStream.NAME:
             for row in rows:
                 assert isinstance(row, UnPartialStatedEventStreamRow)
 
@@ -405,6 +401,15 @@ class EventsWorkerStore(SQLBaseStore):
 
         super().process_replication_rows(stream_name, instance_name, token, rows)
 
+    def process_replication_position(
+        self, stream_name: str, instance_name: str, token: int
+    ) -> None:
+        if stream_name == EventsStream.NAME:
+            self._stream_id_gen.advance(instance_name, token)
+        elif stream_name == BackfillStream.NAME:
+            self._backfill_id_gen.advance(instance_name, -token)
+        super().process_replication_position(stream_name, instance_name, token)
+
     async def have_censored_event(self, event_id: str) -> bool:
         """Check if an event has been censored, i.e. if the content of the event has been erased
         from the database due to a redaction.
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 9769a18a9d..7b60815043 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -439,8 +439,14 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
         rows: Iterable[Any],
     ) -> None:
         if stream_name == PresenceStream.NAME:
-            self._presence_id_gen.advance(instance_name, token)
             for row in rows:
                 self.presence_stream_cache.entity_has_changed(row.user_id, token)
                 self._get_presence_for_user.invalidate((row.user_id,))
         return super().process_replication_rows(stream_name, instance_name, token, rows)
+
+    def process_replication_position(
+        self, stream_name: str, instance_name: str, token: int
+    ) -> None:
+        if stream_name == PresenceStream.NAME:
+            self._presence_id_gen.advance(instance_name, token)
+        super().process_replication_position(stream_name, instance_name, token)
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index d4c64c46ad..d4e4b777da 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -154,6 +154,13 @@ class PushRulesWorkerStore(
                 self.push_rules_stream_cache.entity_has_changed(row.user_id, token)
         return super().process_replication_rows(stream_name, instance_name, token, rows)
 
+    def process_replication_position(
+        self, stream_name: str, instance_name: str, token: int
+    ) -> None:
+        if stream_name == PushRulesStream.NAME:
+            self._push_rules_stream_id_gen.advance(instance_name, token)
+        super().process_replication_position(stream_name, instance_name, token)
+
     @cached(max_entries=5000)
     async def get_push_rules_for_user(self, user_id: str) -> FilteredPushRules:
         rows = await self.db_pool.simple_select_list(
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 40fd781a6a..7f24a3b6ec 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -111,12 +111,12 @@ class PusherWorkerStore(SQLBaseStore):
     def get_pushers_stream_token(self) -> int:
         return self._pushers_id_gen.get_current_token()
 
-    def process_replication_rows(
-        self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
+    def process_replication_position(
+        self, stream_name: str, instance_name: str, token: int
     ) -> None:
         if stream_name == PushersStream.NAME:
             self._pushers_id_gen.advance(instance_name, token)
-        return super().process_replication_rows(stream_name, instance_name, token, rows)
+        super().process_replication_position(stream_name, instance_name, token)
 
     async def get_pushers_by_app_id_and_pushkey(
         self, app_id: str, pushkey: str
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index e06725f69c..86f5bce5f0 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -588,6 +588,13 @@ class ReceiptsWorkerStore(SQLBaseStore):
 
         return super().process_replication_rows(stream_name, instance_name, token, rows)
 
+    def process_replication_position(
+        self, stream_name: str, instance_name: str, token: int
+    ) -> None:
+        if stream_name == ReceiptsStream.NAME:
+            self._receipts_id_gen.advance(instance_name, token)
+        super().process_replication_position(stream_name, instance_name, token)
+
     def _insert_linearized_receipt_txn(
         self,
         txn: LoggingTransaction,
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index b0f5de67a3..e23c927e02 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -300,13 +300,19 @@ class TagsWorkerStore(AccountDataWorkerStore):
         rows: Iterable[Any],
     ) -> None:
         if stream_name == TagAccountDataStream.NAME:
-            self._account_data_id_gen.advance(instance_name, token)
             for row in rows:
                 self.get_tags_for_user.invalidate((row.user_id,))
                 self._account_data_stream_cache.entity_has_changed(row.user_id, token)
 
         super().process_replication_rows(stream_name, instance_name, token, rows)
 
+    def process_replication_position(
+        self, stream_name: str, instance_name: str, token: int
+    ) -> None:
+        if stream_name == TagAccountDataStream.NAME:
+            self._account_data_id_gen.advance(instance_name, token)
+        super().process_replication_position(stream_name, instance_name, token)
+
 
 class TagsStore(TagsWorkerStore):
     pass