diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index d2c874b9a8..9c3eafb562 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -134,6 +134,15 @@ class AbstractStreamIdGenerator(metaclass=abc.ABCMeta):
raise NotImplementedError()
@abc.abstractmethod
+ def get_minimal_local_current_token(self) -> int:
+ """Tries to return a minimal current token for the local instance,
+ i.e. for writers this would be the last successful write.
+
+ If local instance is not a writer (or has written yet) then falls back
+ to returning the normal "current token".
+ """
+
+ @abc.abstractmethod
def get_next(self) -> AsyncContextManager[int]:
"""
Usage:
@@ -312,6 +321,9 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
def get_current_token_for_writer(self, instance_name: str) -> int:
return self.get_current_token()
+ def get_minimal_local_current_token(self) -> int:
+ return self.get_current_token()
+
class MultiWriterIdGenerator(AbstractStreamIdGenerator):
"""Generates and tracks stream IDs for a stream with multiple writers.
@@ -408,6 +420,11 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
# The maximum stream ID that we have seen been allocated across any writer.
self._max_seen_allocated_stream_id = 1
+ # The maximum position of the local instance. This can be higher than
+ # the corresponding position in `current_positions` table when there are
+ # no active writes in progress.
+ self._max_position_of_local_instance = self._max_seen_allocated_stream_id
+
self._sequence_gen = PostgresSequenceGenerator(sequence_name)
# We check that the table and sequence haven't diverged.
@@ -427,6 +444,16 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
self._current_positions.values(), default=1
)
+ # For the case where `stream_positions` is not up to date,
+ # `_persisted_upto_position` may be higher.
+ self._max_seen_allocated_stream_id = max(
+ self._max_seen_allocated_stream_id, self._persisted_upto_position
+ )
+
+ # Bump our local maximum position now that we've loaded things from the
+ # DB.
+ self._max_position_of_local_instance = self._max_seen_allocated_stream_id
+
if not writers:
# If there have been no explicit writers given then any instance can
# write to the stream. In which case, let's pre-seed our own
@@ -545,6 +572,14 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
if instance == self._instance_name:
self._current_positions[instance] = stream_id
+ if self._writers:
+ # If we have explicit writers then make sure that each instance has
+ # a position.
+ for writer in self._writers:
+ self._current_positions.setdefault(
+ writer, self._persisted_upto_position
+ )
+
cur.close()
def _load_next_id_txn(self, txn: Cursor) -> int:
@@ -688,6 +723,9 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
if new_cur:
curr = self._current_positions.get(self._instance_name, 0)
self._current_positions[self._instance_name] = max(curr, new_cur)
+ self._max_position_of_local_instance = max(
+ curr, new_cur, self._max_position_of_local_instance
+ )
self._add_persisted_position(next_id)
@@ -702,10 +740,26 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
# persisted up to position. This stops Synapse from doing a full table
# scan when a new writer announces itself over replication.
with self._lock:
- return self._return_factor * self._current_positions.get(
+ if self._instance_name == instance_name:
+ return self._return_factor * self._max_position_of_local_instance
+
+ pos = self._current_positions.get(
instance_name, self._persisted_upto_position
)
+ # We want to return the maximum "current token" that we can for a
+ # writer, this helps ensure that streams progress as fast as
+ # possible.
+ pos = max(pos, self._persisted_upto_position)
+
+ return self._return_factor * pos
+
+ def get_minimal_local_current_token(self) -> int:
+ with self._lock:
+ return self._return_factor * self._current_positions.get(
+ self._instance_name, self._persisted_upto_position
+ )
+
def get_positions(self) -> Dict[str, int]:
"""Get a copy of the current positon map.
@@ -774,6 +828,18 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
self._persisted_upto_position = max(min_curr, self._persisted_upto_position)
+ # Advance our local max position.
+ self._max_position_of_local_instance = max(
+ self._max_position_of_local_instance, self._persisted_upto_position
+ )
+
+ if not self._unfinished_ids and not self._in_flight_fetches:
+ # If we don't have anything in flight, it's safe to advance to the
+ # max seen stream ID.
+ self._max_position_of_local_instance = max(
+ self._max_seen_allocated_stream_id, self._max_position_of_local_instance
+ )
+
# We now iterate through the seen positions, discarding those that are
# less than the current min positions, and incrementing the min position
# if its exactly one greater.
|