diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index ac56bc9a05..4ff3013908 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -89,31 +89,77 @@ def _load_current_id(
return (max if step > 0 else min)(current_id, step)
-class AbstractStreamIdGenerator(metaclass=abc.ABCMeta):
- @abc.abstractmethod
- def get_next(self) -> AsyncContextManager[int]:
- raise NotImplementedError()
+class AbstractStreamIdTracker(metaclass=abc.ABCMeta):
+ """Tracks the "current" stream ID of a stream that may have multiple writers.
+
+ Stream IDs are monotonically increasing or decreasing integers representing write
+ transactions. The "current" stream ID is the stream ID such that all transactions
+ with equal or smaller stream IDs have completed. Since transactions may complete out
+ of order, this is not the same as the stream ID of the last completed transaction.
+
+ Completed transactions include both committed transactions and transactions that
+ have been rolled back.
+ """
@abc.abstractmethod
- def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
+ def advance(self, instance_name: str, new_id: int) -> None:
+ """Advance the position of the named writer to the given ID, if greater
+ than the existing entry.
+ """
raise NotImplementedError()
@abc.abstractmethod
def get_current_token(self) -> int:
+ """Returns the maximum stream id such that all stream ids less than or
+ equal to it have completed (been committed or rolled back).
+
+ Returns:
+ The maximum stream id.
+ """
raise NotImplementedError()
@abc.abstractmethod
def get_current_token_for_writer(self, instance_name: str) -> int:
+ """Returns the position of the given writer.
+
+ For streams with single writers this is equivalent to `get_current_token`.
+ """
+ raise NotImplementedError()
+
+
+class AbstractStreamIdGenerator(AbstractStreamIdTracker):
+ """Generates stream IDs for a stream that may have multiple writers.
+
+ Each stream ID represents a write transaction, whose completion is tracked
+ so that the "current" stream ID of the stream can be determined.
+
+ See `AbstractStreamIdTracker` for more details.
+ """
+
+ @abc.abstractmethod
+ def get_next(self) -> AsyncContextManager[int]:
+ """
+ Usage:
+ async with stream_id_gen.get_next() as stream_id:
+ # ... persist event ...
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
+ """
+ Usage:
+ async with stream_id_gen.get_next_mult(n) as stream_ids:
+ # ... persist events ...
+ """
raise NotImplementedError()
class StreamIdGenerator(AbstractStreamIdGenerator):
- """Used to generate new stream ids when persisting events while keeping
- track of which transactions have been completed.
+ """Generates and tracks stream IDs for a stream with a single writer.
- This allows us to get the "current" stream id, i.e. the stream id such that
- all ids less than or equal to it have completed. This handles the fact that
- persistence of events can complete out of order.
+ This class must only be used when the current Synapse process is the sole
+ writer for a stream.
Args:
db_conn(connection): A database connection to use to fetch the
@@ -157,12 +203,12 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
# The key and values are the same, but we never look at the values.
self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
+ def advance(self, instance_name: str, new_id: int) -> None:
+ # `StreamIdGenerator` should only be used when there is a single writer,
+ # so replication should never happen.
+ raise Exception("Replication is not supported by StreamIdGenerator")
+
def get_next(self) -> AsyncContextManager[int]:
- """
- Usage:
- async with stream_id_gen.get_next() as stream_id:
- # ... persist event ...
- """
with self._lock:
self._current += self._step
next_id = self._current
@@ -180,11 +226,6 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
return _AsyncCtxManagerWrapper(manager())
def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
- """
- Usage:
- async with stream_id_gen.get_next(n) as stream_ids:
- # ... persist events ...
- """
with self._lock:
next_ids = range(
self._current + self._step,
@@ -208,12 +249,6 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
return _AsyncCtxManagerWrapper(manager())
def get_current_token(self) -> int:
- """Returns the maximum stream id such that all stream ids less than or
- equal to it have been successfully persisted.
-
- Returns:
- The maximum stream id.
- """
with self._lock:
if self._unfinished_ids:
return next(iter(self._unfinished_ids)) - self._step
@@ -221,16 +256,11 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
return self._current
def get_current_token_for_writer(self, instance_name: str) -> int:
- """Returns the position of the given writer.
-
- For streams with single writers this is equivalent to
- `get_current_token`.
- """
return self.get_current_token()
class MultiWriterIdGenerator(AbstractStreamIdGenerator):
- """An ID generator that tracks a stream that can have multiple writers.
+ """Generates and tracks stream IDs for a stream with multiple writers.
Uses a Postgres sequence to coordinate ID assignment, but positions of other
writers will only get updated when `advance` is called (by replication).
@@ -475,12 +505,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
return stream_ids
def get_next(self) -> AsyncContextManager[int]:
- """
- Usage:
- async with stream_id_gen.get_next() as stream_id:
- # ... persist event ...
- """
-
# If we have a list of instances that are allowed to write to this
# stream, make sure we're in it.
if self._writers and self._instance_name not in self._writers:
@@ -492,12 +516,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
return cast(AsyncContextManager[int], _MultiWriterCtxManager(self))
def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]:
- """
- Usage:
- async with stream_id_gen.get_next_mult(5) as stream_ids:
- # ... persist events ...
- """
-
# If we have a list of instances that are allowed to write to this
# stream, make sure we're in it.
if self._writers and self._instance_name not in self._writers:
@@ -597,15 +615,9 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
self._add_persisted_position(next_id)
def get_current_token(self) -> int:
- """Returns the maximum stream id such that all stream ids less than or
- equal to it have been successfully persisted.
- """
-
return self.get_persisted_upto_position()
def get_current_token_for_writer(self, instance_name: str) -> int:
- """Returns the position of the given writer."""
-
# If we don't have an entry for the given instance name, we assume it's a
# new writer.
#
@@ -631,10 +643,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
}
def advance(self, instance_name: str, new_id: int) -> None:
- """Advance the position of the named writer to the given ID, if greater
- than existing entry.
- """
-
new_id *= self._return_factor
with self._lock:
|