summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/storage/__init__.py144
1 files changed, 106 insertions, 38 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index d6ec446bd2..a08c74fac1 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -45,8 +45,10 @@ from syutil.jsonutil import encode_canonical_json
 from synapse.crypto.event_signing import compute_event_reference_hash
 
 
+import imp
 import logging
 import os
+import sqlite3
 
 
 logger = logging.getLogger(__name__)
@@ -610,49 +612,115 @@ class UpgradeDatabaseException(PrepareDatabaseException):
 
 
 def prepare_database(db_conn):
-    """ Set up all the dbs. Since all the *.sql have IF NOT EXISTS, so we
-    don't have to worry about overwriting existing content.
+    """Prepares a database for usage. Will either create all necessary tables
+    or upgrade from an older schema version.
     """
-    c = db_conn.cursor()
-    c.execute("PRAGMA user_version")
-    row = c.fetchone()
+    cur = db_conn.cursor()
+    version_info = get_schema_state(cur)
 
-    if row and row[0]:
-        user_version = row[0]
+    if version_info:
+        user_version, delta_files = version_info
+        _upgrade_existing_database(cur, user_version, delta_files)
+    else:
+        _setup_new_database(cur)
+
+    cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,))
+    db_conn.commit()
+
+    cur.close()
+
+
+def _setup_new_database(cur):
+    sql_script = "BEGIN TRANSACTION;\n"
+    for sql_loc in SCHEMAS:
+        logger.debug("Applying schema %r", sql_loc)
+        sql_script += read_schema(sql_loc)
+        sql_script += "\n"
+    sql_script += "COMMIT TRANSACTION;"
+    cur.executescript(sql_script)
+
+
+def _upgrade_existing_database(cur, user_version, delta_files):
+    """Upgrades an existing database.
+
+    Delta files can either be SQL stored in *.sql files, or python modules
+    in *.py.
+
+    There can be multiple delta files per version. Synapse will keep track of
+    which delta files have been applied, and will apply any that haven't been
+    even if there has been no version bump. This is useful for development
+    where orthogonal schema changes may happen on separate branches.
+    """
+
+    if user_version > SCHEMA_VERSION:
+        raise ValueError(
+            "Cannot use this database as it is too " +
+            "new for the server to understand"
+        )
 
-        if user_version > SCHEMA_VERSION:
-            raise ValueError(
-                "Cannot use this database as it is too " +
-                "new for the server to understand"
+    for v in range(user_version, SCHEMA_VERSION + 1):
+        delta_dir = os.path.join(dir_path, "schema", "delta", v)
+        directory_entries = os.listdir(delta_dir)
+
+        for file_name in directory_entries:
+            relative_path = os.path.join(v, file_name)
+            if relative_path in delta_files:
+                continue
+
+            absolute_path = os.path.join(
+                dir_path, "schema", "delta", relative_path,
             )
-        elif user_version < SCHEMA_VERSION:
-            logger.info(
-                "Upgrading database from version %d",
-                user_version
+            root_name, ext = os.path.splitext(file_name)
+            if ext == ".py":
+                module_name = "synapse.storage.schema.v%d_%s" % (
+                    v, root_name
+                )
+                with open(absolute_path) as schema_file:
+                    module = imp.load_source(
+                        module_name, absolute_path, schema_file
+                    )
+                module.run_upgrade(cur)
+            elif ext == ".sql":
+                with open(absolute_path) as schema_file:
+                    delta_schema = schema_file.read()
+                cur.executescript(delta_schema)
+            else:
+                # Not a valid delta file.
+                logger.warn(
+                    "Found directory entry that did not end in .py or"
+                    " .sql: %s",
+                    relative_path,
+                )
+                continue
+
+            # Mark as done.
+            cur.execute(
+                "INSERT INTO schema_version (version, file)"
+                " VALUES (?,?)",
+                (v, relative_path)
             )
 
-            # Run every version since after the current version.
-            for v in range(user_version + 1, SCHEMA_VERSION + 1):
-                if v in (10, 14,):
-                    raise UpgradeDatabaseException(
-                        "No delta for version 10"
-                    )
-                sql_script = read_schema("delta/v%d" % (v,))
-                c.executescript(sql_script)
 
-            db_conn.commit()
-        else:
-            logger.info("Database is at version %r", user_version)
+def get_schema_state(txn):
+    sql = (
+        "SELECT MAX(version), file FROM schema_version"
+        " WHERE version = (SELECT MAX(version) FROM schema_version)"
+    )
 
-    else:
-        sql_script = "BEGIN TRANSACTION;\n"
-        for sql_loc in SCHEMAS:
-            logger.debug("Applying schema %r", sql_loc)
-            sql_script += read_schema(sql_loc)
-            sql_script += "\n"
-        sql_script += "COMMIT TRANSACTION;"
-        c.executescript(sql_script)
-        db_conn.commit()
-        c.execute("PRAGMA user_version = %d" % SCHEMA_VERSION)
-
-    c.close()
+    try:
+        txn.execute(sql)
+        res = txn.fetchall()
+
+        if res:
+            current_verison = max(r[0] for r in res)
+            applied_delta = [r[1] for r in res]
+
+            return current_verison, applied_delta
+    except sqlite3.OperationalError:
+        txn.execute("PRAGMA user_version")
+        row = txn.fetchone()
+        if row and row[0]:
+            # FIXME: We need to create schema_version table!
+            return row[0], []
+
+    return None