diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index 13c0120bb4..7a638ea8e3 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -22,7 +22,7 @@ import logging
import sys
import time
import traceback
-from typing import Optional
+from typing import Dict, Optional, Set
import yaml
@@ -292,6 +292,34 @@ class Porter(object):
return table, already_ported, total_to_port, forward_chunk, backward_chunk
+ async def get_table_constraints(self) -> Dict[str, Set[str]]:
+ """Returns a map of tables that have foreign key constraints to tables they depend on.
+ """
+
+ def _get_constraints(txn):
+ # We can pull the information about foreign key constraints out from
+ # the postgres schema tables.
+ sql = """
+ SELECT DISTINCT
+ tc.table_name,
+ ccu.table_name AS foreign_table_name
+ FROM
+ information_schema.table_constraints AS tc
+ INNER JOIN information_schema.constraint_column_usage AS ccu
+ USING (table_schema, constraint_name)
+ WHERE tc.constraint_type = 'FOREIGN KEY';
+ """
+ txn.execute(sql)
+
+ results = {}
+ for table, foreign_table in txn:
+ results.setdefault(table, set()).add(foreign_table)
+ return results
+
+ return await self.postgres_store.db_pool.runInteraction(
+ "get_table_constraints", _get_constraints
+ )
+
async def handle_table(
self, table, postgres_size, table_size, forward_chunk, backward_chunk
):
@@ -619,15 +647,43 @@ class Porter(object):
consumeErrors=True,
)
)
+ # Map from table name to args passed to `handle_table`, i.e. a tuple
+ # of: `postgres_size`, `table_size`, `forward_chunk`, `backward_chunk`.
+ tables_to_port_info_map = {r[0]: r[1:] for r in setup_res}
# Step 4. Do the copying.
+ #
+ # This is slightly convoluted as we need to ensure tables are ported
+ # in the correct order due to foreign key constraints.
self.progress.set_state("Copying to postgres")
- await make_deferred_yieldable(
- defer.gatherResults(
- [run_in_background(self.handle_table, *res) for res in setup_res],
- consumeErrors=True,
+
+ constraints = await self.get_table_constraints()
+ tables_ported = set() # type: Set[str]
+
+ while tables_to_port_info_map:
+ # Pulls out all tables that are still to be ported and which
+ # only depend on tables that are already ported (if any).
+ tables_to_port = [
+ table
+ for table in tables_to_port_info_map
+ if not constraints.get(table, set()) - tables_ported
+ ]
+
+ await make_deferred_yieldable(
+ defer.gatherResults(
+ [
+ run_in_background(
+ self.handle_table,
+ table,
+ *tables_to_port_info_map.pop(table),
+ )
+ for table in tables_to_port
+ ],
+ consumeErrors=True,
+ )
)
- )
+
+ tables_ported.update(tables_to_port)
# Step 5. Set up sequences
self.progress.set_state("Setting up sequence generators")
|