summary refs log tree commit diff
path: root/synapse/storage/util
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/util')
-rw-r--r--synapse/storage/util/id_generators.py103
-rw-r--r--synapse/storage/util/sequence.py8
2 files changed, 108 insertions, 3 deletions
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index ddb5c8c60c..5b07847773 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -14,9 +14,10 @@
 # limitations under the License.
 
 import contextlib
+import heapq
 import threading
 from collections import deque
-from typing import Dict, Set
+from typing import Dict, List, Set
 
 from typing_extensions import Deque
 
@@ -210,6 +211,23 @@ class MultiWriterIdGenerator:
         # should be less than the minimum of this set (if not empty).
         self._unfinished_ids = set()  # type: Set[int]
 
+        # We track the max position where we know everything before has been
+        # persisted. This is done by a) looking at the min across all instances
+        # and b) noting that if we have seen a run of persisted positions
+        # without gaps (e.g. 5, 6, 7) then we can skip forward (e.g. to 7).
+        #
+        # Note: There is no guarentee that the IDs generated by the sequence
+        # will be gapless; gaps can form when e.g. a transaction was rolled
+        # back. This means that sometimes we won't be able to skip forward the
+        # position even though everything has been persisted. However, since
+        # gaps should be relatively rare it's still worth doing the book keeping
+        # that allows us to skip forwards when there are gapless runs of
+        # positions.
+        self._persisted_upto_position = (
+            min(self._current_positions.values()) if self._current_positions else 0
+        )
+        self._known_persisted_positions = []  # type: List[int]
+
         self._sequence_gen = PostgresSequenceGenerator(sequence_name)
 
     def _load_current_ids(
@@ -234,9 +252,12 @@ class MultiWriterIdGenerator:
 
         return current_positions
 
-    def _load_next_id_txn(self, txn):
+    def _load_next_id_txn(self, txn) -> int:
         return self._sequence_gen.get_next_id_txn(txn)
 
+    def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
+        return self._sequence_gen.get_next_mult_txn(txn, n)
+
     async def get_next(self):
         """
         Usage:
@@ -262,6 +283,34 @@ class MultiWriterIdGenerator:
 
         return manager()
 
+    async def get_next_mult(self, n: int):
+        """
+        Usage:
+            with await stream_id_gen.get_next_mult(5) as stream_ids:
+                # ... persist events ...
+        """
+        next_ids = await self._db.runInteraction(
+            "_load_next_mult_id", self._load_next_mult_id_txn, n
+        )
+
+        # Assert the fetched ID is actually greater than any ID we've already
+        # seen. If not, then the sequence and table have got out of sync
+        # somehow.
+        assert max(self.get_positions().values(), default=0) < min(next_ids)
+
+        with self._lock:
+            self._unfinished_ids.update(next_ids)
+
+        @contextlib.contextmanager
+        def manager():
+            try:
+                yield next_ids
+            finally:
+                for i in next_ids:
+                    self._mark_id_as_finished(i)
+
+        return manager()
+
     def get_next_txn(self, txn: LoggingTransaction):
         """
         Usage:
@@ -326,3 +375,53 @@ class MultiWriterIdGenerator:
             self._current_positions[instance_name] = max(
                 new_id, self._current_positions.get(instance_name, 0)
             )
+
+            self._add_persisted_position(new_id)
+
+    def get_persisted_upto_position(self) -> int:
+        """Get the max position where all previous positions have been
+        persisted.
+
+        Note: In the worst case scenario this will be equal to the minimum
+        position across writers. This means that the returned position here can
+        lag if one writer doesn't write very often.
+        """
+
+        with self._lock:
+            return self._persisted_upto_position
+
+    def _add_persisted_position(self, new_id: int):
+        """Record that we have persisted a position.
+
+        This is used to keep the `_current_positions` up to date.
+        """
+
+        # We require that the lock is locked by caller
+        assert self._lock.locked()
+
+        heapq.heappush(self._known_persisted_positions, new_id)
+
+        # We move the current min position up if the minimum current positions
+        # of all instances is higher (since by definition all positions less
+        # that that have been persisted).
+        min_curr = min(self._current_positions.values())
+        self._persisted_upto_position = max(min_curr, self._persisted_upto_position)
+
+        # We now iterate through the seen positions, discarding those that are
+        # less than the current min positions, and incrementing the min position
+        # if its exactly one greater.
+        #
+        # This is also where we discard items from `_known_persisted_positions`
+        # (to ensure the list doesn't infinitely grow).
+        while self._known_persisted_positions:
+            if self._known_persisted_positions[0] <= self._persisted_upto_position:
+                heapq.heappop(self._known_persisted_positions)
+            elif (
+                self._known_persisted_positions[0] == self._persisted_upto_position + 1
+            ):
+                heapq.heappop(self._known_persisted_positions)
+                self._persisted_upto_position += 1
+            else:
+                # There was a gap in seen positions, so there is nothing more to
+                # do.
+                break
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index 63dfea4220..ffc1894748 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import abc
 import threading
-from typing import Callable, Optional
+from typing import Callable, List, Optional
 
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
 from synapse.storage.types import Cursor
@@ -39,6 +39,12 @@ class PostgresSequenceGenerator(SequenceGenerator):
         txn.execute("SELECT nextval(?)", (self._sequence_name,))
         return txn.fetchone()[0]
 
+    def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
+        txn.execute(
+            "SELECT nextval(?) FROM generate_series(1, ?)", (self._sequence_name, n)
+        )
+        return [i for (i,) in txn]
+
 
 GetFirstCallbackType = Callable[[Cursor], int]