diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index ee60e2a718..4957e77f4c 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -19,12 +19,15 @@ import logging
import os
import re
from collections import Counter
-from typing import TextIO
+from typing import Optional, TextIO
import attr
+from synapse.config.homeserver import HomeServerConfig
+from synapse.storage.engines import BaseDatabaseEngine
from synapse.storage.engines.postgres import PostgresEngine
-from synapse.storage.types import Cursor
+from synapse.storage.types import Connection, Cursor
+from synapse.types import Collection
logger = logging.getLogger(__name__)
@@ -63,7 +66,12 @@ UNAPPLIED_DELTA_ON_WORKER_ERROR = (
)
-def prepare_database(db_conn, database_engine, config, databases=["main", "state"]):
+def prepare_database(
+ db_conn: Connection,
+ database_engine: BaseDatabaseEngine,
+ config: Optional[HomeServerConfig],
+ databases: Collection[str] = ["main", "state"],
+):
"""Prepares a physical database for usage. Will either create all necessary tables
or upgrade from an older schema version.
@@ -73,16 +81,24 @@ def prepare_database(db_conn, database_engine, config, databases=["main", "state
Args:
db_conn:
database_engine:
- config (synapse.config.homeserver.HomeServerConfig|None):
+ config :
application config, or None if we are connecting to an existing
database which we expect to be configured already
- databases (list[str]): The name of the databases that will be used
+ databases: The name of the databases that will be used
with this physical database. Defaults to all databases.
"""
try:
cur = db_conn.cursor()
+ # sqlite does not automatically start transactions for DDL / SELECT statements,
+ # so we start one before running anything. This ensures that any upgrades
+ # are either applied completely, or not at all.
+ #
+ # (psycopg2 automatically starts a transaction as soon as we run any statements
+ # at all, so this is redundant but harmless there.)
+ cur.execute("BEGIN TRANSACTION")
+
logger.info("%r: Checking existing schema version", databases)
version_info = _get_or_create_schema_state(cur, database_engine)
@@ -622,7 +638,7 @@ def _get_or_create_schema_state(txn, database_engine):
return None
-@attr.s()
+@attr.s(slots=True)
class _DirectoryListing:
"""Helper class to store schema file name and the
absolute path to it.
|