Diffstat (limited to 'synapse/storage/util')
-rw-r--r--  synapse/storage/util/id_generators.py  369
-rw-r--r--  synapse/storage/util/sequence.py         90
2 files changed, 371 insertions(+), 88 deletions(-)
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index b7eb4f8ac9..eccd2d5b7b 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -12,17 +12,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-import contextlib
 import heapq
 import logging
 import threading
 from collections import deque
-from typing import Dict, List, Set
+from contextlib import contextmanager
+from typing import Dict, List, Optional, Set, Union
 
+import attr
 from typing_extensions import Deque
 
+from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.database import DatabasePool, LoggingTransaction
+from synapse.storage.types import Cursor
 from synapse.storage.util.sequence import PostgresSequenceGenerator
 
 logger = logging.getLogger(__name__)
@@ -86,7 +88,7 @@ class StreamIdGenerator:
             upwards, -1 to grow downwards.
 
     Usage:
-        with await stream_id_gen.get_next() as stream_id:
+        async with stream_id_gen.get_next() as stream_id:
             # ... persist event ...
     """
 
@@ -101,10 +103,10 @@ class StreamIdGenerator:
             )
         self._unfinished_ids = deque()  # type: Deque[int]
 
-    async def get_next(self):
+    def get_next(self):
         """
         Usage:
-            with await stream_id_gen.get_next() as stream_id:
+            async with stream_id_gen.get_next() as stream_id:
                 # ... persist event ...
         """
         with self._lock:
@@ -113,7 +115,7 @@ class StreamIdGenerator:
 
             self._unfinished_ids.append(next_id)
 
-        @contextlib.contextmanager
+        @contextmanager
         def manager():
             try:
                 yield next_id
@@ -121,12 +123,12 @@ class StreamIdGenerator:
                 with self._lock:
                     self._unfinished_ids.remove(next_id)
 
-        return manager()
+        return _AsyncCtxManagerWrapper(manager())
 
-    async def get_next_mult(self, n):
+    def get_next_mult(self, n):
         """
         Usage:
-            with await stream_id_gen.get_next(n) as stream_ids:
+            async with stream_id_gen.get_next(n) as stream_ids:
                 # ... persist events ...
         """
         with self._lock:
@@ -140,7 +142,7 @@ class StreamIdGenerator:
             for next_id in next_ids:
                 self._unfinished_ids.append(next_id)
 
-        @contextlib.contextmanager
+        @contextmanager
         def manager():
             try:
                 yield next_ids
@@ -149,7 +151,7 @@ class StreamIdGenerator:
                     for next_id in next_ids:
                         self._unfinished_ids.remove(next_id)
 
-        return manager()
+        return _AsyncCtxManagerWrapper(manager())
 
     def get_current_token(self):
         """Returns the maximum stream id such that all stream ids less than or
@@ -184,12 +186,16 @@ class MultiWriterIdGenerator:
     Args:
         db_conn
         db
+        stream_name: A name for the stream.
         instance_name: The name of this instance.
         table: Database table associated with stream.
         instance_column: Column that stores the row's writer's instance name
         id_column: Column that stores the stream ID.
         sequence_name: The name of the postgres sequence used to generate new
             IDs.
+        writers: A list of known writers to use to populate current positions
+            on startup. Can be empty if nothing uses `get_current_token` or
+            `get_positions` (e.g. caches stream).
         positive: Whether the IDs are positive (true) or negative (false).
             When using negative IDs we go backwards from -1 to -2, -3, etc.
     """
@@ -198,16 +204,20 @@ class MultiWriterIdGenerator:
         self,
         db_conn,
         db: DatabasePool,
+        stream_name: str,
         instance_name: str,
         table: str,
         instance_column: str,
         id_column: str,
         sequence_name: str,
+        writers: List[str],
         positive: bool = True,
     ):
         self._db = db
+        self._stream_name = stream_name
         self._instance_name = instance_name
         self._positive = positive
+        self._writers = writers
         self._return_factor = 1 if positive else -1
 
         # We lock as some functions may be called from DB threads.
@@ -216,14 +226,16 @@ class MultiWriterIdGenerator:
         # Note: If we are a negative stream then we still store all the IDs as
         # positive to make life easier for us, and simply negate the IDs when we
         # return them.
-        self._current_positions = self._load_current_ids(
-            db_conn, table, instance_column, id_column
-        )
+        self._current_positions = {}  # type: Dict[str, int]
 
         # Set of local IDs that we're still processing. The current position
         # should be less than the minimum of this set (if not empty).
         self._unfinished_ids = set()  # type: Set[int]
 
+        # Set of local IDs that we've processed that are larger than the current
+        # position, due to there being smaller unpersisted IDs.
+        self._finished_ids = set()  # type: Set[int]
+
         # We track the max position where we know everything before has been
         # persisted. This is done by a) looking at the min across all instances
         # and b) noting that if we have seen a run of persisted positions
@@ -236,37 +248,113 @@ class MultiWriterIdGenerator:
         # gaps should be relatively rare it's still worth doing the book keeping
         # that allows us to skip forwards when there are gapless runs of
         # positions.
+        #
+        # We start at 1 here as a) the first generated stream ID will be 2, and
+        # b) other parts of the code assume that stream IDs are strictly greater
+        # than 0.
         self._persisted_upto_position = (
-            min(self._current_positions.values()) if self._current_positions else 0
+            min(self._current_positions.values()) if self._current_positions else 1
         )
         self._known_persisted_positions = []  # type: List[int]
 
         self._sequence_gen = PostgresSequenceGenerator(sequence_name)
 
+        # We check that the table and sequence haven't diverged.
+        self._sequence_gen.check_consistency(
+            db_conn, table=table, id_column=id_column, positive=positive
+        )
+
+        # This goes and fills out the above state from the database.
+        self._load_current_ids(db_conn, table, instance_column, id_column)
+
     def _load_current_ids(
         self, db_conn, table: str, instance_column: str, id_column: str
-    ) -> Dict[str, int]:
-        # If positive stream aggregate via MAX. For negative stream use MIN
-        # *and* negate the result to get a positive number.
-        sql = """
-            SELECT %(instance)s, %(agg)s(%(id)s) FROM %(table)s
-            GROUP BY %(instance)s
-        """ % {
-            "instance": instance_column,
-            "id": id_column,
-            "table": table,
-            "agg": "MAX" if self._positive else "-MIN",
-        }
-
+    ):
         cur = db_conn.cursor()
-        cur.execute(sql)
 
-        # `cur` is an iterable over returned rows, which are 2-tuples.
-        current_positions = dict(cur)
+        # Load the current positions of all writers for the stream.
+        if self._writers:
+            # We delete any stale entries in the positions table. This is
+            # important if we add back a writer after a long time; we want to
+            # consider that a "new" writer, rather than using the old stale
+            # entry here.
+            sql = """
+                DELETE FROM stream_positions
+                WHERE
+                    stream_name = ?
+                    AND instance_name != ALL(?)
+            """
+            sql = self._db.engine.convert_param_style(sql)
+            cur.execute(sql, (self._stream_name, self._writers))
+
+            sql = """
+                SELECT instance_name, stream_id FROM stream_positions
+                WHERE stream_name = ?
+            """
+            sql = self._db.engine.convert_param_style(sql)
+
+            cur.execute(sql, (self._stream_name,))
+
+            self._current_positions = {
+                instance: stream_id * self._return_factor
+                for instance, stream_id in cur
+                if instance in self._writers
+            }
 
-        cur.close()
+        # We set the `_persisted_upto_position` to be the minimum of all current
+        # positions. If empty we use the max stream ID from the DB table.
+        min_stream_id = min(self._current_positions.values(), default=None)
+
+        if min_stream_id is None:
+            # We add a GREATEST here to ensure that the result is always
+            # positive. (This can be a problem for e.g. backfill streams where
+            # the server has never backfilled).
+            sql = """
+                SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1)
+                FROM %(table)s
+            """ % {
+                "id": id_column,
+                "table": table,
+                "agg": "MAX" if self._positive else "-MIN",
+            }
+            cur.execute(sql)
+            (stream_id,) = cur.fetchone()
+            self._persisted_upto_position = stream_id
+        else:
+            # If we have a min_stream_id then we pull out everything greater
+            # than or equal to it from the DB so that we can prefill
+            # `_known_persisted_positions` and get a more accurate
+            # `_persisted_upto_position`.
+            #
+            # We also check if any of the later rows are from this instance, in
+            # which case we use that for this instance's current position. This
+            # is to handle the case where we didn't finish persisting to the
+            # stream positions table before restart (or the stream position
+            # table otherwise got out of date).
+
+            sql = """
+                SELECT %(instance)s, %(id)s FROM %(table)s
+                WHERE ? %(cmp)s %(id)s
+            """ % {
+                "id": id_column,
+                "table": table,
+                "instance": instance_column,
+                "cmp": "<=" if self._positive else ">=",
+            }
+            sql = self._db.engine.convert_param_style(sql)
+            cur.execute(sql, (min_stream_id * self._return_factor,))
+
+            self._persisted_upto_position = min_stream_id
 
-        return current_positions
+            with self._lock:
+                for instance, stream_id in cur:
+                    stream_id = self._return_factor * stream_id
+                    self._add_persisted_position(stream_id)
+
+                    if instance == self._instance_name:
+                        self._current_positions[instance] = stream_id
+
+        cur.close()
 
     def _load_next_id_txn(self, txn) -> int:
         return self._sequence_gen.get_next_id_txn(txn)
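
One Postgres-specific detail in `_load_current_ids` above: `instance_name != ALL(?)`
compares the column against an array parameter, so stale rows for writers that
are no longer configured are deleted in a single statement. A minimal standalone
sketch of that statement, assuming a psycopg2 connection and a hypothetical
stream name "events" (neither is part of the patch):

    import psycopg2

    # Hypothetical DSN; `stream_positions` is assumed to match Synapse's schema.
    conn = psycopg2.connect("dbname=synapse")
    writers = ["worker1", "worker2"]
    with conn.cursor() as cur:
        cur.execute(
            """
            DELETE FROM stream_positions
            WHERE stream_name = %s AND instance_name != ALL(%s)
            """,
            # psycopg2 adapts the Python list to a Postgres array for ALL().
            ("events", writers),
        )
    conn.commit()
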
@@ -274,59 +362,23 @@ class MultiWriterIdGenerator:
     def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
         return self._sequence_gen.get_next_mult_txn(txn, n)
 
-    async def get_next(self):
+    def get_next(self):
         """
         Usage:
-            with await stream_id_gen.get_next() as stream_id:
+            async with stream_id_gen.get_next() as stream_id:
                 # ... persist event ...
         """
-        next_id = await self._db.runInteraction("_load_next_id", self._load_next_id_txn)
-
-        # Assert the fetched ID is actually greater than what we currently
-        # believe the ID to be. If not, then the sequence and table have got
-        # out of sync somehow.
-        with self._lock:
-            assert self._current_positions.get(self._instance_name, 0) < next_id
-
-            self._unfinished_ids.add(next_id)
-
-        @contextlib.contextmanager
-        def manager():
-            try:
-                # Multiply by the return factor so that the ID has correct sign.
-                yield self._return_factor * next_id
-            finally:
-                self._mark_id_as_finished(next_id)
 
-        return manager()
+        return _MultiWriterCtxManager(self)
 
-    async def get_next_mult(self, n: int):
+    def get_next_mult(self, n: int):
         """
         Usage:
-            with await stream_id_gen.get_next_mult(5) as stream_ids:
+            async with stream_id_gen.get_next_mult(5) as stream_ids:
                 # ... persist events ...
         """
-        next_ids = await self._db.runInteraction(
-            "_load_next_mult_id", self._load_next_mult_id_txn, n
-        )
-
-        # Assert the fetched ID is actually greater than any ID we've already
-        # seen. If not, then the sequence and table have got out of sync
-        # somehow.
-        with self._lock:
-            assert max(self._current_positions.values(), default=0) < min(next_ids)
 
-            self._unfinished_ids.update(next_ids)
-
-        @contextlib.contextmanager
-        def manager():
-            try:
-                yield [self._return_factor * i for i in next_ids]
-            finally:
-                for i in next_ids:
-                    self._mark_id_as_finished(i)
-
-        return manager()
+        return _MultiWriterCtxManager(self, n)
 
     def get_next_txn(self, txn: LoggingTransaction):
         """
@@ -344,21 +396,63 @@ class MultiWriterIdGenerator:
         txn.call_after(self._mark_id_as_finished, next_id)
         txn.call_on_exception(self._mark_id_as_finished, next_id)
 
+        # Update the `stream_positions` table with the newly updated stream
+        # ID (unless self._writers is empty, in which case we don't bother,
+        # as nothing will read it).
+        #
+        # We only do this on the success path so that the persisted current
+        # position points to a persisted row with the correct instance name.
+        if self._writers:
+            txn.call_after(
+                run_as_background_process,
+                "MultiWriterIdGenerator._update_table",
+                self._db.runInteraction,
+                "MultiWriterIdGenerator._update_table",
+                self._update_stream_positions_table_txn,
+            )
+
         return self._return_factor * next_id
 
     def _mark_id_as_finished(self, next_id: int):
         """The ID has finished being processed so we should advance the
-        current poistion if possible.
+        current position if possible.
         """
 
         with self._lock:
             self._unfinished_ids.discard(next_id)
+            self._finished_ids.add(next_id)
+
+            new_cur = None
+
+            if self._unfinished_ids:
+                # If there are unfinished IDs then the new position will be the
+                # largest finished ID less than the minimum unfinished ID.
+
+                finished = set()
+
+                min_unfinished = min(self._unfinished_ids)
+                for s in self._finished_ids:
+                    if s < min_unfinished:
+                        if new_cur is None or new_cur < s:
+                            new_cur = s
+                    else:
+                        finished.add(s)
+
+                # We clear these out since they're now all less than the new
+                # position.
+                self._finished_ids = finished
+            else:
+                # There are no unfinished IDs so the new position is simply the
+                # largest finished one.
+                new_cur = max(self._finished_ids)
+
+                # We clear these out since they're now all less than the new
+                # position.
+                self._finished_ids.clear()
 
-            # Figure out if its safe to advance the position by checking there
-            # aren't any lower allocated IDs that are yet to finish.
-            if all(c > next_id for c in self._unfinished_ids):
+            if new_cur:
                 curr = self._current_positions.get(self._instance_name, 0)
-                self._current_positions[self._instance_name] = max(curr, next_id)
+                self._current_positions[self._instance_name] = max(curr, new_cur)
 
             self._add_persisted_position(next_id)
 
@@ -367,19 +461,28 @@ class MultiWriterIdGenerator:
         equal to it have been successfully persisted.
         """
 
-        # Currently we don't support this operation, as it's not obvious how to
-        # condense the stream positions of multiple writers into a single int.
-        raise NotImplementedError()
+        return self.get_persisted_upto_position()
 
     def get_current_token_for_writer(self, instance_name: str) -> int:
         """Returns the position of the given writer.
         """
 
+        # If we don't have an entry for the given instance name, we assume it's a
+        # new writer.
+        #
+        # For new writers we assume their initial position to be the current
+        # persisted up to position. This stops Synapse from doing a full table
+        # scan when a new writer announces itself over replication.
         with self._lock:
-            return self._return_factor * self._current_positions.get(instance_name, 0)
+            return self._return_factor * self._current_positions.get(
+                instance_name, self._persisted_upto_position
+            )
 
     def get_positions(self) -> Dict[str, int]:
         """Get a copy of the current positon map.
+
+        Note that this won't necessarily include all configured writers if some
+        writers haven't written anything yet.
         """
 
         with self._lock:
@@ -428,7 +531,7 @@ class MultiWriterIdGenerator:
         # We move the current min position up if the minimum current positions
         # of all instances is higher (since by definition all positions less
         # that that have been persisted).
-        min_curr = min(self._current_positions.values())
+        min_curr = min(self._current_positions.values(), default=0)
         self._persisted_upto_position = max(min_curr, self._persisted_upto_position)
 
         # We now iterate through the seen positions, discarding those that are
@@ -449,3 +552,97 @@ class MultiWriterIdGenerator:
                 # There was a gap in seen positions, so there is nothing more to
                 # do.
                 break
+
+    def _update_stream_positions_table_txn(self, txn: Cursor):
+        """Update the `stream_positions` table with newly persisted position.
+        """
+
+        if not self._writers:
+            return
+
+        # We upsert the value, ensuring on conflict that we always increase the
+        # value (or decrease it if the stream goes backwards).
+        sql = """
+            INSERT INTO stream_positions (stream_name, instance_name, stream_id)
+            VALUES (?, ?, ?)
+            ON CONFLICT (stream_name, instance_name)
+            DO UPDATE SET
+                stream_id = %(agg)s(stream_positions.stream_id, EXCLUDED.stream_id)
+        """ % {
+            "agg": "GREATEST" if self._positive else "LEAST",
+        }
+
+        pos = self.get_current_token_for_writer(self._instance_name)
+        txn.execute(sql, (self._stream_name, self._instance_name, pos))
+
+
+@attr.s(slots=True)
+class _AsyncCtxManagerWrapper:
+    """Helper class to convert a plain context manager to an async one.
+
+    This is mainly useful if you have a plain context manager but the interface
+    requires an async one.
+    """
+
+    inner = attr.ib()
+
+    async def __aenter__(self):
+        return self.inner.__enter__()
+
+    async def __aexit__(self, exc_type, exc, tb):
+        return self.inner.__exit__(exc_type, exc, tb)
+
+
+@attr.s(slots=True)
+class _MultiWriterCtxManager:
+    """Async context manager returned by MultiWriterIdGenerator
+    """
+
+    id_gen = attr.ib(type=MultiWriterIdGenerator)
+    multiple_ids = attr.ib(type=Optional[int], default=None)
+    stream_ids = attr.ib(type=List[int], factory=list)
+
+    async def __aenter__(self) -> Union[int, List[int]]:
+        # It's safe to run this in autocommit mode as fetching values from a
+        # sequence ignores transaction semantics anyway.
+        self.stream_ids = await self.id_gen._db.runInteraction(
+            "_load_next_mult_id",
+            self.id_gen._load_next_mult_id_txn,
+            self.multiple_ids or 1,
+            db_autocommit=True,
+        )
+
+        with self.id_gen._lock:
+            self.id_gen._unfinished_ids.update(self.stream_ids)
+
+        if self.multiple_ids is None:
+            return self.stream_ids[0] * self.id_gen._return_factor
+        else:
+            return [i * self.id_gen._return_factor for i in self.stream_ids]
+
+    async def __aexit__(self, exc_type, exc, tb):
+        for i in self.stream_ids:
+            self.id_gen._mark_id_as_finished(i)
+
+        if exc_type is not None:
+            return False
+
+        # Update the `stream_positions` table with the newly updated stream
+        # ID (unless self._writers is empty, in which case we don't bother,
+        # as nothing will read it).
+        #
+        # We only do this on the success path so that the persisted current
+        # position points to a persisted row with the correct instance name.
+        #
+        # We do this in autocommit mode as a) the upsert works correctly outside
+        # transactions and b) reduces the amount of time the rows are locked
+        # for. If we don't do this then we'll often hit serialization errors due
+        # to the fact that we default to the REPEATABLE READ isolation level.
+        if self.id_gen._writers:
+            await self.id_gen._db.runInteraction(
+                "MultiWriterIdGenerator._update_table",
+                self.id_gen._update_stream_positions_table_txn,
+                db_autocommit=True,
+            )
+
+        return False
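
The `_AsyncCtxManagerWrapper` added above is what lets callers switch from
`with await gen.get_next()` to `async with gen.get_next()`. A minimal runnable
sketch of the pattern, with a hypothetical `allocate` standing in for the real
ID allocation:

    import asyncio
    from contextlib import contextmanager

    class AsyncCtxManagerWrapper:
        """Re-implementation of the wrapper above, for illustration only."""

        def __init__(self, inner):
            self.inner = inner

        async def __aenter__(self):
            return self.inner.__enter__()

        async def __aexit__(self, exc_type, exc, tb):
            return self.inner.__exit__(exc_type, exc, tb)

    @contextmanager
    def allocate(next_id):
        try:
            yield next_id
        finally:
            print("finished", next_id)  # stands in for _mark_id_as_finished

    async def main():
        # The synchronous context manager now works with `async with`.
        async with AsyncCtxManagerWrapper(allocate(42)) as stream_id:
            print("persisting with stream_id", stream_id)

    asyncio.run(main())
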
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index ffc1894748..2dd95e2709 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -13,11 +13,34 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import abc
+import logging
 import threading
 from typing import Callable, List, Optional
 
-from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
-from synapse.storage.types import Cursor
+from synapse.storage.engines import (
+    BaseDatabaseEngine,
+    IncorrectDatabaseSetup,
+    PostgresEngine,
+)
+from synapse.storage.types import Connection, Cursor
+
+logger = logging.getLogger(__name__)
+
+
+_INCONSISTENT_SEQUENCE_ERROR = """
+Postgres sequence '%(seq)s' is inconsistent with associated
+table '%(table)s'. This can happen if Synapse has been downgraded and
+then upgraded again, or due to a bad migration.
+
+To fix this error, shut down Synapse (including any and all workers)
+and run the following SQL:
+
+    SELECT setval('%(seq)s', (
+        %(max_id_sql)s
+    ));
+
+See docs/postgres.md for more information.
+"""
 
 
 class SequenceGenerator(metaclass=abc.ABCMeta):
@@ -28,6 +51,19 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
         """Gets the next ID in the sequence"""
         ...
 
+    @abc.abstractmethod
+    def check_consistency(
+        self, db_conn: Connection, table: str, id_column: str, positive: bool = True
+    ):
+        """Should be called during start up to test that the current value of
+        the sequence is greater than or equal to the maximum ID in the table.
+
+        This is to handle various cases where the sequence value can get out
+        of sync with the table, e.g. if Synapse gets rolled back to a previous
+        version and then rolled forwards again.
+        """
+        ...
+
 
 class PostgresSequenceGenerator(SequenceGenerator):
     """An implementation of SequenceGenerator which uses a postgres sequence"""
@@ -45,6 +81,50 @@ class PostgresSequenceGenerator(SequenceGenerator):
         )
         return [i for (i,) in txn]
 
+    def check_consistency(
+        self, db_conn: Connection, table: str, id_column: str, positive: bool = True
+    ):
+        txn = db_conn.cursor()
+
+        # First we get the current max ID from the table.
+        table_sql = "SELECT GREATEST(%(agg)s(%(id)s), 0) FROM %(table)s" % {
+            "id": id_column,
+            "table": table,
+            "agg": "MAX" if positive else "-MIN",
+        }
+
+        txn.execute(table_sql)
+        row = txn.fetchone()
+        if not row:
+            # Defensive: the aggregate query should always return a row.
+            txn.close()
+            return
+
+        # Now we fetch the current value from the sequence and compare with the
+        # above.
+        max_stream_id = row[0]
+        txn.execute(
+            "SELECT last_value, is_called FROM %(seq)s" % {"seq": self._sequence_name}
+        )
+        last_value, is_called = txn.fetchone()
+        txn.close()
+
+        # If `is_called` is False then `last_value` is actually the value that
+        # will be generated next, so we decrement to get the true "last value".
+        if not is_called:
+            last_value -= 1
+
+        if max_stream_id > last_value:
+            logger.warning(
+                "Postgres sequence %s is behind table %s: %d < %d",
+                self._sequence_name, table,
+                last_value, max_stream_id,
+            )
+            raise IncorrectDatabaseSetup(
+                _INCONSISTENT_SEQUENCE_ERROR
+                % {"seq": self._sequence_name, "table": table, "max_id_sql": table_sql}
+            )
+
 
 GetFirstCallbackType = Callable[[Cursor], int]
 
@@ -81,6 +161,12 @@ class LocalSequenceGenerator(SequenceGenerator):
             self._current_max_id += 1
             return self._current_max_id
 
+    def check_consistency(
+        self, db_conn: Connection, table: str, id_column: str, positive: bool = True
+    ):
+        # There is nothing to do for in-memory sequences.
+        pass
+
 
 def build_sequence_generator(
     database_engine: BaseDatabaseEngine,
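
To illustrate the consistency check added to sequence.py above: it compares the
table's maximum ID with the sequence's `last_value`, adjusting for Postgres's
`is_called` flag. A standalone sketch of that comparison (hypothetical
`sequence_is_consistent` helper, mirroring the patch's logic):

    def sequence_is_consistent(
        last_value: int, is_called: bool, max_stream_id: int
    ) -> bool:
        """True if the sequence is at or past the table's maximum ID.

        If `is_called` is False, Postgres will hand out `last_value` itself on
        the next nextval(), so the true last-used value is one less.
        """
        if not is_called:
            last_value -= 1
        return max_stream_id <= last_value

    # A freshly created sequence (last_value=1, is_called=False) with rows up
    # to ID 10 already in the table is inconsistent and needs a setval().
    assert not sequence_is_consistent(last_value=1, is_called=False, max_stream_id=10)
    assert sequence_is_consistent(last_value=10, is_called=True, max_stream_id=10)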