diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 133c0e7a28..71ef5a72dc 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -15,12 +15,11 @@
import heapq
import logging
import threading
-from collections import deque
+from collections import OrderedDict
from contextlib import contextmanager
-from typing import Dict, List, Optional, Set, Union
+from typing import Dict, List, Optional, Set, Tuple, Union
import attr
-from typing_extensions import Deque
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.database import DatabasePool, LoggingTransaction
@@ -101,7 +100,13 @@ class StreamIdGenerator:
self._current = (max if step > 0 else min)(
self._current, _load_current_id(db_conn, table, column, step)
)
- self._unfinished_ids = deque() # type: Deque[int]
+
+ # We use this as an ordered set, as we want to efficiently append items,
+ # remove items and get the first item. Since we insert IDs in order, the
+ # insertion ordering will ensure it is in the correct order.
+ #
+ # The keys and values are the same, but we never look at the values.
+ self._unfinished_ids = OrderedDict() # type: OrderedDict[int, int]
def get_next(self):
"""
@@ -113,7 +118,7 @@ class StreamIdGenerator:
self._current += self._step
next_id = self._current
- self._unfinished_ids.append(next_id)
+ self._unfinished_ids[next_id] = next_id
@contextmanager
def manager():
@@ -121,7 +126,7 @@ class StreamIdGenerator:
yield next_id
finally:
with self._lock:
- self._unfinished_ids.remove(next_id)
+ self._unfinished_ids.pop(next_id)
return _AsyncCtxManagerWrapper(manager())
@@ -140,7 +145,7 @@ class StreamIdGenerator:
self._current += n * self._step
for next_id in next_ids:
- self._unfinished_ids.append(next_id)
+ self._unfinished_ids[next_id] = next_id
@contextmanager
def manager():
@@ -149,7 +154,7 @@ class StreamIdGenerator:
finally:
with self._lock:
for next_id in next_ids:
- self._unfinished_ids.remove(next_id)
+ self._unfinished_ids.pop(next_id)
return _AsyncCtxManagerWrapper(manager())
@@ -162,7 +167,7 @@ class StreamIdGenerator:
"""
with self._lock:
if self._unfinished_ids:
- return self._unfinished_ids[0] - self._step
+ return next(iter(self._unfinished_ids)) - self._step
return self._current
@@ -186,11 +191,12 @@ class MultiWriterIdGenerator:
Args:
db_conn
db
- stream_name: A name for the stream.
+ stream_name: A name for the stream, for use in the `stream_positions`
+ table. (Does not need to be the same as the replication stream name.)
instance_name: The name of this instance.
- table: Database table associated with stream.
- instance_column: Column that stores the row's writer's instance name
- id_column: Column that stores the stream ID.
+ tables: List of tables associated with the stream. Each entry is a
+ tuple of (table name, column name that stores the writer's instance
+ name, column name that stores the stream ID).
sequence_name: The name of the postgres sequence used to generate new
IDs.
writers: A list of known writers to use to populate current positions
@@ -206,9 +212,7 @@ class MultiWriterIdGenerator:
db: DatabasePool,
stream_name: str,
instance_name: str,
- table: str,
- instance_column: str,
- id_column: str,
+ tables: List[Tuple[str, str, str]],
sequence_name: str,
writers: List[str],
positive: bool = True,
@@ -260,15 +264,20 @@ class MultiWriterIdGenerator:
self._sequence_gen = PostgresSequenceGenerator(sequence_name)
# We check that the table and sequence haven't diverged.
- self._sequence_gen.check_consistency(
- db_conn, table=table, id_column=id_column, positive=positive
- )
+ for table, _, id_column in tables:
+ self._sequence_gen.check_consistency(
+ db_conn,
+ table=table,
+ id_column=id_column,
+ stream_name=stream_name,
+ positive=positive,
+ )
# This goes and fills out the above state from the database.
- self._load_current_ids(db_conn, table, instance_column, id_column)
+ self._load_current_ids(db_conn, tables)
def _load_current_ids(
- self, db_conn, table: str, instance_column: str, id_column: str
+ self, db_conn, tables: List[Tuple[str, str, str]],
):
cur = db_conn.cursor(txn_name="_load_current_ids")
@@ -306,17 +315,22 @@ class MultiWriterIdGenerator:
# We add a GREATEST here to ensure that the result is always
# positive. (This can be a problem for e.g. backfill streams where
# the server has never backfilled).
- sql = """
- SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1)
- FROM %(table)s
- """ % {
- "id": id_column,
- "table": table,
- "agg": "MAX" if self._positive else "-MIN",
- }
- cur.execute(sql)
- (stream_id,) = cur.fetchone()
- self._persisted_upto_position = stream_id
+ max_stream_id = 1
+ for table, _, id_column in tables:
+ sql = """
+ SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1)
+ FROM %(table)s
+ """ % {
+ "id": id_column,
+ "table": table,
+ "agg": "MAX" if self._positive else "-MIN",
+ }
+ cur.execute(sql)
+ (stream_id,) = cur.fetchone()
+
+ max_stream_id = max(max_stream_id, stream_id)
+
+ self._persisted_upto_position = max_stream_id
else:
# If we have a min_stream_id then we pull out everything greater
# than it from the DB so that we can prefill
@@ -329,21 +343,28 @@ class MultiWriterIdGenerator:
# stream positions table before restart (or the stream position
# table otherwise got out of date).
- sql = """
- SELECT %(instance)s, %(id)s FROM %(table)s
- WHERE ? %(cmp)s %(id)s
- """ % {
- "id": id_column,
- "table": table,
- "instance": instance_column,
- "cmp": "<=" if self._positive else ">=",
- }
- cur.execute(sql, (min_stream_id * self._return_factor,))
-
self._persisted_upto_position = min_stream_id
+ rows = []
+ for table, instance_column, id_column in tables:
+ sql = """
+ SELECT %(instance)s, %(id)s FROM %(table)s
+ WHERE ? %(cmp)s %(id)s
+ """ % {
+ "id": id_column,
+ "table": table,
+ "instance": instance_column,
+ "cmp": "<=" if self._positive else ">=",
+ }
+ cur.execute(sql, (min_stream_id * self._return_factor,))
+
+ rows.extend(cur)
+
+ # Sort so that we handle rows in order for each instance.
+ rows.sort()
+
with self._lock:
- for (instance, stream_id,) in cur:
+ for (instance, stream_id,) in rows:
stream_id = self._return_factor * stream_id
self._add_persisted_position(stream_id)
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index 4386b6101e..0ec4dc2918 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -15,9 +15,8 @@
import abc
import logging
import threading
-from typing import Callable, List, Optional
+from typing import TYPE_CHECKING, Callable, List, Optional
-from synapse.storage.database import LoggingDatabaseConnection
from synapse.storage.engines import (
BaseDatabaseEngine,
IncorrectDatabaseSetup,
@@ -25,6 +24,9 @@ from synapse.storage.engines import (
)
from synapse.storage.types import Connection, Cursor
+if TYPE_CHECKING:
+ from synapse.storage.database import LoggingDatabaseConnection
+
logger = logging.getLogger(__name__)
@@ -43,6 +45,21 @@ and run the following SQL:
See docs/postgres.md for more information.
"""
+_INCONSISTENT_STREAM_ERROR = """
+Postgres sequence '%(seq)s' is inconsistent with associated stream position
+of '%(stream_name)s' in the 'stream_positions' table.
+
+This is likely a programming error and should be reported at
+https://github.com/matrix-org/synapse.
+
+A temporary workaround to fix this error is to shut down Synapse (including
+any and all workers) and run the following SQL:
+
+ DELETE FROM stream_positions WHERE stream_name = '%(stream_name)s';
+
+This will need to be done every time the server is restarted.
+"""
+
class SequenceGenerator(metaclass=abc.ABCMeta):
"""A class which generates a unique sequence of integers"""
@@ -53,19 +70,30 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
...
@abc.abstractmethod
+ def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
+ """Get the next `n` IDs in the sequence"""
+ ...
+
+ @abc.abstractmethod
def check_consistency(
self,
- db_conn: LoggingDatabaseConnection,
+ db_conn: "LoggingDatabaseConnection",
table: str,
id_column: str,
+ stream_name: Optional[str] = None,
positive: bool = True,
):
"""Should be called during start up to test that the current value of
the sequence is greater than or equal to the maximum ID in the table.
- This is to handle various cases where the sequence value can get out
- of sync with the table, e.g. if Synapse gets rolled back to a previous
+ This is to handle various cases where the sequence value can get out of
+ sync with the table, e.g. if Synapse gets rolled back to a previous
version and the rolled forwards again.
+
+ If a stream name is given then this will check that any value in the
+ `stream_positions` table is less than or equal to the current sequence
+ value. If it isn't then it's likely that streams have been crossed
+ somewhere (e.g. two ID generators have the same stream name).
"""
...
@@ -88,11 +116,15 @@ class PostgresSequenceGenerator(SequenceGenerator):
def check_consistency(
self,
- db_conn: LoggingDatabaseConnection,
+ db_conn: "LoggingDatabaseConnection",
table: str,
id_column: str,
+ stream_name: Optional[str] = None,
positive: bool = True,
):
+ """See SequenceGenerator.check_consistency for docstring.
+ """
+
txn = db_conn.cursor(txn_name="sequence.check_consistency")
# First we get the current max ID from the table.
@@ -116,6 +148,18 @@ class PostgresSequenceGenerator(SequenceGenerator):
"SELECT last_value, is_called FROM %(seq)s" % {"seq": self._sequence_name}
)
last_value, is_called = txn.fetchone()
+
+ # If we have an associated stream, check the stream_positions table.
+ max_in_stream_positions = None
+ if stream_name:
+ txn.execute(
+ "SELECT MAX(stream_id) FROM stream_positions WHERE stream_name = ?",
+ (stream_name,),
+ )
+ row = txn.fetchone()
+ if row:
+ max_in_stream_positions = row[0]
+
txn.close()
# If `is_called` is False then `last_value` is actually the value that
@@ -136,6 +180,14 @@ class PostgresSequenceGenerator(SequenceGenerator):
% {"seq": self._sequence_name, "table": table, "max_id_sql": table_sql}
)
+ # If we have values in the stream positions table then the largest of
+ # them has to be less than or equal to `last_value`.
+ if max_in_stream_positions and max_in_stream_positions > last_value:
+ raise IncorrectDatabaseSetup(
+ _INCONSISTENT_STREAM_ERROR
+ % {"seq": self._sequence_name, "stream_name": stream_name}
+ )
+
GetFirstCallbackType = Callable[[Cursor], int]
@@ -172,8 +224,24 @@ class LocalSequenceGenerator(SequenceGenerator):
self._current_max_id += 1
return self._current_max_id
+ def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
+ with self._lock:
+ if self._current_max_id is None:
+ assert self._callback is not None
+ self._current_max_id = self._callback(txn)
+ self._callback = None
+
+ first_id = self._current_max_id + 1
+ self._current_max_id += n
+ return [first_id + i for i in range(n)]
+
def check_consistency(
- self, db_conn: Connection, table: str, id_column: str, positive: bool = True
+ self,
+ db_conn: Connection,
+ table: str,
+ id_column: str,
+ stream_name: Optional[str] = None,
+ positive: bool = True,
):
# There is nothing to do for in memory sequences
pass
|