summary refs log tree commit diff
path: root/synapse/storage/util/sequence.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/util/sequence.py')
-rw-r--r--synapse/storage/util/sequence.py99
1 files changed, 97 insertions, 2 deletions
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index ffc1894748..ff2d038ad2 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -13,11 +13,35 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import abc
+import logging
 import threading
 from typing import Callable, List, Optional
 
-from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
-from synapse.storage.types import Cursor
+from synapse.storage.database import LoggingDatabaseConnection
+from synapse.storage.engines import (
+    BaseDatabaseEngine,
+    IncorrectDatabaseSetup,
+    PostgresEngine,
+)
+from synapse.storage.types import Connection, Cursor
+
+logger = logging.getLogger(__name__)
+
+
+_INCONSISTENT_SEQUENCE_ERROR = """
+Postgres sequence '%(seq)s' is inconsistent with associated
+table '%(table)s'. This can happen if Synapse has been downgraded and
+then upgraded again, or due to a bad migration.
+
+To fix this error, shut down Synapse (including any and all workers)
+and run the following SQL:
+
+    SELECT setval('%(seq)s', (
+        %(max_id_sql)s
+    ));
+
+See docs/postgres.md for more information.
+"""
 
 
 class SequenceGenerator(metaclass=abc.ABCMeta):
@@ -28,6 +52,23 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
         """Gets the next ID in the sequence"""
         ...
 
+    @abc.abstractmethod
+    def check_consistency(
+        self,
+        db_conn: LoggingDatabaseConnection,
+        table: str,
+        id_column: str,
+        positive: bool = True,
+    ):
+        """Should be called during start up to test that the current value of
+        the sequence is greater than or equal to the maximum ID in the table.
+
+        This is to handle various cases where the sequence value can get out
+        of sync with the table, e.g. if Synapse gets rolled back to a previous
+        version and the rolled forwards again.
+        """
+        ...
+
 
 class PostgresSequenceGenerator(SequenceGenerator):
     """An implementation of SequenceGenerator which uses a postgres sequence"""
@@ -45,6 +86,54 @@ class PostgresSequenceGenerator(SequenceGenerator):
         )
         return [i for (i,) in txn]
 
+    def check_consistency(
+        self,
+        db_conn: LoggingDatabaseConnection,
+        table: str,
+        id_column: str,
+        positive: bool = True,
+    ):
+        txn = db_conn.cursor(txn_name="sequence.check_consistency")
+
+        # First we get the current max ID from the table.
+        table_sql = "SELECT GREATEST(%(agg)s(%(id)s), 0) FROM %(table)s" % {
+            "id": id_column,
+            "table": table,
+            "agg": "MAX" if positive else "-MIN",
+        }
+
+        txn.execute(table_sql)
+        row = txn.fetchone()
+        if not row:
+            # Table is empty, so nothing to do.
+            txn.close()
+            return
+
+        # Now we fetch the current value from the sequence and compare with the
+        # above.
+        max_stream_id = row[0]
+        txn.execute(
+            "SELECT last_value, is_called FROM %(seq)s" % {"seq": self._sequence_name}
+        )
+        last_value, is_called = txn.fetchone()
+        txn.close()
+
+        # If `is_called` is False then `last_value` is actually the value that
+        # will be generated next, so we decrement to get the true "last value".
+        if not is_called:
+            last_value -= 1
+
+        if max_stream_id > last_value:
+            logger.warning(
+                "Postgres sequence %s is behind table %s: %d < %d",
+                last_value,
+                max_stream_id,
+            )
+            raise IncorrectDatabaseSetup(
+                _INCONSISTENT_SEQUENCE_ERROR
+                % {"seq": self._sequence_name, "table": table, "max_id_sql": table_sql}
+            )
+
 
 GetFirstCallbackType = Callable[[Cursor], int]
 
@@ -81,6 +170,12 @@ class LocalSequenceGenerator(SequenceGenerator):
             self._current_max_id += 1
             return self._current_max_id
 
+    def check_consistency(
+        self, db_conn: Connection, table: str, id_column: str, positive: bool = True
+    ):
+        # There is nothing to do for in memory sequences
+        pass
+
 
 def build_sequence_generator(
     database_engine: BaseDatabaseEngine,