diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 4646926449..f1ba529a2d 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -49,7 +49,6 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.background_updates import BackgroundUpdater
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.types import Connection, Cursor
-from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import Collection
# python 3 does not have a maximum int value
@@ -381,7 +380,10 @@ class DatabasePool:
_TXN_ID = 0
def __init__(
- self, hs, database_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
+ self,
+ hs,
+ database_config: DatabaseConnectionConfig,
+ engine: BaseDatabaseEngine,
):
self.hs = hs
self._clock = hs.get_clock()
@@ -420,16 +422,6 @@ class DatabasePool:
self._check_safe_to_upsert,
)
- # We define this sequence here so that it can be referenced from both
- # the DataStore and PersistEventStore.
- def get_chain_id_txn(txn):
- txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
- return txn.fetchone()[0]
-
- self.event_chain_id_gen = build_sequence_generator(
- engine, get_chain_id_txn, "event_auth_chain_id"
- )
-
def is_running(self) -> bool:
"""Is the database pool currently running"""
return self._db_pool.running
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index a7a11a5bc0..cd1ceac50e 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -44,6 +44,7 @@ from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.search import SearchEntry
from synapse.storage.types import Connection
from synapse.storage.util.id_generators import MultiWriterIdGenerator
+from synapse.storage.util.sequence import SequenceGenerator
from synapse.types import StateMap, get_domain_from_id
from synapse.util import json_encoder
from synapse.util.iterutils import batch_iter, sorted_topologically
@@ -114,12 +115,6 @@ class PersistEventsStore:
) # type: MultiWriterIdGenerator
self._stream_id_gen = self.store._stream_id_gen # type: MultiWriterIdGenerator
- # The consistency of this cannot be checked when the ID generator is
- # created since the database might not yet be up-to-date.
- self.db_pool.event_chain_id_gen.check_consistency(
- db_conn, "event_auth_chains", "chain_id" # type: ignore
- )
-
# This should only exist on instances that are configured to write
assert (
hs.get_instance_name() in hs.config.worker.writers.events
@@ -485,6 +480,7 @@ class PersistEventsStore:
self._add_chain_cover_index(
txn,
self.db_pool,
+ self.store.event_chain_id_gen,
event_to_room_id,
event_to_types,
event_to_auth_chain,
@@ -495,6 +491,7 @@ class PersistEventsStore:
cls,
txn,
db_pool: DatabasePool,
+ event_chain_id_gen: SequenceGenerator,
event_to_room_id: Dict[str, str],
event_to_types: Dict[str, Tuple[str, str]],
event_to_auth_chain: Dict[str, List[str]],
@@ -641,6 +638,7 @@ class PersistEventsStore:
new_chain_tuples = cls._allocate_chain_ids(
txn,
db_pool,
+ event_chain_id_gen,
event_to_room_id,
event_to_types,
event_to_auth_chain,
@@ -779,6 +777,7 @@ class PersistEventsStore:
def _allocate_chain_ids(
txn,
db_pool: DatabasePool,
+ event_chain_id_gen: SequenceGenerator,
event_to_room_id: Dict[str, str],
event_to_types: Dict[str, Tuple[str, str]],
event_to_auth_chain: Dict[str, List[str]],
@@ -891,7 +890,7 @@ class PersistEventsStore:
chain_to_max_seq_no[new_chain_tuple[0]] = new_chain_tuple[1]
# Generate new chain IDs for all unallocated chain IDs.
- newly_allocated_chain_ids = db_pool.event_chain_id_gen.get_next_mult_txn(
+ newly_allocated_chain_ids = event_chain_id_gen.get_next_mult_txn(
txn, len(unallocated_chain_ids)
)
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 89274e75f7..c1626ccf28 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -917,6 +917,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
PersistEventsStore._add_chain_cover_index(
txn,
self.db_pool,
+ self.event_chain_id_gen,
event_to_room_id,
event_to_types,
event_to_auth_chain,
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index c8850a4707..edbe42f2bf 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -45,6 +45,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla
from synapse.storage.database import DatabasePool
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import Collection, JsonDict, get_domain_from_id
from synapse.util.caches.descriptors import cached
from synapse.util.caches.lrucache import LruCache
@@ -156,6 +157,21 @@ class EventsWorkerStore(SQLBaseStore):
self._event_fetch_list = []
self._event_fetch_ongoing = 0
+ # We define this sequence here so that it can be referenced from both
+ # the DataStore and PersistEventStore.
+ def get_chain_id_txn(txn):
+ txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
+ return txn.fetchone()[0]
+
+ self.event_chain_id_gen = build_sequence_generator(
+ db_conn,
+ database.engine,
+ get_chain_id_txn,
+ "event_auth_chain_id",
+ table="event_auth_chains",
+ id_column="chain_id",
+ )
+
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == EventsStream.NAME:
self._stream_id_gen.advance(instance_name, token)
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index d5b5507815..61a7556e56 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -23,7 +23,7 @@ import attr
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.databases.main.stats import StatsStore
from synapse.storage.types import Connection, Cursor
@@ -70,7 +70,12 @@ class TokenLookupResult:
class RegistrationWorkerStore(CacheInvalidationWorkerStore):
- def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self.config = hs.config
@@ -79,9 +84,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
# call `find_max_generated_user_id_localpart` each time, which is
# expensive if there are many entries.
self._user_id_seq = build_sequence_generator(
+ db_conn,
database.engine,
find_max_generated_user_id_localpart,
"user_id_seq",
+ table=None,
+ id_column=None,
)
self._account_validity = hs.config.account_validity
@@ -1036,7 +1044,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
- def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+ def __init__(
+ self,
+ database: DatabasePool,
+ db_conn: LoggingDatabaseConnection,
+ hs: "HomeServer",
+ ):
super().__init__(database, db_conn, hs)
self._clock = hs.get_clock()
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index b16b9905d8..e2240703a7 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -97,10 +97,12 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
return txn.fetchone()[0]
self._state_group_seq_gen = build_sequence_generator(
- self.database_engine, get_max_state_group_txn, "state_group_id_seq"
- )
- self._state_group_seq_gen.check_consistency(
- db_conn, table="state_groups", id_column="id"
+ db_conn,
+ self.database_engine,
+ get_max_state_group_txn,
+ "state_group_id_seq",
+ table="state_groups",
+ id_column="id",
)
@cached(max_entries=10000, iterable=True)
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index 3ea637b281..36a67e7019 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -251,9 +251,14 @@ class LocalSequenceGenerator(SequenceGenerator):
def build_sequence_generator(
+ db_conn: "LoggingDatabaseConnection",
database_engine: BaseDatabaseEngine,
get_first_callback: GetFirstCallbackType,
sequence_name: str,
+ table: Optional[str],
+ id_column: Optional[str],
+ stream_name: Optional[str] = None,
+ positive: bool = True,
) -> SequenceGenerator:
"""Get the best impl of SequenceGenerator available
@@ -265,8 +270,23 @@ def build_sequence_generator(
get_first_callback: a callback which gets the next sequence ID. Used if
we're on sqlite.
sequence_name: the name of a postgres sequence to use.
+ table, id_column, stream_name, positive: If set then `check_consistency`
+ is called on the created sequence. See docstring for
+ `check_consistency` details.
"""
if isinstance(database_engine, PostgresEngine):
- return PostgresSequenceGenerator(sequence_name)
+ seq = PostgresSequenceGenerator(sequence_name) # type: SequenceGenerator
else:
- return LocalSequenceGenerator(get_first_callback)
+ seq = LocalSequenceGenerator(get_first_callback)
+
+ if table:
+ assert id_column
+ seq.check_consistency(
+ db_conn=db_conn,
+ table=table,
+ id_column=id_column,
+ stream_name=stream_name,
+ positive=positive,
+ )
+
+ return seq
|