summary refs log tree commit diff
path: root/synapse/storage/util
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2020-09-24 16:53:51 +0100
committerGitHub <noreply@github.com>2020-09-24 16:53:51 +0100
commitf112cfe5bb2c918c9e942941686a05664d8bd7da (patch)
tree7e58da475bcfe5ddb8c1191d1ef20c32069f5f5f /synapse/storage/util
parentAdd type annotations to SimpleHttpClient (#8372) (diff)
downloadsynapse-f112cfe5bb2c918c9e942941686a05664d8bd7da.tar.xz
Fix MultiWriteIdGenerator's handling of restarts. (#8374)
On startup `MultiWriteIdGenerator` fetches the maximum stream ID for
each instance from the table and uses that as its initial "current
position" for each writer. This is problematic as a) it involves either
a scan of events table or an index (neither of which is ideal), and b)
if rows are being persisted out of order elsewhere while the process
restarts then using the maximum stream ID is not correct. This could
theoretically lead to race conditions where e.g. events that are
persisted out of order are not sent down sync streams.

We fix this by creating a new table that tracks the current positions of
each writer to the stream, and update it each time we finish persisting
a new entry. This is a relatively small overhead when persisting events.
However for the cache invalidation stream this is a much bigger relative
overhead, so instead we note that for invalidation we don't actually
care about reliability over restarts (as there's no caches to
invalidate) and simply don't bother reading and writing to the new table
in that particular case.
Diffstat (limited to 'synapse/storage/util')
-rw-r--r--synapse/storage/util/id_generators.py148
1 files changed, 127 insertions, 21 deletions
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index b0353ac2dc..727fcc521c 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -22,6 +22,7 @@ from typing import Dict, List, Optional, Set, Union
 import attr
 from typing_extensions import Deque
 
+from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.util.sequence import PostgresSequenceGenerator
 
@@ -184,12 +185,16 @@ class MultiWriterIdGenerator:
     Args:
         db_conn
         db
+        stream_name: A name for the stream.
         instance_name: The name of this instance.
         table: Database table associated with stream.
         instance_column: Column that stores the row's writer's instance name
         id_column: Column that stores the stream ID.
         sequence_name: The name of the postgres sequence used to generate new
             IDs.
+        writers: A list of known writers to use to populate current positions
+            on startup. Can be empty if nothing uses `get_current_token` or
+            `get_positions` (e.g. caches stream).
         positive: Whether the IDs are positive (true) or negative (false).
             When using negative IDs we go backwards from -1 to -2, -3, etc.
     """
@@ -198,16 +203,20 @@ class MultiWriterIdGenerator:
         self,
         db_conn,
         db: DatabasePool,
+        stream_name: str,
         instance_name: str,
         table: str,
         instance_column: str,
         id_column: str,
         sequence_name: str,
+        writers: List[str],
         positive: bool = True,
     ):
         self._db = db
+        self._stream_name = stream_name
         self._instance_name = instance_name
         self._positive = positive
+        self._writers = writers
         self._return_factor = 1 if positive else -1
 
         # We lock as some functions may be called from DB threads.
@@ -216,9 +225,7 @@ class MultiWriterIdGenerator:
         # Note: If we are a negative stream then we still store all the IDs as
         # positive to make life easier for us, and simply negate the IDs when we
         # return them.
-        self._current_positions = self._load_current_ids(
-            db_conn, table, instance_column, id_column
-        )
+        self._current_positions = {}  # type: Dict[str, int]
 
         # Set of local IDs that we're still processing. The current position
         # should be less than the minimum of this set (if not empty).
@@ -251,30 +258,80 @@ class MultiWriterIdGenerator:
 
         self._sequence_gen = PostgresSequenceGenerator(sequence_name)
 
+        # This goes and fills out the above state from the database.
+        self._load_current_ids(db_conn, table, instance_column, id_column)
+
     def _load_current_ids(
         self, db_conn, table: str, instance_column: str, id_column: str
-    ) -> Dict[str, int]:
-        # If positive stream aggregate via MAX. For negative stream use MIN
-        # *and* negate the result to get a positive number.
-        sql = """
-            SELECT %(instance)s, %(agg)s(%(id)s) FROM %(table)s
-            GROUP BY %(instance)s
-        """ % {
-            "instance": instance_column,
-            "id": id_column,
-            "table": table,
-            "agg": "MAX" if self._positive else "-MIN",
-        }
-
+    ):
         cur = db_conn.cursor()
-        cur.execute(sql)
 
-        # `cur` is an iterable over returned rows, which are 2-tuples.
-        current_positions = dict(cur)
+        # Load the current positions of all writers for the stream.
+        if self._writers:
+            sql = """
+                SELECT instance_name, stream_id FROM stream_positions
+                WHERE stream_name = ?
+            """
+            sql = self._db.engine.convert_param_style(sql)
 
-        cur.close()
+            cur.execute(sql, (self._stream_name,))
+
+            self._current_positions = {
+                instance: stream_id * self._return_factor
+                for instance, stream_id in cur
+                if instance in self._writers
+            }
+
+        # We set the `_persisted_upto_position` to be the minimum of all current
+        # positions. If empty we use the max stream ID from the DB table.
+        min_stream_id = min(self._current_positions.values(), default=None)
+
+        if min_stream_id is None:
+            sql = """
+                SELECT COALESCE(%(agg)s(%(id)s), 1) FROM %(table)s
+            """ % {
+                "id": id_column,
+                "table": table,
+                "agg": "MAX" if self._positive else "-MIN",
+            }
+            cur.execute(sql)
+            (stream_id,) = cur.fetchone()
+            self._persisted_upto_position = stream_id
+        else:
+            # If we have a min_stream_id then we pull out everything greater
+            # than it from the DB so that we can prefill
+            # `_known_persisted_positions` and get a more accurate
+            # `_persisted_upto_position`.
+            #
+            # We also check if any of the later rows are from this instance, in
+            # which case we use that for this instance's current position. This
+            # is to handle the case where we didn't finish persisting to the
+            # stream positions table before restart (or the stream position
+            # table otherwise got out of date).
+
+            sql = """
+                SELECT %(instance)s, %(id)s FROM %(table)s
+                WHERE ? %(cmp)s %(id)s
+            """ % {
+                "id": id_column,
+                "table": table,
+                "instance": instance_column,
+                "cmp": "<=" if self._positive else ">=",
+            }
+            sql = self._db.engine.convert_param_style(sql)
+            cur.execute(sql, (min_stream_id,))
+
+            self._persisted_upto_position = min_stream_id
+
+            with self._lock:
+                for (instance, stream_id,) in cur:
+                    stream_id = self._return_factor * stream_id
+                    self._add_persisted_position(stream_id)
 
-        return current_positions
+                    if instance == self._instance_name:
+                        self._current_positions[instance] = stream_id
+
+        cur.close()
 
     def _load_next_id_txn(self, txn) -> int:
         return self._sequence_gen.get_next_id_txn(txn)
@@ -316,6 +373,21 @@ class MultiWriterIdGenerator:
         txn.call_after(self._mark_id_as_finished, next_id)
         txn.call_on_exception(self._mark_id_as_finished, next_id)
 
+        # Update the `stream_positions` table with newly updated stream
+        # ID (unless self._writers is not set in which case we don't
+        # bother, as nothing will read it).
+        #
+        # We only do this on the success path so that the persisted current
+        # position points to a persited row with the correct instance name.
+        if self._writers:
+            txn.call_after(
+                run_as_background_process,
+                "MultiWriterIdGenerator._update_table",
+                self._db.runInteraction,
+                "MultiWriterIdGenerator._update_table",
+                self._update_stream_positions_table_txn,
+            )
+
         return self._return_factor * next_id
 
     def _mark_id_as_finished(self, next_id: int):
@@ -447,6 +519,28 @@ class MultiWriterIdGenerator:
                 # do.
                 break
 
+    def _update_stream_positions_table_txn(self, txn):
+        """Update the `stream_positions` table with newly persisted position.
+        """
+
+        if not self._writers:
+            return
+
+        # We upsert the value, ensuring on conflict that we always increase the
+        # value (or decrease if stream goes backwards).
+        sql = """
+            INSERT INTO stream_positions (stream_name, instance_name, stream_id)
+            VALUES (?, ?, ?)
+            ON CONFLICT (stream_name, instance_name)
+            DO UPDATE SET
+                stream_id = %(agg)s(stream_positions.stream_id, EXCLUDED.stream_id)
+        """ % {
+            "agg": "GREATEST" if self._positive else "LEAST",
+        }
+
+        pos = (self.get_current_token_for_writer(self._instance_name),)
+        txn.execute(sql, (self._stream_name, self._instance_name, pos))
+
 
 @attr.s(slots=True)
 class _AsyncCtxManagerWrapper:
@@ -503,4 +597,16 @@ class _MultiWriterCtxManager:
         if exc_type is not None:
             return False
 
+        # Update the `stream_positions` table with newly updated stream
+        # ID (unless self._writers is not set in which case we don't
+        # bother, as nothing will read it).
+        #
+        # We only do this on the success path so that the persisted current
+        # position points to a persisted row with the correct instance name.
+        if self.id_gen._writers:
+            await self.id_gen._db.runInteraction(
+                "MultiWriterIdGenerator._update_table",
+                self.id_gen._update_stream_positions_table_txn,
+            )
+
         return False