summary refs log tree commit diff
path: root/scripts
diff options
context:
space:
mode:
Diffstat (limited to 'scripts')
-rwxr-xr-xscripts/synapse_port_db68
1 files changed, 62 insertions, 6 deletions
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index 13c0120bb4..7a638ea8e3 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -22,7 +22,7 @@ import logging
 import sys
 import time
 import traceback
-from typing import Optional
+from typing import Dict, Optional, Set
 
 import yaml
 
@@ -292,6 +292,34 @@ class Porter(object):
 
         return table, already_ported, total_to_port, forward_chunk, backward_chunk
 
+    async def get_table_constraints(self) -> Dict[str, Set[str]]:
+        """Returns a map of tables that have foreign key constraints to tables they depend on.
+        """
+
+        def _get_constraints(txn):
+            # We can pull the information about foreign key constraints out from
+            # the postgres schema tables.
+            sql = """
+                SELECT DISTINCT
+                    tc.table_name,
+                    ccu.table_name AS foreign_table_name
+                FROM
+                    information_schema.table_constraints AS tc
+                    INNER JOIN information_schema.constraint_column_usage AS ccu
+                    USING (table_schema, constraint_name)
+                WHERE tc.constraint_type = 'FOREIGN KEY';
+            """
+            txn.execute(sql)
+
+            results = {}
+            for table, foreign_table in txn:
+                results.setdefault(table, set()).add(foreign_table)
+            return results
+
+        return await self.postgres_store.db_pool.runInteraction(
+            "get_table_constraints", _get_constraints
+        )
+
     async def handle_table(
         self, table, postgres_size, table_size, forward_chunk, backward_chunk
     ):
@@ -619,15 +647,43 @@ class Porter(object):
                     consumeErrors=True,
                 )
             )
+            # Map from table name to args passed to `handle_table`, i.e. a tuple
+            # of: `postgres_size`, `table_size`, `forward_chunk`, `backward_chunk`.
+            tables_to_port_info_map = {r[0]: r[1:] for r in setup_res}
 
             # Step 4. Do the copying.
+            #
+            # This is slightly convoluted as we need to ensure tables are ported
+            # in the correct order due to foreign key constraints.
             self.progress.set_state("Copying to postgres")
-            await make_deferred_yieldable(
-                defer.gatherResults(
-                    [run_in_background(self.handle_table, *res) for res in setup_res],
-                    consumeErrors=True,
+
+            constraints = await self.get_table_constraints()
+            tables_ported = set()  # type: Set[str]
+
+            while tables_to_port_info_map:
+                # Pulls out all tables that are still to be ported and which
+                # only depend on tables that are already ported (if any).
+                tables_to_port = [
+                    table
+                    for table in tables_to_port_info_map
+                    if not constraints.get(table, set()) - tables_ported
+                ]
+
+                await make_deferred_yieldable(
+                    defer.gatherResults(
+                        [
+                            run_in_background(
+                                self.handle_table,
+                                table,
+                                *tables_to_port_info_map.pop(table),
+                            )
+                            for table in tables_to_port
+                        ],
+                        consumeErrors=True,
+                    )
                 )
-            )
+
+                tables_ported.update(tables_to_port)
 
             # Step 5. Set up sequences
             self.progress.set_state("Setting up sequence generators")