diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index b7eb4f8ac9..02fbb656e8 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -12,16 +12,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-import contextlib
import heapq
import logging
import threading
from collections import deque
-from typing import Dict, List, Set
+from contextlib import contextmanager
+from typing import Dict, List, Optional, Set, Union
+import attr
from typing_extensions import Deque
+from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.util.sequence import PostgresSequenceGenerator
@@ -86,7 +87,7 @@ class StreamIdGenerator:
upwards, -1 to grow downwards.
Usage:
- with await stream_id_gen.get_next() as stream_id:
+ async with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
@@ -101,10 +102,10 @@ class StreamIdGenerator:
)
self._unfinished_ids = deque() # type: Deque[int]
- async def get_next(self):
+ def get_next(self):
"""
Usage:
- with await stream_id_gen.get_next() as stream_id:
+ async with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
with self._lock:
@@ -113,7 +114,7 @@ class StreamIdGenerator:
self._unfinished_ids.append(next_id)
- @contextlib.contextmanager
+ @contextmanager
def manager():
try:
yield next_id
@@ -121,12 +122,12 @@ class StreamIdGenerator:
with self._lock:
self._unfinished_ids.remove(next_id)
- return manager()
+ return _AsyncCtxManagerWrapper(manager())
- async def get_next_mult(self, n):
+ def get_next_mult(self, n):
"""
Usage:
- with await stream_id_gen.get_next(n) as stream_ids:
+ async with stream_id_gen.get_next_mult(n) as stream_ids:
# ... persist events ...
"""
with self._lock:
@@ -140,7 +141,7 @@ class StreamIdGenerator:
for next_id in next_ids:
self._unfinished_ids.append(next_id)
- @contextlib.contextmanager
+ @contextmanager
def manager():
try:
yield next_ids
@@ -149,7 +150,7 @@ class StreamIdGenerator:
for next_id in next_ids:
self._unfinished_ids.remove(next_id)
- return manager()
+ return _AsyncCtxManagerWrapper(manager())
def get_current_token(self):
"""Returns the maximum stream id such that all stream ids less than or
@@ -184,12 +185,16 @@ class MultiWriterIdGenerator:
Args:
db_conn
db
+ stream_name: A name for the stream.
instance_name: The name of this instance.
table: Database table associated with stream.
instance_column: Column that stores the row's writer's instance name
id_column: Column that stores the stream ID.
sequence_name: The name of the postgres sequence used to generate new
IDs.
+ writers: A list of known writers to use to populate current positions
+ on startup. Can be empty if nothing uses `get_current_token` or
+ `get_positions` (e.g. caches stream).
positive: Whether the IDs are positive (true) or negative (false).
When using negative IDs we go backwards from -1 to -2, -3, etc.
"""
@@ -198,16 +203,20 @@ class MultiWriterIdGenerator:
self,
db_conn,
db: DatabasePool,
+ stream_name: str,
instance_name: str,
table: str,
instance_column: str,
id_column: str,
sequence_name: str,
+ writers: List[str],
positive: bool = True,
):
self._db = db
+ self._stream_name = stream_name
self._instance_name = instance_name
self._positive = positive
+ self._writers = writers
self._return_factor = 1 if positive else -1
# We lock as some functions may be called from DB threads.
@@ -216,14 +225,16 @@ class MultiWriterIdGenerator:
# Note: If we are a negative stream then we still store all the IDs as
# positive to make life easier for us, and simply negate the IDs when we
# return them.
- self._current_positions = self._load_current_ids(
- db_conn, table, instance_column, id_column
- )
+ self._current_positions = {} # type: Dict[str, int]
# Set of local IDs that we're still processing. The current position
# should be less than the minimum of this set (if not empty).
self._unfinished_ids = set() # type: Set[int]
+ # Set of local IDs that we've processed that are larger than the current
+ # position, due to there being smaller unpersisted IDs.
+ self._finished_ids = set() # type: Set[int]
+
# We track the max position where we know everything before has been
# persisted. This is done by a) looking at the min across all instances
# and b) noting that if we have seen a run of persisted positions
@@ -236,37 +247,113 @@ class MultiWriterIdGenerator:
# gaps should be relatively rare it's still worth doing the book keeping
# that allows us to skip forwards when there are gapless runs of
# positions.
+ #
+ # We start at 1 here as a) the first generated stream ID will be 2, and
+ # b) other parts of the code assume that stream IDs are strictly greater
+ # than 0.
self._persisted_upto_position = (
- min(self._current_positions.values()) if self._current_positions else 0
+ min(self._current_positions.values()) if self._current_positions else 1
)
self._known_persisted_positions = [] # type: List[int]
self._sequence_gen = PostgresSequenceGenerator(sequence_name)
+ # We check that the table and sequence haven't diverged.
+ self._sequence_gen.check_consistency(
+ db_conn, table=table, id_column=id_column, positive=positive
+ )
+
+ # This goes and fills out the above state from the database.
+ self._load_current_ids(db_conn, table, instance_column, id_column)
+
def _load_current_ids(
self, db_conn, table: str, instance_column: str, id_column: str
- ) -> Dict[str, int]:
- # If positive stream aggregate via MAX. For negative stream use MIN
- # *and* negate the result to get a positive number.
- sql = """
- SELECT %(instance)s, %(agg)s(%(id)s) FROM %(table)s
- GROUP BY %(instance)s
- """ % {
- "instance": instance_column,
- "id": id_column,
- "table": table,
- "agg": "MAX" if self._positive else "-MIN",
- }
-
+ ):
cur = db_conn.cursor()
- cur.execute(sql)
- # `cur` is an iterable over returned rows, which are 2-tuples.
- current_positions = dict(cur)
+ # Load the current positions of all writers for the stream.
+ if self._writers:
+ # We delete any stale entries in the positions table. This is
+ # important if we add back a writer after a long time; we want to
+ # consider that a "new" writer, rather than using the old stale
+ # entry here.
+ sql = """
+ DELETE FROM stream_positions
+ WHERE
+ stream_name = ?
+ AND instance_name != ALL(?)
+ """
+ sql = self._db.engine.convert_param_style(sql)
+ cur.execute(sql, (self._stream_name, self._writers))
+
+ sql = """
+ SELECT instance_name, stream_id FROM stream_positions
+ WHERE stream_name = ?
+ """
+ sql = self._db.engine.convert_param_style(sql)
+
+ cur.execute(sql, (self._stream_name,))
+
+ self._current_positions = {
+ instance: stream_id * self._return_factor
+ for instance, stream_id in cur
+ if instance in self._writers
+ }
- cur.close()
+ # We set the `_persisted_upto_position` to be the minimum of all current
+ # positions. If empty we use the max stream ID from the DB table.
+ min_stream_id = min(self._current_positions.values(), default=None)
+
+ if min_stream_id is None:
+ # We add a GREATEST here to ensure that the result is always
+ # positive. (This can be a problem for e.g. backfill streams where
+ # the server has never backfilled).
+ sql = """
+ SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1)
+ FROM %(table)s
+ """ % {
+ "id": id_column,
+ "table": table,
+ "agg": "MAX" if self._positive else "-MIN",
+ }
+ cur.execute(sql)
+ (stream_id,) = cur.fetchone()
+ self._persisted_upto_position = stream_id
+ else:
+ # If we have a min_stream_id then we pull out everything greater
+ # than it from the DB so that we can prefill
+ # `_known_persisted_positions` and get a more accurate
+ # `_persisted_upto_position`.
+ #
+ # We also check if any of the later rows are from this instance, in
+ # which case we use that for this instance's current position. This
+ # is to handle the case where we didn't finish persisting to the
+ # stream positions table before restart (or the stream position
+ # table otherwise got out of date).
+
+ sql = """
+ SELECT %(instance)s, %(id)s FROM %(table)s
+ WHERE ? %(cmp)s %(id)s
+ """ % {
+ "id": id_column,
+ "table": table,
+ "instance": instance_column,
+ "cmp": "<=" if self._positive else ">=",
+ }
+ sql = self._db.engine.convert_param_style(sql)
+ cur.execute(sql, (min_stream_id,))
- return current_positions
+ self._persisted_upto_position = min_stream_id
+
+ with self._lock:
+ for (instance, stream_id,) in cur:
+ stream_id = self._return_factor * stream_id
+ self._add_persisted_position(stream_id)
+
+ if instance == self._instance_name:
+ self._current_positions[instance] = stream_id
+
+ cur.close()
def _load_next_id_txn(self, txn) -> int:
return self._sequence_gen.get_next_id_txn(txn)
@@ -274,59 +361,23 @@ class MultiWriterIdGenerator:
def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
return self._sequence_gen.get_next_mult_txn(txn, n)
- async def get_next(self):
+ def get_next(self):
"""
Usage:
- with await stream_id_gen.get_next() as stream_id:
+ async with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
- next_id = await self._db.runInteraction("_load_next_id", self._load_next_id_txn)
- # Assert the fetched ID is actually greater than what we currently
- # believe the ID to be. If not, then the sequence and table have got
- # out of sync somehow.
- with self._lock:
- assert self._current_positions.get(self._instance_name, 0) < next_id
-
- self._unfinished_ids.add(next_id)
+ return _MultiWriterCtxManager(self)
- @contextlib.contextmanager
- def manager():
- try:
- # Multiply by the return factor so that the ID has correct sign.
- yield self._return_factor * next_id
- finally:
- self._mark_id_as_finished(next_id)
-
- return manager()
-
- async def get_next_mult(self, n: int):
+ def get_next_mult(self, n: int):
"""
Usage:
- with await stream_id_gen.get_next_mult(5) as stream_ids:
+ async with stream_id_gen.get_next_mult(5) as stream_ids:
# ... persist events ...
"""
- next_ids = await self._db.runInteraction(
- "_load_next_mult_id", self._load_next_mult_id_txn, n
- )
- # Assert the fetched ID is actually greater than any ID we've already
- # seen. If not, then the sequence and table have got out of sync
- # somehow.
- with self._lock:
- assert max(self._current_positions.values(), default=0) < min(next_ids)
-
- self._unfinished_ids.update(next_ids)
-
- @contextlib.contextmanager
- def manager():
- try:
- yield [self._return_factor * i for i in next_ids]
- finally:
- for i in next_ids:
- self._mark_id_as_finished(i)
-
- return manager()
+ return _MultiWriterCtxManager(self, n)
def get_next_txn(self, txn: LoggingTransaction):
"""
@@ -344,21 +395,63 @@ class MultiWriterIdGenerator:
txn.call_after(self._mark_id_as_finished, next_id)
txn.call_on_exception(self._mark_id_as_finished, next_id)
+ # Update the `stream_positions` table with newly updated stream
+ # ID (unless self._writers is not set in which case we don't
+ # bother, as nothing will read it).
+ #
+ # We only do this on the success path so that the persisted current
+ # position points to a persisted row with the correct instance name.
+ if self._writers:
+ txn.call_after(
+ run_as_background_process,
+ "MultiWriterIdGenerator._update_table",
+ self._db.runInteraction,
+ "MultiWriterIdGenerator._update_table",
+ self._update_stream_positions_table_txn,
+ )
+
return self._return_factor * next_id
def _mark_id_as_finished(self, next_id: int):
"""The ID has finished being processed so we should advance the
- current poistion if possible.
+ current position if possible.
"""
with self._lock:
self._unfinished_ids.discard(next_id)
+ self._finished_ids.add(next_id)
+
+ new_cur = None
+
+ if self._unfinished_ids:
+ # If there are unfinished IDs then the new position will be the
+ # largest finished ID less than the minimum unfinished ID.
+
+ finished = set()
+
+ min_unfinshed = min(self._unfinished_ids)
+ for s in self._finished_ids:
+ if s < min_unfinshed:
+ if new_cur is None or new_cur < s:
+ new_cur = s
+ else:
+ finished.add(s)
+
+ # We clear these out since they're now all less than the new
+ # position.
+ self._finished_ids = finished
+ else:
+ # There are no unfinished IDs so the new position is simply the
+ # largest finished one.
+ new_cur = max(self._finished_ids)
+
+ # We clear these out since they're now all less than the new
+ # position.
+ self._finished_ids.clear()
- # Figure out if its safe to advance the position by checking there
- # aren't any lower allocated IDs that are yet to finish.
- if all(c > next_id for c in self._unfinished_ids):
+ if new_cur:
curr = self._current_positions.get(self._instance_name, 0)
- self._current_positions[self._instance_name] = max(curr, next_id)
+ self._current_positions[self._instance_name] = max(curr, new_cur)
self._add_persisted_position(next_id)
@@ -367,19 +460,28 @@ class MultiWriterIdGenerator:
equal to it have been successfully persisted.
"""
- # Currently we don't support this operation, as it's not obvious how to
- # condense the stream positions of multiple writers into a single int.
- raise NotImplementedError()
+ return self.get_persisted_upto_position()
def get_current_token_for_writer(self, instance_name: str) -> int:
"""Returns the position of the given writer.
"""
+ # If we don't have an entry for the given instance name, we assume it's a
+ # new writer.
+ #
+ # For new writers we assume their initial position to be the current
+ # persisted up to position. This stops Synapse from doing a full table
+ # scan when a new writer announces itself over replication.
with self._lock:
- return self._return_factor * self._current_positions.get(instance_name, 0)
+ return self._return_factor * self._current_positions.get(
+ instance_name, self._persisted_upto_position
+ )
def get_positions(self) -> Dict[str, int]:
"""Get a copy of the current positon map.
+
+ Note that this won't necessarily include all configured writers if some
+ writers haven't written anything yet.
"""
with self._lock:
@@ -428,7 +530,7 @@ class MultiWriterIdGenerator:
# We move the current min position up if the minimum current positions
# of all instances is higher (since by definition all positions less
# that that have been persisted).
- min_curr = min(self._current_positions.values())
+ min_curr = min(self._current_positions.values(), default=0)
self._persisted_upto_position = max(min_curr, self._persisted_upto_position)
# We now iterate through the seen positions, discarding those that are
@@ -449,3 +551,95 @@ class MultiWriterIdGenerator:
# There was a gap in seen positions, so there is nothing more to
# do.
break
+
+ def _update_stream_positions_table_txn(self, txn):
+ """Upsert this instance's current token into the `stream_positions` table.
+
+ Runs on the given transaction `txn`; does nothing if there are no writers.
+ """
+ if not self._writers:
+ return
+
+ # We upsert the value, ensuring on conflict that we always increase the
+ # value (or decrease if stream goes backwards).
+ sql = """
+ INSERT INTO stream_positions (stream_name, instance_name, stream_id)
+ VALUES (?, ?, ?)
+ ON CONFLICT (stream_name, instance_name)
+ DO UPDATE SET
+ stream_id = %(agg)s(stream_positions.stream_id, EXCLUDED.stream_id)
+ """ % {
+ "agg": "GREATEST" if self._positive else "LEAST",
+ }
+ pos = self.get_current_token_for_writer(self._instance_name)  # scalar bind value
+ txn.execute(sql, (self._stream_name, self._instance_name, pos))
+
+
+@attr.s(slots=True)
+class _AsyncCtxManagerWrapper:
+ """Helper class to convert a plain context manager to an async one.
+
+ `async with` on the wrapper calls `inner.__enter__`/`inner.__exit__`
+ synchronously; nothing is awaited and their return values are passed through.
+ """
+
+ inner = attr.ib()  # the wrapped synchronous context manager
+
+ async def __aenter__(self):
+ return self.inner.__enter__()
+
+ async def __aexit__(self, exc_type, exc, tb):
+ return self.inner.__exit__(exc_type, exc, tb)  # inner decides on suppression
+
+
+@attr.s(slots=True)
+class _MultiWriterCtxManager:
+ """Async context manager returned by MultiWriterIdGenerator.get_next and
+ get_next_mult: allocates stream IDs on entry, marks them finished on exit."""
+
+ id_gen = attr.ib(type=MultiWriterIdGenerator)
+ multiple_ids = attr.ib(type=Optional[int], default=None)  # None => a single ID
+ stream_ids = attr.ib(type=List[int], factory=list)  # filled in by __aenter__
+
+ async def __aenter__(self) -> Union[int, List[int]]:
+ self.stream_ids = await self.id_gen._db.runInteraction(
+ "_load_next_mult_id",
+ self.id_gen._load_next_mult_id_txn,
+ self.multiple_ids or 1,
+ )
+
+ # Assert the fetched ID is actually greater than any ID we've already
+ # seen. If not, then the sequence and table have got out of sync
+ # somehow.
+ with self.id_gen._lock:
+ assert max(self.id_gen._current_positions.values(), default=0) < min(
+ self.stream_ids
+ )
+
+ self.id_gen._unfinished_ids.update(self.stream_ids)
+
+ if self.multiple_ids is None:
+ return self.stream_ids[0] * self.id_gen._return_factor
+ else:
+ return [i * self.id_gen._return_factor for i in self.stream_ids]
+
+ async def __aexit__(self, exc_type, exc, tb):
+ for i in self.stream_ids:
+ self.id_gen._mark_id_as_finished(i)
+
+ if exc_type is not None:
+ return False  # never suppress the caller's exception
+
+ # Update the `stream_positions` table with newly updated stream
+ # ID (unless self._writers is not set in which case we don't
+ # bother, as nothing will read it).
+ #
+ # We only do this on the success path so that the persisted current
+ # position points to a persisted row with the correct instance name.
+ if self.id_gen._writers:
+ await self.id_gen._db.runInteraction(
+ "MultiWriterIdGenerator._update_table",
+ self.id_gen._update_stream_positions_table_txn,
+ )
+
+ return False
|