summary refs log tree commit diff
diff options
context:
space:
mode:
authorRichard van der Hoff <github@rvanderhoff.org.uk>2017-10-31 17:20:47 +0000
committerGitHub <noreply@github.com>2017-10-31 17:20:47 +0000
commita72e4e3e2860694013dade663f46b82160d08add (patch)
treec868dffb9cf270632649e54fe4ba7f4cc685b188
parentMerge pull request #2611 from matrix-org/dbkr/port_script_drop_nuls (diff)
parentfix tests (diff)
downloadsynapse-a72e4e3e2860694013dade663f46b82160d08add.tar.xz
Merge pull request #2610 from matrix-org/rav/schema_for_pw_providers
DB schema interface for password auth providers
-rw-r--r--docs/password_auth_providers.rst12
-rw-r--r--synapse/storage/prepare_database.py70
-rw-r--r--synapse/storage/schema/schema_version.sql7
-rw-r--r--tests/utils.py1
4 files changed, 90 insertions, 0 deletions
diff --git a/docs/password_auth_providers.rst b/docs/password_auth_providers.rst
index 3da1a67844..ca05a76617 100644
--- a/docs/password_auth_providers.rst
+++ b/docs/password_auth_providers.rst
@@ -37,3 +37,15 @@ Password auth provider classes must provide the following methods:
 
     The method should return a Twisted ``Deferred`` object, which resolves to
     ``True`` if authentication is successful, and ``False`` if not.
+
+Optional methods
+----------------
+
+Password provider classes may optionally provide the following methods.
+
+*class* ``SomeProvider.get_db_schema_files()``
+
+    This method, if implemented, should return an Iterable of ``(name,
+    stream)`` pairs of database schema files. Each file is applied in turn at
+    initialisation, and a record is then made in the database so that it is
+    not re-applied on the next start.
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index a4e08e6757..d1691bbac2 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -44,6 +44,13 @@ def prepare_database(db_conn, database_engine, config):
 
     If `config` is None then prepare_database will assert that no upgrade is
     necessary, *or* will create a fresh database if the database is empty.
+
+    Args:
+        db_conn:
+        database_engine:
+        config (synapse.config.homeserver.HomeServerConfig|None):
+            application config, or None if we are connecting to an existing
+            database which we expect to be configured already
     """
     try:
         cur = db_conn.cursor()
@@ -64,6 +71,10 @@ def prepare_database(db_conn, database_engine, config):
         else:
             _setup_new_database(cur, database_engine)
 
+        # check if any of our configured dynamic modules want a database
+        if config is not None:
+            _apply_module_schemas(cur, database_engine, config)
+
         cur.close()
         db_conn.commit()
     except Exception:
@@ -283,6 +294,65 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
             )
 
 
+def _apply_module_schemas(txn, database_engine, config):
+    """Apply the module schemas for the dynamic modules, if any
+
+    Args:
+        cur: database cursor
+        database_engine: synapse database engine class
+        config (synapse.config.homeserver.HomeServerConfig):
+            application config
+    """
+    for (mod, _config) in config.password_providers:
+        if not hasattr(mod, 'get_db_schema_files'):
+            continue
+        modname = ".".join((mod.__module__, mod.__name__))
+        _apply_module_schema_files(
+            txn, database_engine, modname, mod.get_db_schema_files(),
+        )
+
+
+def _apply_module_schema_files(cur, database_engine, modname, names_and_streams):
+    """Apply the module schemas for a single module
+
+    Args:
+        cur: database cursor
+        database_engine: synapse database engine class
+        modname (str): fully qualified name of the module
+        names_and_streams (Iterable[(str, file)]): the names and streams of
+            schemas to be applied
+    """
+    cur.execute(
+        database_engine.convert_param_style(
+            "SELECT file FROM applied_module_schemas WHERE module_name = ?"
+        ),
+        (modname,)
+    )
+    applied_deltas = set(d for d, in cur)
+    for (name, stream) in names_and_streams:
+        if name in applied_deltas:
+            continue
+
+        root_name, ext = os.path.splitext(name)
+        if ext != '.sql':
+            raise PrepareDatabaseException(
+                "only .sql files are currently supported for module schemas",
+            )
+
+        logger.info("applying schema %s for %s", name, modname)
+        for statement in get_statements(stream):
+            cur.execute(statement)
+
+        # Mark as done.
+        cur.execute(
+            database_engine.convert_param_style(
+                "INSERT INTO applied_module_schemas (module_name, file)"
+                " VALUES (?,?)",
+            ),
+            (modname, name)
+        )
+
+
 def get_statements(f):
     statement_buffer = ""
     in_comment = False  # If we're in a /* ... */ style comment
diff --git a/synapse/storage/schema/schema_version.sql b/synapse/storage/schema/schema_version.sql
index a7ade69986..42e5cb6df5 100644
--- a/synapse/storage/schema/schema_version.sql
+++ b/synapse/storage/schema/schema_version.sql
@@ -25,3 +25,10 @@ CREATE TABLE IF NOT EXISTS applied_schema_deltas(
     file TEXT NOT NULL,
     UNIQUE(version, file)
 );
+
+-- a list of schema files we have loaded on behalf of dynamic modules
+CREATE TABLE IF NOT EXISTS applied_module_schemas(
+    module_name TEXT NOT NULL,
+    file TEXT NOT NULL,
+    UNIQUE(module_name, file)
+);
diff --git a/tests/utils.py b/tests/utils.py
index d2ebce4b2e..ed8a7360f5 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -310,6 +310,7 @@ class SQLiteMemoryDbPool(ConnectionPool, object):
         )
 
         self.config = Mock()
+        self.config.password_providers = []
         self.config.database_config = {"name": "sqlite3"}
 
     def prepare(self):