summary refs log tree commit diff
path: root/synapse/storage/prepare_database.py
diff options
context:
space:
mode:
authorRichard van der Hoff <github@rvanderhoff.org.uk>2017-10-31 17:20:47 +0000
committerGitHub <noreply@github.com>2017-10-31 17:20:47 +0000
commita72e4e3e2860694013dade663f46b82160d08add (patch)
treec868dffb9cf270632649e54fe4ba7f4cc685b188 /synapse/storage/prepare_database.py
parentMerge pull request #2611 from matrix-org/dbkr/port_script_drop_nuls (diff)
parentfix tests (diff)
downloadsynapse-a72e4e3e2860694013dade663f46b82160d08add.tar.xz
Merge pull request #2610 from matrix-org/rav/schema_for_pw_providers
DB schema interface for password auth providers
Diffstat (limited to 'synapse/storage/prepare_database.py')
-rw-r--r--synapse/storage/prepare_database.py70
1 files changed, 70 insertions, 0 deletions
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index a4e08e6757..d1691bbac2 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -44,6 +44,13 @@ def prepare_database(db_conn, database_engine, config):
 
     If `config` is None then prepare_database will assert that no upgrade is
     necessary, *or* will create a fresh database if the database is empty.
+
+    Args:
+        db_conn:
+        database_engine:
+        config (synapse.config.homeserver.HomeServerConfig|None):
+            application config, or None if we are connecting to an existing
+            database which we expect to be configured already
     """
     try:
         cur = db_conn.cursor()
@@ -64,6 +71,10 @@ def prepare_database(db_conn, database_engine, config):
         else:
             _setup_new_database(cur, database_engine)
 
+        # check if any of our configured dynamic modules want a database
+        if config is not None:
+            _apply_module_schemas(cur, database_engine, config)
+
         cur.close()
         db_conn.commit()
     except Exception:
@@ -283,6 +294,65 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
             )
 
 
+def _apply_module_schemas(txn, database_engine, config):
+    """Apply the module schemas for the dynamic modules, if any
+
+    Args:
+        cur: database cursor
+        database_engine: synapse database engine class
+        config (synapse.config.homeserver.HomeServerConfig):
+            application config
+    """
+    for (mod, _config) in config.password_providers:
+        if not hasattr(mod, 'get_db_schema_files'):
+            continue
+        modname = ".".join((mod.__module__, mod.__name__))
+        _apply_module_schema_files(
+            txn, database_engine, modname, mod.get_db_schema_files(),
+        )
+
+
+def _apply_module_schema_files(cur, database_engine, modname, names_and_streams):
+    """Apply the module schemas for a single module
+
+    Args:
+        cur: database cursor
+        database_engine: synapse database engine class
+        modname (str): fully qualified name of the module
+        names_and_streams (Iterable[(str, file)]): the names and streams of
+            schemas to be applied
+    """
+    cur.execute(
+        database_engine.convert_param_style(
+            "SELECT file FROM applied_module_schemas WHERE module_name = ?"
+        ),
+        (modname,)
+    )
+    applied_deltas = set(d for d, in cur)
+    for (name, stream) in names_and_streams:
+        if name in applied_deltas:
+            continue
+
+        root_name, ext = os.path.splitext(name)
+        if ext != '.sql':
+            raise PrepareDatabaseException(
+                "only .sql files are currently supported for module schemas",
+            )
+
+        logger.info("applying schema %s for %s", name, modname)
+        for statement in get_statements(stream):
+            cur.execute(statement)
+
+        # Mark as done.
+        cur.execute(
+            database_engine.convert_param_style(
+                "INSERT INTO applied_module_schemas (module_name, file)"
+                " VALUES (?,?)",
+            ),
+            (modname, name)
+        )
+
+
 def get_statements(f):
     statement_buffer = ""
     in_comment = False  # If we're in a /* ... */ style comment