summary refs log tree commit diff
path: root/synapse/storage/__init__.py
diff options
context:
space:
mode:
author Erik Johnston <erik@matrix.org> 2015-03-04 12:04:19 +0000
committer Erik Johnston <erik@matrix.org> 2015-03-04 12:04:19 +0000
commit 82b34e813de4dadb8ec5bce068f7113e32e60ead (patch)
tree 882f9d407ff6afe56e57d9f3e0cf701b658581c6 /synapse/storage/__init__.py
parent SYN-67: Begin changing the way we handle schema versioning (diff)
download synapse-82b34e813de4dadb8ec5bce068f7113e32e60ead.tar.xz
SYN-67: Finish up implementing new database schema management
Diffstat (limited to 'synapse/storage/__init__.py')
-rw-r--r-- synapse/storage/__init__.py 197
1 files changed, 131 insertions, 66 deletions
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index a08c74fac1..07ccc4e2ee 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -45,35 +45,16 @@ from syutil.jsonutil import encode_canonical_json
 from synapse.crypto.event_signing import compute_event_reference_hash
 
 
+import fnmatch
 import imp
 import logging
 import os
-import sqlite3
+import re
 
 
 logger = logging.getLogger(__name__)
 
 
-SCHEMAS = [
-    "transactions",
-    "users",
-    "profiles",
-    "presence",
-    "im",
-    "room_aliases",
-    "keys",
-    "redactions",
-    "state",
-    "event_edges",
-    "event_signatures",
-    "pusher",
-    "media_repository",
-    "application_services",
-    "filtering",
-    "rejections",
-]
-
-
 # Remember to update this number every time an incompatible change is made to
 # database schema files, so the users will be informed on server restarts.
 SCHEMA_VERSION = 14
@@ -578,28 +559,15 @@ class DataStore(RoomMemberStore, RoomStore,
         )
 
 
-def schema_path(schema):
-    """ Get a filesystem path for the named database schema
-
-    Args:
-        schema: Name of the database schema.
-    Returns:
-        A filesystem path pointing at a ".sql" file.
-
-    """
-    schemaPath = os.path.join(dir_path, "schema", schema + ".sql")
-    return schemaPath
-
-
-def read_schema(schema):
+def read_schema(path):
     """ Read the named database schema.
 
     Args:
-        schema: Name of the datbase schema.
+        path: Path of the database schema.
     Returns:
         A string containing the database schema.
     """
-    with open(schema_path(schema)) as schema_file:
+    with open(path) as schema_file:
         return schema_file.read()
 
 
@@ -616,11 +584,11 @@ def prepare_database(db_conn):
     or upgrade from an older schema version.
     """
     cur = db_conn.cursor()
-    version_info = get_schema_state(cur)
+    version_info = get_or_create_schema_state(cur)
 
     if version_info:
-        user_version, delta_files = version_info
-        _upgrade_existing_database(cur, user_version, delta_files)
+        user_version, delta_files, upgraded = version_info
+        _upgrade_existing_database(cur, user_version, delta_files, upgraded)
     else:
         _setup_new_database(cur)
 
@@ -631,16 +599,52 @@ def prepare_database(db_conn):
 
 
 def _setup_new_database(cur):
+    current_dir = os.path.join(dir_path, "schema", "current")
+    directory_entries = os.listdir(current_dir)
+
+    valid_dirs = []
+    pattern = re.compile(r"^\d+(\.sql)?$")
+    for filename in directory_entries:
+        match = pattern.match(filename)
+        abs_path = os.path.join(current_dir, filename)
+        if match and os.path.isdir(abs_path):
+            ver = int(match.group(0))
+            if ver < SCHEMA_VERSION:
+                valid_dirs.append((ver, abs_path))
+
+    if not valid_dirs:
+        raise RuntimeError("Could not find a suitable current.sql")
+
+    max_current_ver, sql_dir = max(valid_dirs, key=lambda x: x[0])
+
+    logger.debug("Initialising schema v%d", max_current_ver)
+
+    directory_entries = os.listdir(sql_dir)
+
     sql_script = "BEGIN TRANSACTION;\n"
-    for sql_loc in SCHEMAS:
+    for filename in fnmatch.filter(directory_entries, "*.sql"):
+        sql_loc = os.path.join(sql_dir, filename)
         logger.debug("Applying schema %r", sql_loc)
         sql_script += read_schema(sql_loc)
         sql_script += "\n"
     sql_script += "COMMIT TRANSACTION;"
     cur.executescript(sql_script)
 
+    cur.execute(
+        "INSERT INTO schema_version (version, upgraded)"
+        " VALUES (?,?)",
+        (max_current_ver, False)
+    )
+
+    _upgrade_existing_database(
+        cur,
+        current_version=max_current_ver,
+        delta_files=[],
+        upgraded=False
+    )
+
 
-def _upgrade_existing_database(cur, user_version, delta_files):
+def _upgrade_existing_database(cur, current_version, delta_files, upgraded):
     """Upgrades an existing database.
 
     Delta files can either be SQL stored in *.sql files, or python modules
@@ -650,20 +654,41 @@ def _upgrade_existing_database(cur, user_version, delta_files):
     which delta files have been applied, and will apply any that haven't been
     even if there has been no version bump. This is useful for development
     where orthogonal schema changes may happen on separate branches.
+
+    Args:
+        cur (Cursor)
+        current_version (int): The current version of the schema
+        delta_files (list): A list of deltas that have already been applied
+        upgraded (bool): Whether the current version was generated by having
+            applied deltas or from full schema file. If `True` the function
+            will never apply delta files for the given `current_version`, since
+            the current_version wasn't generated by applying those delta files.
     """
 
-    if user_version > SCHEMA_VERSION:
+    if current_version > SCHEMA_VERSION:
         raise ValueError(
             "Cannot use this database as it is too " +
             "new for the server to understand"
         )
 
-    for v in range(user_version, SCHEMA_VERSION + 1):
-        delta_dir = os.path.join(dir_path, "schema", "delta", v)
-        directory_entries = os.listdir(delta_dir)
+    start_ver = current_version
+    if not upgraded:
+        start_ver += 1
+
+    for v in range(start_ver, SCHEMA_VERSION + 1):
+        logger.debug("Upgrading schema to v%d", v)
+
+        delta_dir = os.path.join(dir_path, "schema", "delta", str(v))
+
+        try:
+            directory_entries = os.listdir(delta_dir)
+        except OSError:
+            logger.exception("Could not open delta dir for version %d", v)
+            raise
 
+        directory_entries.sort()
         for file_name in directory_entries:
-            relative_path = os.path.join(v, file_name)
+            relative_path = os.path.join(str(v), file_name)
             if relative_path in delta_files:
                 continue
 
@@ -672,17 +697,19 @@ def _upgrade_existing_database(cur, user_version, delta_files):
             )
             root_name, ext = os.path.splitext(file_name)
             if ext == ".py":
-                module_name = "synapse.storage.schema.v%d_%s" % (
+                module_name = "synapse.storage.v%d_%s" % (
                     v, root_name
                 )
                 with open(absolute_path) as schema_file:
                     module = imp.load_source(
                         module_name, absolute_path, schema_file
                     )
+                logger.debug("Running script %s", relative_path)
                 module.run_upgrade(cur)
             elif ext == ".sql":
                 with open(absolute_path) as schema_file:
                     delta_schema = schema_file.read()
+                logger.debug("Applying schema %s", relative_path)
                 cur.executescript(delta_schema)
             else:
                 # Not a valid delta file.
@@ -695,32 +722,70 @@ def _upgrade_existing_database(cur, user_version, delta_files):
 
             # Mark as done.
             cur.execute(
-                "INSERT INTO schema_version (version, file)"
+                "INSERT INTO schema_deltas (version, file)"
                 " VALUES (?,?)",
                 (v, relative_path)
             )
 
+            cur.execute(
+                "INSERT INTO schema_version (version, upgraded)"
+                " VALUES (?,?)",
+                (v, True)
+            )
+
 
-def get_schema_state(txn):
-    sql = (
-        "SELECT MAX(version), file FROM schema_version"
-        " WHERE version = (SELECT MAX(version) FROM schema_version)"
+def get_or_create_schema_state(txn):
+    schema_path = os.path.join(
+        dir_path, "schema", "schema_version.sql",
     )
+    create_schema = read_schema(schema_path)
+    txn.executescript(create_schema)
 
-    try:
-        txn.execute(sql)
+    txn.execute("SELECT version, upgraded FROM schema_version")
+    row = txn.fetchone()
+    current_version = int(row[0]) if row else None
+    upgraded = bool(row[1]) if row else None
+
+    if current_version:
+        txn.execute(
+            "SELECT file FROM schema_deltas WHERE version >= ?",
+            (current_version,)
+        )
         res = txn.fetchall()
+        return current_version, txn.fetchall(), upgraded
 
-        if res:
-            current_verison = max(r[0] for r in res)
-            applied_delta = [r[1] for r in res]
+    return None
 
-            return current_verison, applied_delta
-    except sqlite3.OperationalError:
-        txn.execute("PRAGMA user_version")
-        row = txn.fetchone()
-        if row and row[0]:
-            # FIXME: We need to create schema_version table!
-            return row[0], []
 
-    return None
+def prepare_sqlite3_database(db_conn):
+    """This function should be called before `prepare_database` on sqlite3
+    databases.
+
+    Since we changed the way we store the current schema version and handle
+    updates to schemas, we need a way to upgrade from the old method to the
+    new. This only affects sqlite databases since they were the only ones
+    supported at the time.
+    """
+    with db_conn:
+        schema_path = os.path.join(
+            dir_path, "schema", "schema_version.sql",
+        )
+        create_schema = read_schema(schema_path)
+        db_conn.executescript(create_schema)
+
+        c = db_conn.execute("SELECT * FROM schema_version")
+        rows = c.fetchall()
+        c.close()
+
+        if not rows:
+            c = db_conn.execute("PRAGMA user_version")
+            row = c.fetchone()
+            c.close()
+
+            if row and row[0]:
+                ver = row[0]
+                db_conn.execute(
+                    "INSERT INTO schema_version (version, upgraded)"
+                    " VALUES (?,?)",
+                    (row[0], False)
+                )