summary refs log tree commit diff
path: root/scripts
diff options
context:
space:
mode:
author Erik Johnston <erik@matrix.org> 2015-04-28 11:16:44 +0100
committer Erik Johnston <erik@matrix.org> 2015-04-28 11:16:44 +0100
commit 4a13ae72019655ad0531f93af18382c196fb362d (patch)
tree cfac0fa20ef0b7c08b0871ce3bcb55c6bcff4068 /scripts
parent Remove accidentally committed debug hardcode hack (diff)
download synapse-4a13ae72019655ad0531f93af18382c196fb362d.tar.xz
Correctly handle total/remaining counts in the presence of sent_transactions table
Diffstat (limited to 'scripts')
-rw-r--r-- scripts/port_from_sqlite_to_postgres.py 290
1 files changed, 167 insertions, 123 deletions
diff --git a/scripts/port_from_sqlite_to_postgres.py b/scripts/port_from_sqlite_to_postgres.py
index 596be75c49..1e52d82fe0 100644
--- a/scripts/port_from_sqlite_to_postgres.py
+++ b/scripts/port_from_sqlite_to_postgres.py
@@ -125,6 +125,12 @@ class Store(object):
     def execute(self, f, *args, **kwargs):
         return self.runInteraction(f.__name__, f, *args, **kwargs)
 
+    def execute_sql(self, sql, *args):
+        def r(txn):
+            txn.execute(sql, args)
+            return txn.fetchall()
+        return self.runInteraction("execute_sql", r)
+
     def insert_many_txn(self, txn, table, headers, rows):
         sql = "INSERT INTO %s (%s) VALUES (%s)" % (
             table,
@@ -146,119 +152,9 @@ class Porter(object):
     def __init__(self, **kwargs):
         self.__dict__.update(kwargs)
 
-    def convert_rows(self, table, headers, rows):
-        bool_col_names = BOOLEAN_COLUMNS.get(table, [])
-
-        bool_cols = [
-            i for i, h in enumerate(headers) if h in bool_col_names
-        ]
-
-        def conv(j, col):
-            if j in bool_cols:
-                return bool(col)
-            return col
-
-        for i, row in enumerate(rows):
-            rows[i] = tuple(
-                self.postgres_store.database_engine.encode_parameter(
-                    conv(j, col)
-                )
-                for j, col in enumerate(row)
-                if j > 0
-            )
-
     @defer.inlineCallbacks
     def setup_table(self, table):
-        def delete_all(txn):
-            txn.execute(
-                "DELETE FROM port_from_sqlite3 WHERE table_name = %s",
-                (table,)
-            )
-            txn.execute("TRUNCATE %s CASCADE" % (table,))
-
-        def get_table_size(txn):
-            txn.execute("SELECT count(*) FROM %s" % (table,))
-            size, = txn.fetchone()
-            return int(size)
-
-        if table == "sent_transactions":
-            # This is a big table, and we really only need some of the recent
-            # data
-
-            yield self.postgres_store.execute(delete_all)
-
-            # Only save things from the last day
-            yesterday = int(time.time()*1000) - 86400000
-
-            # And save the max transaction id from each destination
-            select = (
-                "SELECT rowid, * FROM sent_transactions WHERE rowid IN ("
-                "SELECT max(rowid) FROM sent_transactions"
-                " GROUP BY destination"
-                ")"
-            )
-
-            def r(txn):
-                txn.execute(select)
-                rows = txn.fetchall()
-                headers = [column[0] for column in txn.description]
-
-                ts_ind = headers.index('ts')
-
-                return headers, [r for r in rows if r[ts_ind] < yesterday]
-
-            headers, rows = yield self.sqlite_store.runInteraction(
-                "select", r,
-            )
-
-            self.convert_rows(table, headers, rows)
-
-            inserted_rows = len(rows)
-            max_inserted_rowid = max(r[0] for r in rows)
-
-            def insert(txn):
-                self.postgres_store.insert_many_txn(
-                    txn, table, headers[1:], rows
-                )
-
-            yield self.postgres_store.execute(insert)
-
-            def get_start_id(txn):
-                txn.execute(
-                    "SELECT rowid FROM sent_transactions WHERE ts >= ?"
-                    " ORDER BY rowid ASC LIMIT 1",
-                    (yesterday,)
-                )
-
-                rows = txn.fetchall()
-                if rows:
-                    return rows[0][0]
-                else:
-                    return 1
-
-            next_chunk = yield self.sqlite_store.execute(get_start_id)
-            next_chunk = max(max_inserted_rowid + 1, next_chunk)
-
-            yield self.postgres_store._simple_insert(
-                table="port_from_sqlite3",
-                values={"table_name": table, "rowid": next_chunk}
-            )
-
-            def get_sent_table_size(txn):
-                txn.execute(
-                    "SELECT count(*) FROM sent_transactions"
-                    " WHERE ts >= ?",
-                    (yesterday,)
-                )
-                size, = txn.fetchone()
-                return int(size)
-
-            table_size = yield self.sqlite_store.execute(
-                get_sent_table_size
-            )
-
-            table_size += inserted_rows
-        elif table in APPEND_ONLY_TABLES:
+        if table in APPEND_ONLY_TABLES:
             # It's safe to just carry on inserting.
             next_chunk = yield self.postgres_store._simple_select_one_onecol(
                 table="port_from_sqlite3",
@@ -267,28 +163,47 @@ class Porter(object):
                 allow_none=True,
             )
 
+            total_to_port = None
             if next_chunk is None:
-                yield self.postgres_store._simple_insert(
-                    table="port_from_sqlite3",
-                    values={"table_name": table, "rowid": 1}
-                )
+                if table == "sent_transactions":
+                    next_chunk, already_ported, total_to_port = (
+                        yield self._setup_sent_transactions()
+                    )
+                else:
+                    yield self.postgres_store._simple_insert(
+                        table="port_from_sqlite3",
+                        values={"table_name": table, "rowid": 1}
+                    )
 
-                next_chunk = 1
+                    next_chunk = 1
+                    already_ported = 0
 
-            table_size = yield self.sqlite_store.execute(get_table_size)
+            if total_to_port is None:
+                already_ported, total_to_port = yield self._get_total_count_to_port(
+                    table, next_chunk
+                )
         else:
+            def delete_all(txn):
+                txn.execute(
+                    "DELETE FROM port_from_sqlite3 WHERE table_name = %s",
+                    (table,)
+                )
+                txn.execute("TRUNCATE %s CASCADE" % (table,))
+
             yield self.postgres_store.execute(delete_all)
-            self.postgres_store._simple_insert(
+
+            yield self.postgres_store._simple_insert(
                 table="port_from_sqlite3",
                 values={"table_name": table, "rowid": 0}
             )
 
-            table_size = yield self.sqlite_store.execute(get_table_size)
             next_chunk = 1
 
-        postgres_size = yield self.postgres_store.execute(get_table_size)
+            already_ported, total_to_port = yield self._get_total_count_to_port(
+                table, next_chunk
+            )
 
-        defer.returnValue((table, postgres_size, table_size, next_chunk))
+        defer.returnValue((table, already_ported, total_to_port, next_chunk))
 
     @defer.inlineCallbacks
     def handle_table(self, table, postgres_size, table_size, next_chunk):
@@ -315,7 +230,7 @@ class Porter(object):
             if rows:
                 next_chunk = rows[-1][0] + 1
 
-                self.convert_rows(table, headers, rows)
+                self._convert_rows(table, headers, rows)
 
                 def insert(txn):
                     self.postgres_store.insert_many_txn(
@@ -414,7 +329,7 @@ class Porter(object):
             except Exception as e:
                 logger.info("Failed to create port table: %s", e)
 
-            self.progress.set_state("Preparing tables")
+            self.progress.set_state("Setting up")
 
             # Set up tables.
             setup_res = yield defer.gatherResults(
@@ -444,6 +359,135 @@ class Porter(object):
         finally:
             reactor.stop()
 
+    def _convert_rows(self, table, headers, rows):
+        bool_col_names = BOOLEAN_COLUMNS.get(table, [])
+
+        bool_cols = [
+            i for i, h in enumerate(headers) if h in bool_col_names
+        ]
+
+        def conv(j, col):
+            if j in bool_cols:
+                return bool(col)
+            return col
+
+        for i, row in enumerate(rows):
+            rows[i] = tuple(
+                self.postgres_store.database_engine.encode_parameter(
+                    conv(j, col)
+                )
+                for j, col in enumerate(row)
+                if j > 0
+            )
+
+    @defer.inlineCallbacks
+    def _setup_sent_transactions(self):
+        # Only save things from the last day
+        yesterday = int(time.time()*1000) - 86400000
+
+        # And save the max transaction id from each destination
+        select = (
+            "SELECT rowid, * FROM sent_transactions WHERE rowid IN ("
+            "SELECT max(rowid) FROM sent_transactions"
+            " GROUP BY destination"
+            ")"
+        )
+
+        def r(txn):
+            txn.execute(select)
+            rows = txn.fetchall()
+            headers = [column[0] for column in txn.description]
+
+            ts_ind = headers.index('ts')
+
+            return headers, [r for r in rows if r[ts_ind] < yesterday]
+
+        headers, rows = yield self.sqlite_store.runInteraction(
+            "select", r,
+        )
+
+        self._convert_rows("sent_transactions", headers, rows)
+
+        inserted_rows = len(rows)
+        max_inserted_rowid = max(r[0] for r in rows)
+
+        def insert(txn):
+            self.postgres_store.insert_many_txn(
+                txn, "sent_transactions", headers[1:], rows
+            )
+
+        yield self.postgres_store.execute(insert)
+
+        def get_start_id(txn):
+            txn.execute(
+                "SELECT rowid FROM sent_transactions WHERE ts >= ?"
+                " ORDER BY rowid ASC LIMIT 1",
+                (yesterday,)
+            )
+
+            rows = txn.fetchall()
+            if rows:
+                return rows[0][0]
+            else:
+                return 1
+
+        next_chunk = yield self.sqlite_store.execute(get_start_id)
+        next_chunk = max(max_inserted_rowid + 1, next_chunk)
+
+        yield self.postgres_store._simple_insert(
+            table="port_from_sqlite3",
+            values={"table_name": "sent_transactions", "rowid": next_chunk}
+        )
+
+        def get_sent_table_size(txn):
+            txn.execute(
+                "SELECT count(*) FROM sent_transactions"
+                " WHERE ts >= ?",
+                (yesterday,)
+            )
+            size, = txn.fetchone()
+            return int(size)
+
+        remaining_count = yield self.sqlite_store.execute(
+            get_sent_table_size
+        )
+
+        total_count = remaining_count + inserted_rows
+
+        defer.returnValue((next_chunk, remaining_count, total_count))
+
+    @defer.inlineCallbacks
+    def _get_remaining_count_to_port(self, table, next_chunk):
+        rows = yield self.sqlite_store.execute_sql(
+            "SELECT count(*) FROM %s WHERE rowid >= ?" % (table,),
+            next_chunk,
+        )
+
+        defer.returnValue(rows[0][0])
+
+    @defer.inlineCallbacks
+    def _get_already_ported_count(self, table):
+        rows = yield self.postgres_store.execute_sql(
+            "SELECT count(*) FROM %s" % (table,),
+        )
+
+        defer.returnValue(rows[0][0])
+
+    @defer.inlineCallbacks
+    def _get_total_count_to_port(self, table, next_chunk):
+        remaining, done = yield defer.gatherResults(
+            [
+                self._get_remaining_count_to_port(table, next_chunk),
+                self._get_already_ported_count(table),
+            ],
+            consumeErrors=True,
+        )
+
+        remaining = int(remaining) if remaining else 0
+        done = int(done) if done else 0
+
+        defer.returnValue((done, remaining + done))
+
 
 ##############################################
 ###### The following is simply UI stuff ######