diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 852bd79fee..670811611f 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -36,7 +36,7 @@ from typing import (
)
import attr
-from sortedcontainers import SortedSet
+from sortedcontainers import SortedList, SortedSet
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.database import (
@@ -265,6 +265,15 @@ class MultiWriterIdGenerator:
# should be less than the minimum of this set (if not empty).
self._unfinished_ids: SortedSet[int] = SortedSet()
+ # We also need to track when we've requested some new stream IDs but
+ # they haven't yet been added to the `_unfinished_ids` set. Every time
+ # we request a new stream ID we add the current max stream ID to the
+ # list, and remove it once we've added the newly allocated IDs to the
+ # `_unfinished_ids` set. This means that we *may* be allocated stream
+ # IDs above those in the list, and so we can't advance the local current
+ # position beyond the minimum stream ID in this list.
+ self._in_flight_fetches: SortedList[int] = SortedList()
+
# Set of local IDs that we've processed that are larger than the current
# position, due to there being smaller unpersisted IDs.
self._finished_ids: Set[int] = set()
@@ -290,6 +299,9 @@ class MultiWriterIdGenerator:
)
self._known_persisted_positions: List[int] = []
+    # The maximum stream ID that we have seen allocated across any writer.
+ self._max_seen_allocated_stream_id = 1
+
self._sequence_gen = PostgresSequenceGenerator(sequence_name)
# We check that the table and sequence haven't diverged.
@@ -305,6 +317,10 @@ class MultiWriterIdGenerator:
# This goes and fills out the above state from the database.
self._load_current_ids(db_conn, tables)
+ self._max_seen_allocated_stream_id = max(
+ self._current_positions.values(), default=1
+ )
+
def _load_current_ids(
self,
db_conn: LoggingDatabaseConnection,
@@ -411,10 +427,32 @@ class MultiWriterIdGenerator:
cur.close()
def _load_next_id_txn(self, txn: Cursor) -> int:
- return self._sequence_gen.get_next_id_txn(txn)
+ stream_ids = self._load_next_mult_id_txn(txn, 1)
+ return stream_ids[0]
def _load_next_mult_id_txn(self, txn: Cursor, n: int) -> List[int]:
- return self._sequence_gen.get_next_mult_txn(txn, n)
+ # We need to track that we've requested some more stream IDs, and what
+ # the current max allocated stream ID is. This is to prevent a race
+ # where we've been allocated stream IDs but they have not yet been added
+ # to the `_unfinished_ids` set, allowing the current position to advance
+ # past them.
+ with self._lock:
+ current_max = self._max_seen_allocated_stream_id
+ self._in_flight_fetches.add(current_max)
+
+ try:
+ stream_ids = self._sequence_gen.get_next_mult_txn(txn, n)
+
+ with self._lock:
+ self._unfinished_ids.update(stream_ids)
+ self._max_seen_allocated_stream_id = max(
+ self._max_seen_allocated_stream_id, self._unfinished_ids[-1]
+ )
+ finally:
+ with self._lock:
+ self._in_flight_fetches.remove(current_max)
+
+ return stream_ids
def get_next(self) -> AsyncContextManager[int]:
"""
@@ -463,9 +501,6 @@ class MultiWriterIdGenerator:
next_id = self._load_next_id_txn(txn)
- with self._lock:
- self._unfinished_ids.add(next_id)
-
txn.call_after(self._mark_id_as_finished, next_id)
txn.call_on_exception(self._mark_id_as_finished, next_id)
@@ -497,15 +532,27 @@ class MultiWriterIdGenerator:
new_cur: Optional[int] = None
- if self._unfinished_ids:
+ if self._unfinished_ids or self._in_flight_fetches:
# If there are unfinished IDs then the new position will be the
- # largest finished ID less than the minimum unfinished ID.
+ # largest finished ID strictly less than the minimum unfinished
+ # ID.
+
+ # The minimum unfinished ID needs to take account of both
+ # `_unfinished_ids` and `_in_flight_fetches`.
+ if self._unfinished_ids and self._in_flight_fetches:
+ # `_in_flight_fetches` stores the maximum safe stream ID, so
+ # we add one to make it equivalent to the minimum unsafe ID.
+ min_unfinished = min(
+ self._unfinished_ids[0], self._in_flight_fetches[0] + 1
+ )
+ elif self._in_flight_fetches:
+ min_unfinished = self._in_flight_fetches[0] + 1
+ else:
+ min_unfinished = self._unfinished_ids[0]
finished = set()
-
- min_unfinshed = self._unfinished_ids[0]
for s in self._finished_ids:
- if s < min_unfinshed:
+ if s < min_unfinished:
if new_cur is None or new_cur < s:
new_cur = s
else:
@@ -575,6 +622,10 @@ class MultiWriterIdGenerator:
new_id, self._current_positions.get(instance_name, 0)
)
+ self._max_seen_allocated_stream_id = max(
+ self._max_seen_allocated_stream_id, new_id
+ )
+
self._add_persisted_position(new_id)
def get_persisted_upto_position(self) -> int:
@@ -605,7 +656,11 @@ class MultiWriterIdGenerator:
# to report a recent position when asked, rather than a potentially old
# one (if this instance hasn't written anything for a while).
our_current_position = self._current_positions.get(self._instance_name)
- if our_current_position and not self._unfinished_ids:
+ if (
+ our_current_position
+ and not self._unfinished_ids
+ and not self._in_flight_fetches
+ ):
self._current_positions[self._instance_name] = max(
our_current_position, new_id
)
@@ -697,9 +752,6 @@ class _MultiWriterCtxManager:
db_autocommit=True,
)
- with self.id_gen._lock:
- self.id_gen._unfinished_ids.update(self.stream_ids)
-
if self.multiple_ids is None:
return self.stream_ids[0] * self.id_gen._return_factor
else:
|