diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 9c3eafb562..bd3c81827f 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -650,8 +650,8 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
next_id = self._load_next_id_txn(txn)
- txn.call_after(self._mark_id_as_finished, next_id)
- txn.call_on_exception(self._mark_id_as_finished, next_id)
+ txn.call_after(self._mark_ids_as_finished, [next_id])
+ txn.call_on_exception(self._mark_ids_as_finished, [next_id])
txn.call_after(self._notifier.notify_replication)
# Update the `stream_positions` table with newly updated stream
@@ -671,14 +671,50 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
return self._return_factor * next_id
- def _mark_id_as_finished(self, next_id: int) -> None:
- """The ID has finished being processed so we should advance the
+ def get_next_mult_txn(self, txn: LoggingTransaction, n: int) -> List[int]:
+ """
+ Usage:
+
+ stream_id = stream_id_gen.get_next_txn(txn)
+ # ... persist event ...
+ """
+
+ # If we have a list of instances that are allowed to write to this
+ # stream, make sure we're in it.
+ if self._writers and self._instance_name not in self._writers:
+ raise Exception("Tried to allocate stream ID on non-writer")
+
+ next_ids = self._load_next_mult_id_txn(txn, n)
+
+ txn.call_after(self._mark_ids_as_finished, next_ids)
+ txn.call_on_exception(self._mark_ids_as_finished, next_ids)
+ txn.call_after(self._notifier.notify_replication)
+
+ # Update the `stream_positions` table with newly updated stream
+ # ID (unless self._writers is not set in which case we don't
+ # bother, as nothing will read it).
+ #
+ # We only do this on the success path so that the persisted current
+ # position points to a persisted row with the correct instance name.
+ if self._writers:
+ txn.call_after(
+ run_as_background_process,
+ "MultiWriterIdGenerator._update_table",
+ self._db.runInteraction,
+ "MultiWriterIdGenerator._update_table",
+ self._update_stream_positions_table_txn,
+ )
+
+ return [self._return_factor * next_id for next_id in next_ids]
+
+ def _mark_ids_as_finished(self, next_ids: List[int]) -> None:
+ """These IDs have finished being processed so we should advance the
current position if possible.
"""
with self._lock:
- self._unfinished_ids.discard(next_id)
- self._finished_ids.add(next_id)
+ self._unfinished_ids.difference_update(next_ids)
+ self._finished_ids.update(next_ids)
new_cur: Optional[int] = None
@@ -727,7 +763,10 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
curr, new_cur, self._max_position_of_local_instance
)
- self._add_persisted_position(next_id)
+ # TODO Can we call this for just the last position or somehow batch
+ # _add_persisted_position.
+ for next_id in next_ids:
+ self._add_persisted_position(next_id)
def get_current_token(self) -> int:
return self.get_persisted_upto_position()
@@ -933,8 +972,7 @@ class _MultiWriterCtxManager:
exc: Optional[BaseException],
tb: Optional[TracebackType],
) -> bool:
- for i in self.stream_ids:
- self.id_gen._mark_id_as_finished(i)
+ self.id_gen._mark_ids_as_finished(self.stream_ids)
self.notifier.notify_replication()
|