diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 01206950a9..40fd781a6a 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -27,13 +27,18 @@ from typing import (
)

from synapse.push import PusherConfig, ThrottleParams
+from synapse.replication.tcp.streams import PushersStream
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
)
-from synapse.storage.util.id_generators import StreamIdGenerator
+from synapse.storage.util.id_generators import (
+ AbstractStreamIdGenerator,
+ AbstractStreamIdTracker,
+ StreamIdGenerator,
+)
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached
@@ -52,8 +57,18 @@ class PusherWorkerStore(SQLBaseStore):
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
- self._pushers_id_gen = StreamIdGenerator(
- db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
+
+        # In the worker store this is an ID tracker which we overwrite in the
+        # non-worker class below (the one used on the main process).
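+        # `is_writer` is True only on the main process: workers must never
+        # allocate new pusher stream IDs themselves and instead just follow
+        # the position as it is advanced in `process_replication_rows` below.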
+ self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
+ db_conn,
+ "pushers",
+ "id",
+ extra_tables=[("deleted_pushers", "stream_id")],
+ is_writer=hs.config.worker.worker_app is None,
        )

self.db_pool.updates.register_background_update_handler(
@@ -96,6 +111,19 @@ class PusherWorkerStore(SQLBaseStore):

            yield PusherConfig(**r)

+ def get_pushers_stream_token(self) -> int:
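+        """Get the current token (position) of the pushers stream."""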
+ return self._pushers_id_gen.get_current_token()
+
+ def process_replication_rows(
+ self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
+ ) -> None:
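+        # The pushers stream has a single writer (the main process), so a row
+        # arriving here just advances our local copy of the stream position.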
+ if stream_name == PushersStream.NAME:
+ self._pushers_id_gen.advance(instance_name, token)
+ return super().process_replication_rows(stream_name, instance_name, token, rows)
+
async def get_pushers_by_app_id_and_pushkey(
self, app_id: str, pushkey: str
) -> Iterator[PusherConfig]:
@@ -303,14 +331,11 @@ class PusherWorkerStore(SQLBaseStore):
async def set_throttle_params(
self, pusher_id: str, room_id: str, params: ThrottleParams
) -> None:
- # no need to lock because `pusher_throttle` has a primary key on
- # (pusher, room_id) so simple_upsert will retry
await self.db_pool.simple_upsert(
"pusher_throttle",
{"pusher": pusher_id, "room_id": room_id},
{"last_sent_ts": params.last_sent_ts, "throttle_ms": params.throttle_ms},
desc="set_throttle_params",
- lock=False,
        )

async def _remove_deactivated_pushers(self, progress: dict, batch_size: int) -> int:
@@ -545,8 +570,9 @@ class PusherBackgroundUpdatesStore(SQLBaseStore):


class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
- def get_pushers_stream_token(self) -> int:
- return self._pushers_id_gen.get_current_token()
+ # Because we have write access, this will be a StreamIdGenerator
+ # (see PusherWorkerStore.__init__)
+    _pushers_id_gen: AbstractStreamIdGenerator

async def add_pusher(
self,
@@ -566,8 +592,6 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
device_id: Optional[str] = None,
) -> None:
async with self._pushers_id_gen.get_next() as stream_id:
- # no need to lock because `pushers` has a unique key on
- # (app_id, pushkey, user_name) so simple_upsert will retry
await self.db_pool.simple_upsert(
table="pushers",
keyvalues={"app_id": app_id, "pushkey": pushkey, "user_name": user_id},
@@ -586,7 +610,6 @@ class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
"device_id": device_id,
},
desc="add_pusher",
- lock=False,
            )

user_has_pusher = self.get_if_user_has_pusher.cache.get_immediate(