summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--docs/password_auth_providers.rst12
-rwxr-xr-xscripts/synapse_port_db26
-rw-r--r--synapse/storage/prepare_database.py70
-rw-r--r--synapse/storage/schema/schema_version.sql7
-rw-r--r--tests/utils.py1
5 files changed, 109 insertions, 7 deletions
diff --git a/docs/password_auth_providers.rst b/docs/password_auth_providers.rst
index 3da1a67844..ca05a76617 100644
--- a/docs/password_auth_providers.rst
+++ b/docs/password_auth_providers.rst
@@ -37,3 +37,15 @@ Password auth provider classes must provide the following methods:
 
     The method should return a Twisted ``Deferred`` object, which resolves to
     ``True`` if authentication is successful, and ``False`` if not.
+
+Optional methods
+----------------
+
+Password provider classes may optionally provide the following methods.
+
+*class* ``SomeProvider.get_db_schema_files()``
+
+    This method, if implemented, should return an Iterable of ``(name,
+    stream)`` pairs of database schema files. Each file is applied in turn at
+    initialisation, and a record is then made in the database so that it is
+    not re-applied on the next start.
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index d6d8ee50cb..3a8972efc3 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -320,7 +320,7 @@ class Porter(object):
                     backward_chunk = min(row[0] for row in brows) - 1
 
                 rows = frows + brows
-                self._convert_rows(table, headers, rows)
+                rows = self._convert_rows(table, headers, rows)
 
                 def insert(txn):
                     self.postgres_store.insert_many_txn(
@@ -556,17 +556,29 @@ class Porter(object):
             i for i, h in enumerate(headers) if h in bool_col_names
         ]
 
+        class BadValueException(Exception):
+            pass
+
         def conv(j, col):
             if j in bool_cols:
                 return bool(col)
+            elif isinstance(col, basestring) and "\0" in col:
+                logger.warn("DROPPING ROW: NUL value in table %s col %s: %r", table, headers[j], col)
+                raise BadValueException();
             return col
 
+        outrows = []
         for i, row in enumerate(rows):
-            rows[i] = tuple(
-                conv(j, col)
-                for j, col in enumerate(row)
-                if j > 0
-            )
+            try:
+                outrows.append(tuple(
+                    conv(j, col)
+                    for j, col in enumerate(row)
+                    if j > 0
+                ))
+            except BadValueException:
+                pass
+
+        return outrows
 
     @defer.inlineCallbacks
     def _setup_sent_transactions(self):
@@ -594,7 +606,7 @@ class Porter(object):
             "select", r,
         )
 
-        self._convert_rows("sent_transactions", headers, rows)
+        rows = self._convert_rows("sent_transactions", headers, rows)
 
         inserted_rows = len(rows)
         if inserted_rows:
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index a4e08e6757..d1691bbac2 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -44,6 +44,13 @@ def prepare_database(db_conn, database_engine, config):
 
     If `config` is None then prepare_database will assert that no upgrade is
     necessary, *or* will create a fresh database if the database is empty.
+
+    Args:
+        db_conn:
+        database_engine:
+        config (synapse.config.homeserver.HomeServerConfig|None):
+            application config, or None if we are connecting to an existing
+            database which we expect to be configured already
     """
     try:
         cur = db_conn.cursor()
@@ -64,6 +71,10 @@ def prepare_database(db_conn, database_engine, config):
         else:
             _setup_new_database(cur, database_engine)
 
+        # check if any of our configured dynamic modules want a database
+        if config is not None:
+            _apply_module_schemas(cur, database_engine, config)
+
         cur.close()
         db_conn.commit()
     except Exception:
@@ -283,6 +294,65 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
             )
 
 
+def _apply_module_schemas(txn, database_engine, config):
+    """Apply the module schemas for the dynamic modules, if any
+
+    Args:
+        cur: database cursor
+        database_engine: synapse database engine class
+        config (synapse.config.homeserver.HomeServerConfig):
+            application config
+    """
+    for (mod, _config) in config.password_providers:
+        if not hasattr(mod, 'get_db_schema_files'):
+            continue
+        modname = ".".join((mod.__module__, mod.__name__))
+        _apply_module_schema_files(
+            txn, database_engine, modname, mod.get_db_schema_files(),
+        )
+
+
+def _apply_module_schema_files(cur, database_engine, modname, names_and_streams):
+    """Apply the module schemas for a single module
+
+    Args:
+        cur: database cursor
+        database_engine: synapse database engine class
+        modname (str): fully qualified name of the module
+        names_and_streams (Iterable[(str, file)]): the names and streams of
+            schemas to be applied
+    """
+    cur.execute(
+        database_engine.convert_param_style(
+            "SELECT file FROM applied_module_schemas WHERE module_name = ?"
+        ),
+        (modname,)
+    )
+    applied_deltas = set(d for d, in cur)
+    for (name, stream) in names_and_streams:
+        if name in applied_deltas:
+            continue
+
+        root_name, ext = os.path.splitext(name)
+        if ext != '.sql':
+            raise PrepareDatabaseException(
+                "only .sql files are currently supported for module schemas",
+            )
+
+        logger.info("applying schema %s for %s", name, modname)
+        for statement in get_statements(stream):
+            cur.execute(statement)
+
+        # Mark as done.
+        cur.execute(
+            database_engine.convert_param_style(
+                "INSERT INTO applied_module_schemas (module_name, file)"
+                " VALUES (?,?)",
+            ),
+            (modname, name)
+        )
+
+
 def get_statements(f):
     statement_buffer = ""
     in_comment = False  # If we're in a /* ... */ style comment
diff --git a/synapse/storage/schema/schema_version.sql b/synapse/storage/schema/schema_version.sql
index a7ade69986..42e5cb6df5 100644
--- a/synapse/storage/schema/schema_version.sql
+++ b/synapse/storage/schema/schema_version.sql
@@ -25,3 +25,10 @@ CREATE TABLE IF NOT EXISTS applied_schema_deltas(
     file TEXT NOT NULL,
     UNIQUE(version, file)
 );
+
+-- a list of schema files we have loaded on behalf of dynamic modules
+CREATE TABLE IF NOT EXISTS applied_module_schemas(
+    module_name TEXT NOT NULL,
+    file TEXT NOT NULL,
+    UNIQUE(module_name, file)
+);
diff --git a/tests/utils.py b/tests/utils.py
index d2ebce4b2e..ed8a7360f5 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -310,6 +310,7 @@ class SQLiteMemoryDbPool(ConnectionPool, object):
         )
 
         self.config = Mock()
+        self.config.password_providers = []
         self.config.database_config = {"name": "sqlite3"}
 
     def prepare(self):