summary refs log tree commit diff
diff options
context:
space:
mode:
-rwxr-xr-xscripts/synapse_port_db175
1 files changed, 132 insertions, 43 deletions
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index efd04da2d6..66c61b0198 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -34,7 +34,7 @@ logger = logging.getLogger("synapse_port_db")
 
 
 BOOLEAN_COLUMNS = {
-    "events": ["processed", "outlier"],
+    "events": ["processed", "outlier", "contains_url"],
     "rooms": ["is_public"],
     "event_edges": ["is_state"],
     "presence_list": ["accepted"],
@@ -92,8 +92,12 @@ class Store(object):
 
     _simple_select_onecol_txn = SQLBaseStore.__dict__["_simple_select_onecol_txn"]
     _simple_select_onecol = SQLBaseStore.__dict__["_simple_select_onecol"]
+    _simple_select_one = SQLBaseStore.__dict__["_simple_select_one"]
+    _simple_select_one_txn = SQLBaseStore.__dict__["_simple_select_one_txn"]
     _simple_select_one_onecol = SQLBaseStore.__dict__["_simple_select_one_onecol"]
-    _simple_select_one_onecol_txn = SQLBaseStore.__dict__["_simple_select_one_onecol_txn"]
+    _simple_select_one_onecol_txn = SQLBaseStore.__dict__[
+        "_simple_select_one_onecol_txn"
+    ]
 
     _simple_update_one = SQLBaseStore.__dict__["_simple_update_one"]
     _simple_update_one_txn = SQLBaseStore.__dict__["_simple_update_one_txn"]
@@ -158,31 +162,40 @@ class Porter(object):
     def setup_table(self, table):
         if table in APPEND_ONLY_TABLES:
             # It's safe to just carry on inserting.
-            next_chunk = yield self.postgres_store._simple_select_one_onecol(
+            row = yield self.postgres_store._simple_select_one(
                 table="port_from_sqlite3",
                 keyvalues={"table_name": table},
-                retcol="rowid",
+                retcols=("forward_rowid", "backward_rowid"),
                 allow_none=True,
             )
 
             total_to_port = None
-            if next_chunk is None:
+            if row is None:
                 if table == "sent_transactions":
-                    next_chunk, already_ported, total_to_port = (
+                    forward_chunk, already_ported, total_to_port = (
                         yield self._setup_sent_transactions()
                     )
+                    backward_chunk = 0
                 else:
                     yield self.postgres_store._simple_insert(
                         table="port_from_sqlite3",
-                        values={"table_name": table, "rowid": 1}
+                        values={
+                            "table_name": table,
+                            "forward_rowid": 1,
+                            "backward_rowid": 0,
+                        }
                     )
 
-                    next_chunk = 1
+                    forward_chunk = 1
+                    backward_chunk = 0
                     already_ported = 0
+            else:
+                forward_chunk = row["forward_rowid"]
+                backward_chunk = row["backward_rowid"]
 
             if total_to_port is None:
                 already_ported, total_to_port = yield self._get_total_count_to_port(
-                    table, next_chunk
+                    table, forward_chunk, backward_chunk
                 )
         else:
             def delete_all(txn):
@@ -196,46 +209,85 @@ class Porter(object):
 
             yield self.postgres_store._simple_insert(
                 table="port_from_sqlite3",
-                values={"table_name": table, "rowid": 0}
+                values={
+                    "table_name": table,
+                    "forward_rowid": 1,
+                    "backward_rowid": 0,
+                }
             )
 
-            next_chunk = 1
+            forward_chunk = 1
+            backward_chunk = 0
 
             already_ported, total_to_port = yield self._get_total_count_to_port(
-                table, next_chunk
+                table, forward_chunk, backward_chunk
             )
 
-        defer.returnValue((table, already_ported, total_to_port, next_chunk))
+        defer.returnValue(
+            (table, already_ported, total_to_port, forward_chunk, backward_chunk)
+        )
 
     @defer.inlineCallbacks
-    def handle_table(self, table, postgres_size, table_size, next_chunk):
+    def handle_table(self, table, postgres_size, table_size, forward_chunk,
+                     backward_chunk):
         if not table_size:
             return
 
         self.progress.add_table(table, postgres_size, table_size)
 
         if table == "event_search":
-            yield self.handle_search_table(postgres_size, table_size, next_chunk)
+            yield self.handle_search_table(
+                postgres_size, table_size, forward_chunk, backward_chunk
+            )
             return
 
-        select = (
+        forward_select = (
             "SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?"
             % (table,)
         )
 
+        backward_select = (
+            "SELECT rowid, * FROM %s WHERE rowid <= ? ORDER BY rowid LIMIT ?"
+            % (table,)
+        )
+
+        do_forward = [True]
+        do_backward = [True]
+
         while True:
             def r(txn):
-                txn.execute(select, (next_chunk, self.batch_size,))
-                rows = txn.fetchall()
-                headers = [column[0] for column in txn.description]
+                forward_rows = []
+                backward_rows = []
+                if do_forward[0]:
+                    txn.execute(forward_select, (forward_chunk, self.batch_size,))
+                    forward_rows = txn.fetchall()
+                    if not forward_rows:
+                        do_forward[0] = False
+
+                if do_backward[0]:
+                    txn.execute(backward_select, (backward_chunk, self.batch_size,))
+                    backward_rows = txn.fetchall()
+                    if not backward_rows:
+                        do_backward[0] = False
+
+                if forward_rows or backward_rows:
+                    headers = [column[0] for column in txn.description]
+                else:
+                    headers = None
 
-                return headers, rows
+                return headers, forward_rows, backward_rows
 
-            headers, rows = yield self.sqlite_store.runInteraction("select", r)
+            headers, frows, brows = yield self.sqlite_store.runInteraction(
+                "select", r
+            )
 
-            if rows:
-                next_chunk = rows[-1][0] + 1
+            if frows or brows:
+                if frows:
+                    forward_chunk = max(row[0] for row in frows) + 1
+                if brows:
+                    backward_chunk = min(row[0] for row in brows) - 1
 
+                rows = frows + brows
                 self._convert_rows(table, headers, rows)
 
                 def insert(txn):
@@ -247,7 +299,10 @@ class Porter(object):
                         txn,
                         table="port_from_sqlite3",
                         keyvalues={"table_name": table},
-                        updatevalues={"rowid": next_chunk},
+                        updatevalues={
+                            "forward_rowid": forward_chunk,
+                            "backward_rowid": backward_chunk,
+                        },
                     )
 
                 yield self.postgres_store.execute(insert)
@@ -259,7 +314,8 @@ class Porter(object):
                 return
 
     @defer.inlineCallbacks
-    def handle_search_table(self, postgres_size, table_size, next_chunk):
+    def handle_search_table(self, postgres_size, table_size, forward_chunk,
+                            backward_chunk):
         select = (
             "SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering"
             " FROM event_search as es"
@@ -270,7 +326,7 @@ class Porter(object):
 
         while True:
             def r(txn):
-                txn.execute(select, (next_chunk, self.batch_size,))
+                txn.execute(select, (forward_chunk, self.batch_size,))
                 rows = txn.fetchall()
                 headers = [column[0] for column in txn.description]
 
@@ -279,7 +335,7 @@ class Porter(object):
             headers, rows = yield self.sqlite_store.runInteraction("select", r)
 
             if rows:
-                next_chunk = rows[-1][0] + 1
+                forward_chunk = rows[-1][0] + 1
 
                 # We have to treat event_search differently since it has a
                 # different structure in the two different databases.
@@ -312,7 +368,10 @@ class Porter(object):
                         txn,
                         table="port_from_sqlite3",
                         keyvalues={"table_name": "event_search"},
-                        updatevalues={"rowid": next_chunk},
+                        updatevalues={
+                            "forward_rowid": forward_chunk,
+                            "backward_rowid": backward_chunk,
+                        },
                     )
 
                 yield self.postgres_store.execute(insert)
@@ -324,7 +383,6 @@ class Porter(object):
             else:
                 return
 
-
     def setup_db(self, db_config, database_engine):
         db_conn = database_engine.module.connect(
             **{
@@ -395,10 +453,32 @@ class Porter(object):
                 txn.execute(
                     "CREATE TABLE port_from_sqlite3 ("
                     " table_name varchar(100) NOT NULL UNIQUE,"
-                    " rowid bigint NOT NULL"
+                    " forward_rowid bigint NOT NULL,"
+                    " backward_rowid bigint NOT NULL"
                     ")"
                 )
 
+            # The old port script created a table with just a "rowid" column.
+            # We want people to be able to rerun this script from an old port
+            # so that they can pick up any missing events that were not
+            # ported across.
+            def alter_table(txn):
+                txn.execute(
+                    "ALTER TABLE IF EXISTS port_from_sqlite3"
+                    " RENAME rowid TO forward_rowid"
+                )
+                txn.execute(
+                    "ALTER TABLE IF EXISTS port_from_sqlite3"
+                    " ADD backward_rowid bigint NOT NULL DEFAULT 0"
+                )
+
+            try:
+                yield self.postgres_store.runInteraction(
+                    "alter_table", alter_table
+                )
+            except Exception as e:
+                logger.info("Failed to create port table: %s", e)
+
             try:
                 yield self.postgres_store.runInteraction(
                     "create_port_table", create_port_table
@@ -458,7 +538,7 @@ class Porter(object):
     @defer.inlineCallbacks
     def _setup_sent_transactions(self):
         # Only save things from the last day
-        yesterday = int(time.time()*1000) - 86400000
+        yesterday = int(time.time() * 1000) - 86400000
 
         # And save the max transaction id from each destination
         select = (
@@ -514,7 +594,11 @@ class Porter(object):
 
         yield self.postgres_store._simple_insert(
             table="port_from_sqlite3",
-            values={"table_name": "sent_transactions", "rowid": next_chunk}
+            values={
+                "table_name": "sent_transactions",
+                "forward_rowid": next_chunk,
+                "backward_rowid": 0,
+            }
         )
 
         def get_sent_table_size(txn):
@@ -535,13 +619,18 @@ class Porter(object):
         defer.returnValue((next_chunk, inserted_rows, total_count))
 
     @defer.inlineCallbacks
-    def _get_remaining_count_to_port(self, table, next_chunk):
-        rows = yield self.sqlite_store.execute_sql(
+    def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk):
+        frows = yield self.sqlite_store.execute_sql(
             "SELECT count(*) FROM %s WHERE rowid >= ?" % (table,),
-            next_chunk,
+            forward_chunk,
         )
 
-        defer.returnValue(rows[0][0])
+        brows = yield self.sqlite_store.execute_sql(
+            "SELECT count(*) FROM %s WHERE rowid <= ?" % (table,),
+            backward_chunk,
+        )
+
+        defer.returnValue(frows[0][0] + brows[0][0])
 
     @defer.inlineCallbacks
     def _get_already_ported_count(self, table):
@@ -552,10 +641,10 @@ class Porter(object):
         defer.returnValue(rows[0][0])
 
     @defer.inlineCallbacks
-    def _get_total_count_to_port(self, table, next_chunk):
+    def _get_total_count_to_port(self, table, forward_chunk, backward_chunk):
         remaining, done = yield defer.gatherResults(
             [
-                self._get_remaining_count_to_port(table, next_chunk),
+                self._get_remaining_count_to_port(table, forward_chunk, backward_chunk),
                 self._get_already_ported_count(table),
             ],
             consumeErrors=True,
@@ -686,7 +775,7 @@ class CursesProgress(Progress):
             color = curses.color_pair(2) if perc == 100 else curses.color_pair(1)
 
             self.stdscr.addstr(
-                i+2, left_margin + max_len - len(table),
+                i + 2, left_margin + max_len - len(table),
                 table,
                 curses.A_BOLD | color,
             )
@@ -694,18 +783,18 @@ class CursesProgress(Progress):
             size = 20
 
             progress = "[%s%s]" % (
-                "#" * int(perc*size/100),
-                " " * (size - int(perc*size/100)),
+                "#" * int(perc * size / 100),
+                " " * (size - int(perc * size / 100)),
             )
 
             self.stdscr.addstr(
-                i+2, left_margin + max_len + middle_space,
+                i + 2, left_margin + max_len + middle_space,
                 "%s %3d%% (%d/%d)" % (progress, perc, data["num_done"], data["total"]),
             )
 
         if self.finished:
             self.stdscr.addstr(
-                rows-1, 0,
+                rows - 1, 0,
                 "Press any key to exit...",
             )