summary refs log tree commit diff
path: root/synapse/storage/util/id_generators.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/util/id_generators.py')
-rw-r--r--synapse/storage/util/id_generators.py366
1 files changed, 280 insertions, 86 deletions
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py

index b7eb4f8ac9..02fbb656e8 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py
@@ -12,16 +12,17 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -import contextlib import heapq import logging import threading from collections import deque -from typing import Dict, List, Set +from contextlib import contextmanager +from typing import Dict, List, Optional, Set, Union +import attr from typing_extensions import Deque +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.util.sequence import PostgresSequenceGenerator @@ -86,7 +87,7 @@ class StreamIdGenerator: upwards, -1 to grow downwards. Usage: - with await stream_id_gen.get_next() as stream_id: + async with stream_id_gen.get_next() as stream_id: # ... persist event ... """ @@ -101,10 +102,10 @@ class StreamIdGenerator: ) self._unfinished_ids = deque() # type: Deque[int] - async def get_next(self): + def get_next(self): """ Usage: - with await stream_id_gen.get_next() as stream_id: + async with stream_id_gen.get_next() as stream_id: # ... persist event ... """ with self._lock: @@ -113,7 +114,7 @@ class StreamIdGenerator: self._unfinished_ids.append(next_id) - @contextlib.contextmanager + @contextmanager def manager(): try: yield next_id @@ -121,12 +122,12 @@ class StreamIdGenerator: with self._lock: self._unfinished_ids.remove(next_id) - return manager() + return _AsyncCtxManagerWrapper(manager()) - async def get_next_mult(self, n): + def get_next_mult(self, n): """ Usage: - with await stream_id_gen.get_next(n) as stream_ids: + async with stream_id_gen.get_next(n) as stream_ids: # ... persist events ... """ with self._lock: @@ -140,7 +141,7 @@ class StreamIdGenerator: for next_id in next_ids: self._unfinished_ids.append(next_id) - @contextlib.contextmanager + @contextmanager def manager(): try: yield next_ids @@ -149,7 +150,7 @@ class StreamIdGenerator: for next_id in next_ids: self._unfinished_ids.remove(next_id) - return manager() + return _AsyncCtxManagerWrapper(manager()) def get_current_token(self): """Returns the maximum stream id such that all stream ids less than or @@ -184,12 +185,16 @@ class MultiWriterIdGenerator: Args: db_conn db + stream_name: A name for the stream. instance_name: The name of this instance. table: Database table associated with stream. instance_column: Column that stores the row's writer's instance name id_column: Column that stores the stream ID. sequence_name: The name of the postgres sequence used to generate new IDs. + writers: A list of known writers to use to populate current positions + on startup. Can be empty if nothing uses `get_current_token` or + `get_positions` (e.g. caches stream). positive: Whether the IDs are positive (true) or negative (false). When using negative IDs we go backwards from -1 to -2, -3, etc. """ @@ -198,16 +203,20 @@ class MultiWriterIdGenerator: self, db_conn, db: DatabasePool, + stream_name: str, instance_name: str, table: str, instance_column: str, id_column: str, sequence_name: str, + writers: List[str], positive: bool = True, ): self._db = db + self._stream_name = stream_name self._instance_name = instance_name self._positive = positive + self._writers = writers self._return_factor = 1 if positive else -1 # We lock as some functions may be called from DB threads. @@ -216,14 +225,16 @@ class MultiWriterIdGenerator: # Note: If we are a negative stream then we still store all the IDs as # positive to make life easier for us, and simply negate the IDs when we # return them. - self._current_positions = self._load_current_ids( - db_conn, table, instance_column, id_column - ) + self._current_positions = {} # type: Dict[str, int] # Set of local IDs that we're still processing. The current position # should be less than the minimum of this set (if not empty). self._unfinished_ids = set() # type: Set[int] + # Set of local IDs that we've processed that are larger than the current + # position, due to there being smaller unpersisted IDs. + self._finished_ids = set() # type: Set[int] + # We track the max position where we know everything before has been # persisted. This is done by a) looking at the min across all instances # and b) noting that if we have seen a run of persisted positions @@ -236,37 +247,113 @@ class MultiWriterIdGenerator: # gaps should be relatively rare it's still worth doing the book keeping # that allows us to skip forwards when there are gapless runs of # positions. + # + # We start at 1 here as a) the first generated stream ID will be 2, and + # b) other parts of the code assume that stream IDs are strictly greater + # than 0. self._persisted_upto_position = ( - min(self._current_positions.values()) if self._current_positions else 0 + min(self._current_positions.values()) if self._current_positions else 1 ) self._known_persisted_positions = [] # type: List[int] self._sequence_gen = PostgresSequenceGenerator(sequence_name) + # We check that the table and sequence haven't diverged. + self._sequence_gen.check_consistency( + db_conn, table=table, id_column=id_column, positive=positive + ) + + # This goes and fills out the above state from the database. + self._load_current_ids(db_conn, table, instance_column, id_column) + def _load_current_ids( self, db_conn, table: str, instance_column: str, id_column: str - ) -> Dict[str, int]: - # If positive stream aggregate via MAX. For negative stream use MIN - # *and* negate the result to get a positive number. - sql = """ - SELECT %(instance)s, %(agg)s(%(id)s) FROM %(table)s - GROUP BY %(instance)s - """ % { - "instance": instance_column, - "id": id_column, - "table": table, - "agg": "MAX" if self._positive else "-MIN", - } - + ): cur = db_conn.cursor() - cur.execute(sql) - # `cur` is an iterable over returned rows, which are 2-tuples. - current_positions = dict(cur) + # Load the current positions of all writers for the stream. + if self._writers: + # We delete any stale entries in the positions table. This is + # important if we add back a writer after a long time; we want to + # consider that a "new" writer, rather than using the old stale + # entry here. + sql = """ + DELETE FROM stream_positions + WHERE + stream_name = ? + AND instance_name != ALL(?) + """ + sql = self._db.engine.convert_param_style(sql) + cur.execute(sql, (self._stream_name, self._writers)) + + sql = """ + SELECT instance_name, stream_id FROM stream_positions + WHERE stream_name = ? + """ + sql = self._db.engine.convert_param_style(sql) + + cur.execute(sql, (self._stream_name,)) + + self._current_positions = { + instance: stream_id * self._return_factor + for instance, stream_id in cur + if instance in self._writers + } - cur.close() + # We set the `_persisted_upto_position` to be the minimum of all current + # positions. If empty we use the max stream ID from the DB table. + min_stream_id = min(self._current_positions.values(), default=None) + + if min_stream_id is None: + # We add a GREATEST here to ensure that the result is always + # positive. (This can be a problem for e.g. backfill streams where + # the server has never backfilled). + sql = """ + SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1) + FROM %(table)s + """ % { + "id": id_column, + "table": table, + "agg": "MAX" if self._positive else "-MIN", + } + cur.execute(sql) + (stream_id,) = cur.fetchone() + self._persisted_upto_position = stream_id + else: + # If we have a min_stream_id then we pull out everything greater + # than it from the DB so that we can prefill + # `_known_persisted_positions` and get a more accurate + # `_persisted_upto_position`. + # + # We also check if any of the later rows are from this instance, in + # which case we use that for this instance's current position. This + # is to handle the case where we didn't finish persisting to the + # stream positions table before restart (or the stream position + # table otherwise got out of date). + + sql = """ + SELECT %(instance)s, %(id)s FROM %(table)s + WHERE ? %(cmp)s %(id)s + """ % { + "id": id_column, + "table": table, + "instance": instance_column, + "cmp": "<=" if self._positive else ">=", + } + sql = self._db.engine.convert_param_style(sql) + cur.execute(sql, (min_stream_id,)) - return current_positions + self._persisted_upto_position = min_stream_id + + with self._lock: + for (instance, stream_id,) in cur: + stream_id = self._return_factor * stream_id + self._add_persisted_position(stream_id) + + if instance == self._instance_name: + self._current_positions[instance] = stream_id + + cur.close() def _load_next_id_txn(self, txn) -> int: return self._sequence_gen.get_next_id_txn(txn) @@ -274,59 +361,23 @@ class MultiWriterIdGenerator: def _load_next_mult_id_txn(self, txn, n: int) -> List[int]: return self._sequence_gen.get_next_mult_txn(txn, n) - async def get_next(self): + def get_next(self): """ Usage: - with await stream_id_gen.get_next() as stream_id: + async with stream_id_gen.get_next() as stream_id: # ... persist event ... """ - next_id = await self._db.runInteraction("_load_next_id", self._load_next_id_txn) - # Assert the fetched ID is actually greater than what we currently - # believe the ID to be. If not, then the sequence and table have got - # out of sync somehow. - with self._lock: - assert self._current_positions.get(self._instance_name, 0) < next_id - - self._unfinished_ids.add(next_id) + return _MultiWriterCtxManager(self) - @contextlib.contextmanager - def manager(): - try: - # Multiply by the return factor so that the ID has correct sign. - yield self._return_factor * next_id - finally: - self._mark_id_as_finished(next_id) - - return manager() - - async def get_next_mult(self, n: int): + def get_next_mult(self, n: int): """ Usage: - with await stream_id_gen.get_next_mult(5) as stream_ids: + async with stream_id_gen.get_next_mult(5) as stream_ids: # ... persist events ... """ - next_ids = await self._db.runInteraction( - "_load_next_mult_id", self._load_next_mult_id_txn, n - ) - # Assert the fetched ID is actually greater than any ID we've already - # seen. If not, then the sequence and table have got out of sync - # somehow. - with self._lock: - assert max(self._current_positions.values(), default=0) < min(next_ids) - - self._unfinished_ids.update(next_ids) - - @contextlib.contextmanager - def manager(): - try: - yield [self._return_factor * i for i in next_ids] - finally: - for i in next_ids: - self._mark_id_as_finished(i) - - return manager() + return _MultiWriterCtxManager(self, n) def get_next_txn(self, txn: LoggingTransaction): """ @@ -344,21 +395,63 @@ class MultiWriterIdGenerator: txn.call_after(self._mark_id_as_finished, next_id) txn.call_on_exception(self._mark_id_as_finished, next_id) + # Update the `stream_positions` table with newly updated stream + # ID (unless self._writers is not set in which case we don't + # bother, as nothing will read it). + # + # We only do this on the success path so that the persisted current + # position points to a persited row with the correct instance name. + if self._writers: + txn.call_after( + run_as_background_process, + "MultiWriterIdGenerator._update_table", + self._db.runInteraction, + "MultiWriterIdGenerator._update_table", + self._update_stream_positions_table_txn, + ) + return self._return_factor * next_id def _mark_id_as_finished(self, next_id: int): """The ID has finished being processed so we should advance the - current poistion if possible. + current position if possible. """ with self._lock: self._unfinished_ids.discard(next_id) + self._finished_ids.add(next_id) + + new_cur = None + + if self._unfinished_ids: + # If there are unfinished IDs then the new position will be the + # largest finished ID less than the minimum unfinished ID. + + finished = set() + + min_unfinshed = min(self._unfinished_ids) + for s in self._finished_ids: + if s < min_unfinshed: + if new_cur is None or new_cur < s: + new_cur = s + else: + finished.add(s) + + # We clear these out since they're now all less than the new + # position. + self._finished_ids = finished + else: + # There are no unfinished IDs so the new position is simply the + # largest finished one. + new_cur = max(self._finished_ids) + + # We clear these out since they're now all less than the new + # position. + self._finished_ids.clear() - # Figure out if its safe to advance the position by checking there - # aren't any lower allocated IDs that are yet to finish. - if all(c > next_id for c in self._unfinished_ids): + if new_cur: curr = self._current_positions.get(self._instance_name, 0) - self._current_positions[self._instance_name] = max(curr, next_id) + self._current_positions[self._instance_name] = max(curr, new_cur) self._add_persisted_position(next_id) @@ -367,19 +460,28 @@ class MultiWriterIdGenerator: equal to it have been successfully persisted. """ - # Currently we don't support this operation, as it's not obvious how to - # condense the stream positions of multiple writers into a single int. - raise NotImplementedError() + return self.get_persisted_upto_position() def get_current_token_for_writer(self, instance_name: str) -> int: """Returns the position of the given writer. """ + # If we don't have an entry for the given instance name, we assume it's a + # new writer. + # + # For new writers we assume their initial position to be the current + # persisted up to position. This stops Synapse from doing a full table + # scan when a new writer announces itself over replication. with self._lock: - return self._return_factor * self._current_positions.get(instance_name, 0) + return self._return_factor * self._current_positions.get( + instance_name, self._persisted_upto_position + ) def get_positions(self) -> Dict[str, int]: """Get a copy of the current positon map. + + Note that this won't necessarily include all configured writers if some + writers haven't written anything yet. """ with self._lock: @@ -428,7 +530,7 @@ class MultiWriterIdGenerator: # We move the current min position up if the minimum current positions # of all instances is higher (since by definition all positions less # that that have been persisted). - min_curr = min(self._current_positions.values()) + min_curr = min(self._current_positions.values(), default=0) self._persisted_upto_position = max(min_curr, self._persisted_upto_position) # We now iterate through the seen positions, discarding those that are @@ -449,3 +551,95 @@ class MultiWriterIdGenerator: # There was a gap in seen positions, so there is nothing more to # do. break + + def _update_stream_positions_table_txn(self, txn): + """Update the `stream_positions` table with newly persisted position. + """ + + if not self._writers: + return + + # We upsert the value, ensuring on conflict that we always increase the + # value (or decrease if stream goes backwards). + sql = """ + INSERT INTO stream_positions (stream_name, instance_name, stream_id) + VALUES (?, ?, ?) + ON CONFLICT (stream_name, instance_name) + DO UPDATE SET + stream_id = %(agg)s(stream_positions.stream_id, EXCLUDED.stream_id) + """ % { + "agg": "GREATEST" if self._positive else "LEAST", + } + + pos = (self.get_current_token_for_writer(self._instance_name),) + txn.execute(sql, (self._stream_name, self._instance_name, pos)) + + +@attr.s(slots=True) +class _AsyncCtxManagerWrapper: + """Helper class to convert a plain context manager to an async one. + + This is mainly useful if you have a plain context manager but the interface + requires an async one. + """ + + inner = attr.ib() + + async def __aenter__(self): + return self.inner.__enter__() + + async def __aexit__(self, exc_type, exc, tb): + return self.inner.__exit__(exc_type, exc, tb) + + +@attr.s(slots=True) +class _MultiWriterCtxManager: + """Async context manager returned by MultiWriterIdGenerator + """ + + id_gen = attr.ib(type=MultiWriterIdGenerator) + multiple_ids = attr.ib(type=Optional[int], default=None) + stream_ids = attr.ib(type=List[int], factory=list) + + async def __aenter__(self) -> Union[int, List[int]]: + self.stream_ids = await self.id_gen._db.runInteraction( + "_load_next_mult_id", + self.id_gen._load_next_mult_id_txn, + self.multiple_ids or 1, + ) + + # Assert the fetched ID is actually greater than any ID we've already + # seen. If not, then the sequence and table have got out of sync + # somehow. + with self.id_gen._lock: + assert max(self.id_gen._current_positions.values(), default=0) < min( + self.stream_ids + ) + + self.id_gen._unfinished_ids.update(self.stream_ids) + + if self.multiple_ids is None: + return self.stream_ids[0] * self.id_gen._return_factor + else: + return [i * self.id_gen._return_factor for i in self.stream_ids] + + async def __aexit__(self, exc_type, exc, tb): + for i in self.stream_ids: + self.id_gen._mark_id_as_finished(i) + + if exc_type is not None: + return False + + # Update the `stream_positions` table with newly updated stream + # ID (unless self._writers is not set in which case we don't + # bother, as nothing will read it). + # + # We only do this on the success path so that the persisted current + # position points to a persisted row with the correct instance name. + if self.id_gen._writers: + await self.id_gen._db.runInteraction( + "MultiWriterIdGenerator._update_table", + self.id_gen._update_stream_positions_table_txn, + ) + + return False