diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index e96eed8a6d..9cc3b51fe6 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -14,20 +14,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import fnmatch
import imp
import logging
import os
import re
+from collections import Counter
+from typing import TextIO
+
+import attr
from synapse.storage.engines.postgres import PostgresEngine
+from synapse.storage.types import Cursor
logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
-SCHEMA_VERSION = 56
+# XXX: If you're about to bump this to 59 (or higher) please create an update
+# that drops the unused `cache_invalidation_stream` table, as per #7436!
+# XXX: Also add an update to drop `account_data_max_stream_id` as per #7656!
+SCHEMA_VERSION = 58
dir_path = os.path.abspath(os.path.dirname(__file__))
@@ -40,7 +47,7 @@ class UpgradeDatabaseException(PrepareDatabaseException):
pass
-def prepare_database(db_conn, database_engine, config):
+def prepare_database(db_conn, database_engine, config, data_stores=["main", "state"]):
"""Prepares a database for usage. Will either create all necessary tables
or upgrade from an older schema version.
@@ -53,7 +60,10 @@ def prepare_database(db_conn, database_engine, config):
config (synapse.config.homeserver.HomeServerConfig|None):
application config, or None if we are connecting to an existing
database which we expect to be configured already
+        data_stores (list[str]): The names of the data stores that will be used
+ with this database. Defaults to all data stores.
"""
+
try:
cur = db_conn.cursor()
version_info = _get_or_create_schema_state(cur, database_engine)
@@ -65,13 +75,22 @@ def prepare_database(db_conn, database_engine, config):
if user_version != SCHEMA_VERSION:
# If we don't pass in a config file then we are expecting to
# have already upgraded the DB.
- raise UpgradeDatabaseException("Database needs to be upgraded")
+ raise UpgradeDatabaseException(
+ "Expected database schema version %i but got %i"
+ % (SCHEMA_VERSION, user_version)
+ )
else:
_upgrade_existing_database(
- cur, user_version, delta_files, upgraded, database_engine, config
+ cur,
+ user_version,
+ delta_files,
+ upgraded,
+ database_engine,
+ config,
+ data_stores=data_stores,
)
else:
- _setup_new_database(cur, database_engine)
+ _setup_new_database(cur, database_engine, data_stores=data_stores)
# check if any of our configured dynamic modules want a database
if config is not None:
@@ -84,9 +103,10 @@ def prepare_database(db_conn, database_engine, config):
raise
-def _setup_new_database(cur, database_engine):
+def _setup_new_database(cur, database_engine, data_stores):
"""Sets up the database by finding a base set of "full schemas" and then
- applying any necessary deltas.
+ applying any necessary deltas, including schemas from the given data
+ stores.
The "full_schemas" directory has subdirectories named after versions. This
function searches for the highest version less than or equal to
@@ -111,52 +131,83 @@ def _setup_new_database(cur, database_engine):
In the example foo.sql and bar.sql would be run, and then any delta files
for versions strictly greater than 11.
+
+    Note: we apply the full schemas and deltas from the top-level `schema/`
+    folder as well as those in the specified data stores.
+
+ Args:
+ cur (Cursor): a database cursor
+ database_engine (DatabaseEngine)
+ data_stores (list[str]): The names of the data stores to instantiate
+ on the given database.
"""
- current_dir = os.path.join(dir_path, "schema", "full_schemas")
- directory_entries = os.listdir(current_dir)
- valid_dirs = []
- pattern = re.compile(r"^\d+(\.sql)?$")
+    # We're about to set up a brand new database, so check that it's
+    # configured to our liking.
+ database_engine.check_new_database(cur)
- if isinstance(database_engine, PostgresEngine):
- specific = "postgres"
- else:
- specific = "sqlite"
+ current_dir = os.path.join(dir_path, "schema", "full_schemas")
+ directory_entries = os.listdir(current_dir)
- specific_pattern = re.compile(r"^\d+(\.sql." + specific + r")?$")
+ # First we find the highest full schema version we have
+ valid_versions = []
for filename in directory_entries:
- match = pattern.match(filename) or specific_pattern.match(filename)
- abs_path = os.path.join(current_dir, filename)
- if match and os.path.isdir(abs_path):
- ver = int(match.group(0))
- if ver <= SCHEMA_VERSION:
- valid_dirs.append((ver, abs_path))
- else:
- logger.debug("Ignoring entry '%s' in 'full_schemas'", filename)
+ try:
+ ver = int(filename)
+ except ValueError:
+ continue
+
+ if ver <= SCHEMA_VERSION:
+ valid_versions.append(ver)
- if not valid_dirs:
+ if not valid_versions:
raise PrepareDatabaseException(
"Could not find a suitable base set of full schemas"
)
- max_current_ver, sql_dir = max(valid_dirs, key=lambda x: x[0])
+ max_current_ver = max(valid_versions)
logger.debug("Initialising schema v%d", max_current_ver)
- directory_entries = os.listdir(sql_dir)
+    # Now let's find all the full schema files, both in the global schema and
+ # in data store schemas.
+ directories = [os.path.join(current_dir, str(max_current_ver))]
+ directories.extend(
+ os.path.join(
+ dir_path,
+ "data_stores",
+ data_store,
+ "schema",
+ "full_schemas",
+ str(max_current_ver),
+ )
+ for data_store in data_stores
+ )
+
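+    # Gather every schema file in those directories, remembering where each
+    # file lives so it can be applied below.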
+ directory_entries = []
+ for directory in directories:
+ directory_entries.extend(
+ _DirectoryListing(file_name, os.path.join(directory, file_name))
+ for file_name in os.listdir(directory)
+ )
+
+ if isinstance(database_engine, PostgresEngine):
+ specific = "postgres"
+ else:
+ specific = "sqlite"
- for filename in sorted(
- fnmatch.filter(directory_entries, "*.sql")
- + fnmatch.filter(directory_entries, "*.sql." + specific)
- ):
- sql_loc = os.path.join(sql_dir, filename)
- logger.debug("Applying schema %s", sql_loc)
- executescript(cur, sql_loc)
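+    # Sort by file name so that the schema files are applied in a
+    # deterministic order, whatever order the OS lists them in.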
+ directory_entries.sort()
+ for entry in directory_entries:
+ if entry.file_name.endswith(".sql") or entry.file_name.endswith(
+ ".sql." + specific
+ ):
+ logger.debug("Applying schema %s", entry.absolute_path)
+ executescript(cur, entry.absolute_path)
cur.execute(
database_engine.convert_param_style(
- "INSERT INTO schema_version (version, upgraded)" " VALUES (?,?)"
+ "INSERT INTO schema_version (version, upgraded) VALUES (?,?)"
),
(max_current_ver, False),
)
@@ -168,6 +219,7 @@ def _setup_new_database(cur, database_engine):
upgraded=False,
database_engine=database_engine,
config=None,
+ data_stores=data_stores,
is_empty=True,
)
@@ -179,6 +231,7 @@ def _upgrade_existing_database(
upgraded,
database_engine,
config,
+ data_stores,
is_empty=False,
):
"""Upgrades an existing database.
@@ -215,6 +268,10 @@ def _upgrade_existing_database(
only if `upgraded` is True. Then `foo.sql` and `bar.py` would be run in
some arbitrary order.
+ Note: we apply the delta files from the specified data stores as well as
+ those in the top-level schema. We apply all delta files across data stores
+ for a version before applying those in the next version.
+
Args:
cur (Cursor)
current_version (int): The current version of the schema.
@@ -224,7 +281,19 @@ def _upgrade_existing_database(
applied deltas or from full schema file. If `True` the function
will never apply delta files for the given `current_version`, since
the current_version wasn't generated by applying those delta files.
+ database_engine (DatabaseEngine)
+ config (synapse.config.homeserver.HomeServerConfig|None):
+ None if we are initialising a blank database, otherwise the application
+ config
+ data_stores (list[str]): The names of the data stores to instantiate
+ on the given database.
+        is_empty (bool): Is this a blank database, i.e. do we need to run the
+            upgrade portions of the delta scripts?
"""
+ if is_empty:
+ assert not applied_delta_files
+ else:
+ assert config
if current_version > SCHEMA_VERSION:
raise ValueError(
@@ -232,6 +301,13 @@ def _upgrade_existing_database(
+ "new for the server to understand"
)
+ # some of the deltas assume that config.server_name is set correctly, so now
+ # is a good time to run the sanity check.
+ if not is_empty and "main" in data_stores:
+ from synapse.storage.data_stores.main import check_database_before_upgrade
+
+ check_database_before_upgrade(cur, database_engine, config)
+
start_ver = current_version
if not upgraded:
start_ver += 1
@@ -248,24 +324,65 @@ def _upgrade_existing_database(
for v in range(start_ver, SCHEMA_VERSION + 1):
logger.info("Upgrading schema to v%d", v)
+ # We need to search both the global and per data store schema
+ # directories for schema updates.
+
+ # First we find the directories to search in
delta_dir = os.path.join(dir_path, "schema", "delta", str(v))
+ directories = [delta_dir]
+ for data_store in data_stores:
+ directories.append(
+ os.path.join(
+ dir_path, "data_stores", data_store, "schema", "delta", str(v)
+ )
+ )
- try:
- directory_entries = os.listdir(delta_dir)
- except OSError:
- logger.exception("Could not open delta dir for version %d", v)
- raise UpgradeDatabaseException(
- "Could not open delta dir for version %d" % (v,)
+ # Used to check if we have any duplicate file names
+ file_name_counter = Counter()
+
+ # Now find which directories have anything of interest.
+ directory_entries = []
+ for directory in directories:
+ logger.debug("Looking for schema deltas in %s", directory)
+ try:
+ file_names = os.listdir(directory)
+ directory_entries.extend(
+ _DirectoryListing(file_name, os.path.join(directory, file_name))
+ for file_name in file_names
+ )
+
+ for file_name in file_names:
+ file_name_counter[file_name] += 1
+ except FileNotFoundError:
+                # A data store may have no delta directory for a given version.
+ pass
+ except OSError:
+ raise UpgradeDatabaseException(
+ "Could not open delta dir for version %d: %s" % (v, directory)
+ )
+
+ duplicates = {
+ file_name for file_name, count in file_name_counter.items() if count > 1
+ }
+ if duplicates:
+ # We don't support using the same file name in the same delta version.
+ raise PrepareDatabaseException(
+ "Found multiple delta files with the same name in v%d: %s"
+ % (v, duplicates,)
)
+ # We sort to ensure that we apply the delta files in a consistent
+ # order (to avoid bugs caused by inconsistent directory listing order)
directory_entries.sort()
- for file_name in directory_entries:
+ for entry in directory_entries:
+ file_name = entry.file_name
relative_path = os.path.join(str(v), file_name)
- logger.debug("Found file: %s", relative_path)
+ absolute_path = entry.absolute_path
+
+ logger.debug("Found file: %s (%s)", relative_path, absolute_path)
if relative_path in applied_delta_files:
continue
- absolute_path = os.path.join(dir_path, "schema", "delta", relative_path)
root_name, ext = os.path.splitext(file_name)
if ext == ".py":
# This is a python upgrade module. We need to import into some
@@ -352,7 +469,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
),
(modname,),
)
- applied_deltas = set(d for d, in cur)
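+    # each row is a single-column tuple; `d,` unpacks that column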
+ applied_deltas = {d for d, in cur}
for (name, stream) in names_and_streams:
if name in applied_deltas:
continue
@@ -364,13 +481,12 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
)
logger.info("applying schema %s for %s", name, modname)
- for statement in get_statements(stream):
- cur.execute(statement)
+ execute_statements_from_stream(cur, stream)
# Mark as done.
cur.execute(
database_engine.convert_param_style(
- "INSERT INTO applied_module_schemas (module_name, file)" " VALUES (?,?)"
+ "INSERT INTO applied_module_schemas (module_name, file) VALUES (?,?)"
),
(modname, name),
)
@@ -423,8 +539,12 @@ def get_statements(f):
def executescript(txn, schema_path):
with open(schema_path, "r") as f:
- for statement in get_statements(f):
- txn.execute(statement)
+ execute_statements_from_stream(txn, f)
+
+
+def execute_statements_from_stream(cur: Cursor, f: TextIO):
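+    """Reads SQL statements from a file-like object and executes them one at a time."""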
+ for statement in get_statements(f):
+ cur.execute(statement)
def _get_or_create_schema_state(txn, database_engine):
@@ -448,3 +568,16 @@ def _get_or_create_schema_state(txn, database_engine):
return current_version, applied_deltas, upgraded
return None
+
+
+@attr.s()
+class _DirectoryListing(object):
+ """Helper class to store schema file name and the
+ absolute path to it.
+
+    These entries get sorted, so for consistency we want to ensure that the
+    `file_name` attr is kept first (attrs compares fields in definition order).
+ """
+
+ file_name = attr.ib()
+ absolute_path = attr.ib()