diff options
Diffstat (limited to 'synapse/storage/util')
-rw-r--r-- | synapse/storage/util/id_generators.py | 109 | ||||
-rw-r--r-- | synapse/storage/util/sequence.py | 82 |
2 files changed, 140 insertions, 51 deletions
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 133c0e7a28..71ef5a72dc 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -15,12 +15,11 @@ import heapq import logging import threading -from collections import deque +from collections import OrderedDict from contextlib import contextmanager -from typing import Dict, List, Optional, Set, Union +from typing import Dict, List, Optional, Set, Tuple, Union import attr -from typing_extensions import Deque from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.database import DatabasePool, LoggingTransaction @@ -101,7 +100,13 @@ class StreamIdGenerator: self._current = (max if step > 0 else min)( self._current, _load_current_id(db_conn, table, column, step) ) - self._unfinished_ids = deque() # type: Deque[int] + + # We use this as an ordered set, as we want to efficiently append items, + # remove items and get the first item. Since we insert IDs in order, the + # insertion ordering will ensure its in the correct ordering. + # + # The key and values are the same, but we never look at the values. + self._unfinished_ids = OrderedDict() # type: OrderedDict[int, int] def get_next(self): """ @@ -113,7 +118,7 @@ class StreamIdGenerator: self._current += self._step next_id = self._current - self._unfinished_ids.append(next_id) + self._unfinished_ids[next_id] = next_id @contextmanager def manager(): @@ -121,7 +126,7 @@ class StreamIdGenerator: yield next_id finally: with self._lock: - self._unfinished_ids.remove(next_id) + self._unfinished_ids.pop(next_id) return _AsyncCtxManagerWrapper(manager()) @@ -140,7 +145,7 @@ class StreamIdGenerator: self._current += n * self._step for next_id in next_ids: - self._unfinished_ids.append(next_id) + self._unfinished_ids[next_id] = next_id @contextmanager def manager(): @@ -149,7 +154,7 @@ class StreamIdGenerator: finally: with self._lock: for next_id in next_ids: - self._unfinished_ids.remove(next_id) + self._unfinished_ids.pop(next_id) return _AsyncCtxManagerWrapper(manager()) @@ -162,7 +167,7 @@ class StreamIdGenerator: """ with self._lock: if self._unfinished_ids: - return self._unfinished_ids[0] - self._step + return next(iter(self._unfinished_ids)) - self._step return self._current @@ -186,11 +191,12 @@ class MultiWriterIdGenerator: Args: db_conn db - stream_name: A name for the stream. + stream_name: A name for the stream, for use in the `stream_positions` + table. (Does not need to be the same as the replication stream name) instance_name: The name of this instance. - table: Database table associated with stream. - instance_column: Column that stores the row's writer's instance name - id_column: Column that stores the stream ID. + tables: List of tables associated with the stream. Tuple of table + name, column name that stores the writer's instance name, and + column name that stores the stream ID. sequence_name: The name of the postgres sequence used to generate new IDs. writers: A list of known writers to use to populate current positions @@ -206,9 +212,7 @@ class MultiWriterIdGenerator: db: DatabasePool, stream_name: str, instance_name: str, - table: str, - instance_column: str, - id_column: str, + tables: List[Tuple[str, str, str]], sequence_name: str, writers: List[str], positive: bool = True, @@ -260,15 +264,20 @@ class MultiWriterIdGenerator: self._sequence_gen = PostgresSequenceGenerator(sequence_name) # We check that the table and sequence haven't diverged. - self._sequence_gen.check_consistency( - db_conn, table=table, id_column=id_column, positive=positive - ) + for table, _, id_column in tables: + self._sequence_gen.check_consistency( + db_conn, + table=table, + id_column=id_column, + stream_name=stream_name, + positive=positive, + ) # This goes and fills out the above state from the database. - self._load_current_ids(db_conn, table, instance_column, id_column) + self._load_current_ids(db_conn, tables) def _load_current_ids( - self, db_conn, table: str, instance_column: str, id_column: str + self, db_conn, tables: List[Tuple[str, str, str]], ): cur = db_conn.cursor(txn_name="_load_current_ids") @@ -306,17 +315,22 @@ class MultiWriterIdGenerator: # We add a GREATEST here to ensure that the result is always # positive. (This can be a problem for e.g. backfill streams where # the server has never backfilled). - sql = """ - SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1) - FROM %(table)s - """ % { - "id": id_column, - "table": table, - "agg": "MAX" if self._positive else "-MIN", - } - cur.execute(sql) - (stream_id,) = cur.fetchone() - self._persisted_upto_position = stream_id + max_stream_id = 1 + for table, _, id_column in tables: + sql = """ + SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1) + FROM %(table)s + """ % { + "id": id_column, + "table": table, + "agg": "MAX" if self._positive else "-MIN", + } + cur.execute(sql) + (stream_id,) = cur.fetchone() + + max_stream_id = max(max_stream_id, stream_id) + + self._persisted_upto_position = max_stream_id else: # If we have a min_stream_id then we pull out everything greater # than it from the DB so that we can prefill @@ -329,21 +343,28 @@ class MultiWriterIdGenerator: # stream positions table before restart (or the stream position # table otherwise got out of date). - sql = """ - SELECT %(instance)s, %(id)s FROM %(table)s - WHERE ? %(cmp)s %(id)s - """ % { - "id": id_column, - "table": table, - "instance": instance_column, - "cmp": "<=" if self._positive else ">=", - } - cur.execute(sql, (min_stream_id * self._return_factor,)) - self._persisted_upto_position = min_stream_id + rows = [] + for table, instance_column, id_column in tables: + sql = """ + SELECT %(instance)s, %(id)s FROM %(table)s + WHERE ? %(cmp)s %(id)s + """ % { + "id": id_column, + "table": table, + "instance": instance_column, + "cmp": "<=" if self._positive else ">=", + } + cur.execute(sql, (min_stream_id * self._return_factor,)) + + rows.extend(cur) + + # Sort so that we handle rows in order for each instance. + rows.sort() + with self._lock: - for (instance, stream_id,) in cur: + for (instance, stream_id,) in rows: stream_id = self._return_factor * stream_id self._add_persisted_position(stream_id) diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py index 4386b6101e..0ec4dc2918 100644 --- a/synapse/storage/util/sequence.py +++ b/synapse/storage/util/sequence.py @@ -15,9 +15,8 @@ import abc import logging import threading -from typing import Callable, List, Optional +from typing import TYPE_CHECKING, Callable, List, Optional -from synapse.storage.database import LoggingDatabaseConnection from synapse.storage.engines import ( BaseDatabaseEngine, IncorrectDatabaseSetup, @@ -25,6 +24,9 @@ from synapse.storage.engines import ( ) from synapse.storage.types import Connection, Cursor +if TYPE_CHECKING: + from synapse.storage.database import LoggingDatabaseConnection + logger = logging.getLogger(__name__) @@ -43,6 +45,21 @@ and run the following SQL: See docs/postgres.md for more information. """ +_INCONSISTENT_STREAM_ERROR = """ +Postgres sequence '%(seq)s' is inconsistent with associated stream position +of '%(stream_name)s' in the 'stream_positions' table. + +This is likely a programming error and should be reported at +https://github.com/matrix-org/synapse. + +A temporary workaround to fix this error is to shut down Synapse (including +any and all workers) and run the following SQL: + + DELETE FROM stream_positions WHERE stream_name = '%(stream_name)s'; + +This will need to be done every time the server is restarted. +""" + class SequenceGenerator(metaclass=abc.ABCMeta): """A class which generates a unique sequence of integers""" @@ -53,19 +70,30 @@ class SequenceGenerator(metaclass=abc.ABCMeta): ... @abc.abstractmethod + def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]: + """Get the next `n` IDs in the sequence""" + ... + + @abc.abstractmethod def check_consistency( self, - db_conn: LoggingDatabaseConnection, + db_conn: "LoggingDatabaseConnection", table: str, id_column: str, + stream_name: Optional[str] = None, positive: bool = True, ): """Should be called during start up to test that the current value of the sequence is greater than or equal to the maximum ID in the table. - This is to handle various cases where the sequence value can get out - of sync with the table, e.g. if Synapse gets rolled back to a previous + This is to handle various cases where the sequence value can get out of + sync with the table, e.g. if Synapse gets rolled back to a previous version and the rolled forwards again. + + If a stream name is given then this will check that any value in the + `stream_positions` table is less than or equal to the current sequence + value. If it isn't then it's likely that streams have been crossed + somewhere (e.g. two ID generators have the same stream name). """ ... @@ -88,11 +116,15 @@ class PostgresSequenceGenerator(SequenceGenerator): def check_consistency( self, - db_conn: LoggingDatabaseConnection, + db_conn: "LoggingDatabaseConnection", table: str, id_column: str, + stream_name: Optional[str] = None, positive: bool = True, ): + """See SequenceGenerator.check_consistency for docstring. + """ + txn = db_conn.cursor(txn_name="sequence.check_consistency") # First we get the current max ID from the table. @@ -116,6 +148,18 @@ class PostgresSequenceGenerator(SequenceGenerator): "SELECT last_value, is_called FROM %(seq)s" % {"seq": self._sequence_name} ) last_value, is_called = txn.fetchone() + + # If we have an associated stream check the stream_positions table. + max_in_stream_positions = None + if stream_name: + txn.execute( + "SELECT MAX(stream_id) FROM stream_positions WHERE stream_name = ?", + (stream_name,), + ) + row = txn.fetchone() + if row: + max_in_stream_positions = row[0] + txn.close() # If `is_called` is False then `last_value` is actually the value that @@ -136,6 +180,14 @@ class PostgresSequenceGenerator(SequenceGenerator): % {"seq": self._sequence_name, "table": table, "max_id_sql": table_sql} ) + # If we have values in the stream positions table then they have to be + # less than or equal to `last_value` + if max_in_stream_positions and max_in_stream_positions > last_value: + raise IncorrectDatabaseSetup( + _INCONSISTENT_STREAM_ERROR + % {"seq": self._sequence_name, "stream_name": stream_name} + ) + GetFirstCallbackType = Callable[[Cursor], int] @@ -172,8 +224,24 @@ class LocalSequenceGenerator(SequenceGenerator): self._current_max_id += 1 return self._current_max_id + def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]: + with self._lock: + if self._current_max_id is None: + assert self._callback is not None + self._current_max_id = self._callback(txn) + self._callback = None + + first_id = self._current_max_id + 1 + self._current_max_id += n + return [first_id + i for i in range(n)] + def check_consistency( - self, db_conn: Connection, table: str, id_column: str, positive: bool = True + self, + db_conn: Connection, + table: str, + id_column: str, + stream_name: Optional[str] = None, + positive: bool = True, ): # There is nothing to do for in memory sequences pass |