summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/devices.py54
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py19
-rw-r--r--synapse/storage/databases/main/push_rule.py24
-rw-r--r--synapse/storage/databases/main/pusher.py42
4 files changed, 83 insertions, 56 deletions
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 48384e238c..1c771e48f7 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -57,10 +57,7 @@ from synapse.storage.database import (
 from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.types import Cursor
-from synapse.storage.util.id_generators import (
-    AbstractStreamIdGenerator,
-    StreamIdGenerator,
-)
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.types import (
     JsonDict,
     JsonMapping,
@@ -99,19 +96,21 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
 
         # In the worker store this is an ID tracker which we overwrite in the non-worker
         # class below that is used on the main process.
-        self._device_list_id_gen = StreamIdGenerator(
-            db_conn,
-            hs.get_replication_notifier(),
-            "device_lists_stream",
-            "stream_id",
-            extra_tables=[
-                ("user_signature_stream", "stream_id"),
-                ("device_lists_outbound_pokes", "stream_id"),
-                ("device_lists_changes_in_room", "stream_id"),
-                ("device_lists_remote_pending", "stream_id"),
-                ("device_lists_changes_converted_stream_position", "stream_id"),
+        self._device_list_id_gen = MultiWriterIdGenerator(
+            db_conn=db_conn,
+            db=database,
+            notifier=hs.get_replication_notifier(),
+            stream_name="device_lists_stream",
+            instance_name=self._instance_name,
+            tables=[
+                ("device_lists_stream", "instance_name", "stream_id"),
+                ("user_signature_stream", "instance_name", "stream_id"),
+                ("device_lists_outbound_pokes", "instance_name", "stream_id"),
+                ("device_lists_changes_in_room", "instance_name", "stream_id"),
+                ("device_lists_remote_pending", "instance_name", "stream_id"),
             ],
-            is_writer=hs.config.worker.worker_app is None,
+            sequence_name="device_lists_sequence",
+            writers=["master"],
         )
 
         device_list_max = self._device_list_id_gen.get_current_token()
@@ -762,6 +761,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
                 "stream_id": stream_id,
                 "from_user_id": from_user_id,
                 "user_ids": json_encoder.encode(user_ids),
+                "instance_name": self._instance_name,
             },
         )
 
@@ -1582,6 +1582,8 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
     ):
         super().__init__(database, db_conn, hs)
 
+        self._instance_name = hs.get_instance_name()
+
         self.db_pool.updates.register_background_index_update(
             "device_lists_stream_idx",
             index_name="device_lists_stream_user_id",
@@ -1694,6 +1696,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
                     "device_lists_outbound_pokes",
                     {
                         "stream_id": stream_id,
+                        "instance_name": self._instance_name,
                         "destination": destination,
                         "user_id": user_id,
                         "device_id": device_id,
@@ -1730,10 +1733,6 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
 
 
 class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
-    # Because we have write access, this will be a StreamIdGenerator
-    # (see DeviceWorkerStore.__init__)
-    _device_list_id_gen: AbstractStreamIdGenerator
-
     def __init__(
         self,
         database: DatabasePool,
@@ -2092,9 +2091,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         self.db_pool.simple_insert_many_txn(
             txn,
             table="device_lists_stream",
-            keys=("stream_id", "user_id", "device_id"),
+            keys=("instance_name", "stream_id", "user_id", "device_id"),
             values=[
-                (stream_id, user_id, device_id)
+                (self._instance_name, stream_id, user_id, device_id)
                 for stream_id, device_id in zip(stream_ids, device_ids)
             ],
         )
@@ -2124,6 +2123,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         values = [
             (
                 destination,
+                self._instance_name,
                 next(stream_id_iterator),
                 user_id,
                 device_id,
@@ -2139,6 +2139,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             table="device_lists_outbound_pokes",
             keys=(
                 "destination",
+                "instance_name",
                 "stream_id",
                 "user_id",
                 "device_id",
@@ -2157,7 +2158,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
                 device_id,
                 {
                     stream_id: destination
-                    for (destination, stream_id, _, _, _, _, _) in values
+                    for (destination, _, stream_id, _, _, _, _, _) in values
                 },
             )
 
@@ -2210,6 +2211,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
                 "device_id",
                 "room_id",
                 "stream_id",
+                "instance_name",
                 "converted_to_destinations",
                 "opentracing_context",
             ),
@@ -2219,6 +2221,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
                     device_id,
                     room_id,
                     stream_id,
+                    self._instance_name,
                     # We only need to calculate outbound pokes for local users
                     not self.hs.is_mine_id(user_id),
                     encoded_context,
@@ -2338,7 +2341,10 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
                     "user_id": user_id,
                     "device_id": device_id,
                 },
-                values={"stream_id": stream_id},
+                values={
+                    "stream_id": stream_id,
+                    "instance_name": self._instance_name,
+                },
                 desc="add_remote_device_list_to_pending",
             )
 
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index b219ea70ee..38d8785faa 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -58,7 +58,7 @@ from synapse.storage.database import (
 )
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.engines import PostgresEngine
-from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.types import JsonDict, JsonMapping
 from synapse.util import json_decoder, json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
@@ -1448,11 +1448,17 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
     ):
         super().__init__(database, db_conn, hs)
 
-        self._cross_signing_id_gen = StreamIdGenerator(
-            db_conn,
-            hs.get_replication_notifier(),
-            "e2e_cross_signing_keys",
-            "stream_id",
+        self._cross_signing_id_gen = MultiWriterIdGenerator(
+            db_conn=db_conn,
+            db=database,
+            notifier=hs.get_replication_notifier(),
+            stream_name="e2e_cross_signing_keys",
+            instance_name=self._instance_name,
+            tables=[
+                ("e2e_cross_signing_keys", "instance_name", "stream_id"),
+            ],
+            sequence_name="e2e_cross_signing_keys_sequence",
+            writers=["master"],
         )
 
     async def set_e2e_device_keys(
@@ -1627,6 +1633,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
                 "keytype": key_type,
                 "keydata": json_encoder.encode(key),
                 "stream_id": stream_id,
+                "instance_name": self._instance_name,
             },
         )
 
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 660c834518..2a39dc9f90 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -53,7 +53,7 @@ from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
-from synapse.storage.util.id_generators import IdGenerator, StreamIdGenerator
+from synapse.storage.util.id_generators import IdGenerator, MultiWriterIdGenerator
 from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRules
 from synapse.types import JsonDict
 from synapse.util import json_encoder, unwrapFirstError
@@ -126,7 +126,7 @@ class PushRulesWorkerStore(
     `get_max_push_rules_stream_id` which can be called in the initializer.
     """
 
-    _push_rules_stream_id_gen: StreamIdGenerator
+    _push_rules_stream_id_gen: MultiWriterIdGenerator
 
     def __init__(
         self,
@@ -140,14 +140,17 @@ class PushRulesWorkerStore(
             hs.get_instance_name() in hs.config.worker.writers.push_rules
         )
 
-        # In the worker store this is an ID tracker which we overwrite in the non-worker
-        # class below that is used on the main process.
-        self._push_rules_stream_id_gen = StreamIdGenerator(
-            db_conn,
-            hs.get_replication_notifier(),
-            "push_rules_stream",
-            "stream_id",
-            is_writer=self._is_push_writer,
+        self._push_rules_stream_id_gen = MultiWriterIdGenerator(
+            db_conn=db_conn,
+            db=database,
+            notifier=hs.get_replication_notifier(),
+            stream_name="push_rules_stream",
+            instance_name=self._instance_name,
+            tables=[
+                ("push_rules_stream", "instance_name", "stream_id"),
+            ],
+            sequence_name="push_rules_stream_sequence",
+            writers=hs.config.worker.writers.push_rules,
         )
 
         push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict(
@@ -880,6 +883,7 @@ class PushRulesWorkerStore(
             raise Exception("Not a push writer")
 
         values = {
+            "instance_name": self._instance_name,
             "stream_id": stream_id,
             "event_stream_ordering": event_stream_ordering,
             "user_id": user_id,
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 39e22d3b43..a8a37b6c85 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -40,10 +40,7 @@ from synapse.storage.database import (
     LoggingDatabaseConnection,
     LoggingTransaction,
 )
-from synapse.storage.util.id_generators import (
-    AbstractStreamIdGenerator,
-    StreamIdGenerator,
-)
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
@@ -84,15 +81,20 @@ class PusherWorkerStore(SQLBaseStore):
     ):
         super().__init__(database, db_conn, hs)
 
-        # In the worker store this is an ID tracker which we overwrite in the non-worker
-        # class below that is used on the main process.
-        self._pushers_id_gen = StreamIdGenerator(
-            db_conn,
-            hs.get_replication_notifier(),
-            "pushers",
-            "id",
-            extra_tables=[("deleted_pushers", "stream_id")],
-            is_writer=hs.config.worker.worker_app is None,
+        self._instance_name = hs.get_instance_name()
+
+        self._pushers_id_gen = MultiWriterIdGenerator(
+            db_conn=db_conn,
+            db=database,
+            notifier=hs.get_replication_notifier(),
+            stream_name="pushers",
+            instance_name=self._instance_name,
+            tables=[
+                ("pushers", "instance_name", "id"),
+                ("deleted_pushers", "instance_name", "stream_id"),
+            ],
+            sequence_name="pushers_sequence",
+            writers=["master"],
         )
 
         self.db_pool.updates.register_background_update_handler(
@@ -655,7 +657,7 @@ class PusherBackgroundUpdatesStore(SQLBaseStore):
 class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
     # Because we have write access, this will be a StreamIdGenerator
     # (see PusherWorkerStore.__init__)
-    _pushers_id_gen: AbstractStreamIdGenerator
+    _pushers_id_gen: MultiWriterIdGenerator
 
     async def add_pusher(
         self,
@@ -688,6 +690,7 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
                     "last_stream_ordering": last_stream_ordering,
                     "profile_tag": profile_tag,
                     "id": stream_id,
+                    "instance_name": self._instance_name,
                     "enabled": enabled,
                     "device_id": device_id,
                     # XXX(quenting): We're only really persisting the access token ID
@@ -735,6 +738,7 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
                 table="deleted_pushers",
                 values={
                     "stream_id": stream_id,
+                    "instance_name": self._instance_name,
                     "app_id": app_id,
                     "pushkey": pushkey,
                     "user_id": user_id,
@@ -773,9 +777,15 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
             self.db_pool.simple_insert_many_txn(
                 txn,
                 table="deleted_pushers",
-                keys=("stream_id", "app_id", "pushkey", "user_id"),
+                keys=("stream_id", "instance_name", "app_id", "pushkey", "user_id"),
                 values=[
-                    (stream_id, pusher.app_id, pusher.pushkey, user_id)
+                    (
+                        stream_id,
+                        self._instance_name,
+                        pusher.app_id,
+                        pusher.pushkey,
+                        user_id,
+                    )
                     for stream_id, pusher in zip(stream_ids, pushers)
                 ],
             )