diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 2dfe4c0b66..0d7108f01b 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -186,11 +186,13 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
column: str,
extra_tables: Iterable[Tuple[str, str]] = (),
step: int = 1,
+ is_writer: bool = True,
) -> None:
assert step != 0
self._lock = threading.Lock()
self._step: int = step
self._current: int = _load_current_id(db_conn, table, column, step)
+ self._is_writer = is_writer
for table, column in extra_tables:
self._current = (max if step > 0 else min)(
self._current, _load_current_id(db_conn, table, column, step)
@@ -204,9 +206,11 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
def advance(self, instance_name: str, new_id: int) -> None:
- # `StreamIdGenerator` should only be used when there is a single writer,
- # so replication should never happen.
- raise Exception("Replication is not supported by StreamIdGenerator")
+ # Advance should never be called on a writer instance, only over replication
+ if self._is_writer:
+ raise Exception("Replication is not supported by writer StreamIdGenerator")
+
+ self._current = (max if self._step > 0 else min)(self._current, new_id)
def get_next(self) -> AsyncContextManager[int]:
with self._lock:
@@ -249,6 +253,9 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
return _AsyncCtxManagerWrapper(manager())
def get_current_token(self) -> int:
+ if not self._is_writer:
+ return self._current
+
with self._lock:
if self._unfinished_ids:
return next(iter(self._unfinished_ids)) - self._step
|