summary refs log tree commit diff
path: root/scripts
diff options
context:
space:
mode:
Diffstat (limited to 'scripts')
-rwxr-xr-xscripts/synapse_port_db119
1 files changed, 46 insertions, 73 deletions
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index c4cf11d19a..e393a9b2f7 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -55,6 +55,7 @@ from synapse.storage.data_stores.main.stats import StatsStore
 from synapse.storage.data_stores.main.user_directory import (
     UserDirectoryBackgroundUpdateStore,
 )
+from synapse.storage.database import Database
 from synapse.storage.engines import create_engine
 from synapse.storage.prepare_database import prepare_database
 from synapse.util import Clock
@@ -139,48 +140,15 @@ class Store(
     UserDirectoryBackgroundUpdateStore,
     StatsStore,
 ):
-    def __init__(self, db_conn, hs):
-        super().__init__(db_conn, hs)
-        self.db_pool = hs.get_db_pool()
-
-    @defer.inlineCallbacks
-    def runInteraction(self, desc, func, *args, **kwargs):
-        def r(conn):
-            try:
-                i = 0
-                N = 5
-                while True:
-                    try:
-                        txn = conn.cursor()
-                        return func(
-                            LoggingTransaction(txn, desc, self.database_engine, [], []),
-                            *args,
-                            **kwargs
-                        )
-                    except self.database_engine.module.DatabaseError as e:
-                        if self.database_engine.is_deadlock(e):
-                            logger.warning("[TXN DEADLOCK] {%s} %d/%d", desc, i, N)
-                            if i < N:
-                                i += 1
-                                conn.rollback()
-                                continue
-                        raise
-            except Exception as e:
-                logger.debug("[TXN FAIL] {%s} %s", desc, e)
-                raise
-
-        with PreserveLoggingContext():
-            return (yield self.db_pool.runWithConnection(r))
-
     def execute(self, f, *args, **kwargs):
-        return self.runInteraction(f.__name__, f, *args, **kwargs)
+        return self.db.runInteraction(f.__name__, f, *args, **kwargs)
 
     def execute_sql(self, sql, *args):
         def r(txn):
             txn.execute(sql, args)
             return txn.fetchall()
 
-        return self.runInteraction("execute_sql", r)
+        return self.db.runInteraction("execute_sql", r)
 
     def insert_many_txn(self, txn, table, headers, rows):
         sql = "INSERT INTO %s (%s) VALUES (%s)" % (
@@ -223,7 +191,7 @@ class Porter(object):
     def setup_table(self, table):
         if table in APPEND_ONLY_TABLES:
             # It's safe to just carry on inserting.
-            row = yield self.postgres_store.simple_select_one(
+            row = yield self.postgres_store.db.simple_select_one(
                 table="port_from_sqlite3",
                 keyvalues={"table_name": table},
                 retcols=("forward_rowid", "backward_rowid"),
@@ -233,12 +201,14 @@ class Porter(object):
             total_to_port = None
             if row is None:
                 if table == "sent_transactions":
-                    forward_chunk, already_ported, total_to_port = (
-                        yield self._setup_sent_transactions()
-                    )
+                    (
+                        forward_chunk,
+                        already_ported,
+                        total_to_port,
+                    ) = yield self._setup_sent_transactions()
                     backward_chunk = 0
                 else:
-                    yield self.postgres_store.simple_insert(
+                    yield self.postgres_store.db.simple_insert(
                         table="port_from_sqlite3",
                         values={
                             "table_name": table,
@@ -268,7 +238,7 @@ class Porter(object):
 
             yield self.postgres_store.execute(delete_all)
 
-            yield self.postgres_store.simple_insert(
+            yield self.postgres_store.db.simple_insert(
                 table="port_from_sqlite3",
                 values={"table_name": table, "forward_rowid": 1, "backward_rowid": 0},
             )
@@ -322,7 +292,7 @@ class Porter(object):
         if table == "user_directory_stream_pos":
             # We need to make sure there is a single row, `(X, null), as that is
             # what synapse expects to be there.
-            yield self.postgres_store.simple_insert(
+            yield self.postgres_store.db.simple_insert(
                 table=table, values={"stream_id": None}
             )
             self.progress.update(table, table_size)  # Mark table as done
@@ -363,7 +333,9 @@ class Porter(object):
 
                 return headers, forward_rows, backward_rows
 
-            headers, frows, brows = yield self.sqlite_store.runInteraction("select", r)
+            headers, frows, brows = yield self.sqlite_store.db.runInteraction(
+                "select", r
+            )
 
             if frows or brows:
                 if frows:
@@ -377,7 +349,7 @@ class Porter(object):
                 def insert(txn):
                     self.postgres_store.insert_many_txn(txn, table, headers[1:], rows)
 
-                    self.postgres_store.simple_update_one_txn(
+                    self.postgres_store.db.simple_update_one_txn(
                         txn,
                         table="port_from_sqlite3",
                         keyvalues={"table_name": table},
@@ -416,7 +388,7 @@ class Porter(object):
 
                 return headers, rows
 
-            headers, rows = yield self.sqlite_store.runInteraction("select", r)
+            headers, rows = yield self.sqlite_store.db.runInteraction("select", r)
 
             if rows:
                 forward_chunk = rows[-1][0] + 1
@@ -433,8 +405,8 @@ class Porter(object):
                     rows_dict = []
                     for row in rows:
                         d = dict(zip(headers, row))
-                        if "\0" in d['value']:
-                            logger.warning('dropping search row %s', d)
+                        if "\0" in d["value"]:
+                            logger.warning("dropping search row %s", d)
                         else:
                             rows_dict.append(d)
 
@@ -454,7 +426,7 @@ class Porter(object):
                         ],
                     )
 
-                    self.postgres_store.simple_update_one_txn(
+                    self.postgres_store.db.simple_update_one_txn(
                         txn,
                         table="port_from_sqlite3",
                         keyvalues={"table_name": "event_search"},
@@ -504,17 +476,14 @@ class Porter(object):
         self.progress.set_state("Preparing %s" % config["name"])
         conn = self.setup_db(config, engine)
 
-        db_pool = adbapi.ConnectionPool(
-            config["name"], **config["args"]
-        )
+        db_pool = adbapi.ConnectionPool(config["name"], **config["args"])
 
         hs = MockHomeserver(self.hs_config, engine, conn, db_pool)
 
-        store = Store(conn, hs)
+        store = Store(Database(hs), conn, hs)
 
-        yield store.runInteraction(
-            "%s_engine.check_database" % config["name"],
-            engine.check_database,
+        yield store.db.runInteraction(
+            "%s_engine.check_database" % config["name"], engine.check_database,
         )
 
         return store
@@ -522,7 +491,9 @@ class Porter(object):
     @defer.inlineCallbacks
     def run_background_updates_on_postgres(self):
         # Manually apply all background updates on the PostgreSQL database.
-        postgres_ready = yield self.postgres_store.has_completed_background_updates()
+        postgres_ready = (
+            yield self.postgres_store.db.updates.has_completed_background_updates()
+        )
 
         if not postgres_ready:
             # Only say that we're running background updates when there are background
@@ -530,9 +501,9 @@ class Porter(object):
             self.progress.set_state("Running background updates on PostgreSQL")
 
         while not postgres_ready:
-            yield self.postgres_store.do_next_background_update(100)
+            yield self.postgres_store.db.updates.do_next_background_update(100)
             postgres_ready = yield (
-                self.postgres_store.has_completed_background_updates()
+                self.postgres_store.db.updates.has_completed_background_updates()
             )
 
     @defer.inlineCallbacks
@@ -541,7 +512,9 @@ class Porter(object):
             self.sqlite_store = yield self.build_db_store(self.sqlite_config)
 
             # Check if all background updates are done, abort if not.
-            updates_complete = yield self.sqlite_store.has_completed_background_updates()
+            updates_complete = (
+                yield self.sqlite_store.db.updates.has_completed_background_updates()
+            )
             if not updates_complete:
                 sys.stderr.write(
                     "Pending background updates exist in the SQLite3 database."
@@ -582,22 +555,22 @@ class Porter(object):
                 )
 
             try:
-                yield self.postgres_store.runInteraction("alter_table", alter_table)
+                yield self.postgres_store.db.runInteraction("alter_table", alter_table)
             except Exception:
                 # On Error Resume Next
                 pass
 
-            yield self.postgres_store.runInteraction(
+            yield self.postgres_store.db.runInteraction(
                 "create_port_table", create_port_table
             )
 
             # Step 2. Get tables.
             self.progress.set_state("Fetching tables")
-            sqlite_tables = yield self.sqlite_store.simple_select_onecol(
+            sqlite_tables = yield self.sqlite_store.db.simple_select_onecol(
                 table="sqlite_master", keyvalues={"type": "table"}, retcol="name"
             )
 
-            postgres_tables = yield self.postgres_store.simple_select_onecol(
+            postgres_tables = yield self.postgres_store.db.simple_select_onecol(
                 table="information_schema.tables",
                 keyvalues={},
                 retcol="distinct table_name",
@@ -687,11 +660,11 @@ class Porter(object):
             rows = txn.fetchall()
             headers = [column[0] for column in txn.description]
 
-            ts_ind = headers.index('ts')
+            ts_ind = headers.index("ts")
 
             return headers, [r for r in rows if r[ts_ind] < yesterday]
 
-        headers, rows = yield self.sqlite_store.runInteraction("select", r)
+        headers, rows = yield self.sqlite_store.db.runInteraction("select", r)
 
         rows = self._convert_rows("sent_transactions", headers, rows)
 
@@ -724,7 +697,7 @@ class Porter(object):
         next_chunk = yield self.sqlite_store.execute(get_start_id)
         next_chunk = max(max_inserted_rowid + 1, next_chunk)
 
-        yield self.postgres_store.simple_insert(
+        yield self.postgres_store.db.simple_insert(
             table="port_from_sqlite3",
             values={
                 "table_name": "sent_transactions",
@@ -737,7 +710,7 @@ class Porter(object):
             txn.execute(
                 "SELECT count(*) FROM sent_transactions" " WHERE ts >= ?", (yesterday,)
             )
-            size, = txn.fetchone()
+            (size,) = txn.fetchone()
             return int(size)
 
         remaining_count = yield self.sqlite_store.execute(get_sent_table_size)
@@ -790,7 +763,7 @@ class Porter(object):
             next_id = curr_id + 1
             txn.execute("ALTER SEQUENCE state_group_id_seq RESTART WITH %s", (next_id,))
 
-        return self.postgres_store.runInteraction("setup_state_group_id_seq", r)
+        return self.postgres_store.db.runInteraction("setup_state_group_id_seq", r)
 
 
 ##############################################
@@ -871,7 +844,7 @@ class CursesProgress(Progress):
         duration = int(now) - int(self.start_time)
 
         minutes, seconds = divmod(duration, 60)
-        duration_str = '%02dm %02ds' % (minutes, seconds)
+        duration_str = "%02dm %02ds" % (minutes, seconds)
 
         if self.finished:
             status = "Time spent: %s (Done!)" % (duration_str,)
@@ -881,7 +854,7 @@ class CursesProgress(Progress):
                 left = float(self.total_remaining) / self.total_processed
 
                 est_remaining = (int(now) - self.start_time) * left
-                est_remaining_str = '%02dm %02ds remaining' % divmod(est_remaining, 60)
+                est_remaining_str = "%02dm %02ds remaining" % divmod(est_remaining, 60)
             else:
                 est_remaining_str = "Unknown"
             status = "Time spent: %s (est. remaining: %s)" % (
@@ -967,7 +940,7 @@ if __name__ == "__main__":
         description="A script to port an existing synapse SQLite database to"
         " a new PostgreSQL database."
     )
-    parser.add_argument("-v", action='store_true')
+    parser.add_argument("-v", action="store_true")
     parser.add_argument(
         "--sqlite-database",
         required=True,
@@ -976,12 +949,12 @@ if __name__ == "__main__":
     )
     parser.add_argument(
         "--postgres-config",
-        type=argparse.FileType('r'),
+        type=argparse.FileType("r"),
         required=True,
         help="The database config file for the PostgreSQL database",
     )
     parser.add_argument(
-        "--curses", action='store_true', help="display a curses based progress UI"
+        "--curses", action="store_true", help="display a curses based progress UI"
     )
 
     parser.add_argument(