summary refs log tree commit diff
path: root/synapse/storage/prepare_database.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/prepare_database.py')
-rw-r--r--synapse/storage/prepare_database.py235
1 files changed, 184 insertions, 51 deletions
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index e96eed8a6d..9cc3b51fe6 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -14,20 +14,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import fnmatch
 import imp
 import logging
 import os
 import re
+from collections import Counter
+from typing import TextIO
+
+import attr
 
 from synapse.storage.engines.postgres import PostgresEngine
+from synapse.storage.types import Cursor
 
 logger = logging.getLogger(__name__)
 
 
 # Remember to update this number every time a change is made to database
 # schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 56
+# XXX: If you're about to bump this to 59 (or higher) please create an update
+# that drops the unused `cache_invalidation_stream` table, as per #7436!
+# XXX: Also add an update to drop `account_data_max_stream_id` as per #7656!
+SCHEMA_VERSION = 58
 
 dir_path = os.path.abspath(os.path.dirname(__file__))
 
@@ -40,7 +47,7 @@ class UpgradeDatabaseException(PrepareDatabaseException):
     pass
 
 
-def prepare_database(db_conn, database_engine, config):
+def prepare_database(db_conn, database_engine, config, data_stores=["main", "state"]):
     """Prepares a database for usage. Will either create all necessary tables
     or upgrade from an older schema version.
 
@@ -53,7 +60,10 @@ def prepare_database(db_conn, database_engine, config):
         config (synapse.config.homeserver.HomeServerConfig|None):
             application config, or None if we are connecting to an existing
             database which we expect to be configured already
+        data_stores (list[str]): The name of the data stores that will be used
+            with this database. Defaults to all data stores.
     """
+
     try:
         cur = db_conn.cursor()
         version_info = _get_or_create_schema_state(cur, database_engine)
@@ -65,13 +75,22 @@ def prepare_database(db_conn, database_engine, config):
                 if user_version != SCHEMA_VERSION:
                     # If we don't pass in a config file then we are expecting to
                     # have already upgraded the DB.
-                    raise UpgradeDatabaseException("Database needs to be upgraded")
+                    raise UpgradeDatabaseException(
+                        "Expected database schema version %i but got %i"
+                        % (SCHEMA_VERSION, user_version)
+                    )
             else:
                 _upgrade_existing_database(
-                    cur, user_version, delta_files, upgraded, database_engine, config
+                    cur,
+                    user_version,
+                    delta_files,
+                    upgraded,
+                    database_engine,
+                    config,
+                    data_stores=data_stores,
                 )
         else:
-            _setup_new_database(cur, database_engine)
+            _setup_new_database(cur, database_engine, data_stores=data_stores)
 
         # check if any of our configured dynamic modules want a database
         if config is not None:
@@ -84,9 +103,10 @@ def prepare_database(db_conn, database_engine, config):
         raise
 
 
-def _setup_new_database(cur, database_engine):
+def _setup_new_database(cur, database_engine, data_stores):
     """Sets up the database by finding a base set of "full schemas" and then
-    applying any necessary deltas.
+    applying any necessary deltas, including schemas from the given data
+    stores.
 
     The "full_schemas" directory has subdirectories named after versions. This
     function searches for the highest version less than or equal to
@@ -111,52 +131,83 @@ def _setup_new_database(cur, database_engine):
 
     In the example foo.sql and bar.sql would be run, and then any delta files
     for versions strictly greater than 11.
+
+    Note: we apply the full schemas and deltas from the top level `schema/`
+    folder as well those in the data stores specified.
+
+    Args:
+        cur (Cursor): a database cursor
+        database_engine (DatabaseEngine)
+        data_stores (list[str]): The names of the data stores to instantiate
+            on the given database.
     """
-    current_dir = os.path.join(dir_path, "schema", "full_schemas")
-    directory_entries = os.listdir(current_dir)
 
-    valid_dirs = []
-    pattern = re.compile(r"^\d+(\.sql)?$")
+    # We're about to set up a brand new database so we check that its
+    # configured to our liking.
+    database_engine.check_new_database(cur)
 
-    if isinstance(database_engine, PostgresEngine):
-        specific = "postgres"
-    else:
-        specific = "sqlite"
+    current_dir = os.path.join(dir_path, "schema", "full_schemas")
+    directory_entries = os.listdir(current_dir)
 
-    specific_pattern = re.compile(r"^\d+(\.sql." + specific + r")?$")
+    # First we find the highest full schema version we have
+    valid_versions = []
 
     for filename in directory_entries:
-        match = pattern.match(filename) or specific_pattern.match(filename)
-        abs_path = os.path.join(current_dir, filename)
-        if match and os.path.isdir(abs_path):
-            ver = int(match.group(0))
-            if ver <= SCHEMA_VERSION:
-                valid_dirs.append((ver, abs_path))
-        else:
-            logger.debug("Ignoring entry '%s' in 'full_schemas'", filename)
+        try:
+            ver = int(filename)
+        except ValueError:
+            continue
+
+        if ver <= SCHEMA_VERSION:
+            valid_versions.append(ver)
 
-    if not valid_dirs:
+    if not valid_versions:
         raise PrepareDatabaseException(
             "Could not find a suitable base set of full schemas"
         )
 
-    max_current_ver, sql_dir = max(valid_dirs, key=lambda x: x[0])
+    max_current_ver = max(valid_versions)
 
     logger.debug("Initialising schema v%d", max_current_ver)
 
-    directory_entries = os.listdir(sql_dir)
+    # Now lets find all the full schema files, both in the global schema and
+    # in data store schemas.
+    directories = [os.path.join(current_dir, str(max_current_ver))]
+    directories.extend(
+        os.path.join(
+            dir_path,
+            "data_stores",
+            data_store,
+            "schema",
+            "full_schemas",
+            str(max_current_ver),
+        )
+        for data_store in data_stores
+    )
+
+    directory_entries = []
+    for directory in directories:
+        directory_entries.extend(
+            _DirectoryListing(file_name, os.path.join(directory, file_name))
+            for file_name in os.listdir(directory)
+        )
+
+    if isinstance(database_engine, PostgresEngine):
+        specific = "postgres"
+    else:
+        specific = "sqlite"
 
-    for filename in sorted(
-        fnmatch.filter(directory_entries, "*.sql")
-        + fnmatch.filter(directory_entries, "*.sql." + specific)
-    ):
-        sql_loc = os.path.join(sql_dir, filename)
-        logger.debug("Applying schema %s", sql_loc)
-        executescript(cur, sql_loc)
+    directory_entries.sort()
+    for entry in directory_entries:
+        if entry.file_name.endswith(".sql") or entry.file_name.endswith(
+            ".sql." + specific
+        ):
+            logger.debug("Applying schema %s", entry.absolute_path)
+            executescript(cur, entry.absolute_path)
 
     cur.execute(
         database_engine.convert_param_style(
-            "INSERT INTO schema_version (version, upgraded)" " VALUES (?,?)"
+            "INSERT INTO schema_version (version, upgraded) VALUES (?,?)"
         ),
         (max_current_ver, False),
     )
@@ -168,6 +219,7 @@ def _setup_new_database(cur, database_engine):
         upgraded=False,
         database_engine=database_engine,
         config=None,
+        data_stores=data_stores,
         is_empty=True,
     )
 
@@ -179,6 +231,7 @@ def _upgrade_existing_database(
     upgraded,
     database_engine,
     config,
+    data_stores,
     is_empty=False,
 ):
     """Upgrades an existing database.
@@ -215,6 +268,10 @@ def _upgrade_existing_database(
     only if `upgraded` is True. Then `foo.sql` and `bar.py` would be run in
     some arbitrary order.
 
+    Note: we apply the delta files from the specified data stores as well as
+    those in the top-level schema. We apply all delta files across data stores
+    for a version before applying those in the next version.
+
     Args:
         cur (Cursor)
         current_version (int): The current version of the schema.
@@ -224,7 +281,19 @@ def _upgrade_existing_database(
             applied deltas or from full schema file. If `True` the function
             will never apply delta files for the given `current_version`, since
             the current_version wasn't generated by applying those delta files.
+        database_engine (DatabaseEngine)
+        config (synapse.config.homeserver.HomeServerConfig|None):
+            None if we are initialising a blank database, otherwise the application
+            config
+        data_stores (list[str]): The names of the data stores to instantiate
+            on the given database.
+        is_empty (bool): Is this a blank database? I.e. do we need to run the
+            upgrade portions of the delta scripts.
     """
+    if is_empty:
+        assert not applied_delta_files
+    else:
+        assert config
 
     if current_version > SCHEMA_VERSION:
         raise ValueError(
@@ -232,6 +301,13 @@ def _upgrade_existing_database(
             + "new for the server to understand"
         )
 
+    # some of the deltas assume that config.server_name is set correctly, so now
+    # is a good time to run the sanity check.
+    if not is_empty and "main" in data_stores:
+        from synapse.storage.data_stores.main import check_database_before_upgrade
+
+        check_database_before_upgrade(cur, database_engine, config)
+
     start_ver = current_version
     if not upgraded:
         start_ver += 1
@@ -248,24 +324,65 @@ def _upgrade_existing_database(
     for v in range(start_ver, SCHEMA_VERSION + 1):
         logger.info("Upgrading schema to v%d", v)
 
+        # We need to search both the global and per data store schema
+        # directories for schema updates.
+
+        # First we find the directories to search in
         delta_dir = os.path.join(dir_path, "schema", "delta", str(v))
+        directories = [delta_dir]
+        for data_store in data_stores:
+            directories.append(
+                os.path.join(
+                    dir_path, "data_stores", data_store, "schema", "delta", str(v)
+                )
+            )
 
-        try:
-            directory_entries = os.listdir(delta_dir)
-        except OSError:
-            logger.exception("Could not open delta dir for version %d", v)
-            raise UpgradeDatabaseException(
-                "Could not open delta dir for version %d" % (v,)
+        # Used to check if we have any duplicate file names
+        file_name_counter = Counter()
+
+        # Now find which directories have anything of interest.
+        directory_entries = []
+        for directory in directories:
+            logger.debug("Looking for schema deltas in %s", directory)
+            try:
+                file_names = os.listdir(directory)
+                directory_entries.extend(
+                    _DirectoryListing(file_name, os.path.join(directory, file_name))
+                    for file_name in file_names
+                )
+
+                for file_name in file_names:
+                    file_name_counter[file_name] += 1
+            except FileNotFoundError:
+                # Data stores can have empty entries for a given version delta.
+                pass
+            except OSError:
+                raise UpgradeDatabaseException(
+                    "Could not open delta dir for version %d: %s" % (v, directory)
+                )
+
+        duplicates = {
+            file_name for file_name, count in file_name_counter.items() if count > 1
+        }
+        if duplicates:
+            # We don't support using the same file name in the same delta version.
+            raise PrepareDatabaseException(
+                "Found multiple delta files with the same name in v%d: %s"
+                % (v, duplicates,)
             )
 
+        # We sort to ensure that we apply the delta files in a consistent
+        # order (to avoid bugs caused by inconsistent directory listing order)
         directory_entries.sort()
-        for file_name in directory_entries:
+        for entry in directory_entries:
+            file_name = entry.file_name
             relative_path = os.path.join(str(v), file_name)
-            logger.debug("Found file: %s", relative_path)
+            absolute_path = entry.absolute_path
+
+            logger.debug("Found file: %s (%s)", relative_path, absolute_path)
             if relative_path in applied_delta_files:
                 continue
 
-            absolute_path = os.path.join(dir_path, "schema", "delta", relative_path)
             root_name, ext = os.path.splitext(file_name)
             if ext == ".py":
                 # This is a python upgrade module. We need to import into some
@@ -352,7 +469,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
         ),
         (modname,),
     )
-    applied_deltas = set(d for d, in cur)
+    applied_deltas = {d for d, in cur}
     for (name, stream) in names_and_streams:
         if name in applied_deltas:
             continue
@@ -364,13 +481,12 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
             )
 
         logger.info("applying schema %s for %s", name, modname)
-        for statement in get_statements(stream):
-            cur.execute(statement)
+        execute_statements_from_stream(cur, stream)
 
         # Mark as done.
         cur.execute(
             database_engine.convert_param_style(
-                "INSERT INTO applied_module_schemas (module_name, file)" " VALUES (?,?)"
+                "INSERT INTO applied_module_schemas (module_name, file) VALUES (?,?)"
             ),
             (modname, name),
         )
@@ -423,8 +539,12 @@ def get_statements(f):
 
 def executescript(txn, schema_path):
     with open(schema_path, "r") as f:
-        for statement in get_statements(f):
-            txn.execute(statement)
+        execute_statements_from_stream(txn, f)
+
+
+def execute_statements_from_stream(cur: Cursor, f: TextIO):
+    for statement in get_statements(f):
+        cur.execute(statement)
 
 
 def _get_or_create_schema_state(txn, database_engine):
@@ -448,3 +568,16 @@ def _get_or_create_schema_state(txn, database_engine):
         return current_version, applied_deltas, upgraded
 
     return None
+
+
+@attr.s()
+class _DirectoryListing(object):
+    """Helper class to store schema file name and the
+    absolute path to it.
+
+    These entries get sorted, so for consistency we want to ensure that
+    `file_name` attr is kept first.
+    """
+
+    file_name = attr.ib()
+    absolute_path = attr.ib()