summary refs log tree commit diff
path: root/synapse/storage/databases/main/pusher.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/pusher.py')
-rw-r--r--synapse/storage/databases/main/pusher.py42
1 files changed, 26 insertions, 16 deletions
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 39e22d3b43..a8a37b6c85 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -40,10 +40,7 @@ from synapse.storage.database import (
     LoggingDatabaseConnection,
     LoggingTransaction,
 )
-from synapse.storage.util.id_generators import (
-    AbstractStreamIdGenerator,
-    StreamIdGenerator,
-)
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
@@ -84,15 +81,20 @@ class PusherWorkerStore(SQLBaseStore):
     ):
         super().__init__(database, db_conn, hs)
 
-        # In the worker store this is an ID tracker which we overwrite in the non-worker
-        # class below that is used on the main process.
-        self._pushers_id_gen = StreamIdGenerator(
-            db_conn,
-            hs.get_replication_notifier(),
-            "pushers",
-            "id",
-            extra_tables=[("deleted_pushers", "stream_id")],
-            is_writer=hs.config.worker.worker_app is None,
+        self._instance_name = hs.get_instance_name()
+
+        self._pushers_id_gen = MultiWriterIdGenerator(
+            db_conn=db_conn,
+            db=database,
+            notifier=hs.get_replication_notifier(),
+            stream_name="pushers",
+            instance_name=self._instance_name,
+            tables=[
+                ("pushers", "instance_name", "id"),
+                ("deleted_pushers", "instance_name", "stream_id"),
+            ],
+            sequence_name="pushers_sequence",
+            writers=["master"],
         )
 
         self.db_pool.updates.register_background_update_handler(
@@ -655,7 +657,7 @@ class PusherBackgroundUpdatesStore(SQLBaseStore):
 class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
     # Because we have write access, this will be a StreamIdGenerator
     # (see PusherWorkerStore.__init__)
-    _pushers_id_gen: AbstractStreamIdGenerator
+    _pushers_id_gen: MultiWriterIdGenerator
 
     async def add_pusher(
         self,
@@ -688,6 +690,7 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
                     "last_stream_ordering": last_stream_ordering,
                     "profile_tag": profile_tag,
                     "id": stream_id,
+                    "instance_name": self._instance_name,
                     "enabled": enabled,
                     "device_id": device_id,
                     # XXX(quenting): We're only really persisting the access token ID
@@ -735,6 +738,7 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
                 table="deleted_pushers",
                 values={
                     "stream_id": stream_id,
+                    "instance_name": self._instance_name,
                     "app_id": app_id,
                     "pushkey": pushkey,
                     "user_id": user_id,
@@ -773,9 +777,15 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
             self.db_pool.simple_insert_many_txn(
                 txn,
                 table="deleted_pushers",
-                keys=("stream_id", "app_id", "pushkey", "user_id"),
+                keys=("stream_id", "instance_name", "app_id", "pushkey", "user_id"),
                 values=[
-                    (stream_id, pusher.app_id, pusher.pushkey, user_id)
+                    (
+                        stream_id,
+                        self._instance_name,
+                        pusher.app_id,
+                        pusher.pushkey,
+                        user_id,
+                    )
                     for stream_id, pusher in zip(stream_ids, pushers)
                 ],
             )