summary refs log tree commit diff
path: root/scripts/synapse_port_db
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/synapse_port_db')
-rwxr-xr-xscripts/synapse_port_db35
1 files changed, 28 insertions, 7 deletions
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index dc7fe940e8..d46581e4e1 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -42,6 +42,14 @@ BOOLEAN_COLUMNS = {
     "public_room_list_stream": ["visibility"],
     "device_lists_outbound_pokes": ["sent"],
     "users_who_share_rooms": ["share_private"],
+    "groups": ["is_public"],
+    "group_rooms": ["is_public"],
+    "group_users": ["is_public", "is_admin"],
+    "group_summary_rooms": ["is_public"],
+    "group_room_categories": ["is_public"],
+    "group_summary_users": ["is_public"],
+    "group_roles": ["is_public"],
+    "local_group_membership": ["is_publicised", "is_admin"],
 }
 
 
@@ -112,6 +120,7 @@ class Store(object):
 
     _simple_update_one = SQLBaseStore.__dict__["_simple_update_one"]
     _simple_update_one_txn = SQLBaseStore.__dict__["_simple_update_one_txn"]
+    _simple_update_txn = SQLBaseStore.__dict__["_simple_update_txn"]
 
     def runInteraction(self, desc, func, *args, **kwargs):
         def r(conn):
@@ -318,7 +327,7 @@ class Porter(object):
                     backward_chunk = min(row[0] for row in brows) - 1
 
                 rows = frows + brows
-                self._convert_rows(table, headers, rows)
+                rows = self._convert_rows(table, headers, rows)
 
                 def insert(txn):
                     self.postgres_store.insert_many_txn(
@@ -554,17 +563,29 @@ class Porter(object):
             i for i, h in enumerate(headers) if h in bool_col_names
         ]
 
+        class BadValueException(Exception):
+            pass
+
         def conv(j, col):
             if j in bool_cols:
                 return bool(col)
+            elif isinstance(col, basestring) and "\0" in col:
+                logger.warn("DROPPING ROW: NUL value in table %s col %s: %r", table, headers[j], col)
+                raise BadValueException();
             return col
 
+        outrows = []
         for i, row in enumerate(rows):
-            rows[i] = tuple(
-                conv(j, col)
-                for j, col in enumerate(row)
-                if j > 0
-            )
+            try:
+                outrows.append(tuple(
+                    conv(j, col)
+                    for j, col in enumerate(row)
+                    if j > 0
+                ))
+            except BadValueException:
+                pass
+
+        return outrows
 
     @defer.inlineCallbacks
     def _setup_sent_transactions(self):
@@ -592,7 +613,7 @@ class Porter(object):
             "select", r,
         )
 
-        self._convert_rows("sent_transactions", headers, rows)
+        rows = self._convert_rows("sent_transactions", headers, rows)
 
         inserted_rows = len(rows)
         if inserted_rows: