diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 459754feab..f91a2eae7a 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -18,9 +18,10 @@ import logging
import os
import re
from collections import Counter
-from typing import Optional, TextIO
+from typing import Generator, Iterable, List, Optional, TextIO, Tuple
import attr
+from typing_extensions import Counter as CounterType
from synapse.config.homeserver import HomeServerConfig
from synapse.storage.database import LoggingDatabaseConnection
@@ -70,7 +71,7 @@ def prepare_database(
db_conn: LoggingDatabaseConnection,
database_engine: BaseDatabaseEngine,
config: Optional[HomeServerConfig],
- databases: Collection[str] = ["main", "state"],
+ databases: Collection[str] = ("main", "state"),
):
"""Prepares a physical database for usage. Will either create all necessary tables
or upgrade from an older schema version.
@@ -155,7 +156,9 @@ def prepare_database(
raise
-def _setup_new_database(cur, database_engine, databases):
+def _setup_new_database(
+ cur: Cursor, database_engine: BaseDatabaseEngine, databases: Collection[str]
+) -> None:
"""Sets up the physical database by finding a base set of "full schemas" and
then applying any necessary deltas, including schemas from the given data
stores.
@@ -188,10 +191,9 @@ def _setup_new_database(cur, database_engine, databases):
folder as well those in the data stores specified.
Args:
- cur (Cursor): a database cursor
- database_engine (DatabaseEngine)
- databases (list[str]): The names of the databases to instantiate
- on the given physical database.
+ cur: a database cursor
+ database_engine
+ databases: The names of the databases to instantiate on the given physical database.
"""
# We're about to set up a brand new database so we check that its
@@ -199,12 +201,11 @@ def _setup_new_database(cur, database_engine, databases):
database_engine.check_new_database(cur)
current_dir = os.path.join(dir_path, "schema", "full_schemas")
- directory_entries = os.listdir(current_dir)
# First we find the highest full schema version we have
valid_versions = []
- for filename in directory_entries:
+ for filename in os.listdir(current_dir):
try:
ver = int(filename)
except ValueError:
@@ -237,7 +238,7 @@ def _setup_new_database(cur, database_engine, databases):
for database in databases
)
- directory_entries = []
+ directory_entries = [] # type: List[_DirectoryListing]
for directory in directories:
directory_entries.extend(
_DirectoryListing(file_name, os.path.join(directory, file_name))
@@ -275,15 +276,15 @@ def _setup_new_database(cur, database_engine, databases):
def _upgrade_existing_database(
- cur,
- current_version,
- applied_delta_files,
- upgraded,
- database_engine,
- config,
- databases,
- is_empty=False,
-):
+ cur: Cursor,
+ current_version: int,
+ applied_delta_files: List[str],
+ upgraded: bool,
+ database_engine: BaseDatabaseEngine,
+ config: Optional[HomeServerConfig],
+ databases: Collection[str],
+ is_empty: bool = False,
+) -> None:
"""Upgrades an existing physical database.
Delta files can either be SQL stored in *.sql files, or python modules
@@ -323,21 +324,20 @@ def _upgrade_existing_database(
for a version before applying those in the next version.
Args:
- cur (Cursor)
- current_version (int): The current version of the schema.
- applied_delta_files (list): A list of deltas that have already been
- applied.
- upgraded (bool): Whether the current version was generated by having
+ cur
+ current_version: The current version of the schema.
+ applied_delta_files: A list of deltas that have already been applied.
+ upgraded: Whether the current version was generated by having
applied deltas or from full schema file. If `True` the function
will never apply delta files for the given `current_version`, since
the current_version wasn't generated by applying those delta files.
- database_engine (DatabaseEngine)
- config (synapse.config.homeserver.HomeServerConfig|None):
+ database_engine
+ config:
None if we are initialising a blank database, otherwise the application
config
- databases (list[str]): The names of the databases to instantiate
+ databases: The names of the databases to instantiate
on the given physical database.
- is_empty (bool): Is this a blank database? I.e. do we need to run the
+ is_empty: Is this a blank database? I.e. do we need to run the
upgrade portions of the delta scripts.
"""
if is_empty:
@@ -358,6 +358,7 @@ def _upgrade_existing_database(
if not is_empty and "main" in databases:
from synapse.storage.databases.main import check_database_before_upgrade
+ assert config is not None
check_database_before_upgrade(cur, database_engine, config)
start_ver = current_version
@@ -388,10 +389,10 @@ def _upgrade_existing_database(
)
# Used to check if we have any duplicate file names
- file_name_counter = Counter()
+ file_name_counter = Counter() # type: CounterType[str]
# Now find which directories have anything of interest.
- directory_entries = []
+ directory_entries = [] # type: List[_DirectoryListing]
for directory in directories:
logger.debug("Looking for schema deltas in %s", directory)
try:
@@ -445,11 +446,11 @@ def _upgrade_existing_database(
module_name = "synapse.storage.v%d_%s" % (v, root_name)
with open(absolute_path) as python_file:
- module = imp.load_source(module_name, absolute_path, python_file)
+ module = imp.load_source(module_name, absolute_path, python_file) # type: ignore
logger.info("Running script %s", relative_path)
- module.run_create(cur, database_engine)
+ module.run_create(cur, database_engine) # type: ignore
if not is_empty:
- module.run_upgrade(cur, database_engine, config=config)
+ module.run_upgrade(cur, database_engine, config=config) # type: ignore
elif ext == ".pyc" or file_name == "__pycache__":
# Sometimes .pyc files turn up anyway even though we've
# disabled their generation; e.g. from distribution package
@@ -497,14 +498,15 @@ def _upgrade_existing_database(
logger.info("Schema now up to date")
-def _apply_module_schemas(txn, database_engine, config):
+def _apply_module_schemas(
+ txn: Cursor, database_engine: BaseDatabaseEngine, config: HomeServerConfig
+) -> None:
"""Apply the module schemas for the dynamic modules, if any
Args:
cur: database cursor
- database_engine: synapse database engine class
- config (synapse.config.homeserver.HomeServerConfig):
- application config
+ database_engine:
+ config: application config
"""
for (mod, _config) in config.password_providers:
if not hasattr(mod, "get_db_schema_files"):
@@ -515,15 +517,19 @@ def _apply_module_schemas(txn, database_engine, config):
)
-def _apply_module_schema_files(cur, database_engine, modname, names_and_streams):
+def _apply_module_schema_files(
+ cur: Cursor,
+ database_engine: BaseDatabaseEngine,
+ modname: str,
+ names_and_streams: Iterable[Tuple[str, TextIO]],
+) -> None:
"""Apply the module schemas for a single module
Args:
cur: database cursor
database_engine: synapse database engine class
- modname (str): fully qualified name of the module
- names_and_streams (Iterable[(str, file)]): the names and streams of
- schemas to be applied
+ modname: fully qualified name of the module
+ names_and_streams: the names and streams of schemas to be applied
"""
cur.execute(
"SELECT file FROM applied_module_schemas WHERE module_name = ?", (modname,),
@@ -549,7 +555,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
)
-def get_statements(f):
+def get_statements(f: Iterable[str]) -> Generator[str, None, None]:
statement_buffer = ""
in_comment = False # If we're in a /* ... */ style comment
@@ -594,17 +600,19 @@ def get_statements(f):
statement_buffer = statements[-1].strip()
-def executescript(txn, schema_path):
+def executescript(txn: Cursor, schema_path: str) -> None:
with open(schema_path, "r") as f:
execute_statements_from_stream(txn, f)
-def execute_statements_from_stream(cur: Cursor, f: TextIO):
+def execute_statements_from_stream(cur: Cursor, f: TextIO) -> None:
for statement in get_statements(f):
cur.execute(statement)
-def _get_or_create_schema_state(txn, database_engine):
+def _get_or_create_schema_state(
+ txn: Cursor, database_engine: BaseDatabaseEngine
+) -> Optional[Tuple[int, List[str], bool]]:
# Bluntly try creating the schema_version tables.
schema_path = os.path.join(dir_path, "schema", "schema_version.sql")
executescript(txn, schema_path)
@@ -612,7 +620,6 @@ def _get_or_create_schema_state(txn, database_engine):
txn.execute("SELECT version, upgraded FROM schema_version")
row = txn.fetchone()
current_version = int(row[0]) if row else None
- upgraded = bool(row[1]) if row else None
if current_version:
txn.execute(
@@ -620,6 +627,7 @@ def _get_or_create_schema_state(txn, database_engine):
(current_version,),
)
applied_deltas = [d for d, in txn]
+ upgraded = bool(row[1])
return current_version, applied_deltas, upgraded
return None
@@ -634,5 +642,5 @@ class _DirectoryListing:
`file_name` attr is kept first.
"""
- file_name = attr.ib()
- absolute_path = attr.ib()
+ file_name = attr.ib(type=str)
+ absolute_path = attr.ib(type=str)
|