diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 8670ffbfa3..9adff3f4f5 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -20,6 +20,7 @@ from collections import OrderedDict
from contextlib import contextmanager
from types import TracebackType
from typing import (
+ TYPE_CHECKING,
AsyncContextManager,
ContextManager,
Dict,
@@ -49,6 +50,9 @@ from synapse.storage.database import (
from synapse.storage.types import Cursor
from synapse.storage.util.sequence import PostgresSequenceGenerator
+if TYPE_CHECKING:
+ from synapse.notifier import ReplicationNotifier
+
logger = logging.getLogger(__name__)
@@ -182,6 +186,7 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
def __init__(
self,
db_conn: LoggingDatabaseConnection,
+ notifier: "ReplicationNotifier",
table: str,
column: str,
extra_tables: Iterable[Tuple[str, str]] = (),
@@ -205,6 +210,8 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
# The key and values are the same, but we never look at the values.
self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
+ self._notifier = notifier
+
def advance(self, instance_name: str, new_id: int) -> None:
# Advance should never be called on a writer instance, only over replication
if self._is_writer:
@@ -227,6 +234,8 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
with self._lock:
self._unfinished_ids.pop(next_id)
+ self._notifier.notify_replication()
+
return _AsyncCtxManagerWrapper(manager())
def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
@@ -250,6 +259,8 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
for next_id in next_ids:
self._unfinished_ids.pop(next_id)
+ self._notifier.notify_replication()
+
return _AsyncCtxManagerWrapper(manager())
def get_current_token(self) -> int:
@@ -296,6 +307,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
self,
db_conn: LoggingDatabaseConnection,
db: DatabasePool,
+ notifier: "ReplicationNotifier",
stream_name: str,
instance_name: str,
tables: List[Tuple[str, str, str]],
@@ -304,6 +316,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
positive: bool = True,
) -> None:
self._db = db
+ self._notifier = notifier
self._stream_name = stream_name
self._instance_name = instance_name
self._positive = positive
@@ -535,7 +548,9 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
# Cast safety: the second argument to _MultiWriterCtxManager, multiple_ids,
# controls the return type. If `None` or omitted, the context manager yields
# a single integer stream_id; otherwise it yields a list of stream_ids.
- return cast(AsyncContextManager[int], _MultiWriterCtxManager(self))
+ return cast(
+ AsyncContextManager[int], _MultiWriterCtxManager(self, self._notifier)
+ )
def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]:
# If we have a list of instances that are allowed to write to this
@@ -544,7 +559,10 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
raise Exception("Tried to allocate stream ID on non-writer")
# Cast safety: see get_next.
- return cast(AsyncContextManager[List[int]], _MultiWriterCtxManager(self, n))
+ return cast(
+ AsyncContextManager[List[int]],
+ _MultiWriterCtxManager(self, self._notifier, n),
+ )
def get_next_txn(self, txn: LoggingTransaction) -> int:
"""
@@ -563,6 +581,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
txn.call_after(self._mark_id_as_finished, next_id)
txn.call_on_exception(self._mark_id_as_finished, next_id)
+ txn.call_after(self._notifier.notify_replication)
# Update the `stream_positions` table with newly updated stream
# ID (unless self._writers is not set in which case we don't
@@ -787,6 +806,7 @@ class _MultiWriterCtxManager:
"""Async context manager returned by MultiWriterIdGenerator"""
id_gen: MultiWriterIdGenerator
+ notifier: "ReplicationNotifier"
multiple_ids: Optional[int] = None
stream_ids: List[int] = attr.Factory(list)
@@ -814,6 +834,8 @@ class _MultiWriterCtxManager:
for i in self.stream_ids:
self.id_gen._mark_id_as_finished(i)
+ self.notifier.notify_replication()
+
if exc_type is not None:
return False
|