summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/storage/database.py11
-rw-r--r--synapse/storage/databases/main/events.py36
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py9
-rw-r--r--synapse/storage/util/sequence.py10
4 files changed, 37 insertions, 29 deletions
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 6cfadc2b4e..a19d65ad23 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -49,6 +49,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.background_updates import BackgroundUpdater
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
 from synapse.storage.types import Connection, Cursor
+from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import Collection
 
 # python 3 does not have a maximum int value
@@ -412,6 +413,16 @@ class DatabasePool:
                 self._check_safe_to_upsert,
             )
 
+        # We define this sequence here so that it can be referenced from both
+        # the DataStore and PersistEventStore.
+        def get_chain_id_txn(txn):
+            txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
+            return txn.fetchone()[0]
+
+        self.event_chain_id_gen = build_sequence_generator(
+            engine, get_chain_id_txn, "event_auth_chain_id"
+        )
+
     def is_running(self) -> bool:
         """Is the database pool currently running
         """
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index e0fbcc58cf..3216b3f3c8 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -43,7 +43,6 @@ from synapse.storage._base import db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool, LoggingTransaction
 from synapse.storage.databases.main.search import SearchEntry
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
-from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import StateMap, get_domain_from_id
 from synapse.util import json_encoder
 from synapse.util.iterutils import batch_iter, sorted_topologically
@@ -100,14 +99,6 @@ class PersistEventsStore:
         self._clock = hs.get_clock()
         self._instance_name = hs.get_instance_name()
 
-        def get_chain_id_txn(txn):
-            txn.execute("SELECT COALESCE(max(chain_id), 0) FROM event_auth_chains")
-            return txn.fetchone()[0]
-
-        self._event_chain_id_gen = build_sequence_generator(
-            db.engine, get_chain_id_txn, "event_auth_chain_id"
-        )
-
         self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
         self.is_mine_id = hs.is_mine_id
 
@@ -479,12 +470,13 @@ class PersistEventsStore:
         event_to_room_id = {e.event_id: e.room_id for e in state_events.values()}
 
         self._add_chain_cover_index(
-            txn, event_to_room_id, event_to_types, event_to_auth_chain
+            txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain,
         )
 
+    @staticmethod
     def _add_chain_cover_index(
-        self,
         txn,
+        db_pool: DatabasePool,
         event_to_room_id: Dict[str, str],
         event_to_types: Dict[str, Tuple[str, str]],
         event_to_auth_chain: Dict[str, List[str]],
@@ -507,7 +499,7 @@ class PersistEventsStore:
         # We check if there are any events that need to be handled in the rooms
         # we're looking at. These should just be out of band memberships, where
         # we didn't have the auth chain when we first persisted.
-        rows = self.db_pool.simple_select_many_txn(
+        rows = db_pool.simple_select_many_txn(
             txn,
             table="event_auth_chain_to_calculate",
             keyvalues={},
@@ -523,7 +515,7 @@ class PersistEventsStore:
             # (We could pull out the auth events for all rows at once using
             # simple_select_many, but this case happens rarely and almost always
             # with a single row.)
-            auth_events = self.db_pool.simple_select_onecol_txn(
+            auth_events = db_pool.simple_select_onecol_txn(
                 txn, "event_auth", keyvalues={"event_id": event_id}, retcol="auth_id",
             )
 
@@ -572,9 +564,7 @@ class PersistEventsStore:
 
                     events_to_calc_chain_id_for.add(auth_id)
 
-                    event_to_auth_chain[
-                        auth_id
-                    ] = self.db_pool.simple_select_onecol_txn(
+                    event_to_auth_chain[auth_id] = db_pool.simple_select_onecol_txn(
                         txn,
                         "event_auth",
                         keyvalues={"event_id": auth_id},
@@ -606,7 +596,7 @@ class PersistEventsStore:
                     room_id = event_to_room_id.get(event_id)
                     if room_id:
                         e_type, state_key = event_to_types[event_id]
-                        self.db_pool.simple_insert_txn(
+                        db_pool.simple_insert_txn(
                             txn,
                             table="event_auth_chain_to_calculate",
                             values={
@@ -651,7 +641,7 @@ class PersistEventsStore:
                 proposed_new_id = existing_chain_id[0]
                 proposed_new_seq = existing_chain_id[1] + 1
                 if (proposed_new_id, proposed_new_seq) not in chains_tuples_allocated:
-                    already_allocated = self.db_pool.simple_select_one_onecol_txn(
+                    already_allocated = db_pool.simple_select_one_onecol_txn(
                         txn,
                         table="event_auth_chains",
                         keyvalues={
@@ -672,14 +662,14 @@ class PersistEventsStore:
                         )
 
             if not new_chain_tuple:
-                new_chain_tuple = (self._event_chain_id_gen.get_next_id_txn(txn), 1)
+                new_chain_tuple = (db_pool.event_chain_id_gen.get_next_id_txn(txn), 1)
 
             chains_tuples_allocated.add(new_chain_tuple)
 
             chain_map[event_id] = new_chain_tuple
             new_chain_tuples[event_id] = new_chain_tuple
 
-        self.db_pool.simple_insert_many_txn(
+        db_pool.simple_insert_many_txn(
             txn,
             table="event_auth_chains",
             values=[
@@ -688,7 +678,7 @@ class PersistEventsStore:
             ],
         )
 
-        self.db_pool.simple_delete_many_txn(
+        db_pool.simple_delete_many_txn(
             txn,
             table="event_auth_chain_to_calculate",
             keyvalues={},
@@ -721,7 +711,7 @@ class PersistEventsStore:
         # Step 1, fetch all existing links from all the chains we've seen
         # referenced.
         chain_links = _LinkMap()
-        rows = self.db_pool.simple_select_many_txn(
+        rows = db_pool.simple_select_many_txn(
             txn,
             table="event_auth_chain_links",
             column="origin_chain_id",
@@ -785,7 +775,7 @@ class PersistEventsStore:
                         (chain_id, sequence_number), (target_id, target_seq)
                     )
 
-        self.db_pool.simple_insert_many_txn(
+        db_pool.simple_insert_many_txn(
             txn,
             table="event_auth_chain_links",
             values=[
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 90a40a92b4..7128dc1742 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -21,6 +21,7 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events import make_event_from_dict
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
+from synapse.storage.databases.main.events import PersistEventsStore
 from synapse.storage.types import Cursor
 from synapse.types import JsonDict
 
@@ -833,8 +834,12 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             #
             # Annoyingly we need to gut wrench into the persit event store so that
             # we can reuse the function to calculate the chain cover for rooms.
-            self.hs.get_datastores().persist_events._add_chain_cover_index(
-                txn, event_to_room_id, event_to_types, event_to_auth_chain,
+            PersistEventsStore._add_chain_cover_index(
+                txn,
+                self.db_pool,
+                event_to_room_id,
+                event_to_types,
+                event_to_auth_chain,
             )
 
             return new_last_depth, new_last_stream, count
diff --git a/synapse/storage/util/sequence.py b/synapse/storage/util/sequence.py
index 4386b6101e..412df6b8ef 100644
--- a/synapse/storage/util/sequence.py
+++ b/synapse/storage/util/sequence.py
@@ -15,9 +15,8 @@
 import abc
 import logging
 import threading
-from typing import Callable, List, Optional
+from typing import TYPE_CHECKING, Callable, List, Optional
 
-from synapse.storage.database import LoggingDatabaseConnection
 from synapse.storage.engines import (
     BaseDatabaseEngine,
     IncorrectDatabaseSetup,
@@ -25,6 +24,9 @@ from synapse.storage.engines import (
 )
 from synapse.storage.types import Connection, Cursor
 
+if TYPE_CHECKING:
+    from synapse.storage.database import LoggingDatabaseConnection
+
 logger = logging.getLogger(__name__)
 
 
@@ -55,7 +57,7 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
     @abc.abstractmethod
     def check_consistency(
         self,
-        db_conn: LoggingDatabaseConnection,
+        db_conn: "LoggingDatabaseConnection",
         table: str,
         id_column: str,
         positive: bool = True,
@@ -88,7 +90,7 @@ class PostgresSequenceGenerator(SequenceGenerator):
 
     def check_consistency(
         self,
-        db_conn: LoggingDatabaseConnection,
+        db_conn: "LoggingDatabaseConnection",
         table: str,
         id_column: str,
         positive: bool = True,