summary refs log tree commit diff
path: root/synapse/storage/databases/main
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2021-02-24 10:13:53 +0000
committerGitHub <noreply@github.com>2021-02-24 10:13:53 +0000
commit0b5c967813a8e5f3338b6ecd5f3742e91b8a100b (patch)
tree8001d34fb268c378624382d13a94c6e2c0b047e4 /synapse/storage/databases/main
parentAdd a comment about systemd-python. (#9464) (diff)
downloadsynapse-0b5c967813a8e5f3338b6ecd5f3742e91b8a100b.tar.xz
Refactor to ensure we call check_consistency (#9470)
The idea here is to stop people forgetting to call `check_consistency`. Folks can still just pass in `None` to the new args in `build_sequence_generator`, but hopefully they won't.
Diffstat (limited to 'synapse/storage/databases/main')
-rw-r--r--synapse/storage/databases/main/events.py13
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py1
-rw-r--r--synapse/storage/databases/main/events_worker.py16
-rw-r--r--synapse/storage/databases/main/registration.py19
4 files changed, 39 insertions, 10 deletions
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index a7a11a5bc0..cd1ceac50e 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -44,6 +44,7 @@ from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.databases.main.search import SearchEntry
 from synapse.storage.types import Connection
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
+from synapse.storage.util.sequence import SequenceGenerator
 from synapse.types import StateMap, get_domain_from_id
 from synapse.util import json_encoder
 from synapse.util.iterutils import batch_iter, sorted_topologically
@@ -114,12 +115,6 @@ class PersistEventsStore:
         )  # type: MultiWriterIdGenerator
         self._stream_id_gen = self.store._stream_id_gen  # type: MultiWriterIdGenerator
 
-        # The consistency of this cannot be checked when the ID generator is
-        # created since the database might not yet be up-to-date.
-        self.db_pool.event_chain_id_gen.check_consistency(
-            db_conn, "event_auth_chains", "chain_id"  # type: ignore
-        )
-
         # This should only exist on instances that are configured to write
         assert (
             hs.get_instance_name() in hs.config.worker.writers.events
@@ -485,6 +480,7 @@ class PersistEventsStore:
         self._add_chain_cover_index(
             txn,
             self.db_pool,
+            self.store.event_chain_id_gen,
             event_to_room_id,
             event_to_types,
             event_to_auth_chain,
@@ -495,6 +491,7 @@ class PersistEventsStore:
         cls,
         txn,
         db_pool: DatabasePool,
+        event_chain_id_gen: SequenceGenerator,
         event_to_room_id: Dict[str, str],
         event_to_types: Dict[str, Tuple[str, str]],
         event_to_auth_chain: Dict[str, List[str]],
@@ -641,6 +638,7 @@ class PersistEventsStore:
         new_chain_tuples = cls._allocate_chain_ids(
             txn,
             db_pool,
+            event_chain_id_gen,
             event_to_room_id,
             event_to_types,
             event_to_auth_chain,
@@ -779,6 +777,7 @@ class PersistEventsStore:
     def _allocate_chain_ids(
         txn,
         db_pool: DatabasePool,
+        event_chain_id_gen: SequenceGenerator,
         event_to_room_id: Dict[str, str],
         event_to_types: Dict[str, Tuple[str, str]],
         event_to_auth_chain: Dict[str, List[str]],
@@ -891,7 +890,7 @@ class PersistEventsStore:
             chain_to_max_seq_no[new_chain_tuple[0]] = new_chain_tuple[1]
 
         # Generate new chain IDs for all unallocated chain IDs.
-        newly_allocated_chain_ids = db_pool.event_chain_id_gen.get_next_mult_txn(
+        newly_allocated_chain_ids = event_chain_id_gen.get_next_mult_txn(
             txn, len(unallocated_chain_ids)
         )
 
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 89274e75f7..c1626ccf28 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -917,6 +917,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
         PersistEventsStore._add_chain_cover_index(
             txn,
             self.db_pool,
+            self.event_chain_id_gen,
             event_to_room_id,
             event_to_types,
             event_to_auth_chain,
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index c8850a4707..edbe42f2bf 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -45,6 +45,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla
 from synapse.storage.database import DatabasePool
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import Collection, JsonDict, get_domain_from_id
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.lrucache import LruCache
@@ -156,6 +157,21 @@ class EventsWorkerStore(SQLBaseStore):
         self._event_fetch_list = []
         self._event_fetch_ongoing = 0
 
+        # We define this sequence here so that it can be referenced from both
+        # the DataStore and PersistEventStore.
+        def get_chain_id_txn(txn):
+            txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
+            return txn.fetchone()[0]
+
+        self.event_chain_id_gen = build_sequence_generator(
+            db_conn,
+            database.engine,
+            get_chain_id_txn,
+            "event_auth_chain_id",
+            table="event_auth_chains",
+            id_column="chain_id",
+        )
+
     def process_replication_rows(self, stream_name, instance_name, token, rows):
         if stream_name == EventsStream.NAME:
             self._stream_id_gen.advance(instance_name, token)
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index d5b5507815..61a7556e56 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -23,7 +23,7 @@ import attr
 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, StoreError, SynapseError, ThreepidValidationError
 from synapse.metrics.background_process_metrics import wrap_as_background_process
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.databases.main.stats import StatsStore
 from synapse.storage.types import Connection, Cursor
@@ -70,7 +70,12 @@ class TokenLookupResult:
 
 
 class RegistrationWorkerStore(CacheInvalidationWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self.config = hs.config
@@ -79,9 +84,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         # call `find_max_generated_user_id_localpart` each time, which is
         # expensive if there are many entries.
         self._user_id_seq = build_sequence_generator(
+            db_conn,
             database.engine,
             find_max_generated_user_id_localpart,
             "user_id_seq",
+            table=None,
+            id_column=None,
         )
 
         self._account_validity = hs.config.account_validity
@@ -1036,7 +1044,12 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 
 
 class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
-    def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
         super().__init__(database, db_conn, hs)
 
         self._clock = hs.get_clock()