summary refs log tree commit diff
path: root/synapse/storage/prepare_database.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/prepare_database.py')
-rw-r--r--synapse/storage/prepare_database.py76
1 files changed, 73 insertions, 3 deletions
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 72b670b83b..b290f834b3 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014 - 2016 OpenMarket Ltd
+# Copyright 2018 New Vector Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,13 +20,12 @@ import logging
 import os
 import re
 
-
 logger = logging.getLogger(__name__)
 
 
 # Remember to update this number every time a change is made to database
 # schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 43
+SCHEMA_VERSION = 50
 
 dir_path = os.path.abspath(os.path.dirname(__file__))
 
@@ -44,6 +44,13 @@ def prepare_database(db_conn, database_engine, config):
 
     If `config` is None then prepare_database will assert that no upgrade is
     necessary, *or* will create a fresh database if the database is empty.
+
+    Args:
+        db_conn:
+        database_engine:
+        config (synapse.config.homeserver.HomeServerConfig|None):
+            application config, or None if we are connecting to an existing
+            database which we expect to be configured already
     """
     try:
         cur = db_conn.cursor()
@@ -64,9 +71,13 @@ def prepare_database(db_conn, database_engine, config):
         else:
             _setup_new_database(cur, database_engine)
 
+        # check if any of our configured dynamic modules want a database
+        if config is not None:
+            _apply_module_schemas(cur, database_engine, config)
+
         cur.close()
         db_conn.commit()
-    except:
+    except Exception:
         db_conn.rollback()
         raise
 
@@ -283,6 +294,65 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
             )
 
 
+def _apply_module_schemas(txn, database_engine, config):
+    """Apply the module schemas for the dynamic modules, if any
+
+    Args:
+        cur: database cursor
+        database_engine: synapse database engine class
+        config (synapse.config.homeserver.HomeServerConfig):
+            application config
+    """
+    for (mod, _config) in config.password_providers:
+        if not hasattr(mod, 'get_db_schema_files'):
+            continue
+        modname = ".".join((mod.__module__, mod.__name__))
+        _apply_module_schema_files(
+            txn, database_engine, modname, mod.get_db_schema_files(),
+        )
+
+
+def _apply_module_schema_files(cur, database_engine, modname, names_and_streams):
+    """Apply the module schemas for a single module
+
+    Args:
+        cur: database cursor
+        database_engine: synapse database engine class
+        modname (str): fully qualified name of the module
+        names_and_streams (Iterable[(str, file)]): the names and streams of
+            schemas to be applied
+    """
+    cur.execute(
+        database_engine.convert_param_style(
+            "SELECT file FROM applied_module_schemas WHERE module_name = ?"
+        ),
+        (modname,)
+    )
+    applied_deltas = set(d for d, in cur)
+    for (name, stream) in names_and_streams:
+        if name in applied_deltas:
+            continue
+
+        root_name, ext = os.path.splitext(name)
+        if ext != '.sql':
+            raise PrepareDatabaseException(
+                "only .sql files are currently supported for module schemas",
+            )
+
+        logger.info("applying schema %s for %s", name, modname)
+        for statement in get_statements(stream):
+            cur.execute(statement)
+
+        # Mark as done.
+        cur.execute(
+            database_engine.convert_param_style(
+                "INSERT INTO applied_module_schemas (module_name, file)"
+                " VALUES (?,?)",
+            ),
+            (modname, name)
+        )
+
+
 def get_statements(f):
     statement_buffer = ""
     in_comment = False  # If we're in a /* ... */ style comment