summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/databases/main/account_data.py2
-rw-r--r--synapse/storage/databases/main/cache.py1
-rw-r--r--synapse/storage/databases/main/deviceinbox.py3
-rw-r--r--synapse/storage/databases/main/devices.py1
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py5
-rw-r--r--synapse/storage/databases/main/events_worker.py10
-rw-r--r--synapse/storage/databases/main/presence.py3
-rw-r--r--synapse/storage/databases/main/push_rule.py1
-rw-r--r--synapse/storage/databases/main/pusher.py1
-rw-r--r--synapse/storage/databases/main/receipts.py2
-rw-r--r--synapse/storage/databases/main/room.py6
-rw-r--r--synapse/storage/util/id_generators.py26
12 files changed, 54 insertions, 7 deletions
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 881d7089db..8a359d7eb8 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -75,6 +75,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             self._account_data_id_gen = MultiWriterIdGenerator(
                 db_conn=db_conn,
                 db=database,
+                notifier=hs.get_replication_notifier(),
                 stream_name="account_data",
                 instance_name=self._instance_name,
                 tables=[
@@ -95,6 +96,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             # SQLite).
             self._account_data_id_gen = StreamIdGenerator(
                 db_conn,
+                hs.get_replication_notifier(),
                 "room_account_data",
                 "stream_id",
                 extra_tables=[("room_tags_revisions", "stream_id")],
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 2179a8bf59..5b66431691 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -75,6 +75,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             self._cache_id_gen = MultiWriterIdGenerator(
                 db_conn,
                 database,
+                notifier=hs.get_replication_notifier(),
                 stream_name="caches",
                 instance_name=hs.get_instance_name(),
                 tables=[
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 713be91c5d..8e61aba454 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -91,6 +91,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                 MultiWriterIdGenerator(
                     db_conn=db_conn,
                     db=database,
+                    notifier=hs.get_replication_notifier(),
                     stream_name="to_device",
                     instance_name=self._instance_name,
                     tables=[("device_inbox", "instance_name", "stream_id")],
@@ -101,7 +102,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
         else:
             self._can_write_to_device = True
             self._device_inbox_id_gen = StreamIdGenerator(
-                db_conn, "device_inbox", "stream_id"
+                db_conn, hs.get_replication_notifier(), "device_inbox", "stream_id"
             )
 
         max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index cd186c8472..903606fb46 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -92,6 +92,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
         # class below that is used on the main process.
         self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
             db_conn,
+            hs.get_replication_notifier(),
             "device_lists_stream",
             "stream_id",
             extra_tables=[
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 4c691642e2..c4ac6c33ba 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -1181,7 +1181,10 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
         super().__init__(database, db_conn, hs)
 
         self._cross_signing_id_gen = StreamIdGenerator(
-            db_conn, "e2e_cross_signing_keys", "stream_id"
+            db_conn,
+            hs.get_replication_notifier(),
+            "e2e_cross_signing_keys",
+            "stream_id",
         )
 
     async def set_e2e_device_keys(
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index d150fa8a94..d8a8bcafb6 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -191,6 +191,7 @@ class EventsWorkerStore(SQLBaseStore):
             self._stream_id_gen = MultiWriterIdGenerator(
                 db_conn=db_conn,
                 db=database,
+                notifier=hs.get_replication_notifier(),
                 stream_name="events",
                 instance_name=hs.get_instance_name(),
                 tables=[("events", "instance_name", "stream_ordering")],
@@ -200,6 +201,7 @@ class EventsWorkerStore(SQLBaseStore):
             self._backfill_id_gen = MultiWriterIdGenerator(
                 db_conn=db_conn,
                 db=database,
+                notifier=hs.get_replication_notifier(),
                 stream_name="backfill",
                 instance_name=hs.get_instance_name(),
                 tables=[("events", "instance_name", "stream_ordering")],
@@ -217,12 +219,14 @@ class EventsWorkerStore(SQLBaseStore):
             # SQLite).
             self._stream_id_gen = StreamIdGenerator(
                 db_conn,
+                hs.get_replication_notifier(),
                 "events",
                 "stream_ordering",
                 is_writer=hs.get_instance_name() in hs.config.worker.writers.events,
             )
             self._backfill_id_gen = StreamIdGenerator(
                 db_conn,
+                hs.get_replication_notifier(),
                 "events",
                 "stream_ordering",
                 step=-1,
@@ -300,6 +304,7 @@ class EventsWorkerStore(SQLBaseStore):
             self._un_partial_stated_events_stream_id_gen = MultiWriterIdGenerator(
                 db_conn=db_conn,
                 db=database,
+                notifier=hs.get_replication_notifier(),
                 stream_name="un_partial_stated_event_stream",
                 instance_name=hs.get_instance_name(),
                 tables=[
@@ -311,7 +316,10 @@ class EventsWorkerStore(SQLBaseStore):
             )
         else:
             self._un_partial_stated_events_stream_id_gen = StreamIdGenerator(
-                db_conn, "un_partial_stated_event_stream", "stream_id"
+                db_conn,
+                hs.get_replication_notifier(),
+                "un_partial_stated_event_stream",
+                "stream_id",
             )
 
     def get_un_partial_stated_events_token(self) -> int:
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 7b60815043..beb210f8ee 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -77,6 +77,7 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
             self._presence_id_gen = MultiWriterIdGenerator(
                 db_conn=db_conn,
                 db=database,
+                notifier=hs.get_replication_notifier(),
                 stream_name="presence_stream",
                 instance_name=self._instance_name,
                 tables=[("presence_stream", "instance_name", "stream_id")],
@@ -85,7 +86,7 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
             )
         else:
             self._presence_id_gen = StreamIdGenerator(
-                db_conn, "presence_stream", "stream_id"
+                db_conn, hs.get_replication_notifier(), "presence_stream", "stream_id"
             )
 
         self.hs = hs
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 03182887d1..14ca167b34 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -118,6 +118,7 @@ class PushRulesWorkerStore(
         # class below that is used on the main process.
         self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
             db_conn,
+            hs.get_replication_notifier(),
             "push_rules_stream",
             "stream_id",
             is_writer=hs.config.worker.worker_app is None,
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 7f24a3b6ec..df53e726e6 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -62,6 +62,7 @@ class PusherWorkerStore(SQLBaseStore):
         # class below that is used on the main process.
         self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
             db_conn,
+            hs.get_replication_notifier(),
             "pushers",
             "id",
             extra_tables=[("deleted_pushers", "stream_id")],
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 86f5bce5f0..3468f354e6 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -73,6 +73,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             self._receipts_id_gen = MultiWriterIdGenerator(
                 db_conn=db_conn,
                 db=database,
+                notifier=hs.get_replication_notifier(),
                 stream_name="receipts",
                 instance_name=self._instance_name,
                 tables=[("receipts_linearized", "instance_name", "stream_id")],
@@ -91,6 +92,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             # SQLite).
             self._receipts_id_gen = StreamIdGenerator(
                 db_conn,
+                hs.get_replication_notifier(),
                 "receipts_linearized",
                 "stream_id",
                 is_writer=hs.get_instance_name() in hs.config.worker.writers.receipts,
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 78906a5e1d..7264a33cd4 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -126,6 +126,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
             self._un_partial_stated_rooms_stream_id_gen = MultiWriterIdGenerator(
                 db_conn=db_conn,
                 db=database,
+                notifier=hs.get_replication_notifier(),
                 stream_name="un_partial_stated_room_stream",
                 instance_name=self._instance_name,
                 tables=[
@@ -137,7 +138,10 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
             )
         else:
             self._un_partial_stated_rooms_stream_id_gen = StreamIdGenerator(
-                db_conn, "un_partial_stated_room_stream", "stream_id"
+                db_conn,
+                hs.get_replication_notifier(),
+                "un_partial_stated_room_stream",
+                "stream_id",
             )
 
     async def store_room(
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 8670ffbfa3..9adff3f4f5 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -20,6 +20,7 @@ from collections import OrderedDict
 from contextlib import contextmanager
 from types import TracebackType
 from typing import (
+    TYPE_CHECKING,
     AsyncContextManager,
     ContextManager,
     Dict,
@@ -49,6 +50,9 @@ from synapse.storage.database import (
 from synapse.storage.types import Cursor
 from synapse.storage.util.sequence import PostgresSequenceGenerator
 
+if TYPE_CHECKING:
+    from synapse.notifier import ReplicationNotifier
+
 logger = logging.getLogger(__name__)
 
 
@@ -182,6 +186,7 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
     def __init__(
         self,
         db_conn: LoggingDatabaseConnection,
+        notifier: "ReplicationNotifier",
         table: str,
         column: str,
         extra_tables: Iterable[Tuple[str, str]] = (),
@@ -205,6 +210,8 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
         # The key and values are the same, but we never look at the values.
         self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
 
+        self._notifier = notifier
+
     def advance(self, instance_name: str, new_id: int) -> None:
         # Advance should never be called on a writer instance, only over replication
         if self._is_writer:
@@ -227,6 +234,8 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
                 with self._lock:
                     self._unfinished_ids.pop(next_id)
 
+                self._notifier.notify_replication()
+
         return _AsyncCtxManagerWrapper(manager())
 
     def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
@@ -250,6 +259,8 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
                     for next_id in next_ids:
                         self._unfinished_ids.pop(next_id)
 
+                self._notifier.notify_replication()
+
         return _AsyncCtxManagerWrapper(manager())
 
     def get_current_token(self) -> int:
@@ -296,6 +307,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
         self,
         db_conn: LoggingDatabaseConnection,
         db: DatabasePool,
+        notifier: "ReplicationNotifier",
         stream_name: str,
         instance_name: str,
         tables: List[Tuple[str, str, str]],
@@ -304,6 +316,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
         positive: bool = True,
     ) -> None:
         self._db = db
+        self._notifier = notifier
         self._stream_name = stream_name
         self._instance_name = instance_name
         self._positive = positive
@@ -535,7 +548,9 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
         # Cast safety: the second argument to _MultiWriterCtxManager, multiple_ids,
         # controls the return type. If `None` or omitted, the context manager yields
         # a single integer stream_id; otherwise it yields a list of stream_ids.
-        return cast(AsyncContextManager[int], _MultiWriterCtxManager(self))
+        return cast(
+            AsyncContextManager[int], _MultiWriterCtxManager(self, self._notifier)
+        )
 
     def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]:
         # If we have a list of instances that are allowed to write to this
@@ -544,7 +559,10 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
             raise Exception("Tried to allocate stream ID on non-writer")
 
         # Cast safety: see get_next.
-        return cast(AsyncContextManager[List[int]], _MultiWriterCtxManager(self, n))
+        return cast(
+            AsyncContextManager[List[int]],
+            _MultiWriterCtxManager(self, self._notifier, n),
+        )
 
     def get_next_txn(self, txn: LoggingTransaction) -> int:
         """
@@ -563,6 +581,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
 
         txn.call_after(self._mark_id_as_finished, next_id)
         txn.call_on_exception(self._mark_id_as_finished, next_id)
+        txn.call_after(self._notifier.notify_replication)
 
         # Update the `stream_positions` table with newly updated stream
         # ID (unless self._writers is not set in which case we don't
@@ -787,6 +806,7 @@ class _MultiWriterCtxManager:
     """Async context manager returned by MultiWriterIdGenerator"""
 
     id_gen: MultiWriterIdGenerator
+    notifier: "ReplicationNotifier"
     multiple_ids: Optional[int] = None
     stream_ids: List[int] = attr.Factory(list)
 
@@ -814,6 +834,8 @@ class _MultiWriterCtxManager:
         for i in self.stream_ids:
             self.id_gen._mark_id_as_finished(i)
 
+        self.notifier.notify_replication()
+
         if exc_type is not None:
             return False