Diffstat (limited to 'scripts/synapse_port_db')
-rwxr-xr-x  scripts/synapse_port_db  166
1 file changed, 124 insertions(+), 42 deletions(-)
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index ae2887b7d2..5ad17aa90f 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -22,6 +22,7 @@ import logging
 import sys
 import time
 import traceback
+from typing import Dict, Optional, Set
 
 import yaml
 
@@ -39,6 +40,7 @@ from synapse.storage.database import DatabasePool, make_conn
 from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateStore
 from synapse.storage.databases.main.deviceinbox import DeviceInboxBackgroundUpdateStore
 from synapse.storage.databases.main.devices import DeviceBackgroundUpdateStore
+from synapse.storage.databases.main.end_to_end_keys import EndToEndKeyBackgroundStore
 from synapse.storage.databases.main.events_bg_updates import (
     EventsBackgroundUpdatesStore,
 )
@@ -90,6 +92,7 @@ BOOLEAN_COLUMNS = {
     "room_stats_state": ["is_federatable"],
     "local_media_repository": ["safe_from_quarantine"],
     "users": ["shadow_banned"],
+    "e2e_fallback_keys_json": ["used"],
 }
 
 
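The columns listed in BOOLEAN_COLUMNS need special handling because SQLite has no real boolean type: values arrive as 0/1 integers and must be coerced before being written to Postgres. A minimal sketch of that coercion, assuming each row is a mutable list and `headers` gives the column names in row order (the helper name is hypothetical, not the script's actual code):

```python
# Hypothetical helper: coerce integer-encoded booleans in rows pulled
# from SQLite before writing them to Postgres.
def coerce_booleans(table, headers, rows):
    bool_col_indexes = [
        i for i, name in enumerate(headers)
        if name in BOOLEAN_COLUMNS.get(table, [])
    ]
    for row in rows:
        for i in bool_col_indexes:
            if row[i] is not None:
                row[i] = bool(row[i])
    return rows
```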
@@ -151,7 +154,7 @@ IGNORED_TABLES = {
 
 # Error returned by the run function. Used at the top level of the script
 # to handle errors and return codes.
-end_error = None
+end_error = None  # type: Optional[str]
 # The exec_info for the error, if any. If error is defined but not exec_info, the script
 # will show only the error message without the stacktrace. If exec_info is defined but
 # not the error, the script will show nothing outside of what's printed in the run
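
As a rough illustration of the convention these two globals encode, the top level of the script might consume them like this (a hedged sketch only; `report_outcome` is a hypothetical name, and the real handling lives at the bottom of the script):

```python
# Hypothetical sketch of the top-level consumer of these globals.
def report_outcome():
    if end_error:
        # Show the recorded error message, plus a stacktrace if
        # exec_info was captured alongside it.
        sys.stderr.write(end_error + "\n")
        if end_error_exec_info:
            exc_type, exc_value, exc_tb = end_error_exec_info
            traceback.print_exception(exc_type, exc_value, exc_tb)
        sys.exit(5)
```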
@@ -172,6 +175,7 @@ class Store(
     StateBackgroundUpdateStore,
     MainStateBackgroundUpdateStore,
     UserDirectoryBackgroundUpdateStore,
+    EndToEndKeyBackgroundStore,
     StatsStore,
 ):
     def execute(self, f, *args, **kwargs):
@@ -288,6 +292,34 @@ class Porter(object):
 
         return table, already_ported, total_to_port, forward_chunk, backward_chunk
 
+    async def get_table_constraints(self) -> Dict[str, Set[str]]:
+        """Returns a map from each table with foreign key constraints to
+        the set of tables it depends on.
+        """
+
+        def _get_constraints(txn):
+            # We can pull the information about foreign key constraints out from
+            # the postgres schema tables.
+            sql = """
+                SELECT DISTINCT
+                    tc.table_name,
+                    ccu.table_name AS foreign_table_name
+                FROM
+                    information_schema.table_constraints AS tc
+                    INNER JOIN information_schema.constraint_column_usage AS ccu
+                    USING (table_schema, constraint_name)
+                WHERE tc.constraint_type = 'FOREIGN KEY';
+            """
+            txn.execute(sql)
+
+            results = {}
+            for table, foreign_table in txn:
+                results.setdefault(table, set()).add(foreign_table)
+            return results
+
+        return await self.postgres_store.db_pool.runInteraction(
+            "get_table_constraints", _get_constraints
+        )
+
     async def handle_table(
         self, table, postgres_size, table_size, forward_chunk, backward_chunk
     ):
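
To make get_table_constraints' return value concrete: for a toy schema where both `events` and `room_memberships` reference `rooms`, the mapping and the "safe to port yet?" check used later behave like this (illustrative table names only):

```python
# Illustrative output of get_table_constraints() for a toy schema.
constraints = {
    "events": {"rooms"},
    "room_memberships": {"rooms"},
}

tables_ported = {"rooms"}

# A table is safe to port once every table it depends on is ported:
safe_now = [
    table
    for table in ("events", "room_memberships")
    if not constraints.get(table, set()) - tables_ported
]
assert safe_now == ["events", "room_memberships"]
```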
@@ -489,7 +521,7 @@ class Porter(object):
 
         hs = MockHomeserver(self.hs_config)
 
-        with make_conn(db_config, engine) as db_conn:
+        with make_conn(db_config, engine, "portdb") as db_conn:
             engine.check_database(
                 db_conn, allow_outdated_version=allow_outdated_version
             )
@@ -587,7 +619,18 @@ class Porter(object):
                 "create_port_table", create_port_table
             )
 
-            # Step 2. Get tables.
+            # Step 2. Set up sequences
+            #
+            # We do this before porting the tables so that even if we fail
+            # halfway through, the postgres DB always has sequences whose
+            # values are greater than the maximum IDs in their respective
+            # tables. If we don't, then creating the `DataStore` object will
+            # fail due to the inconsistency.
+            self.progress.set_state("Setting up sequence generators")
+            await self._setup_state_group_id_seq()
+            await self._setup_user_id_seq()
+            await self._setup_events_stream_seqs()
+
+            # Step 3. Get tables.
             self.progress.set_state("Fetching tables")
             sqlite_tables = await self.sqlite_store.db_pool.simple_select_onecol(
                 table="sqlite_master", keyvalues={"type": "table"}, retcol="name"
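
The sequence setup in step 2 establishes the invariant that every Postgres sequence starts strictly past the largest ID already present in the source data, so later nextval() calls cannot collide with ported rows. Schematically (a sketch with a hypothetical helper; Postgres does not allow the sequence name itself to be a bind parameter, hence the string interpolation):

```python
# Sketch of the invariant: restart a sequence just past the largest
# existing ID. The sequence name is interpolated directly because it
# cannot be a bind parameter (safe here: the names are hard-coded).
def bump_sequence(txn, seq_name, max_existing_id):
    if max_existing_id is None:
        return  # no rows yet; leave the sequence at its default start
    txn.execute(
        "ALTER SEQUENCE %s RESTART WITH %%s" % (seq_name,),
        (max_existing_id + 1,),
    )
```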
@@ -602,7 +645,7 @@ class Porter(object):
             tables = set(sqlite_tables) & set(postgres_tables)
             logger.info("Found %d tables", len(tables))
 
-            # Step 3. Figure out what still needs copying
+            # Step 4. Figure out what still needs copying
             self.progress.set_state("Checking on port progress")
             setup_res = await make_deferred_yieldable(
                 defer.gatherResults(
@@ -615,26 +658,48 @@ class Porter(object):
                     consumeErrors=True,
                 )
             )
-
-            # Step 4. Do the copying.
+            # Map from table name to args passed to `handle_table`, i.e. a tuple
+            # of: `postgres_size`, `table_size`, `forward_chunk`, `backward_chunk`.
+            tables_to_port_info_map = {r[0]: r[1:] for r in setup_res}
+
+            # Step 5. Do the copying.
+            #
+            # This is slightly convoluted as we need to ensure tables are ported
+            # in the correct order due to foreign key constraints.
             self.progress.set_state("Copying to postgres")
-            await make_deferred_yieldable(
-                defer.gatherResults(
-                    [run_in_background(self.handle_table, *res) for res in setup_res],
-                    consumeErrors=True,
+
+            constraints = await self.get_table_constraints()
+            tables_ported = set()  # type: Set[str]
+
+            while tables_to_port_info_map:
+                # Pulls out all tables that are still to be ported and which
+                # only depend on tables that are already ported (if any).
+                tables_to_port = [
+                    table
+                    for table in tables_to_port_info_map
+                    if not constraints.get(table, set()) - tables_ported
+                ]
+
+                await make_deferred_yieldable(
+                    defer.gatherResults(
+                        [
+                            run_in_background(
+                                self.handle_table,
+                                table,
+                                *tables_to_port_info_map.pop(table),
+                            )
+                            for table in tables_to_port
+                        ],
+                        consumeErrors=True,
+                    )
                 )
-            )
 
-            # Step 5. Set up sequences
-            self.progress.set_state("Setting up sequence generators")
-            await self._setup_state_group_id_seq()
-            await self._setup_user_id_seq()
-            await self._setup_events_stream_seqs()
+                tables_ported.update(tables_to_port)
 
             self.progress.done()
         except Exception as e:
             global end_error_exec_info
-            end_error = e
+            end_error = str(e)
             end_error_exec_info = sys.exc_info()
             logger.exception("")
         finally:
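
The while loop above amounts to a batched topological sort: each pass ports every remaining table whose dependencies are all done. A self-contained toy version of the same idea (made-up table names; the explicit cycle check is an addition for clarity, since the script implicitly assumes the constraint graph is acyclic):

```python
# Toy version of the batched, dependency-ordered porting loop.
pending = {"rooms", "events", "room_memberships"}
constraints = {
    "events": {"rooms"},
    "room_memberships": {"rooms", "events"},
}
done = set()

while pending:
    batch = {t for t in pending if not constraints.get(t, set()) - done}
    if not batch:
        raise RuntimeError("circular foreign key constraints")
    pending -= batch  # each table in `batch` would be ported here
    done |= batch

# Batches were: {"rooms"}, then {"events"}, then {"room_memberships"}.
```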
@@ -788,45 +853,62 @@ class Porter(object):
 
         return done, remaining + done
 
-    def _setup_state_group_id_seq(self):
+    async def _setup_state_group_id_seq(self):
+        curr_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
+            table="state_groups", keyvalues={}, retcol="MAX(id)", allow_none=True
+        )
+
+        if not curr_id:
+            return
+
         def r(txn):
-            txn.execute("SELECT MAX(id) FROM state_groups")
-            curr_id = txn.fetchone()[0]
-            if not curr_id:
-                return
             next_id = curr_id + 1
             txn.execute("ALTER SEQUENCE state_group_id_seq RESTART WITH %s", (next_id,))
 
-        return self.postgres_store.db_pool.runInteraction("setup_state_group_id_seq", r)
+        await self.postgres_store.db_pool.runInteraction("setup_state_group_id_seq", r)
+
+    async def _setup_user_id_seq(self):
+        curr_id = await self.sqlite_store.db_pool.runInteraction(
+            "setup_user_id_seq", find_max_generated_user_id_localpart
+        )
 
-    def _setup_user_id_seq(self):
         def r(txn):
-            next_id = find_max_generated_user_id_localpart(txn) + 1
+            next_id = curr_id + 1
             txn.execute("ALTER SEQUENCE user_id_seq RESTART WITH %s", (next_id,))
 
-        return self.postgres_store.db_pool.runInteraction("setup_user_id_seq", r)
+        await self.postgres_store.db_pool.runInteraction("setup_user_id_seq", r)
 
-    def _setup_events_stream_seqs(self):
-        def r(txn):
-            txn.execute("SELECT MAX(stream_ordering) FROM events")
-            curr_id = txn.fetchone()[0]
-            if curr_id:
-                next_id = curr_id + 1
-                txn.execute(
-                    "ALTER SEQUENCE events_stream_seq RESTART WITH %s", (next_id,)
-                )
+    async def _setup_events_stream_seqs(self):
+        """Set the event stream sequences to the correct values.
+        """
+
+        # We get called before we've ported the events table, so we need to
+        # fetch the current positions from the SQLite store.
+        curr_forward_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
+            table="events", keyvalues={}, retcol="MAX(stream_ordering)", allow_none=True
+        )
 
-            txn.execute("SELECT -MIN(stream_ordering) FROM events")
-            curr_id = txn.fetchone()[0]
-            if curr_id:
-                next_id = curr_id + 1
+        curr_backward_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
+            table="events",
+            keyvalues={},
+            retcol="MAX(-MIN(stream_ordering), 1)",
+            allow_none=True,
+        )
+
+        def _setup_events_stream_seqs_set_pos(txn):
+            if curr_forward_id:
                 txn.execute(
-                    "ALTER SEQUENCE events_backfill_stream_seq RESTART WITH %s",
-                    (next_id,),
+                    "ALTER SEQUENCE events_stream_seq RESTART WITH %s",
+                    (curr_forward_id + 1,),
                 )
 
-        return self.postgres_store.db_pool.runInteraction(
-            "_setup_events_stream_seqs", r
+            if curr_backward_id:
+                txn.execute(
+                    "ALTER SEQUENCE events_backfill_stream_seq RESTART WITH %s",
+                    (curr_backward_id + 1,),
+                )
+
+        return await self.postgres_store.db_pool.runInteraction(
+            "_setup_events_stream_seqs", _setup_events_stream_seqs_set_pos,
         )
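
One subtlety in the backward-stream query above: in SQLite, MAX(x, y) with two arguments is the scalar maximum, so `MAX(-MIN(stream_ordering), 1)` yields the negated minimum stream ordering clamped to at least 1. A quick standalone demonstration with made-up values:

```python
# Demonstrates the two-argument scalar MAX() used to compute the
# backfill sequence position.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE events (stream_ordering INTEGER)")
conn.executemany("INSERT INTO events VALUES (?)", [(-3,), (1,), (2,)])

(curr_backward_id,) = conn.execute(
    "SELECT MAX(-MIN(stream_ordering), 1) FROM events"
).fetchone()
assert curr_backward_id == 3  # -MIN is 3; MAX(3, 1) keeps it at 3
```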