diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 3c13859faa..0d7108f01b 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -186,11 +186,13 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
column: str,
extra_tables: Iterable[Tuple[str, str]] = (),
step: int = 1,
+ is_writer: bool = True,
) -> None:
assert step != 0
self._lock = threading.Lock()
self._step: int = step
self._current: int = _load_current_id(db_conn, table, column, step)
+ self._is_writer = is_writer
for table, column in extra_tables:
self._current = (max if step > 0 else min)(
self._current, _load_current_id(db_conn, table, column, step)
@@ -204,9 +206,11 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
def advance(self, instance_name: str, new_id: int) -> None:
- # `StreamIdGenerator` should only be used when there is a single writer,
- # so replication should never happen.
- raise Exception("Replication is not supported by StreamIdGenerator")
+ # Advance should never be called on a writer instance, only over replication
+ if self._is_writer:
+ raise Exception("Replication is not supported by writer StreamIdGenerator")
+
+ self._current = (max if self._step > 0 else min)(self._current, new_id)
def get_next(self) -> AsyncContextManager[int]:
with self._lock:
@@ -249,6 +253,9 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
return _AsyncCtxManagerWrapper(manager())
def get_current_token(self) -> int:
+ if not self._is_writer:
+ return self._current
+
with self._lock:
if self._unfinished_ids:
return next(iter(self._unfinished_ids)) - self._step
@@ -460,8 +467,17 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
# Cast safety: this corresponds to the types returned by the query above.
rows.extend(cast(Iterable[Tuple[str, int]], cur))
- # Sort so that we handle rows in order for each instance.
- rows.sort()
+ # Sort by stream_id (ascending, lowest -> highest) so that we handle
+ # rows in order for each instance because we don't want to overwrite
+ # the current_position of an instance to a lower stream ID than
+ # we're actually at.
+ def sort_by_stream_id_key_func(row: Tuple[str, int]) -> int:
+ (instance, stream_id) = row
+ # If `stream_id` is ever `None`, we will see a `TypeError: '<'
+ # not supported between instances of 'NoneType' and 'X'` error.
+ return stream_id
+
+ rows.sort(key=sort_by_stream_id_key_func)
with self._lock:
for (
|