summary refs log tree commit diff
path: root/scripts/synapse_port_db
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/synapse_port_db')
-rwxr-xr-xscripts/synapse_port_db233
1 files changed, 130 insertions, 103 deletions
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index cb77314f1e..e8b698f3ff 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -27,13 +27,16 @@ from six import string_types
 
 import yaml
 
-from twisted.enterprise import adbapi
 from twisted.internet import defer, reactor
 
+import synapse
 from synapse.config.database import DatabaseConnectionConfig
 from synapse.config.homeserver import HomeServerConfig
-from synapse.logging.context import PreserveLoggingContext
-from synapse.storage._base import LoggingTransaction
+from synapse.logging.context import (
+    LoggingContext,
+    make_deferred_yieldable,
+    run_in_background,
+)
 from synapse.storage.data_stores.main.client_ips import ClientIpBackgroundUpdateStore
 from synapse.storage.data_stores.main.deviceinbox import (
     DeviceInboxBackgroundUpdateStore,
@@ -61,6 +64,7 @@ from synapse.storage.database import Database, make_conn
 from synapse.storage.engines import create_engine
 from synapse.storage.prepare_database import prepare_database
 from synapse.util import Clock
+from synapse.util.versionstring import get_version_string
 
 logger = logging.getLogger("synapse_port_db")
 
@@ -125,6 +129,13 @@ APPEND_ONLY_TABLES = [
 ]
 
 
+# Error returned by the run function. Used at the top-level part of the script to
+# handle errors and return codes.
+end_error = None
+# The exec_info for the error, if any. If error is defined but not exec_info the script
+# will show only the error message without the stacktrace, if exec_info is defined but
+# not the error then the script will show nothing outside of what's printed in the run
+# function. If both are defined, the script will print both the error and the stacktrace.
 end_error_exec_info = None
 
 
@@ -177,6 +188,7 @@ class MockHomeserver:
         self.clock = Clock(reactor)
         self.config = config
         self.hostname = config.server_name
+        self.version_string = "Synapse/"+get_version_string(synapse)
 
     def get_clock(self):
         return self.clock
@@ -189,11 +201,10 @@ class Porter(object):
     def __init__(self, **kwargs):
         self.__dict__.update(kwargs)
 
-    @defer.inlineCallbacks
-    def setup_table(self, table):
+    async def setup_table(self, table):
         if table in APPEND_ONLY_TABLES:
             # It's safe to just carry on inserting.
-            row = yield self.postgres_store.db.simple_select_one(
+            row = await self.postgres_store.db.simple_select_one(
                 table="port_from_sqlite3",
                 keyvalues={"table_name": table},
                 retcols=("forward_rowid", "backward_rowid"),
@@ -207,10 +218,10 @@ class Porter(object):
                         forward_chunk,
                         already_ported,
                         total_to_port,
-                    ) = yield self._setup_sent_transactions()
+                    ) = await self._setup_sent_transactions()
                     backward_chunk = 0
                 else:
-                    yield self.postgres_store.db.simple_insert(
+                    await self.postgres_store.db.simple_insert(
                         table="port_from_sqlite3",
                         values={
                             "table_name": table,
@@ -227,7 +238,7 @@ class Porter(object):
                 backward_chunk = row["backward_rowid"]
 
             if total_to_port is None:
-                already_ported, total_to_port = yield self._get_total_count_to_port(
+                already_ported, total_to_port = await self._get_total_count_to_port(
                     table, forward_chunk, backward_chunk
                 )
         else:
@@ -238,9 +249,9 @@ class Porter(object):
                 )
                 txn.execute("TRUNCATE %s CASCADE" % (table,))
 
-            yield self.postgres_store.execute(delete_all)
+            await self.postgres_store.execute(delete_all)
 
-            yield self.postgres_store.db.simple_insert(
+            await self.postgres_store.db.simple_insert(
                 table="port_from_sqlite3",
                 values={"table_name": table, "forward_rowid": 1, "backward_rowid": 0},
             )
@@ -248,16 +259,13 @@ class Porter(object):
             forward_chunk = 1
             backward_chunk = 0
 
-            already_ported, total_to_port = yield self._get_total_count_to_port(
+            already_ported, total_to_port = await self._get_total_count_to_port(
                 table, forward_chunk, backward_chunk
             )
 
-        defer.returnValue(
-            (table, already_ported, total_to_port, forward_chunk, backward_chunk)
-        )
+        return table, already_ported, total_to_port, forward_chunk, backward_chunk
 
-    @defer.inlineCallbacks
-    def handle_table(
+    async def handle_table(
         self, table, postgres_size, table_size, forward_chunk, backward_chunk
     ):
         logger.info(
@@ -275,7 +283,7 @@ class Porter(object):
         self.progress.add_table(table, postgres_size, table_size)
 
         if table == "event_search":
-            yield self.handle_search_table(
+            await self.handle_search_table(
                 postgres_size, table_size, forward_chunk, backward_chunk
             )
             return
@@ -294,7 +302,7 @@ class Porter(object):
         if table == "user_directory_stream_pos":
             # We need to make sure there is a single row, `(X, null), as that is
             # what synapse expects to be there.
-            yield self.postgres_store.db.simple_insert(
+            await self.postgres_store.db.simple_insert(
                 table=table, values={"stream_id": None}
             )
             self.progress.update(table, table_size)  # Mark table as done
@@ -335,7 +343,7 @@ class Porter(object):
 
                 return headers, forward_rows, backward_rows
 
-            headers, frows, brows = yield self.sqlite_store.db.runInteraction(
+            headers, frows, brows = await self.sqlite_store.db.runInteraction(
                 "select", r
             )
 
@@ -361,7 +369,7 @@ class Porter(object):
                         },
                     )
 
-                yield self.postgres_store.execute(insert)
+                await self.postgres_store.execute(insert)
 
                 postgres_size += len(rows)
 
@@ -369,8 +377,7 @@ class Porter(object):
             else:
                 return
 
-    @defer.inlineCallbacks
-    def handle_search_table(
+    async def handle_search_table(
         self, postgres_size, table_size, forward_chunk, backward_chunk
     ):
         select = (
@@ -390,7 +397,7 @@ class Porter(object):
 
                 return headers, rows
 
-            headers, rows = yield self.sqlite_store.db.runInteraction("select", r)
+            headers, rows = await self.sqlite_store.db.runInteraction("select", r)
 
             if rows:
                 forward_chunk = rows[-1][0] + 1
@@ -438,7 +445,7 @@ class Porter(object):
                         },
                     )
 
-                yield self.postgres_store.execute(insert)
+                await self.postgres_store.execute(insert)
 
                 postgres_size += len(rows)
 
@@ -447,20 +454,15 @@ class Porter(object):
             else:
                 return
 
-    def setup_db(self, db_config: DatabaseConnectionConfig, engine):
-        db_conn = make_conn(db_config, engine)
-        prepare_database(db_conn, engine, config=None)
-
-        db_conn.commit()
-
-        return db_conn
-
-    @defer.inlineCallbacks
-    def build_db_store(self, db_config: DatabaseConnectionConfig):
+    def build_db_store(
+        self, db_config: DatabaseConnectionConfig, allow_outdated_version: bool = False,
+    ):
         """Builds and returns a database store using the provided configuration.
 
         Args:
-            config: The database configuration
+            db_config: The database configuration
+            allow_outdated_version: True to suppress errors about the database server
+                version being too old to run a complete synapse
 
         Returns:
             The built Store object.
@@ -468,24 +470,23 @@ class Porter(object):
         self.progress.set_state("Preparing %s" % db_config.config["name"])
 
         engine = create_engine(db_config.config)
-        conn = self.setup_db(db_config, engine)
 
         hs = MockHomeserver(self.hs_config)
 
-        store = Store(Database(hs, db_config, engine), conn, hs)
-
-        yield store.db.runInteraction(
-            "%s_engine.check_database" % db_config.config["name"],
-            engine.check_database,
-        )
+        with make_conn(db_config, engine) as db_conn:
+            engine.check_database(
+                db_conn, allow_outdated_version=allow_outdated_version
+            )
+            prepare_database(db_conn, engine, config=self.hs_config)
+            store = Store(Database(hs, db_config, engine), db_conn, hs)
+            db_conn.commit()
 
         return store
 
-    @defer.inlineCallbacks
-    def run_background_updates_on_postgres(self):
+    async def run_background_updates_on_postgres(self):
         # Manually apply all background updates on the PostgreSQL database.
         postgres_ready = (
-            yield self.postgres_store.db.updates.has_completed_background_updates()
+            await self.postgres_store.db.updates.has_completed_background_updates()
         )
 
         if not postgres_ready:
@@ -494,35 +495,44 @@ class Porter(object):
             self.progress.set_state("Running background updates on PostgreSQL")
 
         while not postgres_ready:
-            yield self.postgres_store.db.updates.do_next_background_update(100)
-            postgres_ready = yield (
+            await self.postgres_store.db.updates.do_next_background_update(100)
+            postgres_ready = await (
                 self.postgres_store.db.updates.has_completed_background_updates()
             )
 
-    @defer.inlineCallbacks
-    def run(self):
+    async def run(self):
+        """Ports the SQLite database to a PostgreSQL database.
+
+        When a fatal error is met, its message is assigned to the global "end_error"
+        variable. When this error comes with a stacktrace, its exec_info is assigned to
+        the global "end_error_exec_info" variable.
+        """
+        global end_error
+
         try:
-            self.sqlite_store = yield self.build_db_store(
-                DatabaseConnectionConfig("master-sqlite", self.sqlite_config)
+            # we allow people to port away from outdated versions of sqlite.
+            self.sqlite_store = self.build_db_store(
+                DatabaseConnectionConfig("master-sqlite", self.sqlite_config),
+                allow_outdated_version=True,
             )
 
             # Check if all background updates are done, abort if not.
             updates_complete = (
-                yield self.sqlite_store.db.updates.has_completed_background_updates()
+                await self.sqlite_store.db.updates.has_completed_background_updates()
             )
             if not updates_complete:
-                sys.stderr.write(
+                end_error = (
                     "Pending background updates exist in the SQLite3 database."
                     " Please start Synapse again and wait until every update has finished"
                     " before running this script.\n"
                 )
-                defer.returnValue(None)
+                return
 
-            self.postgres_store = yield self.build_db_store(
+            self.postgres_store = self.build_db_store(
                 self.hs_config.get_single_database()
             )
 
-            yield self.run_background_updates_on_postgres()
+            await self.run_background_updates_on_postgres()
 
             self.progress.set_state("Creating port tables")
 
@@ -550,22 +560,22 @@ class Porter(object):
                 )
 
             try:
-                yield self.postgres_store.db.runInteraction("alter_table", alter_table)
+                await self.postgres_store.db.runInteraction("alter_table", alter_table)
             except Exception:
                 # On Error Resume Next
                 pass
 
-            yield self.postgres_store.db.runInteraction(
+            await self.postgres_store.db.runInteraction(
                 "create_port_table", create_port_table
             )
 
             # Step 2. Get tables.
             self.progress.set_state("Fetching tables")
-            sqlite_tables = yield self.sqlite_store.db.simple_select_onecol(
+            sqlite_tables = await self.sqlite_store.db.simple_select_onecol(
                 table="sqlite_master", keyvalues={"type": "table"}, retcol="name"
             )
 
-            postgres_tables = yield self.postgres_store.db.simple_select_onecol(
+            postgres_tables = await self.postgres_store.db.simple_select_onecol(
                 table="information_schema.tables",
                 keyvalues={},
                 retcol="distinct table_name",
@@ -576,28 +586,34 @@ class Porter(object):
 
             # Step 3. Figure out what still needs copying
             self.progress.set_state("Checking on port progress")
-            setup_res = yield defer.gatherResults(
-                [
-                    self.setup_table(table)
-                    for table in tables
-                    if table not in ["schema_version", "applied_schema_deltas"]
-                    and not table.startswith("sqlite_")
-                ],
-                consumeErrors=True,
+            setup_res = await make_deferred_yieldable(
+                defer.gatherResults(
+                    [
+                        run_in_background(self.setup_table, table)
+                        for table in tables
+                        if table not in ["schema_version", "applied_schema_deltas"]
+                        and not table.startswith("sqlite_")
+                    ],
+                    consumeErrors=True,
+                )
             )
 
             # Step 4. Do the copying.
             self.progress.set_state("Copying to postgres")
-            yield defer.gatherResults(
-                [self.handle_table(*res) for res in setup_res], consumeErrors=True
+            await make_deferred_yieldable(
+                defer.gatherResults(
+                    [run_in_background(self.handle_table, *res) for res in setup_res],
+                    consumeErrors=True,
+                )
             )
 
             # Step 5. Do final post-processing
-            yield self._setup_state_group_id_seq()
+            await self._setup_state_group_id_seq()
 
             self.progress.done()
-        except Exception:
+        except Exception as e:
             global end_error_exec_info
+            end_error = e
             end_error_exec_info = sys.exc_info()
             logger.exception("")
         finally:
@@ -637,8 +653,7 @@ class Porter(object):
 
         return outrows
 
-    @defer.inlineCallbacks
-    def _setup_sent_transactions(self):
+    async def _setup_sent_transactions(self):
         # Only save things from the last day
         yesterday = int(time.time() * 1000) - 86400000
 
@@ -659,7 +674,7 @@ class Porter(object):
 
             return headers, [r for r in rows if r[ts_ind] < yesterday]
 
-        headers, rows = yield self.sqlite_store.db.runInteraction("select", r)
+        headers, rows = await self.sqlite_store.db.runInteraction("select", r)
 
         rows = self._convert_rows("sent_transactions", headers, rows)
 
@@ -672,7 +687,7 @@ class Porter(object):
                     txn, "sent_transactions", headers[1:], rows
                 )
 
-            yield self.postgres_store.execute(insert)
+            await self.postgres_store.execute(insert)
         else:
             max_inserted_rowid = 0
 
@@ -689,10 +704,10 @@ class Porter(object):
             else:
                 return 1
 
-        next_chunk = yield self.sqlite_store.execute(get_start_id)
+        next_chunk = await self.sqlite_store.execute(get_start_id)
         next_chunk = max(max_inserted_rowid + 1, next_chunk)
 
-        yield self.postgres_store.db.simple_insert(
+        await self.postgres_store.db.simple_insert(
             table="port_from_sqlite3",
             values={
                 "table_name": "sent_transactions",
@@ -708,46 +723,49 @@ class Porter(object):
             (size,) = txn.fetchone()
             return int(size)
 
-        remaining_count = yield self.sqlite_store.execute(get_sent_table_size)
+        remaining_count = await self.sqlite_store.execute(get_sent_table_size)
 
         total_count = remaining_count + inserted_rows
 
-        defer.returnValue((next_chunk, inserted_rows, total_count))
+        return next_chunk, inserted_rows, total_count
 
-    @defer.inlineCallbacks
-    def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk):
-        frows = yield self.sqlite_store.execute_sql(
+    async def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk):
+        frows = await self.sqlite_store.execute_sql(
             "SELECT count(*) FROM %s WHERE rowid >= ?" % (table,), forward_chunk
         )
 
-        brows = yield self.sqlite_store.execute_sql(
+        brows = await self.sqlite_store.execute_sql(
             "SELECT count(*) FROM %s WHERE rowid <= ?" % (table,), backward_chunk
         )
 
-        defer.returnValue(frows[0][0] + brows[0][0])
+        return frows[0][0] + brows[0][0]
 
-    @defer.inlineCallbacks
-    def _get_already_ported_count(self, table):
-        rows = yield self.postgres_store.execute_sql(
+    async def _get_already_ported_count(self, table):
+        rows = await self.postgres_store.execute_sql(
             "SELECT count(*) FROM %s" % (table,)
         )
 
-        defer.returnValue(rows[0][0])
+        return rows[0][0]
 
-    @defer.inlineCallbacks
-    def _get_total_count_to_port(self, table, forward_chunk, backward_chunk):
-        remaining, done = yield defer.gatherResults(
-            [
-                self._get_remaining_count_to_port(table, forward_chunk, backward_chunk),
-                self._get_already_ported_count(table),
-            ],
-            consumeErrors=True,
+    async def _get_total_count_to_port(self, table, forward_chunk, backward_chunk):
+        remaining, done = await make_deferred_yieldable(
+            defer.gatherResults(
+                [
+                    run_in_background(
+                        self._get_remaining_count_to_port,
+                        table,
+                        forward_chunk,
+                        backward_chunk,
+                    ),
+                    run_in_background(self._get_already_ported_count, table),
+                ],
+            )
         )
 
         remaining = int(remaining) if remaining else 0
         done = int(done) if done else 0
 
-        defer.returnValue((done, remaining + done))
+        return done, remaining + done
 
     def _setup_state_group_id_seq(self):
         def r(txn):
@@ -1013,7 +1031,12 @@ if __name__ == "__main__":
             hs_config=config,
         )
 
-        reactor.callWhenRunning(porter.run)
+        @defer.inlineCallbacks
+        def run():
+            with LoggingContext("synapse_port_db_run"):
+                yield defer.ensureDeferred(porter.run())
+
+        reactor.callWhenRunning(run)
 
         reactor.run()
 
@@ -1022,7 +1045,11 @@ if __name__ == "__main__":
     else:
         start()
 
-    if end_error_exec_info:
-        exc_type, exc_value, exc_traceback = end_error_exec_info
-        traceback.print_exception(exc_type, exc_value, exc_traceback)
+    if end_error:
+        if end_error_exec_info:
+            exc_type, exc_value, exc_traceback = end_error_exec_info
+            traceback.print_exception(exc_type, exc_value, exc_traceback)
+
+        sys.stderr.write(end_error)
+
         sys.exit(5)