summary refs log tree commit diff
path: root/synapse/storage/util/id_generators.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/util/id_generators.py')
-rw-r--r--synapse/storage/util/id_generators.py116
1 files changed, 62 insertions, 54 deletions
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index ac56bc9a05..4ff3013908 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -89,31 +89,77 @@ def _load_current_id(
     return (max if step > 0 else min)(current_id, step)
 
 
-class AbstractStreamIdGenerator(metaclass=abc.ABCMeta):
-    @abc.abstractmethod
-    def get_next(self) -> AsyncContextManager[int]:
-        raise NotImplementedError()
+class AbstractStreamIdTracker(metaclass=abc.ABCMeta):
+    """Tracks the "current" stream ID of a stream that may have multiple writers.
+
+    Stream IDs are monotonically increasing or decreasing integers representing write
+    transactions. The "current" stream ID is the stream ID such that all transactions
+    with equal or smaller stream IDs have completed. Since transactions may complete out
+    of order, this is not the same as the stream ID of the last completed transaction.
+
+    Completed transactions include both committed transactions and transactions that
+    have been rolled back.
+    """
 
     @abc.abstractmethod
-    def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
+    def advance(self, instance_name: str, new_id: int) -> None:
+        """Advance the position of the named writer to the given ID, if greater
+        than existing entry.
+        """
         raise NotImplementedError()
 
     @abc.abstractmethod
     def get_current_token(self) -> int:
+        """Returns the maximum stream id such that all stream ids less than or
+        equal to it have been successfully persisted.
+
+        Returns:
+            The maximum stream id.
+        """
         raise NotImplementedError()
 
     @abc.abstractmethod
     def get_current_token_for_writer(self, instance_name: str) -> int:
+        """Returns the position of the given writer.
+
+        For streams with single writers this is equivalent to `get_current_token`.
+        """
+        raise NotImplementedError()
+
+
+class AbstractStreamIdGenerator(AbstractStreamIdTracker):
+    """Generates stream IDs for a stream that may have multiple writers.
+
+    Each stream ID represents a write transaction, whose completion is tracked
+    so that the "current" stream ID of the stream can be determined.
+
+    See `AbstractStreamIdTracker` for more details.
+    """
+
+    @abc.abstractmethod
+    def get_next(self) -> AsyncContextManager[int]:
+        """
+        Usage:
+            async with stream_id_gen.get_next() as stream_id:
+                # ... persist event ...
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
+        """
+        Usage:
+            async with stream_id_gen.get_next(n) as stream_ids:
+                # ... persist events ...
+        """
         raise NotImplementedError()
 
 
 class StreamIdGenerator(AbstractStreamIdGenerator):
-    """Used to generate new stream ids when persisting events while keeping
-    track of which transactions have been completed.
+    """Generates and tracks stream IDs for a stream with a single writer.
 
-    This allows us to get the "current" stream id, i.e. the stream id such that
-    all ids less than or equal to it have completed. This handles the fact that
-    persistence of events can complete out of order.
+    This class must only be used when the current Synapse process is the sole
+    writer for a stream.
 
     Args:
         db_conn(connection):  A database connection to use to fetch the
@@ -157,12 +203,12 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
         # The key and values are the same, but we never look at the values.
         self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
 
+    def advance(self, instance_name: str, new_id: int) -> None:
+        # `StreamIdGenerator` should only be used when there is a single writer,
+        # so replication should never happen.
+        raise Exception("Replication is not supported by StreamIdGenerator")
+
     def get_next(self) -> AsyncContextManager[int]:
-        """
-        Usage:
-            async with stream_id_gen.get_next() as stream_id:
-                # ... persist event ...
-        """
         with self._lock:
             self._current += self._step
             next_id = self._current
@@ -180,11 +226,6 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
         return _AsyncCtxManagerWrapper(manager())
 
     def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
-        """
-        Usage:
-            async with stream_id_gen.get_next(n) as stream_ids:
-                # ... persist events ...
-        """
         with self._lock:
             next_ids = range(
                 self._current + self._step,
@@ -208,12 +249,6 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
         return _AsyncCtxManagerWrapper(manager())
 
     def get_current_token(self) -> int:
-        """Returns the maximum stream id such that all stream ids less than or
-        equal to it have been successfully persisted.
-
-        Returns:
-            The maximum stream id.
-        """
         with self._lock:
             if self._unfinished_ids:
                 return next(iter(self._unfinished_ids)) - self._step
@@ -221,16 +256,11 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
             return self._current
 
     def get_current_token_for_writer(self, instance_name: str) -> int:
-        """Returns the position of the given writer.
-
-        For streams with single writers this is equivalent to
-        `get_current_token`.
-        """
         return self.get_current_token()
 
 
 class MultiWriterIdGenerator(AbstractStreamIdGenerator):
-    """An ID generator that tracks a stream that can have multiple writers.
+    """Generates and tracks stream IDs for a stream with multiple writers.
 
     Uses a Postgres sequence to coordinate ID assignment, but positions of other
     writers will only get updated when `advance` is called (by replication).
@@ -475,12 +505,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
         return stream_ids
 
     def get_next(self) -> AsyncContextManager[int]:
-        """
-        Usage:
-            async with stream_id_gen.get_next() as stream_id:
-                # ... persist event ...
-        """
-
         # If we have a list of instances that are allowed to write to this
         # stream, make sure we're in it.
         if self._writers and self._instance_name not in self._writers:
@@ -492,12 +516,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
         return cast(AsyncContextManager[int], _MultiWriterCtxManager(self))
 
     def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]:
-        """
-        Usage:
-            async with stream_id_gen.get_next_mult(5) as stream_ids:
-                # ... persist events ...
-        """
-
         # If we have a list of instances that are allowed to write to this
         # stream, make sure we're in it.
         if self._writers and self._instance_name not in self._writers:
@@ -597,15 +615,9 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
             self._add_persisted_position(next_id)
 
     def get_current_token(self) -> int:
-        """Returns the maximum stream id such that all stream ids less than or
-        equal to it have been successfully persisted.
-        """
-
         return self.get_persisted_upto_position()
 
     def get_current_token_for_writer(self, instance_name: str) -> int:
-        """Returns the position of the given writer."""
-
         # If we don't have an entry for the given instance name, we assume it's a
         # new writer.
         #
@@ -631,10 +643,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
             }
 
     def advance(self, instance_name: str, new_id: int) -> None:
-        """Advance the position of the named writer to the given ID, if greater
-        than existing entry.
-        """
-
         new_id *= self._return_factor
 
         with self._lock: