1 files changed, 49 insertions, 13 deletions
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 9adff3f4f5..d2c874b9a8 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -93,8 +93,11 @@ def _load_current_id(
return res
-class AbstractStreamIdTracker(metaclass=abc.ABCMeta):
- """Tracks the "current" stream ID of a stream that may have multiple writers.
+class AbstractStreamIdGenerator(metaclass=abc.ABCMeta):
+ """Generates or tracks stream IDs for a stream that may have multiple writers.
+
+ Each stream ID represents a write transaction, whose completion is tracked
+ so that the "current" stream ID of the stream can be determined.
Stream IDs are monotonically increasing or decreasing integers representing write
transactions. The "current" stream ID is the stream ID such that all transactions
@@ -130,16 +133,6 @@ class AbstractStreamIdTracker(metaclass=abc.ABCMeta):
"""
raise NotImplementedError()
-
-class AbstractStreamIdGenerator(AbstractStreamIdTracker):
- """Generates stream IDs for a stream that may have multiple writers.
-
- Each stream ID represents a write transaction, whose completion is tracked
- so that the "current" stream ID of the stream can be determined.
-
- See `AbstractStreamIdTracker` for more details.
- """
-
@abc.abstractmethod
def get_next(self) -> AsyncContextManager[int]:
"""
@@ -158,6 +151,15 @@ class AbstractStreamIdGenerator(AbstractStreamIdTracker):
"""
raise NotImplementedError()
+ @abc.abstractmethod
+ def get_next_txn(self, txn: LoggingTransaction) -> int:
+ """
+ Usage:
+ stream_id_gen.get_next_txn(txn)
+ # ... persist events ...
+ """
+ raise NotImplementedError()
+
class StreamIdGenerator(AbstractStreamIdGenerator):
"""Generates and tracks stream IDs for a stream with a single writer.
@@ -263,6 +265,40 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
return _AsyncCtxManagerWrapper(manager())
+ def get_next_txn(self, txn: LoggingTransaction) -> int:
+ """
+ Retrieve the next stream ID from within a database transaction.
+
+ Clean-up functions will be called when the transaction finishes.
+
+ Args:
+ txn: The database transaction object.
+
+ Returns:
+ The next stream ID.
+ """
+ if not self._is_writer:
+ raise Exception("Tried to allocate stream ID on non-writer")
+
+ # Get the next stream ID.
+ with self._lock:
+ self._current += self._step
+ next_id = self._current
+
+ self._unfinished_ids[next_id] = next_id
+
+ def clear_unfinished_id(id_to_clear: int) -> None:
+ """A function to mark processing this ID as finished"""
+ with self._lock:
+ self._unfinished_ids.pop(id_to_clear)
+
+ # Mark this ID as finished once the database transaction itself finishes.
+ txn.call_after(clear_unfinished_id, next_id)
+ txn.call_on_exception(clear_unfinished_id, next_id)
+
+ # Return the new ID.
+ return next_id
+
def get_current_token(self) -> int:
if not self._is_writer:
return self._current
@@ -568,7 +604,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
"""
Usage:
- stream_id = stream_id_gen.get_next(txn)
+ stream_id = stream_id_gen.get_next_txn(txn)
# ... persist event ...
"""
|