diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index ddb5c8c60c..5b07847773 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -14,9 +14,10 @@
# limitations under the License.
import contextlib
+import heapq
import threading
from collections import deque
-from typing import Dict, Set
+from typing import Dict, List, Set
from typing_extensions import Deque
@@ -210,6 +211,23 @@ class MultiWriterIdGenerator:
# should be less than the minimum of this set (if not empty).
self._unfinished_ids = set() # type: Set[int]
+ # We track the max position where we know everything before has been
+ # persisted. This is done by a) looking at the min across all instances
+ # and b) noting that if we have seen a run of persisted positions
+ # without gaps (e.g. 5, 6, 7) then we can skip forward (e.g. to 7).
+ #
+        # Note: There is no guarantee that the IDs generated by the sequence
+        # will be gapless; gaps can form when e.g. a transaction was rolled
+        # back. This means that sometimes we won't be able to skip the
+        # position forward even though everything before it has been
+        # persisted. However, since gaps should be relatively rare it's still
+        # worth doing the bookkeeping that allows us to skip forwards over
+        # gapless runs of positions.
+ self._persisted_upto_position = (
+ min(self._current_positions.values()) if self._current_positions else 0
+ )
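+        # Min-heap (via `heapq`) of persisted positions that may still be
+        # ahead of `_persisted_upto_position`; `_add_persisted_position` pops
+        # from it to advance the position over gapless runs.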
+ self._known_persisted_positions = [] # type: List[int]
+
self._sequence_gen = PostgresSequenceGenerator(sequence_name)
def _load_current_ids(
@@ -234,9 +252,12 @@ class MultiWriterIdGenerator:
return current_positions
- def _load_next_id_txn(self, txn):
+ def _load_next_id_txn(self, txn) -> int:
return self._sequence_gen.get_next_id_txn(txn)
+ def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
+ return self._sequence_gen.get_next_mult_txn(txn, n)
+
async def get_next(self):
"""
Usage:
@@ -262,6 +283,34 @@ class MultiWriterIdGenerator:
return manager()
+ async def get_next_mult(self, n: int):
+ """
+ Usage:
+ with await stream_id_gen.get_next_mult(5) as stream_ids:
+ # ... persist events ...
+ """
+ next_ids = await self._db.runInteraction(
+ "_load_next_mult_id", self._load_next_mult_id_txn, n
+ )
+
+        # Assert that the fetched IDs are actually greater than any ID we've
+        # already seen. If not, then the sequence and table have got out of
+        # sync somehow.
+ assert max(self.get_positions().values(), default=0) < min(next_ids)
+
+ with self._lock:
+ self._unfinished_ids.update(next_ids)
+
+ @contextlib.contextmanager
+ def manager():
+ try:
+ yield next_ids
+ finally:
+ for i in next_ids:
+ self._mark_id_as_finished(i)
+
+ return manager()
+
def get_next_txn(self, txn: LoggingTransaction):
"""
Usage:
@@ -326,3 +375,53 @@ class MultiWriterIdGenerator:
self._current_positions[instance_name] = max(
new_id, self._current_positions.get(instance_name, 0)
)
+
+ self._add_persisted_position(new_id)
+
+ def get_persisted_upto_position(self) -> int:
+ """Get the max position where all previous positions have been
+ persisted.
+
+ Note: In the worst case scenario this will be equal to the minimum
+ position across writers. This means that the returned position here can
+ lag if one writer doesn't write very often.
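+
+        For example, if positions 1, 2, 3 and 5 have been persisted then this
+        returns 3, since position 4 may still be in flight.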
+ """
+
+ with self._lock:
+ return self._persisted_upto_position
+
+ def _add_persisted_position(self, new_id: int):
+ """Record that we have persisted a position.
+
+        This is used to keep `_persisted_upto_position` up to date.
+ """
+
+        # We require that the lock is held by the caller.
+ assert self._lock.locked()
+
+ heapq.heappush(self._known_persisted_positions, new_id)
+
+        # We move the position up if the minimum current position of all
+        # instances is higher (since by definition all positions less than
+        # that have been persisted).
+ min_curr = min(self._current_positions.values())
+ self._persisted_upto_position = max(min_curr, self._persisted_upto_position)
+
+        # We now iterate through the seen positions, discarding those that are
+        # less than the current persisted position, and incrementing the
+        # position if it's exactly one greater.
+ #
+ # This is also where we discard items from `_known_persisted_positions`
+ # (to ensure the list doesn't infinitely grow).
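+        #
+        # For example, if the current position is 5 and the heap holds
+        # [6, 7, 9] then we pop 6 and 7 (advancing the position to 7) and
+        # stop at 9, since 8 has not yet been seen.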
+ while self._known_persisted_positions:
+ if self._known_persisted_positions[0] <= self._persisted_upto_position:
+ heapq.heappop(self._known_persisted_positions)
+ elif (
+ self._known_persisted_positions[0] == self._persisted_upto_position + 1
+ ):
+ heapq.heappop(self._known_persisted_positions)
+ self._persisted_upto_position += 1
+ else:
+ # There was a gap in seen positions, so there is nothing more to
+ # do.
+ break
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index 63dfea4220..ffc1894748 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -14,7 +14,7 @@
# limitations under the License.
import abc
import threading
-from typing import Callable, Optional
+from typing import Callable, List, Optional
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
from synapse.storage.types import Cursor
@@ -39,6 +39,12 @@ class PostgresSequenceGenerator(SequenceGenerator):
txn.execute("SELECT nextval(?)", (self._sequence_name,))
return txn.fetchone()[0]
+ def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
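+        # `generate_series(1, n)` produces `n` rows, so `nextval` is evaluated
+        # once per row, returning `n` fresh IDs from the sequence in a single
+        # query.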
+ txn.execute(
+ "SELECT nextval(?) FROM generate_series(1, ?)", (self._sequence_name, n)
+ )
+ return [i for (i,) in txn]
+
GetFirstCallbackType = Callable[[Cursor], int]