summary refs log tree commit diff
path: root/synapse/replication/slave/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/slave/storage')
-rw-r--r--synapse/replication/slave/storage/_base.py50
-rw-r--r--synapse/replication/slave/storage/account_data.py6
-rw-r--r--synapse/replication/slave/storage/deviceinbox.py6
-rw-r--r--synapse/replication/slave/storage/devices.py6
-rw-r--r--synapse/replication/slave/storage/events.py6
-rw-r--r--synapse/replication/slave/storage/groups.py6
-rw-r--r--synapse/replication/slave/storage/presence.py6
-rw-r--r--synapse/replication/slave/storage/push_rule.py6
-rw-r--r--synapse/replication/slave/storage/pushers.py6
-rw-r--r--synapse/replication/slave/storage/receipts.py6
-rw-r--r--synapse/replication/slave/storage/room.py4
11 files changed, 31 insertions, 77 deletions
diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py
index 5d7c8871a4..2904bd0235 100644
--- a/synapse/replication/slave/storage/_base.py
+++ b/synapse/replication/slave/storage/_base.py
@@ -18,14 +18,10 @@ from typing import Optional
 
 import six
 
-from synapse.storage.data_stores.main.cache import (
-    CURRENT_STATE_CACHE_NAME,
-    CacheInvalidationWorkerStore,
-)
+from synapse.storage.data_stores.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.database import Database
 from synapse.storage.engines import PostgresEngine
-
-from ._slaved_id_tracker import SlavedIdTracker
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
 
 logger = logging.getLogger(__name__)
 
@@ -41,40 +37,16 @@ class BaseSlavedStore(CacheInvalidationWorkerStore):
     def __init__(self, database: Database, db_conn, hs):
         super(BaseSlavedStore, self).__init__(database, db_conn, hs)
         if isinstance(self.database_engine, PostgresEngine):
-            self._cache_id_gen = SlavedIdTracker(
-                db_conn, "cache_invalidation_stream", "stream_id"
-            )  # type: Optional[SlavedIdTracker]
+            self._cache_id_gen = MultiWriterIdGenerator(
+                db_conn,
+                database,
+                instance_name=hs.get_instance_name(),
+                table="cache_invalidation_stream_by_instance",
+                instance_column="instance_name",
+                id_column="stream_id",
+                sequence_name="cache_invalidation_stream_seq",
+            )  # type: Optional[MultiWriterIdGenerator]
         else:
             self._cache_id_gen = None
 
         self.hs = hs
-
-    def get_cache_stream_token(self):
-        if self._cache_id_gen:
-            return self._cache_id_gen.get_current_token()
-        else:
-            return 0
-
-    def process_replication_rows(self, stream_name, token, rows):
-        if stream_name == "caches":
-            if self._cache_id_gen:
-                self._cache_id_gen.advance(token)
-            for row in rows:
-                if row.cache_func == CURRENT_STATE_CACHE_NAME:
-                    if row.keys is None:
-                        raise Exception(
-                            "Can't send an 'invalidate all' for current state cache"
-                        )
-
-                    room_id = row.keys[0]
-                    members_changed = set(row.keys[1:])
-                    self._invalidate_state_caches(room_id, members_changed)
-                else:
-                    self._attempt_to_invalidate_cache(row.cache_func, row.keys)
-
-    def _invalidate_cache_and_stream(self, txn, cache_func, keys):
-        txn.call_after(cache_func.invalidate, keys)
-        txn.call_after(self._send_invalidation_poke, cache_func, keys)
-
-    def _send_invalidation_poke(self, cache_func, keys):
-        self.hs.get_tcp_replication().send_invalidate_cache(cache_func, keys)
diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py
index 65e54b1c71..2a4f5c7cfd 100644
--- a/synapse/replication/slave/storage/account_data.py
+++ b/synapse/replication/slave/storage/account_data.py
@@ -32,7 +32,7 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
     def get_max_account_data_stream_id(self):
         return self._account_data_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "tag_account_data":
             self._account_data_id_gen.advance(token)
             for row in rows:
@@ -51,6 +51,4 @@ class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlaved
                     (row.user_id, row.room_id, row.data_type)
                 )
                 self._account_data_stream_cache.entity_has_changed(row.user_id, token)
-        return super(SlavedAccountDataStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py
index c923751e50..6e7fd259d4 100644
--- a/synapse/replication/slave/storage/deviceinbox.py
+++ b/synapse/replication/slave/storage/deviceinbox.py
@@ -43,7 +43,7 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
             expiry_ms=30 * 60 * 1000,
         )
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "to_device":
             self._device_inbox_id_gen.advance(token)
             for row in rows:
@@ -55,6 +55,4 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
                     self._device_federation_outbox_stream_cache.entity_has_changed(
                         row.entity, token
                     )
-        return super(SlavedDeviceInboxStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py
index 58fb0eaae3..9d8067342f 100644
--- a/synapse/replication/slave/storage/devices.py
+++ b/synapse/replication/slave/storage/devices.py
@@ -48,7 +48,7 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
             "DeviceListFederationStreamChangeCache", device_list_max
         )
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == DeviceListsStream.NAME:
             self._device_list_id_gen.advance(token)
             self._invalidate_caches_for_devices(token, rows)
@@ -56,9 +56,7 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
             self._device_list_id_gen.advance(token)
             for row in rows:
                 self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
-        return super(SlavedDeviceStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
 
     def _invalidate_caches_for_devices(self, token, rows):
         for row in rows:
diff --git a/synapse/replication/slave/storage/events.py b/synapse/replication/slave/storage/events.py
index 15011259df..b313720a4b 100644
--- a/synapse/replication/slave/storage/events.py
+++ b/synapse/replication/slave/storage/events.py
@@ -93,7 +93,7 @@ class SlavedEventStore(
     def get_room_min_stream_ordering(self):
         return self._backfill_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "events":
             self._stream_id_gen.advance(token)
             for row in rows:
@@ -111,9 +111,7 @@ class SlavedEventStore(
                     row.relates_to,
                     backfilled=True,
                 )
-        return super(SlavedEventStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
 
     def _process_event_stream_row(self, token, row):
         data = row.data
diff --git a/synapse/replication/slave/storage/groups.py b/synapse/replication/slave/storage/groups.py
index 01bcf0e882..1851e7d525 100644
--- a/synapse/replication/slave/storage/groups.py
+++ b/synapse/replication/slave/storage/groups.py
@@ -37,12 +37,10 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
     def get_group_stream_token(self):
         return self._group_updates_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "groups":
             self._group_updates_id_gen.advance(token)
             for row in rows:
                 self._group_updates_stream_cache.entity_has_changed(row.user_id, token)
 
-        return super(SlavedGroupServerStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/presence.py b/synapse/replication/slave/storage/presence.py
index fae3125072..bd79ba99be 100644
--- a/synapse/replication/slave/storage/presence.py
+++ b/synapse/replication/slave/storage/presence.py
@@ -41,12 +41,10 @@ class SlavedPresenceStore(BaseSlavedStore):
     def get_current_presence_token(self):
         return self._presence_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "presence":
             self._presence_id_gen.advance(token)
             for row in rows:
                 self.presence_stream_cache.entity_has_changed(row.user_id, token)
                 self._get_presence_for_user.invalidate((row.user_id,))
-        return super(SlavedPresenceStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/push_rule.py b/synapse/replication/slave/storage/push_rule.py
index 6138796da4..5d5816d7eb 100644
--- a/synapse/replication/slave/storage/push_rule.py
+++ b/synapse/replication/slave/storage/push_rule.py
@@ -37,13 +37,11 @@ class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
     def get_max_push_rules_stream_id(self):
         return self._push_rules_stream_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "push_rules":
             self._push_rules_stream_id_gen.advance(token)
             for row in rows:
                 self.get_push_rules_for_user.invalidate((row.user_id,))
                 self.get_push_rules_enabled_for_user.invalidate((row.user_id,))
                 self.push_rules_stream_cache.entity_has_changed(row.user_id, token)
-        return super(SlavedPushRuleStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/pushers.py b/synapse/replication/slave/storage/pushers.py
index 67be337945..cb78b49acb 100644
--- a/synapse/replication/slave/storage/pushers.py
+++ b/synapse/replication/slave/storage/pushers.py
@@ -31,9 +31,7 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
     def get_pushers_stream_token(self):
         return self._pushers_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "pushers":
             self._pushers_id_gen.advance(token)
-        return super(SlavedPusherStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py
index 993432edcb..be716cc558 100644
--- a/synapse/replication/slave/storage/receipts.py
+++ b/synapse/replication/slave/storage/receipts.py
@@ -51,7 +51,7 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
         self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id)
         self.get_receipts_for_room.invalidate((room_id, receipt_type))
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "receipts":
             self._receipts_id_gen.advance(token)
             for row in rows:
@@ -60,6 +60,4 @@ class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore):
                 )
                 self._receipts_stream_cache.entity_has_changed(row.room_id, token)
 
-        return super(SlavedReceiptsStore, self).process_replication_rows(
-            stream_name, token, rows
-        )
+        return super().process_replication_rows(stream_name, instance_name, token, rows)
diff --git a/synapse/replication/slave/storage/room.py b/synapse/replication/slave/storage/room.py
index 10dda8708f..8873bf37e5 100644
--- a/synapse/replication/slave/storage/room.py
+++ b/synapse/replication/slave/storage/room.py
@@ -30,8 +30,8 @@ class RoomStore(RoomWorkerStore, BaseSlavedStore):
     def get_current_public_room_stream_id(self):
         return self._public_room_id_gen.get_current_token()
 
-    def process_replication_rows(self, stream_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == "public_rooms":
             self._public_room_id_gen.advance(token)
 
-        return super(RoomStore, self).process_replication_rows(stream_name, token, rows)
+        return super().process_replication_rows(stream_name, instance_name, token, rows)