diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index d6ec446bd2..a08c74fac1 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -45,8 +45,10 @@ from syutil.jsonutil import encode_canonical_json
from synapse.crypto.event_signing import compute_event_reference_hash
+import imp
import logging
import os
+import sqlite3
logger = logging.getLogger(__name__)
@@ -610,49 +612,115 @@ class UpgradeDatabaseException(PrepareDatabaseException):
def prepare_database(db_conn):
- """ Set up all the dbs. Since all the *.sql have IF NOT EXISTS, so we
- don't have to worry about overwriting existing content.
+ """Prepares a database for usage. Will either create all necessary tables
+ or upgrade from an older schema version.
"""
- c = db_conn.cursor()
- c.execute("PRAGMA user_version")
- row = c.fetchone()
+ cur = db_conn.cursor()
+ version_info = get_schema_state(cur)
- if row and row[0]:
- user_version = row[0]
+ if version_info:
+ user_version, delta_files = version_info
+ _upgrade_existing_database(cur, user_version, delta_files)
+ else:
+ _setup_new_database(cur)
+
+ cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,))
+ db_conn.commit()
+
+ cur.close()
+
+
+def _setup_new_database(cur):
+ sql_script = "BEGIN TRANSACTION;\n"
+ for sql_loc in SCHEMAS:
+ logger.debug("Applying schema %r", sql_loc)
+ sql_script += read_schema(sql_loc)
+ sql_script += "\n"
+ sql_script += "COMMIT TRANSACTION;"
+ cur.executescript(sql_script)
+
+
+def _upgrade_existing_database(cur, user_version, delta_files):
+ """Upgrades an existing database.
+
+ Delta files can either be SQL stored in *.sql files, or python modules
+ in *.py.
+
+ There can be multiple delta files per version. Synapse will keep track of
+ which delta files have been applied, and will apply any that haven't been
+ even if there has been no version bump. This is useful for development
+ where orthogonal schema changes may happen on separate branches.
+ """
+
+ if user_version > SCHEMA_VERSION:
+ raise ValueError(
+ "Cannot use this database as it is too " +
+ "new for the server to understand"
+ )
- if user_version > SCHEMA_VERSION:
- raise ValueError(
- "Cannot use this database as it is too " +
- "new for the server to understand"
+ for v in range(user_version, SCHEMA_VERSION + 1):
+ delta_dir = os.path.join(dir_path, "schema", "delta", v)
+ directory_entries = os.listdir(delta_dir)
+
+ for file_name in directory_entries:
+ relative_path = os.path.join(v, file_name)
+ if relative_path in delta_files:
+ continue
+
+ absolute_path = os.path.join(
+ dir_path, "schema", "delta", relative_path,
)
- elif user_version < SCHEMA_VERSION:
- logger.info(
- "Upgrading database from version %d",
- user_version
+ root_name, ext = os.path.splitext(file_name)
+ if ext == ".py":
+ module_name = "synapse.storage.schema.v%d_%s" % (
+ v, root_name
+ )
+ with open(absolute_path) as schema_file:
+ module = imp.load_source(
+ module_name, absolute_path, schema_file
+ )
+ module.run_upgrade(cur)
+ elif ext == ".sql":
+ with open(absolute_path) as schema_file:
+ delta_schema = schema_file.read()
+ cur.executescript(delta_schema)
+ else:
+ # Not a valid delta file.
+ logger.warn(
+ "Found directory entry that did not end in .py or"
+ " .sql: %s",
+ relative_path,
+ )
+ continue
+
+ # Mark as done.
+ cur.execute(
+ "INSERT INTO schema_version (version, file)"
+ " VALUES (?,?)",
+ (v, relative_path)
)
- # Run every version since after the current version.
- for v in range(user_version + 1, SCHEMA_VERSION + 1):
- if v in (10, 14,):
- raise UpgradeDatabaseException(
- "No delta for version 10"
- )
- sql_script = read_schema("delta/v%d" % (v,))
- c.executescript(sql_script)
- db_conn.commit()
- else:
- logger.info("Database is at version %r", user_version)
+def get_schema_state(txn):
+ sql = (
+ "SELECT MAX(version), file FROM schema_version"
+ " WHERE version = (SELECT MAX(version) FROM schema_version)"
+ )
- else:
- sql_script = "BEGIN TRANSACTION;\n"
- for sql_loc in SCHEMAS:
- logger.debug("Applying schema %r", sql_loc)
- sql_script += read_schema(sql_loc)
- sql_script += "\n"
- sql_script += "COMMIT TRANSACTION;"
- c.executescript(sql_script)
- db_conn.commit()
- c.execute("PRAGMA user_version = %d" % SCHEMA_VERSION)
-
- c.close()
+ try:
+ txn.execute(sql)
+ res = txn.fetchall()
+
+ if res:
+ current_verison = max(r[0] for r in res)
+ applied_delta = [r[1] for r in res]
+
+ return current_verison, applied_delta
+ except sqlite3.OperationalError:
+ txn.execute("PRAGMA user_version")
+ row = txn.fetchone()
+ if row and row[0]:
+ # FIXME: We need to create schema_version table!
+ return row[0], []
+
+ return None
|