diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index ffc1894748..2dd95e2709 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -13,11 +13,34 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
+import logging
import threading
from typing import Callable, List, Optional
-from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
-from synapse.storage.types import Cursor
+from synapse.storage.engines import (
+ BaseDatabaseEngine,
+ IncorrectDatabaseSetup,
+ PostgresEngine,
+)
+from synapse.storage.types import Connection, Cursor
+
+# Module-level logger, named after this module.
+logger = logging.getLogger(__name__)
+
+
+# Message template used with `IncorrectDatabaseSetup` when a Postgres sequence
+# is found to be behind the table it is associated with. It is %-formatted
+# with a mapping providing the keys `seq`, `table` and `max_id_sql`; the
+# latter is the SQL that computes the table's current maximum ID, embedded in
+# the suggested `setval` statement.
+_INCONSISTENT_SEQUENCE_ERROR = """
+Postgres sequence '%(seq)s' is inconsistent with associated
+table '%(table)s'. This can happen if Synapse has been downgraded and
+then upgraded again, or due to a bad migration.
+
+To fix this error, shut down Synapse (including any and all workers)
+and run the following SQL:
+
+ SELECT setval('%(seq)s', (
+ %(max_id_sql)s
+ ));
+
+See docs/postgres.md for more information.
+"""
class SequenceGenerator(metaclass=abc.ABCMeta):
@@ -28,6 +51,19 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
"""Gets the next ID in the sequence"""
...
+ @abc.abstractmethod
+ def check_consistency(
+ self, db_conn: Connection, table: str, id_column: str, positive: bool = True
+ ):
+ """Should be called during start up to test that the current value of
+ the sequence is greater than or equal to the maximum ID in the table.
+
+ This is to handle various cases where the sequence value can get out
+ of sync with the table, e.g. if Synapse gets rolled back to a previous
+ version and then rolled forwards again.
+
+ Args:
+     db_conn: Database connection used to run the check.
+     table: Name of the table associated with the sequence.
+     id_column: Name of the column in `table` holding the IDs.
+     positive: Whether the IDs are positive (the default). The Postgres
+         implementation compares against the negated minimum ID rather
+         than the maximum ID when this is False.
+ """
+ ...
+
class PostgresSequenceGenerator(SequenceGenerator):
"""An implementation of SequenceGenerator which uses a postgres sequence"""
@@ -45,6 +81,50 @@ class PostgresSequenceGenerator(SequenceGenerator):
)
return [i for (i,) in txn]
+    def check_consistency(
+        self, db_conn: Connection, table: str, id_column: str, positive: bool = True
+    ):
+        """Check that the Postgres sequence is not behind the IDs in `table`.
+
+        Compares the largest ID stored in `table` with the last value handed
+        out by the sequence, and refuses to continue if the sequence is behind.
+
+        Args:
+            db_conn: Database connection used to run the two queries.
+            table: Name of the table associated with the sequence.
+            id_column: Name of the column in `table` holding the IDs.
+            positive: If True (the default) the table's `MAX` ID is used; if
+                False the negated `MIN` ID is used instead.
+
+        Raises:
+            IncorrectDatabaseSetup: If the table contains an ID greater than
+                the last value generated by the sequence.
+        """
+        # First we get the current max ID from the table.
+        #
+        # NOTE(review): identifiers cannot be passed as query parameters, so
+        # `table` and `id_column` are interpolated directly and are assumed to
+        # be trusted, hard-coded names - confirm against callers.
+        table_sql = "SELECT GREATEST(%(agg)s(%(id)s), 0) FROM %(table)s" % {
+            "id": id_column,
+            "table": table,
+            "agg": "MAX" if positive else "-MIN",
+        }
+
+        txn = db_conn.cursor()
+        try:
+            txn.execute(table_sql)
+            row = txn.fetchone()
+            if not row:
+                # Table is empty, so nothing to do.
+                return
+
+            # Now we fetch the current value from the sequence and compare
+            # with the above.
+            max_stream_id = row[0]
+            txn.execute(
+                "SELECT last_value, is_called FROM %(seq)s"
+                % {"seq": self._sequence_name}
+            )
+            last_value, is_called = txn.fetchone()
+        finally:
+            # Always release the cursor, even if one of the queries fails.
+            txn.close()
+
+        # If `is_called` is False then `last_value` is actually the value that
+        # will be generated next, so we decrement to get the true "last value".
+        if not is_called:
+            last_value -= 1
+
+        if max_stream_id > last_value:
+            # The format string has four placeholders, so all four values must
+            # be supplied or the message itself fails to format.
+            logger.warning(
+                "Postgres sequence %s is behind table %s: %d < %d",
+                self._sequence_name,
+                table,
+                last_value,
+                max_stream_id,
+            )
+            raise IncorrectDatabaseSetup(
+                _INCONSISTENT_SEQUENCE_ERROR
+                % {"seq": self._sequence_name, "table": table, "max_id_sql": table_sql}
+            )
+
+# Type of a callback that is given a database cursor and returns an int.
+# NOTE(review): presumably used to obtain the initial ID for an in-memory
+# sequence - confirm against LocalSequenceGenerator.
GetFirstCallbackType = Callable[[Cursor], int]
@@ -81,6 +161,12 @@ class LocalSequenceGenerator(SequenceGenerator):
self._current_max_id += 1
return self._current_max_id
+    def check_consistency(
+        self, db_conn: Connection, table: str, id_column: str, positive: bool = True
+    ):
+        """Do nothing: in-memory sequences have nothing to check.
+
+        All arguments are accepted only to satisfy the `SequenceGenerator`
+        interface and are ignored.
+        """
+        return None
+
def build_sequence_generator(
database_engine: BaseDatabaseEngine,
|