summary refs log tree commit diff
path: root/scripts/synapse_port_db
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/synapse_port_db')
-rwxr-xr-x scripts/synapse_port_db 148
1 files changed, 104 insertions, 44 deletions
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index 7d158a46a4..b9b828c154 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -1,6 +1,7 @@
 #!/usr/bin/env python
 # -*- coding: utf-8 -*-
 # Copyright 2015, 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -29,6 +30,8 @@ import time
 import traceback
 import yaml
 
+from six import string_types
+
 
 logger = logging.getLogger("synapse_port_db")
 
@@ -42,6 +45,14 @@ BOOLEAN_COLUMNS = {
     "public_room_list_stream": ["visibility"],
     "device_lists_outbound_pokes": ["sent"],
     "users_who_share_rooms": ["share_private"],
+    "groups": ["is_public"],
+    "group_rooms": ["is_public"],
+    "group_users": ["is_public", "is_admin"],
+    "group_summary_rooms": ["is_public"],
+    "group_room_categories": ["is_public"],
+    "group_summary_users": ["is_public"],
+    "group_roles": ["is_public"],
+    "local_group_membership": ["is_publicised", "is_admin"],
 }
 
 
@@ -112,6 +123,7 @@ class Store(object):
 
     _simple_update_one = SQLBaseStore.__dict__["_simple_update_one"]
     _simple_update_one_txn = SQLBaseStore.__dict__["_simple_update_one_txn"]
+    _simple_update_txn = SQLBaseStore.__dict__["_simple_update_txn"]
 
     def runInteraction(self, desc, func, *args, **kwargs):
         def r(conn):
@@ -241,6 +253,12 @@ class Porter(object):
     @defer.inlineCallbacks
     def handle_table(self, table, postgres_size, table_size, forward_chunk,
                      backward_chunk):
+        logger.info(
+            "Table %s: %i/%i (rows %i-%i) already ported",
+            table, postgres_size, table_size,
+            backward_chunk+1, forward_chunk-1,
+        )
+
         if not table_size:
             return
 
@@ -252,6 +270,25 @@ class Porter(object):
             )
             return
 
+        if table in (
+            "user_directory", "user_directory_search", "users_who_share_rooms",
+            "users_in_pubic_room",
+        ):
+            # We don't port these tables, as they're a faff and we can regenerate
+            # them anyway.
+            self.progress.update(table, table_size)  # Mark table as done
+            return
+
+        if table == "user_directory_stream_pos":
+            # We need to make sure there is a single row, `(X, null)`, as that is
+            # what synapse expects to be there.
+            yield self.postgres_store._simple_insert(
+                table=table,
+                values={"stream_id": None},
+            )
+            self.progress.update(table, table_size)  # Mark table as done
+            return
+
         forward_select = (
             "SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?"
             % (table,)
@@ -299,7 +336,7 @@ class Porter(object):
                     backward_chunk = min(row[0] for row in brows) - 1
 
                 rows = frows + brows
-                self._convert_rows(table, headers, rows)
+                rows = self._convert_rows(table, headers, rows)
 
                 def insert(txn):
                     self.postgres_store.insert_many_txn(
@@ -357,10 +394,13 @@ class Porter(object):
                         " VALUES (?,?,?,?,to_tsvector('english', ?),?,?)"
                     )
 
-                    rows_dict = [
-                        dict(zip(headers, row))
-                        for row in rows
-                    ]
+                    rows_dict = []
+                    for row in rows:
+                        d = dict(zip(headers, row))
+                        if "\0" in d['value']:
+                            logger.warn('dropping search row %s', d)
+                        else:
+                            rows_dict.append(d)
 
                     txn.executemany(sql, [
                         (
@@ -436,31 +476,10 @@ class Porter(object):
             self.progress.set_state("Preparing PostgreSQL")
             self.setup_db(postgres_config, postgres_engine)
 
-            # Step 2. Get tables.
-            self.progress.set_state("Fetching tables")
-            sqlite_tables = yield self.sqlite_store._simple_select_onecol(
-                table="sqlite_master",
-                keyvalues={
-                    "type": "table",
-                },
-                retcol="name",
-            )
-
-            postgres_tables = yield self.postgres_store._simple_select_onecol(
-                table="information_schema.tables",
-                keyvalues={},
-                retcol="distinct table_name",
-            )
-
-            tables = set(sqlite_tables) & set(postgres_tables)
-
-            self.progress.set_state("Creating tables")
-
-            logger.info("Found %d tables", len(tables))
-
+            self.progress.set_state("Creating port tables")
             def create_port_table(txn):
                 txn.execute(
-                    "CREATE TABLE port_from_sqlite3 ("
+                    "CREATE TABLE IF NOT EXISTS port_from_sqlite3 ("
                     " table_name varchar(100) NOT NULL UNIQUE,"
                     " forward_rowid bigint NOT NULL,"
                     " backward_rowid bigint NOT NULL"
@@ -486,18 +505,33 @@ class Porter(object):
                     "alter_table", alter_table
                 )
             except Exception as e:
-                logger.info("Failed to create port table: %s", e)
+                pass
 
-            try:
-                yield self.postgres_store.runInteraction(
-                    "create_port_table", create_port_table
-                )
-            except Exception as e:
-                logger.info("Failed to create port table: %s", e)
+            yield self.postgres_store.runInteraction(
+                "create_port_table", create_port_table
+            )
 
-            self.progress.set_state("Setting up")
+            # Step 2. Get tables.
+            self.progress.set_state("Fetching tables")
+            sqlite_tables = yield self.sqlite_store._simple_select_onecol(
+                table="sqlite_master",
+                keyvalues={
+                    "type": "table",
+                },
+                retcol="name",
+            )
 
-            # Set up tables.
+            postgres_tables = yield self.postgres_store._simple_select_onecol(
+                table="information_schema.tables",
+                keyvalues={},
+                retcol="distinct table_name",
+            )
+
+            tables = set(sqlite_tables) & set(postgres_tables)
+            logger.info("Found %d tables", len(tables))
+
+            # Step 3. Figure out what still needs copying
+            self.progress.set_state("Checking on port progress")
             setup_res = yield defer.gatherResults(
                 [
                     self.setup_table(table)
@@ -508,7 +542,8 @@ class Porter(object):
                 consumeErrors=True,
             )
 
-            # Process tables.
+            # Step 4. Do the copying.
+            self.progress.set_state("Copying to postgres")
             yield defer.gatherResults(
                 [
                     self.handle_table(*res)
@@ -517,6 +552,9 @@ class Porter(object):
                 consumeErrors=True,
             )
 
+            # Step 5. Do final post-processing
+            yield self._setup_state_group_id_seq()
+
             self.progress.done()
         except:
             global end_error_exec_info
@@ -532,17 +570,29 @@ class Porter(object):
             i for i, h in enumerate(headers) if h in bool_col_names
         ]
 
+        class BadValueException(Exception):
+            pass
+
         def conv(j, col):
             if j in bool_cols:
                 return bool(col)
+            elif isinstance(col, string_types) and "\0" in col:
+                logger.warn("DROPPING ROW: NUL value in table %s col %s: %r", table, headers[j], col)
+                raise BadValueException();
             return col
 
+        outrows = []
         for i, row in enumerate(rows):
-            rows[i] = tuple(
-                conv(j, col)
-                for j, col in enumerate(row)
-                if j > 0
-            )
+            try:
+                outrows.append(tuple(
+                    conv(j, col)
+                    for j, col in enumerate(row)
+                    if j > 0
+                ))
+            except BadValueException:
+                pass
+
+        return outrows
 
     @defer.inlineCallbacks
     def _setup_sent_transactions(self):
@@ -570,7 +620,7 @@ class Porter(object):
             "select", r,
         )
 
-        self._convert_rows("sent_transactions", headers, rows)
+        rows = self._convert_rows("sent_transactions", headers, rows)
 
         inserted_rows = len(rows)
         if inserted_rows:
@@ -664,6 +714,16 @@ class Porter(object):
 
         defer.returnValue((done, remaining + done))
 
+    def _setup_state_group_id_seq(self):
+        def r(txn):
+            txn.execute("SELECT MAX(id) FROM state_groups")
+            next_id = txn.fetchone()[0]+1
+            txn.execute(
+                "ALTER SEQUENCE state_group_id_seq RESTART WITH %s",
+                (next_id,),
+            )
+        return self.postgres_store.runInteraction("setup_state_group_id_seq", r)
+
 
 ##############################################
 ###### The following is simply UI stuff ######