summary refs log tree commit diff
path: root/scripts/synapse_port_db
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/synapse_port_db')
-rwxr-xr-xscripts/synapse_port_db447
1 files changed, 271 insertions, 176 deletions
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index b6ba19c776..9a0fbc61d8 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -2,6 +2,7 @@
 # -*- coding: utf-8 -*-
 # Copyright 2015, 2016 OpenMarket Ltd
 # Copyright 2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -26,12 +27,44 @@ from six import string_types
 
 import yaml
 
-from twisted.enterprise import adbapi
 from twisted.internet import defer, reactor
 
-from synapse.storage._base import LoggingTransaction, SQLBaseStore
+import synapse
+from synapse.config.database import DatabaseConnectionConfig
+from synapse.config.homeserver import HomeServerConfig
+from synapse.logging.context import (
+    LoggingContext,
+    make_deferred_yieldable,
+    run_in_background,
+)
+from synapse.storage.data_stores.main.client_ips import ClientIpBackgroundUpdateStore
+from synapse.storage.data_stores.main.deviceinbox import (
+    DeviceInboxBackgroundUpdateStore,
+)
+from synapse.storage.data_stores.main.devices import DeviceBackgroundUpdateStore
+from synapse.storage.data_stores.main.events_bg_updates import (
+    EventsBackgroundUpdatesStore,
+)
+from synapse.storage.data_stores.main.media_repository import (
+    MediaRepositoryBackgroundUpdateStore,
+)
+from synapse.storage.data_stores.main.registration import (
+    RegistrationBackgroundUpdateStore,
+)
+from synapse.storage.data_stores.main.room import RoomBackgroundUpdateStore
+from synapse.storage.data_stores.main.roommember import RoomMemberBackgroundUpdateStore
+from synapse.storage.data_stores.main.search import SearchBackgroundUpdateStore
+from synapse.storage.data_stores.main.state import MainStateBackgroundUpdateStore
+from synapse.storage.data_stores.main.stats import StatsStore
+from synapse.storage.data_stores.main.user_directory import (
+    UserDirectoryBackgroundUpdateStore,
+)
+from synapse.storage.data_stores.state.bg_updates import StateBackgroundUpdateStore
+from synapse.storage.database import Database, make_conn
 from synapse.storage.engines import create_engine
 from synapse.storage.prepare_database import prepare_database
+from synapse.util import Clock
+from synapse.util.versionstring import get_version_string
 
 logger = logging.getLogger("synapse_port_db")
 
@@ -43,6 +76,7 @@ BOOLEAN_COLUMNS = {
     "presence_list": ["accepted"],
     "presence_stream": ["currently_active"],
     "public_room_list_stream": ["visibility"],
+    "devices": ["hidden"],
     "device_lists_outbound_pokes": ["sent"],
     "users_who_share_rooms": ["share_private"],
     "groups": ["is_public"],
@@ -55,6 +89,8 @@ BOOLEAN_COLUMNS = {
     "local_group_membership": ["is_publicised", "is_admin"],
     "e2e_room_keys": ["is_verified"],
     "account_validity": ["email_sent"],
+    "redactions": ["have_censored"],
+    "room_stats_state": ["is_federatable"],
 }
 
 
@@ -86,79 +122,47 @@ APPEND_ONLY_TABLES = [
     "presence_stream",
     "push_rules_stream",
     "ex_outlier_stream",
-    "cache_invalidation_stream",
+    "cache_invalidation_stream_by_instance",
     "public_room_list_stream",
     "state_group_edges",
     "stream_ordering_to_exterm",
 ]
 
 
+# Error returned by the run function. Used at the top-level part of the script to
+# handle errors and return codes.
+end_error = None
+# The exec_info for the error, if any. If error is defined but not exec_info the script
+# will show only the error message without the stacktrace, if exec_info is defined but
+# not the error then the script will show nothing outside of what's printed in the run
+# function. If both are defined, the script will print both the error and the stacktrace.
 end_error_exec_info = None
 
 
-class Store(object):
-    """This object is used to pull out some of the convenience API from the
-    Storage layer.
-
-    *All* database interactions should go through this object.
-    """
-
-    def __init__(self, db_pool, engine):
-        self.db_pool = db_pool
-        self.database_engine = engine
-
-    _simple_insert_txn = SQLBaseStore.__dict__["_simple_insert_txn"]
-    _simple_insert = SQLBaseStore.__dict__["_simple_insert"]
-
-    _simple_select_onecol_txn = SQLBaseStore.__dict__["_simple_select_onecol_txn"]
-    _simple_select_onecol = SQLBaseStore.__dict__["_simple_select_onecol"]
-    _simple_select_one = SQLBaseStore.__dict__["_simple_select_one"]
-    _simple_select_one_txn = SQLBaseStore.__dict__["_simple_select_one_txn"]
-    _simple_select_one_onecol = SQLBaseStore.__dict__["_simple_select_one_onecol"]
-    _simple_select_one_onecol_txn = SQLBaseStore.__dict__[
-        "_simple_select_one_onecol_txn"
-    ]
-
-    _simple_update_one = SQLBaseStore.__dict__["_simple_update_one"]
-    _simple_update_one_txn = SQLBaseStore.__dict__["_simple_update_one_txn"]
-    _simple_update_txn = SQLBaseStore.__dict__["_simple_update_txn"]
-
-    def runInteraction(self, desc, func, *args, **kwargs):
-        def r(conn):
-            try:
-                i = 0
-                N = 5
-                while True:
-                    try:
-                        txn = conn.cursor()
-                        return func(
-                            LoggingTransaction(txn, desc, self.database_engine, [], []),
-                            *args,
-                            **kwargs
-                        )
-                    except self.database_engine.module.DatabaseError as e:
-                        if self.database_engine.is_deadlock(e):
-                            logger.warn("[TXN DEADLOCK] {%s} %d/%d", desc, i, N)
-                            if i < N:
-                                i += 1
-                                conn.rollback()
-                                continue
-                        raise
-            except Exception as e:
-                logger.debug("[TXN FAIL] {%s} %s", desc, e)
-                raise
-
-        return self.db_pool.runWithConnection(r)
-
+class Store(
+    ClientIpBackgroundUpdateStore,
+    DeviceInboxBackgroundUpdateStore,
+    DeviceBackgroundUpdateStore,
+    EventsBackgroundUpdatesStore,
+    MediaRepositoryBackgroundUpdateStore,
+    RegistrationBackgroundUpdateStore,
+    RoomBackgroundUpdateStore,
+    RoomMemberBackgroundUpdateStore,
+    SearchBackgroundUpdateStore,
+    StateBackgroundUpdateStore,
+    MainStateBackgroundUpdateStore,
+    UserDirectoryBackgroundUpdateStore,
+    StatsStore,
+):
     def execute(self, f, *args, **kwargs):
-        return self.runInteraction(f.__name__, f, *args, **kwargs)
+        return self.db.runInteraction(f.__name__, f, *args, **kwargs)
 
     def execute_sql(self, sql, *args):
         def r(txn):
             txn.execute(sql, args)
             return txn.fetchall()
 
-        return self.runInteraction("execute_sql", r)
+        return self.db.runInteraction("execute_sql", r)
 
     def insert_many_txn(self, txn, table, headers, rows):
         sql = "INSERT INTO %s (%s) VALUES (%s)" % (
@@ -173,16 +177,37 @@ class Store(object):
             logger.exception("Failed to insert: %s", table)
             raise
 
+    def set_room_is_public(self, room_id, is_public):
+        raise Exception(
+            "Attempt to set room_is_public during port_db: database not empty?"
+        )
+
+
+class MockHomeserver:
+    def __init__(self, config):
+        self.clock = Clock(reactor)
+        self.config = config
+        self.hostname = config.server_name
+        self.version_string = "Synapse/" + get_version_string(synapse)
+
+    def get_clock(self):
+        return self.clock
+
+    def get_reactor(self):
+        return reactor
+
+    def get_instance_name(self):
+        return "master"
+
 
 class Porter(object):
     def __init__(self, **kwargs):
         self.__dict__.update(kwargs)
 
-    @defer.inlineCallbacks
-    def setup_table(self, table):
+    async def setup_table(self, table):
         if table in APPEND_ONLY_TABLES:
             # It's safe to just carry on inserting.
-            row = yield self.postgres_store._simple_select_one(
+            row = await self.postgres_store.db.simple_select_one(
                 table="port_from_sqlite3",
                 keyvalues={"table_name": table},
                 retcols=("forward_rowid", "backward_rowid"),
@@ -192,12 +217,14 @@ class Porter(object):
             total_to_port = None
             if row is None:
                 if table == "sent_transactions":
-                    forward_chunk, already_ported, total_to_port = (
-                        yield self._setup_sent_transactions()
-                    )
+                    (
+                        forward_chunk,
+                        already_ported,
+                        total_to_port,
+                    ) = await self._setup_sent_transactions()
                     backward_chunk = 0
                 else:
-                    yield self.postgres_store._simple_insert(
+                    await self.postgres_store.db.simple_insert(
                         table="port_from_sqlite3",
                         values={
                             "table_name": table,
@@ -214,7 +241,7 @@ class Porter(object):
                 backward_chunk = row["backward_rowid"]
 
             if total_to_port is None:
-                already_ported, total_to_port = yield self._get_total_count_to_port(
+                already_ported, total_to_port = await self._get_total_count_to_port(
                     table, forward_chunk, backward_chunk
                 )
         else:
@@ -225,9 +252,9 @@ class Porter(object):
                 )
                 txn.execute("TRUNCATE %s CASCADE" % (table,))
 
-            yield self.postgres_store.execute(delete_all)
+            await self.postgres_store.execute(delete_all)
 
-            yield self.postgres_store._simple_insert(
+            await self.postgres_store.db.simple_insert(
                 table="port_from_sqlite3",
                 values={"table_name": table, "forward_rowid": 1, "backward_rowid": 0},
             )
@@ -235,16 +262,13 @@ class Porter(object):
             forward_chunk = 1
             backward_chunk = 0
 
-            already_ported, total_to_port = yield self._get_total_count_to_port(
+            already_ported, total_to_port = await self._get_total_count_to_port(
                 table, forward_chunk, backward_chunk
             )
 
-        defer.returnValue(
-            (table, already_ported, total_to_port, forward_chunk, backward_chunk)
-        )
+        return table, already_ported, total_to_port, forward_chunk, backward_chunk
 
-    @defer.inlineCallbacks
-    def handle_table(
+    async def handle_table(
         self, table, postgres_size, table_size, forward_chunk, backward_chunk
     ):
         logger.info(
@@ -262,7 +286,7 @@ class Porter(object):
         self.progress.add_table(table, postgres_size, table_size)
 
         if table == "event_search":
-            yield self.handle_search_table(
+            await self.handle_search_table(
                 postgres_size, table_size, forward_chunk, backward_chunk
             )
             return
@@ -281,7 +305,7 @@ class Porter(object):
         if table == "user_directory_stream_pos":
             # We need to make sure there is a single row, `(X, null), as that is
             # what synapse expects to be there.
-            yield self.postgres_store._simple_insert(
+            await self.postgres_store.db.simple_insert(
                 table=table, values={"stream_id": None}
             )
             self.progress.update(table, table_size)  # Mark table as done
@@ -322,7 +346,9 @@ class Porter(object):
 
                 return headers, forward_rows, backward_rows
 
-            headers, frows, brows = yield self.sqlite_store.runInteraction("select", r)
+            headers, frows, brows = await self.sqlite_store.db.runInteraction(
+                "select", r
+            )
 
             if frows or brows:
                 if frows:
@@ -336,7 +362,7 @@ class Porter(object):
                 def insert(txn):
                     self.postgres_store.insert_many_txn(txn, table, headers[1:], rows)
 
-                    self.postgres_store._simple_update_one_txn(
+                    self.postgres_store.db.simple_update_one_txn(
                         txn,
                         table="port_from_sqlite3",
                         keyvalues={"table_name": table},
@@ -346,7 +372,7 @@ class Porter(object):
                         },
                     )
 
-                yield self.postgres_store.execute(insert)
+                await self.postgres_store.execute(insert)
 
                 postgres_size += len(rows)
 
@@ -354,8 +380,7 @@ class Porter(object):
             else:
                 return
 
-    @defer.inlineCallbacks
-    def handle_search_table(
+    async def handle_search_table(
         self, postgres_size, table_size, forward_chunk, backward_chunk
     ):
         select = (
@@ -375,7 +400,7 @@ class Porter(object):
 
                 return headers, rows
 
-            headers, rows = yield self.sqlite_store.runInteraction("select", r)
+            headers, rows = await self.sqlite_store.db.runInteraction("select", r)
 
             if rows:
                 forward_chunk = rows[-1][0] + 1
@@ -392,8 +417,8 @@ class Porter(object):
                     rows_dict = []
                     for row in rows:
                         d = dict(zip(headers, row))
-                        if "\0" in d['value']:
-                            logger.warn('dropping search row %s', d)
+                        if "\0" in d["value"]:
+                            logger.warning("dropping search row %s", d)
                         else:
                             rows_dict.append(d)
 
@@ -413,7 +438,7 @@ class Porter(object):
                         ],
                     )
 
-                    self.postgres_store._simple_update_one_txn(
+                    self.postgres_store.db.simple_update_one_txn(
                         txn,
                         table="port_from_sqlite3",
                         keyvalues={"table_name": "event_search"},
@@ -423,7 +448,7 @@ class Porter(object):
                         },
                     )
 
-                yield self.postgres_store.execute(insert)
+                await self.postgres_store.execute(insert)
 
                 postgres_size += len(rows)
 
@@ -432,44 +457,85 @@ class Porter(object):
             else:
                 return
 
-    def setup_db(self, db_config, database_engine):
-        db_conn = database_engine.module.connect(
-            **{
-                k: v
-                for k, v in db_config.get("args", {}).items()
-                if not k.startswith("cp_")
-            }
-        )
+    def build_db_store(
+        self, db_config: DatabaseConnectionConfig, allow_outdated_version: bool = False,
+    ):
+        """Builds and returns a database store using the provided configuration.
 
-        prepare_database(db_conn, database_engine, config=None)
+        Args:
+            db_config: The database configuration
+            allow_outdated_version: True to suppress errors about the database server
+                version being too old to run a complete synapse
 
-        db_conn.commit()
+        Returns:
+            The built Store object.
+        """
+        self.progress.set_state("Preparing %s" % db_config.config["name"])
 
-    @defer.inlineCallbacks
-    def run(self):
-        try:
-            sqlite_db_pool = adbapi.ConnectionPool(
-                self.sqlite_config["name"], **self.sqlite_config["args"]
+        engine = create_engine(db_config.config)
+
+        hs = MockHomeserver(self.hs_config)
+
+        with make_conn(db_config, engine) as db_conn:
+            engine.check_database(
+                db_conn, allow_outdated_version=allow_outdated_version
             )
+            prepare_database(db_conn, engine, config=self.hs_config)
+            store = Store(Database(hs, db_config, engine), db_conn, hs)
+            db_conn.commit()
+
+        return store
 
-            postgres_db_pool = adbapi.ConnectionPool(
-                self.postgres_config["name"], **self.postgres_config["args"]
+    async def run_background_updates_on_postgres(self):
+        # Manually apply all background updates on the PostgreSQL database.
+        postgres_ready = (
+            await self.postgres_store.db.updates.has_completed_background_updates()
+        )
+
+        if not postgres_ready:
+            # Only say that we're running background updates when there are background
+            # updates to run.
+            self.progress.set_state("Running background updates on PostgreSQL")
+
+        while not postgres_ready:
+            await self.postgres_store.db.updates.do_next_background_update(100)
+            postgres_ready = await (
+                self.postgres_store.db.updates.has_completed_background_updates()
             )
 
-            sqlite_engine = create_engine(sqlite_config)
-            postgres_engine = create_engine(postgres_config)
+    async def run(self):
+        """Ports the SQLite database to a PostgreSQL database.
 
-            self.sqlite_store = Store(sqlite_db_pool, sqlite_engine)
-            self.postgres_store = Store(postgres_db_pool, postgres_engine)
+        When a fatal error is met, its message is assigned to the global "end_error"
+        variable. When this error comes with a stacktrace, its exec_info is assigned to
+        the global "end_error_exec_info" variable.
+        """
+        global end_error
+
+        try:
+            # we allow people to port away from outdated versions of sqlite.
+            self.sqlite_store = self.build_db_store(
+                DatabaseConnectionConfig("master-sqlite", self.sqlite_config),
+                allow_outdated_version=True,
+            )
 
-            yield self.postgres_store.execute(postgres_engine.check_database)
+            # Check if all background updates are done, abort if not.
+            updates_complete = (
+                await self.sqlite_store.db.updates.has_completed_background_updates()
+            )
+            if not updates_complete:
+                end_error = (
+                    "Pending background updates exist in the SQLite3 database."
+                    " Please start Synapse again and wait until every update has finished"
+                    " before running this script.\n"
+                )
+                return
 
-            # Step 1. Set up databases.
-            self.progress.set_state("Preparing SQLite3")
-            self.setup_db(sqlite_config, sqlite_engine)
+            self.postgres_store = self.build_db_store(
+                self.hs_config.get_single_database()
+            )
 
-            self.progress.set_state("Preparing PostgreSQL")
-            self.setup_db(postgres_config, postgres_engine)
+            await self.run_background_updates_on_postgres()
 
             self.progress.set_state("Creating port tables")
 
@@ -497,22 +563,22 @@ class Porter(object):
                 )
 
             try:
-                yield self.postgres_store.runInteraction("alter_table", alter_table)
+                await self.postgres_store.db.runInteraction("alter_table", alter_table)
             except Exception:
                 # On Error Resume Next
                 pass
 
-            yield self.postgres_store.runInteraction(
+            await self.postgres_store.db.runInteraction(
                 "create_port_table", create_port_table
             )
 
             # Step 2. Get tables.
             self.progress.set_state("Fetching tables")
-            sqlite_tables = yield self.sqlite_store._simple_select_onecol(
+            sqlite_tables = await self.sqlite_store.db.simple_select_onecol(
                 table="sqlite_master", keyvalues={"type": "table"}, retcol="name"
             )
 
-            postgres_tables = yield self.postgres_store._simple_select_onecol(
+            postgres_tables = await self.postgres_store.db.simple_select_onecol(
                 table="information_schema.tables",
                 keyvalues={},
                 retcol="distinct table_name",
@@ -523,28 +589,34 @@ class Porter(object):
 
             # Step 3. Figure out what still needs copying
             self.progress.set_state("Checking on port progress")
-            setup_res = yield defer.gatherResults(
-                [
-                    self.setup_table(table)
-                    for table in tables
-                    if table not in ["schema_version", "applied_schema_deltas"]
-                    and not table.startswith("sqlite_")
-                ],
-                consumeErrors=True,
+            setup_res = await make_deferred_yieldable(
+                defer.gatherResults(
+                    [
+                        run_in_background(self.setup_table, table)
+                        for table in tables
+                        if table not in ["schema_version", "applied_schema_deltas"]
+                        and not table.startswith("sqlite_")
+                    ],
+                    consumeErrors=True,
+                )
             )
 
             # Step 4. Do the copying.
             self.progress.set_state("Copying to postgres")
-            yield defer.gatherResults(
-                [self.handle_table(*res) for res in setup_res], consumeErrors=True
+            await make_deferred_yieldable(
+                defer.gatherResults(
+                    [run_in_background(self.handle_table, *res) for res in setup_res],
+                    consumeErrors=True,
+                )
             )
 
             # Step 5. Do final post-processing
-            yield self._setup_state_group_id_seq()
+            await self._setup_state_group_id_seq()
 
             self.progress.done()
-        except Exception:
+        except Exception as e:
             global end_error_exec_info
+            end_error = e
             end_error_exec_info = sys.exc_info()
             logger.exception("")
         finally:
@@ -561,8 +633,10 @@ class Porter(object):
         def conv(j, col):
             if j in bool_cols:
                 return bool(col)
+            if isinstance(col, bytes):
+                return bytearray(col)
             elif isinstance(col, string_types) and "\0" in col:
-                logger.warn(
+                logger.warning(
                     "DROPPING ROW: NUL value in table %s col %s: %r",
                     table,
                     headers[j],
@@ -582,8 +656,7 @@ class Porter(object):
 
         return outrows
 
-    @defer.inlineCallbacks
-    def _setup_sent_transactions(self):
+    async def _setup_sent_transactions(self):
         # Only save things from the last day
         yesterday = int(time.time() * 1000) - 86400000
 
@@ -600,11 +673,11 @@ class Porter(object):
             rows = txn.fetchall()
             headers = [column[0] for column in txn.description]
 
-            ts_ind = headers.index('ts')
+            ts_ind = headers.index("ts")
 
             return headers, [r for r in rows if r[ts_ind] < yesterday]
 
-        headers, rows = yield self.sqlite_store.runInteraction("select", r)
+        headers, rows = await self.sqlite_store.db.runInteraction("select", r)
 
         rows = self._convert_rows("sent_transactions", headers, rows)
 
@@ -617,7 +690,7 @@ class Porter(object):
                     txn, "sent_transactions", headers[1:], rows
                 )
 
-            yield self.postgres_store.execute(insert)
+            await self.postgres_store.execute(insert)
         else:
             max_inserted_rowid = 0
 
@@ -634,10 +707,10 @@ class Porter(object):
             else:
                 return 1
 
-        next_chunk = yield self.sqlite_store.execute(get_start_id)
+        next_chunk = await self.sqlite_store.execute(get_start_id)
         next_chunk = max(max_inserted_rowid + 1, next_chunk)
 
-        yield self.postgres_store._simple_insert(
+        await self.postgres_store.db.simple_insert(
             table="port_from_sqlite3",
             values={
                 "table_name": "sent_transactions",
@@ -650,57 +723,63 @@ class Porter(object):
             txn.execute(
                 "SELECT count(*) FROM sent_transactions" " WHERE ts >= ?", (yesterday,)
             )
-            size, = txn.fetchone()
+            (size,) = txn.fetchone()
             return int(size)
 
-        remaining_count = yield self.sqlite_store.execute(get_sent_table_size)
+        remaining_count = await self.sqlite_store.execute(get_sent_table_size)
 
         total_count = remaining_count + inserted_rows
 
-        defer.returnValue((next_chunk, inserted_rows, total_count))
+        return next_chunk, inserted_rows, total_count
 
-    @defer.inlineCallbacks
-    def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk):
-        frows = yield self.sqlite_store.execute_sql(
+    async def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk):
+        frows = await self.sqlite_store.execute_sql(
             "SELECT count(*) FROM %s WHERE rowid >= ?" % (table,), forward_chunk
         )
 
-        brows = yield self.sqlite_store.execute_sql(
+        brows = await self.sqlite_store.execute_sql(
             "SELECT count(*) FROM %s WHERE rowid <= ?" % (table,), backward_chunk
         )
 
-        defer.returnValue(frows[0][0] + brows[0][0])
+        return frows[0][0] + brows[0][0]
 
-    @defer.inlineCallbacks
-    def _get_already_ported_count(self, table):
-        rows = yield self.postgres_store.execute_sql(
+    async def _get_already_ported_count(self, table):
+        rows = await self.postgres_store.execute_sql(
             "SELECT count(*) FROM %s" % (table,)
         )
 
-        defer.returnValue(rows[0][0])
+        return rows[0][0]
 
-    @defer.inlineCallbacks
-    def _get_total_count_to_port(self, table, forward_chunk, backward_chunk):
-        remaining, done = yield defer.gatherResults(
-            [
-                self._get_remaining_count_to_port(table, forward_chunk, backward_chunk),
-                self._get_already_ported_count(table),
-            ],
-            consumeErrors=True,
+    async def _get_total_count_to_port(self, table, forward_chunk, backward_chunk):
+        remaining, done = await make_deferred_yieldable(
+            defer.gatherResults(
+                [
+                    run_in_background(
+                        self._get_remaining_count_to_port,
+                        table,
+                        forward_chunk,
+                        backward_chunk,
+                    ),
+                    run_in_background(self._get_already_ported_count, table),
+                ],
+            )
         )
 
         remaining = int(remaining) if remaining else 0
         done = int(done) if done else 0
 
-        defer.returnValue((done, remaining + done))
+        return done, remaining + done
 
     def _setup_state_group_id_seq(self):
         def r(txn):
             txn.execute("SELECT MAX(id) FROM state_groups")
-            next_id = txn.fetchone()[0] + 1
+            curr_id = txn.fetchone()[0]
+            if not curr_id:
+                return
+            next_id = curr_id + 1
             txn.execute("ALTER SEQUENCE state_group_id_seq RESTART WITH %s", (next_id,))
 
-        return self.postgres_store.runInteraction("setup_state_group_id_seq", r)
+        return self.postgres_store.db.runInteraction("setup_state_group_id_seq", r)
 
 
 ##############################################
@@ -781,7 +860,7 @@ class CursesProgress(Progress):
         duration = int(now) - int(self.start_time)
 
         minutes, seconds = divmod(duration, 60)
-        duration_str = '%02dm %02ds' % (minutes, seconds)
+        duration_str = "%02dm %02ds" % (minutes, seconds)
 
         if self.finished:
             status = "Time spent: %s (Done!)" % (duration_str,)
@@ -791,7 +870,7 @@ class CursesProgress(Progress):
                 left = float(self.total_remaining) / self.total_processed
 
                 est_remaining = (int(now) - self.start_time) * left
-                est_remaining_str = '%02dm %02ds remaining' % divmod(est_remaining, 60)
+                est_remaining_str = "%02dm %02ds remaining" % divmod(est_remaining, 60)
             else:
                 est_remaining_str = "Unknown"
             status = "Time spent: %s (est. remaining: %s)" % (
@@ -877,7 +956,7 @@ if __name__ == "__main__":
         description="A script to port an existing synapse SQLite database to"
         " a new PostgreSQL database."
     )
-    parser.add_argument("-v", action='store_true')
+    parser.add_argument("-v", action="store_true")
     parser.add_argument(
         "--sqlite-database",
         required=True,
@@ -886,12 +965,12 @@ if __name__ == "__main__":
     )
     parser.add_argument(
         "--postgres-config",
-        type=argparse.FileType('r'),
+        type=argparse.FileType("r"),
         required=True,
         help="The database config file for the PostgreSQL database",
     )
     parser.add_argument(
-        "--curses", action='store_true', help="display a curses based progress UI"
+        "--curses", action="store_true", help="display a curses based progress UI"
     )
 
     parser.add_argument(
@@ -924,18 +1003,24 @@ if __name__ == "__main__":
         },
     }
 
-    postgres_config = yaml.safe_load(args.postgres_config)
+    hs_config = yaml.safe_load(args.postgres_config)
+
+    if "database" not in hs_config:
+        sys.stderr.write("The configuration file must have a 'database' section.\n")
+        sys.exit(4)
 
-    if "database" in postgres_config:
-        postgres_config = postgres_config["database"]
+    postgres_config = hs_config["database"]
 
     if "name" not in postgres_config:
-        sys.stderr.write("Malformed database config: no 'name'")
+        sys.stderr.write("Malformed database config: no 'name'\n")
         sys.exit(2)
     if postgres_config["name"] != "psycopg2":
-        sys.stderr.write("Database must use 'psycopg2' connector.")
+        sys.stderr.write("Database must use the 'psycopg2' connector.\n")
         sys.exit(3)
 
+    config = HomeServerConfig()
+    config.parse_config_dict(hs_config, "", "")
+
     def start(stdscr=None):
         if stdscr:
             progress = CursesProgress(stdscr)
@@ -944,12 +1029,17 @@ if __name__ == "__main__":
 
         porter = Porter(
             sqlite_config=sqlite_config,
-            postgres_config=postgres_config,
             progress=progress,
             batch_size=args.batch_size,
+            hs_config=config,
         )
 
-        reactor.callWhenRunning(porter.run)
+        @defer.inlineCallbacks
+        def run():
+            with LoggingContext("synapse_port_db_run"):
+                yield defer.ensureDeferred(porter.run())
+
+        reactor.callWhenRunning(run)
 
         reactor.run()
 
@@ -958,6 +1048,11 @@ if __name__ == "__main__":
     else:
         start()
 
-    if end_error_exec_info:
-        exc_type, exc_value, exc_traceback = end_error_exec_info
-        traceback.print_exception(exc_type, exc_value, exc_traceback)
+    if end_error:
+        if end_error_exec_info:
+            exc_type, exc_value, exc_traceback = end_error_exec_info
+            traceback.print_exception(exc_type, exc_value, exc_traceback)
+
+        sys.stderr.write(end_error)
+
+        sys.exit(5)