summary refs log tree commit diff
path: root/scripts
diff options
context:
space:
mode:
Diffstat (limited to 'scripts')
-rwxr-xr-xscripts/synapse_port_db84
1 files changed, 54 insertions, 30 deletions
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index 604b961bd2..5ad17aa90f 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -619,7 +619,18 @@ class Porter(object):
                 "create_port_table", create_port_table
             )
 
-            # Step 2. Get tables.
+            # Step 2. Set up sequences
+            #
+            # We do this before porting the tables so that event if we fail half
+            # way through the postgres DB always have sequences that are greater
+            # than their respective tables. If we don't then creating the
+            # `DataStore` object will fail due to the inconsistency.
+            self.progress.set_state("Setting up sequence generators")
+            await self._setup_state_group_id_seq()
+            await self._setup_user_id_seq()
+            await self._setup_events_stream_seqs()
+
+            # Step 3. Get tables.
             self.progress.set_state("Fetching tables")
             sqlite_tables = await self.sqlite_store.db_pool.simple_select_onecol(
                 table="sqlite_master", keyvalues={"type": "table"}, retcol="name"
@@ -634,7 +645,7 @@ class Porter(object):
             tables = set(sqlite_tables) & set(postgres_tables)
             logger.info("Found %d tables", len(tables))
 
-            # Step 3. Figure out what still needs copying
+            # Step 4. Figure out what still needs copying
             self.progress.set_state("Checking on port progress")
             setup_res = await make_deferred_yieldable(
                 defer.gatherResults(
@@ -651,7 +662,7 @@ class Porter(object):
             # of: `postgres_size`, `table_size`, `forward_chunk`, `backward_chunk`.
             tables_to_port_info_map = {r[0]: r[1:] for r in setup_res}
 
-            # Step 4. Do the copying.
+            # Step 5. Do the copying.
             #
             # This is slightly convoluted as we need to ensure tables are ported
             # in the correct order due to foreign key constraints.
@@ -685,12 +696,6 @@ class Porter(object):
 
                 tables_ported.update(tables_to_port)
 
-            # Step 5. Set up sequences
-            self.progress.set_state("Setting up sequence generators")
-            await self._setup_state_group_id_seq()
-            await self._setup_user_id_seq()
-            await self._setup_events_stream_seqs()
-
             self.progress.done()
         except Exception as e:
             global end_error_exec_info
@@ -848,43 +853,62 @@ class Porter(object):
 
         return done, remaining + done
 
-    def _setup_state_group_id_seq(self):
+    async def _setup_state_group_id_seq(self):
+        curr_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
+            table="state_groups", keyvalues={}, retcol="MAX(id)", allow_none=True
+        )
+
+        if not curr_id:
+            return
+
         def r(txn):
-            txn.execute("SELECT MAX(id) FROM state_groups")
-            curr_id = txn.fetchone()[0]
-            if not curr_id:
-                return
             next_id = curr_id + 1
             txn.execute("ALTER SEQUENCE state_group_id_seq RESTART WITH %s", (next_id,))
 
-        return self.postgres_store.db_pool.runInteraction("setup_state_group_id_seq", r)
+        await self.postgres_store.db_pool.runInteraction("setup_state_group_id_seq", r)
+
+    async def _setup_user_id_seq(self):
+        curr_id = await self.sqlite_store.db_pool.runInteraction(
+            "setup_user_id_seq", find_max_generated_user_id_localpart
+        )
 
-    def _setup_user_id_seq(self):
         def r(txn):
-            next_id = find_max_generated_user_id_localpart(txn) + 1
+            next_id = curr_id + 1
             txn.execute("ALTER SEQUENCE user_id_seq RESTART WITH %s", (next_id,))
 
         return self.postgres_store.db_pool.runInteraction("setup_user_id_seq", r)
 
-    def _setup_events_stream_seqs(self):
-        def r(txn):
-            txn.execute("SELECT MAX(stream_ordering) FROM events")
-            curr_id = txn.fetchone()[0]
-            if curr_id:
-                next_id = curr_id + 1
+    async def _setup_events_stream_seqs(self):
+        """Set the event stream sequences to the correct values.
+        """
+
+        # We get called before we've ported the events table, so we need to
+        # fetch the current positions from the SQLite store.
+        curr_forward_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
+            table="events", keyvalues={}, retcol="MAX(stream_ordering)", allow_none=True
+        )
+
+        curr_backward_id = await self.sqlite_store.db_pool.simple_select_one_onecol(
+            table="events",
+            keyvalues={},
+            retcol="MAX(-MIN(stream_ordering), 1)",
+            allow_none=True,
+        )
+
+        def _setup_events_stream_seqs_set_pos(txn):
+            if curr_forward_id:
                 txn.execute(
-                    "ALTER SEQUENCE events_stream_seq RESTART WITH %s", (next_id,)
+                    "ALTER SEQUENCE events_stream_seq RESTART WITH %s",
+                    (curr_forward_id + 1,),
                 )
 
-            txn.execute("SELECT GREATEST(-MIN(stream_ordering), 1) FROM events")
-            curr_id = txn.fetchone()[0]
-            next_id = curr_id + 1
             txn.execute(
-                "ALTER SEQUENCE events_backfill_stream_seq RESTART WITH %s", (next_id,),
+                "ALTER SEQUENCE events_backfill_stream_seq RESTART WITH %s",
+                (curr_backward_id + 1,),
             )
 
-        return self.postgres_store.db_pool.runInteraction(
-            "_setup_events_stream_seqs", r
+        return await self.postgres_store.db_pool.runInteraction(
+            "_setup_events_stream_seqs", _setup_events_stream_seqs_set_pos,
         )