diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index 4386b6101e..0ec4dc2918 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -15,9 +15,8 @@
import abc
import logging
import threading
-from typing import Callable, List, Optional
+from typing import TYPE_CHECKING, Callable, List, Optional
-from synapse.storage.database import LoggingDatabaseConnection
from synapse.storage.engines import (
BaseDatabaseEngine,
IncorrectDatabaseSetup,
@@ -25,6 +24,9 @@ from synapse.storage.engines import (
)
from synapse.storage.types import Connection, Cursor
+if TYPE_CHECKING:
+ from synapse.storage.database import LoggingDatabaseConnection
+
logger = logging.getLogger(__name__)
@@ -43,6 +45,21 @@ and run the following SQL:
See docs/postgres.md for more information.
"""
+_INCONSISTENT_STREAM_ERROR = """
+Postgres sequence '%(seq)s' is inconsistent with associated stream position
+of '%(stream_name)s' in the 'stream_positions' table.
+
+This is likely a programming error and should be reported at
+https://github.com/matrix-org/synapse.
+
+A temporary workaround to fix this error is to shut down Synapse (including
+any and all workers) and run the following SQL:
+
+ DELETE FROM stream_positions WHERE stream_name = '%(stream_name)s';
+
+This will need to be done every time the server is restarted.
+"""
+
class SequenceGenerator(metaclass=abc.ABCMeta):
"""A class which generates a unique sequence of integers"""
@@ -53,19 +70,30 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
...
@abc.abstractmethod
+ def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
+ """Get the next `n` IDs in the sequence"""
+ ...
+
+ @abc.abstractmethod
def check_consistency(
self,
- db_conn: LoggingDatabaseConnection,
+ db_conn: "LoggingDatabaseConnection",
table: str,
id_column: str,
+ stream_name: Optional[str] = None,
positive: bool = True,
):
"""Should be called during start up to test that the current value of
the sequence is greater than or equal to the maximum ID in the table.
- This is to handle various cases where the sequence value can get out
- of sync with the table, e.g. if Synapse gets rolled back to a previous
+ This is to handle various cases where the sequence value can get out of
+ sync with the table, e.g. if Synapse gets rolled back to a previous
version and the rolled forwards again.
+
+ If a stream name is given then this will check that any value in the
+ `stream_positions` table is less than or equal to the current sequence
+ value. If it isn't then it's likely that streams have been crossed
+ somewhere (e.g. two ID generators have the same stream name).
"""
...
@@ -88,11 +116,15 @@ class PostgresSequenceGenerator(SequenceGenerator):
def check_consistency(
self,
- db_conn: LoggingDatabaseConnection,
+ db_conn: "LoggingDatabaseConnection",
table: str,
id_column: str,
+ stream_name: Optional[str] = None,
positive: bool = True,
):
+ """See SequenceGenerator.check_consistency for docstring.
+ """
+
txn = db_conn.cursor(txn_name="sequence.check_consistency")
# First we get the current max ID from the table.
@@ -116,6 +148,18 @@ class PostgresSequenceGenerator(SequenceGenerator):
"SELECT last_value, is_called FROM %(seq)s" % {"seq": self._sequence_name}
)
last_value, is_called = txn.fetchone()
+
+ # If we have an associated stream check the stream_positions table.
+ max_in_stream_positions = None
+ if stream_name:
+ txn.execute(
+ "SELECT MAX(stream_id) FROM stream_positions WHERE stream_name = ?",
+ (stream_name,),
+ )
+ row = txn.fetchone()
+ if row:
+ max_in_stream_positions = row[0]
+
txn.close()
# If `is_called` is False then `last_value` is actually the value that
@@ -136,6 +180,14 @@ class PostgresSequenceGenerator(SequenceGenerator):
% {"seq": self._sequence_name, "table": table, "max_id_sql": table_sql}
)
+ # If we have values in the stream positions table then they have to be
+ # less than or equal to `last_value`
+ if max_in_stream_positions and max_in_stream_positions > last_value:
+ raise IncorrectDatabaseSetup(
+ _INCONSISTENT_STREAM_ERROR
+ % {"seq": self._sequence_name, "stream_name": stream_name}
+ )
+
GetFirstCallbackType = Callable[[Cursor], int]
@@ -172,8 +224,24 @@ class LocalSequenceGenerator(SequenceGenerator):
self._current_max_id += 1
return self._current_max_id
+ def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
+ with self._lock:
+ if self._current_max_id is None:
+ assert self._callback is not None
+ self._current_max_id = self._callback(txn)
+ self._callback = None
+
+ first_id = self._current_max_id + 1
+ self._current_max_id += n
+ return [first_id + i for i in range(n)]
+
def check_consistency(
- self, db_conn: Connection, table: str, id_column: str, positive: bool = True
+ self,
+ db_conn: Connection,
+ table: str,
+ id_column: str,
+ stream_name: Optional[str] = None,
+ positive: bool = True,
):
# There is nothing to do for in memory sequences
pass
|