diff --git a/changelog.d/17309.misc b/changelog.d/17309.misc
new file mode 100644
index 0000000000..cb6b9504b3
--- /dev/null
+++ b/changelog.d/17309.misc
@@ -0,0 +1 @@
+When rolling back to a previous Synapse version and then forwards again to this release, don't require server operators to manually run SQL.
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 59c8e05c39..48f88a6f8a 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -276,9 +276,6 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
# no active writes in progress.
self._max_position_of_local_instance = self._max_seen_allocated_stream_id
- # This goes and fills out the above state from the database.
- self._load_current_ids(db_conn, tables)
-
self._sequence_gen = build_sequence_generator(
db_conn=db_conn,
database_engine=db.engine,
@@ -303,6 +300,13 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
positive=positive,
)
+ # This goes and fills out the above state from the database.
+ # This may read from the PostgreSQL sequence, and
+ # SequenceGenerator.check_consistency might have fixed up the sequence, which
+ # means the SequenceGenerator needs to be set up before we read the value from
+ # the sequence.
+ self._load_current_ids(db_conn, tables, sequence_name)
+
self._max_seen_allocated_stream_id = max(
self._current_positions.values(), default=1
)
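
The comment above pins down an initialization order: the sequence generator is built first because its consistency check might have fixed up the Postgres sequence, and only then is _load_current_ids allowed to read from that sequence. Below is a minimal standalone sketch of that interaction; it assumes psycopg2, uses illustrative names (fixup_sequence, load_writer_position, foobar/foobar_seq) rather than Synapse's actual helpers, and shows just one possible form of "fixing up" (advancing a sequence that lags behind its table).

    import psycopg2  # assumed Postgres driver for this sketch

    def fixup_sequence(cur, sequence: str, table: str) -> None:
        # One possible consistency fix: advance the sequence if it lags
        # behind the highest stream_id already stored in the table.
        cur.execute(f"SELECT coalesce(max(stream_id), 0) FROM {table}")
        max_stream_id = cur.fetchone()[0]
        cur.execute(f"SELECT last_value FROM {sequence}")
        last_value = cur.fetchone()[0]
        if last_value < max_stream_id:
            cur.execute("SELECT setval(%s, %s)", (sequence, max_stream_id))

    def load_writer_position(cur, sequence: str) -> int:
        # Must run *after* fixup_sequence, otherwise it can return the
        # stale, pre-fix-up value.
        cur.execute(f"SELECT last_value FROM {sequence}")
        return cur.fetchone()[0]

    with psycopg2.connect("dbname=synapse") as conn:  # placeholder DSN
        with conn.cursor() as cur:
            fixup_sequence(cur, "foobar_seq", "foobar")
            position = load_writer_position(cur, "foobar_seq")
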
@@ -327,6 +331,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
self,
db_conn: LoggingDatabaseConnection,
tables: List[Tuple[str, str, str]],
+ sequence_name: str,
) -> None:
cur = db_conn.cursor(txn_name="_load_current_ids")
@@ -360,6 +365,18 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
if instance in self._writers
}
+ # If we're a writer, we can assume we're at the end of the stream.
+ # Usually we would get that from the stream_positions table, but in some
+ # cases, such as after rolling back Synapse, the stream_positions table might
+ # not be up to date. If we're using Postgres for the sequences, we can just
+ # use the current sequence value as our own position.
+ if self._instance_name in self._writers:
+ if isinstance(self._db.engine, PostgresEngine):
+ cur.execute(f"SELECT last_value FROM {sequence_name}")
+ row = cur.fetchone()
+ assert row is not None
+ self._current_positions[self._instance_name] = row[0]
+
# We set the `_persisted_upto_position` to be the minimum of all current
# positions. If empty we use the max stream ID from the DB table.
min_stream_id = min(self._current_positions.values(), default=None)
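
Restated outside the diff, the new branch boils down to: a writer running on Postgres takes its starting position from the sequence head, and everyone else keeps whatever stream_positions recorded. The helper below is only an illustration with hypothetical names, not Synapse's API; the final comment shows the rollback case the change targets (stream_positions still says 3 while the sequence has already handed out 5).

    from typing import Dict, List, Optional

    def resolve_position(
        instance_name: str,
        writers: List[str],
        stream_positions: Dict[str, int],
        sequence_last_value: Optional[int],
    ) -> int:
        # sequence_last_value is the Postgres sequence head, or None when
        # running on SQLite, where there is no sequence to consult.
        if instance_name in writers and sequence_last_value is not None:
            # stream_positions may be stale after a rollback; trust the
            # sequence instead.
            return sequence_last_value
        return stream_positions.get(instance_name, 1)

    # resolve_position("first", ["first", "second"], {"first": 3}, 5) == 5
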
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index 9be2923e6f..12b89cecb6 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -18,7 +18,7 @@
# [This file includes modifications made by New Vector Limited]
#
#
-from typing import List, Optional
+from typing import Dict, List, Optional
from twisted.test.proto_helpers import MemoryReactor
@@ -42,9 +42,13 @@ from tests.utils import USE_POSTGRES_FOR_TESTS
class MultiWriterIdGeneratorBase(HomeserverTestCase):
+ positive: bool = True
+ tables: List[str] = ["foobar"]
+
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.db_pool: DatabasePool = self.store.db_pool
+ self.instances: Dict[str, MultiWriterIdGenerator] = {}
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
@@ -57,18 +61,22 @@ class MultiWriterIdGeneratorBase(HomeserverTestCase):
if USE_POSTGRES_FOR_TESTS:
txn.execute("CREATE SEQUENCE foobar_seq")
- txn.execute(
- """
- CREATE TABLE foobar (
- stream_id BIGINT NOT NULL,
- instance_name TEXT NOT NULL,
- data TEXT
- );
- """
- )
+ for table in self.tables:
+ txn.execute(
+ """
+ CREATE TABLE %s (
+ stream_id BIGINT NOT NULL,
+ instance_name TEXT NOT NULL,
+ data TEXT
+ );
+ """
+ % (table,)
+ )
def _create_id_generator(
- self, instance_name: str = "master", writers: Optional[List[str]] = None
+ self,
+ instance_name: str = "master",
+ writers: Optional[List[str]] = None,
) -> MultiWriterIdGenerator:
def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator:
return MultiWriterIdGenerator(
@@ -77,36 +85,93 @@ class MultiWriterIdGeneratorBase(HomeserverTestCase):
notifier=self.hs.get_replication_notifier(),
stream_name="test_stream",
instance_name=instance_name,
- tables=[("foobar", "instance_name", "stream_id")],
+ tables=[(table, "instance_name", "stream_id") for table in self.tables],
sequence_name="foobar_seq",
writers=writers or ["master"],
+ positive=self.positive,
)
- return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
+ self.instances[instance_name] = self.get_success_or_raise(
+ self.db_pool.runWithConnection(_create)
+ )
+ return self.instances[instance_name]
+
+ def _replicate(self, instance_name: str) -> None:
+ """Similate a replication event for the given instance."""
+
+ writer = self.instances[instance_name]
+ token = writer.get_current_token_for_writer(instance_name)
+ for generator in self.instances.values():
+ if writer != generator:
+ generator.advance(instance_name, token)
+
+ def _replicate_all(self) -> None:
+ """Similate a replication event for all instances."""
- def _insert_rows(self, instance_name: str, number: int) -> None:
+ for instance_name in self.instances:
+ self._replicate(instance_name)
+
+ def _insert_row(
+ self, instance_name: str, stream_id: int, table: Optional[str] = None
+ ) -> None:
+ """Insert one row as the given instance with given stream_id."""
+
+ if table is None:
+ table = self.tables[0]
+
+ factor = 1 if self.positive else -1
+
+ def _insert(txn: LoggingTransaction) -> None:
+ txn.execute(
+ "INSERT INTO %s VALUES (?, ?)" % (table,),
+ (
+ stream_id,
+ instance_name,
+ ),
+ )
+ txn.execute(
+ """
+ INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
+ ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
+ """,
+ (instance_name, stream_id * factor, stream_id * factor),
+ )
+
+ self.get_success(self.db_pool.runInteraction("_insert_row", _insert))
+
+ def _insert_rows(
+ self,
+ instance_name: str,
+ number: int,
+ table: Optional[str] = None,
+ update_stream_table: bool = True,
+ ) -> None:
"""Insert N rows as the given instance, inserting with stream IDs pulled
from the postgres sequence.
"""
+ if table is None:
+ table = self.tables[0]
+
+ factor = 1 if self.positive else -1
+
def _insert(txn: LoggingTransaction) -> None:
for _ in range(number):
next_val = self.seq_gen.get_next_id_txn(txn)
txn.execute(
- "INSERT INTO foobar (stream_id, instance_name) VALUES (?, ?)",
- (
- next_val,
- instance_name,
- ),
+ "INSERT INTO %s (stream_id, instance_name) VALUES (?, ?)"
+ % (table,),
+ (next_val, instance_name),
)
- txn.execute(
- """
- INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
- ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
- """,
- (instance_name, next_val, next_val),
- )
+ if update_stream_table:
+ txn.execute(
+ """
+ INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
+ ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
+ """,
+ (instance_name, next_val * factor, next_val * factor),
+ )
self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
@@ -353,7 +418,9 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
id_gen = self._create_id_generator("first", writers=["first", "second"])
- self.assertEqual(id_gen.get_positions(), {"first": 3, "second": 5})
+ # When the writer is created, it assumes its own position is at the current
+ # head of the sequence, rather than the possibly stale value in the
+ # stream_positions table.
+ self.assertEqual(id_gen.get_positions(), {"first": 5, "second": 5})
self.assertEqual(id_gen.get_persisted_upto_position(), 5)
@@ -375,11 +442,13 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
correctly.
"""
self._insert_rows("first", 3)
- self._insert_rows("second", 4)
-
first_id_gen = self._create_id_generator("first", writers=["first", "second"])
+
+ self._insert_rows("second", 4)
second_id_gen = self._create_id_generator("second", writers=["first", "second"])
+ self._replicate_all()
+
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
@@ -398,6 +467,9 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
self.assertEqual(
first_id_gen.get_positions(), {"first": 3, "second": 7}
)
+ self.assertEqual(
+ second_id_gen.get_positions(), {"first": 3, "second": 7}
+ )
self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
self.get_success(_get_next_async())
@@ -432,11 +504,11 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
"""
# Insert some rows for two out of three of the ID gens.
self._insert_rows("first", 3)
- self._insert_rows("second", 4)
-
first_id_gen = self._create_id_generator(
"first", writers=["first", "second", "third"]
)
+
+ self._insert_rows("second", 4)
second_id_gen = self._create_id_generator(
"second", writers=["first", "second", "third"]
)
@@ -444,6 +516,8 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
"third", writers=["first", "second", "third"]
)
+ self._replicate_all()
+
self.assertEqual(
first_id_gen.get_positions(), {"first": 3, "second": 7, "third": 7}
)
@@ -546,11 +620,13 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
def test_minimal_local_token(self) -> None:
self._insert_rows("first", 3)
- self._insert_rows("second", 4)
-
first_id_gen = self._create_id_generator("first", writers=["first", "second"])
+
+ self._insert_rows("second", 4)
second_id_gen = self._create_id_generator("second", writers=["first", "second"])
+ self._replicate_all()
+
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
self.assertEqual(first_id_gen.get_minimal_local_current_token(), 3)
@@ -562,15 +638,17 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
token when there are no writes.
"""
self._insert_rows("first", 3)
- self._insert_rows("second", 4)
-
first_id_gen = self._create_id_generator(
"first", writers=["first", "second", "third"]
)
+
+ self._insert_rows("second", 4)
second_id_gen = self._create_id_generator(
"second", writers=["first", "second", "third"]
)
+ self._replicate_all()
+
self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 7)
self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7)
self.assertEqual(second_id_gen.get_current_token(), 7)
@@ -609,68 +687,13 @@ class WorkerMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
self.assertEqual(second_id_gen.get_current_token(), 7)
-class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
+class BackwardsMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
"""Tests MultiWriterIdGenerator that produce *negative* stream IDs."""
if not USE_POSTGRES_FOR_TESTS:
skip = "Requires Postgres"
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.store = hs.get_datastores().main
- self.db_pool: DatabasePool = self.store.db_pool
-
- self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
-
- def _setup_db(self, txn: LoggingTransaction) -> None:
- txn.execute("CREATE SEQUENCE foobar_seq")
- txn.execute(
- """
- CREATE TABLE foobar (
- stream_id BIGINT NOT NULL,
- instance_name TEXT NOT NULL,
- data TEXT
- );
- """
- )
-
- def _create_id_generator(
- self, instance_name: str = "master", writers: Optional[List[str]] = None
- ) -> MultiWriterIdGenerator:
- def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator:
- return MultiWriterIdGenerator(
- conn,
- self.db_pool,
- notifier=self.hs.get_replication_notifier(),
- stream_name="test_stream",
- instance_name=instance_name,
- tables=[("foobar", "instance_name", "stream_id")],
- sequence_name="foobar_seq",
- writers=writers or ["master"],
- positive=False,
- )
-
- return self.get_success(self.db_pool.runWithConnection(_create))
-
- def _insert_row(self, instance_name: str, stream_id: int) -> None:
- """Insert one row as the given instance with given stream_id."""
-
- def _insert(txn: LoggingTransaction) -> None:
- txn.execute(
- "INSERT INTO foobar VALUES (?, ?)",
- (
- stream_id,
- instance_name,
- ),
- )
- txn.execute(
- """
- INSERT INTO stream_positions VALUES ('test_stream', ?, ?)
- ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = ?
- """,
- (instance_name, -stream_id, -stream_id),
- )
-
- self.get_success(self.db_pool.runInteraction("_insert_row", _insert))
+ positive = False
def test_single_instance(self) -> None:
"""Test that reads and writes from a single process are handled
@@ -716,7 +739,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
async def _get_next_async() -> None:
async with id_gen_1.get_next() as stream_id:
self._insert_row("first", stream_id)
- id_gen_2.advance("first", stream_id)
+ self._replicate("first")
self.get_success(_get_next_async())
@@ -728,7 +751,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
async def _get_next_async2() -> None:
async with id_gen_2.get_next() as stream_id:
self._insert_row("second", stream_id)
- id_gen_1.advance("second", stream_id)
+ self._replicate("second")
self.get_success(_get_next_async2())
@@ -738,98 +761,26 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)
-class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
+class MultiTableMultiWriterIdGeneratorTestCase(MultiWriterIdGeneratorBase):
if not USE_POSTGRES_FOR_TESTS:
skip = "Requires Postgres"
- def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
- self.store = hs.get_datastores().main
- self.db_pool: DatabasePool = self.store.db_pool
-
- self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
-
- def _setup_db(self, txn: LoggingTransaction) -> None:
- txn.execute("CREATE SEQUENCE foobar_seq")
- txn.execute(
- """
- CREATE TABLE foobar1 (
- stream_id BIGINT NOT NULL,
- instance_name TEXT NOT NULL,
- data TEXT
- );
- """
- )
-
- txn.execute(
- """
- CREATE TABLE foobar2 (
- stream_id BIGINT NOT NULL,
- instance_name TEXT NOT NULL,
- data TEXT
- );
- """
- )
-
- def _create_id_generator(
- self, instance_name: str = "master", writers: Optional[List[str]] = None
- ) -> MultiWriterIdGenerator:
- def _create(conn: LoggingDatabaseConnection) -> MultiWriterIdGenerator:
- return MultiWriterIdGenerator(
- conn,
- self.db_pool,
- notifier=self.hs.get_replication_notifier(),
- stream_name="test_stream",
- instance_name=instance_name,
- tables=[
- ("foobar1", "instance_name", "stream_id"),
- ("foobar2", "instance_name", "stream_id"),
- ],
- sequence_name="foobar_seq",
- writers=writers or ["master"],
- )
-
- return self.get_success_or_raise(self.db_pool.runWithConnection(_create))
-
- def _insert_rows(
- self,
- table: str,
- instance_name: str,
- number: int,
- update_stream_table: bool = True,
- ) -> None:
- """Insert N rows as the given instance, inserting with stream IDs pulled
- from the postgres sequence.
- """
-
- def _insert(txn: LoggingTransaction) -> None:
- for _ in range(number):
- txn.execute(
- "INSERT INTO %s VALUES (nextval('foobar_seq'), ?)" % (table,),
- (instance_name,),
- )
- if update_stream_table:
- txn.execute(
- """
- INSERT INTO stream_positions VALUES ('test_stream', ?, lastval())
- ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = lastval()
- """,
- (instance_name,),
- )
-
- self.get_success(self.db_pool.runInteraction("_insert_rows", _insert))
+ tables = ["foobar1", "foobar2"]
def test_load_existing_stream(self) -> None:
"""Test creating ID gens with multiple tables that have rows from after
the position in `stream_positions` table.
"""
- self._insert_rows("foobar1", "first", 3)
- self._insert_rows("foobar2", "second", 3)
- self._insert_rows("foobar2", "second", 1, update_stream_table=False)
-
+ self._insert_rows("first", 3, table="foobar1")
first_id_gen = self._create_id_generator("first", writers=["first", "second"])
+
+ self._insert_rows("second", 3, table="foobar2")
+ self._insert_rows("second", 1, table="foobar2", update_stream_table=False)
second_id_gen = self._create_id_generator("second", writers=["first", "second"])
- self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 6})
+ self._replicate_all()
+
+ self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7)
self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
self.assertEqual(first_id_gen.get_persisted_upto_position(), 7)
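
The reordered tests above all follow the same shape: insert rows for an instance, create its writer (which now starts at the sequence head), then call _replicate_all so every generator learns the other writers' tokens before the assertions run. A test aimed directly at the rollback scenario could look like the sketch below; it is an illustration built on the helpers in this file, not part of the patch, and the class and test names are made up.

    class RollbackRecoveryExampleTestCase(MultiWriterIdGeneratorBase):
        # Illustrative sketch only.
        if not USE_POSTGRES_FOR_TESTS:
            skip = "Requires Postgres"

        def test_writer_starts_at_sequence_head(self) -> None:
            # Five rows exist, but stream_positions only records up to 3 for
            # "first" -- e.g. the last two rows were written by a newer
            # Synapse before rolling back and forwards again.
            self._insert_rows("first", 3)
            self._insert_rows("first", 2, update_stream_table=False)

            id_gen = self._create_id_generator("first", writers=["first"])

            # The writer trusts the Postgres sequence head rather than the
            # stale stream_positions row.
            self.assertEqual(id_gen.get_positions(), {"first": 5})
            self.assertEqual(id_gen.get_persisted_upto_position(), 5)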