summary refs log tree commit diff
path: root/synapse/storage/prepare_database.py
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2020-12-30 08:09:53 -0500
committerGitHub <noreply@github.com>2020-12-30 08:09:53 -0500
commit637282bb5019ce1656001927eea1be46c4854815 (patch)
tree7e9a9ba64918de5e995241a02054174ced099d75 /synapse/storage/prepare_database.py
parentDoc/move database setup instructions in install md (#8987) (diff)
downloadsynapse-637282bb5019ce1656001927eea1be46c4854815.tar.xz
Add additional type hints to the storage module. (#8980)
Diffstat (limited to 'synapse/storage/prepare_database.py')
-rw-r--r--synapse/storage/prepare_database.py104
1 files changed, 56 insertions, 48 deletions
diff --git a/synapse/storage/prepare_database.py b/synapse/storage/prepare_database.py
index 459754feab..f91a2eae7a 100644
--- a/synapse/storage/prepare_database.py
+++ b/synapse/storage/prepare_database.py
@@ -18,9 +18,10 @@ import logging
 import os
 import re
 from collections import Counter
-from typing import Optional, TextIO
+from typing import Generator, Iterable, List, Optional, TextIO, Tuple
 
 import attr
+from typing_extensions import Counter as CounterType
 
 from synapse.config.homeserver import HomeServerConfig
 from synapse.storage.database import LoggingDatabaseConnection
@@ -70,7 +71,7 @@ def prepare_database(
     db_conn: LoggingDatabaseConnection,
     database_engine: BaseDatabaseEngine,
     config: Optional[HomeServerConfig],
-    databases: Collection[str] = ["main", "state"],
+    databases: Collection[str] = ("main", "state"),
 ):
     """Prepares a physical database for usage. Will either create all necessary tables
     or upgrade from an older schema version.
@@ -155,7 +156,9 @@ def prepare_database(
         raise
 
 
-def _setup_new_database(cur, database_engine, databases):
+def _setup_new_database(
+    cur: Cursor, database_engine: BaseDatabaseEngine, databases: Collection[str]
+) -> None:
     """Sets up the physical database by finding a base set of "full schemas" and
     then applying any necessary deltas, including schemas from the given data
     stores.
@@ -188,10 +191,9 @@ def _setup_new_database(cur, database_engine, databases):
     folder as well those in the data stores specified.
 
     Args:
-        cur (Cursor): a database cursor
-        database_engine (DatabaseEngine)
-        databases (list[str]): The names of the databases to instantiate
-            on the given physical database.
+        cur: a database cursor
+        database_engine
+        databases: The names of the databases to instantiate on the given physical database.
     """
 
     # We're about to set up a brand new database so we check that its
@@ -199,12 +201,11 @@ def _setup_new_database(cur, database_engine, databases):
     database_engine.check_new_database(cur)
 
     current_dir = os.path.join(dir_path, "schema", "full_schemas")
-    directory_entries = os.listdir(current_dir)
 
     # First we find the highest full schema version we have
     valid_versions = []
 
-    for filename in directory_entries:
+    for filename in os.listdir(current_dir):
         try:
             ver = int(filename)
         except ValueError:
@@ -237,7 +238,7 @@ def _setup_new_database(cur, database_engine, databases):
         for database in databases
     )
 
-    directory_entries = []
+    directory_entries = []  # type: List[_DirectoryListing]
     for directory in directories:
         directory_entries.extend(
             _DirectoryListing(file_name, os.path.join(directory, file_name))
@@ -275,15 +276,15 @@ def _setup_new_database(cur, database_engine, databases):
 
 
 def _upgrade_existing_database(
-    cur,
-    current_version,
-    applied_delta_files,
-    upgraded,
-    database_engine,
-    config,
-    databases,
-    is_empty=False,
-):
+    cur: Cursor,
+    current_version: int,
+    applied_delta_files: List[str],
+    upgraded: bool,
+    database_engine: BaseDatabaseEngine,
+    config: Optional[HomeServerConfig],
+    databases: Collection[str],
+    is_empty: bool = False,
+) -> None:
     """Upgrades an existing physical database.
 
     Delta files can either be SQL stored in *.sql files, or python modules
@@ -323,21 +324,20 @@ def _upgrade_existing_database(
     for a version before applying those in the next version.
 
     Args:
-        cur (Cursor)
-        current_version (int): The current version of the schema.
-        applied_delta_files (list): A list of deltas that have already been
-            applied.
-        upgraded (bool): Whether the current version was generated by having
+        cur
+        current_version: The current version of the schema.
+        applied_delta_files: A list of deltas that have already been applied.
+        upgraded: Whether the current version was generated by having
             applied deltas or from full schema file. If `True` the function
             will never apply delta files for the given `current_version`, since
             the current_version wasn't generated by applying those delta files.
-        database_engine (DatabaseEngine)
-        config (synapse.config.homeserver.HomeServerConfig|None):
+        database_engine
+        config:
             None if we are initialising a blank database, otherwise the application
             config
-        databases (list[str]): The names of the databases to instantiate
+        databases: The names of the databases to instantiate
             on the given physical database.
-        is_empty (bool): Is this a blank database? I.e. do we need to run the
+        is_empty: Is this a blank database? I.e. do we need to run the
             upgrade portions of the delta scripts.
     """
     if is_empty:
@@ -358,6 +358,7 @@ def _upgrade_existing_database(
     if not is_empty and "main" in databases:
         from synapse.storage.databases.main import check_database_before_upgrade
 
+        assert config is not None
         check_database_before_upgrade(cur, database_engine, config)
 
     start_ver = current_version
@@ -388,10 +389,10 @@ def _upgrade_existing_database(
             )
 
         # Used to check if we have any duplicate file names
-        file_name_counter = Counter()
+        file_name_counter = Counter()  # type: CounterType[str]
 
         # Now find which directories have anything of interest.
-        directory_entries = []
+        directory_entries = []  # type: List[_DirectoryListing]
         for directory in directories:
             logger.debug("Looking for schema deltas in %s", directory)
             try:
@@ -445,11 +446,11 @@ def _upgrade_existing_database(
 
                 module_name = "synapse.storage.v%d_%s" % (v, root_name)
                 with open(absolute_path) as python_file:
-                    module = imp.load_source(module_name, absolute_path, python_file)
+                    module = imp.load_source(module_name, absolute_path, python_file)  # type: ignore
                 logger.info("Running script %s", relative_path)
-                module.run_create(cur, database_engine)
+                module.run_create(cur, database_engine)  # type: ignore
                 if not is_empty:
-                    module.run_upgrade(cur, database_engine, config=config)
+                    module.run_upgrade(cur, database_engine, config=config)  # type: ignore
             elif ext == ".pyc" or file_name == "__pycache__":
                 # Sometimes .pyc files turn up anyway even though we've
                 # disabled their generation; e.g. from distribution package
@@ -497,14 +498,15 @@ def _upgrade_existing_database(
     logger.info("Schema now up to date")
 
 
-def _apply_module_schemas(txn, database_engine, config):
+def _apply_module_schemas(
+    txn: Cursor, database_engine: BaseDatabaseEngine, config: HomeServerConfig
+) -> None:
     """Apply the module schemas for the dynamic modules, if any
 
     Args:
         cur: database cursor
-        database_engine: synapse database engine class
-        config (synapse.config.homeserver.HomeServerConfig):
-            application config
+        database_engine:
+        config: application config
     """
     for (mod, _config) in config.password_providers:
         if not hasattr(mod, "get_db_schema_files"):
@@ -515,15 +517,19 @@ def _apply_module_schemas(txn, database_engine, config):
         )
 
 
-def _apply_module_schema_files(cur, database_engine, modname, names_and_streams):
+def _apply_module_schema_files(
+    cur: Cursor,
+    database_engine: BaseDatabaseEngine,
+    modname: str,
+    names_and_streams: Iterable[Tuple[str, TextIO]],
+) -> None:
     """Apply the module schemas for a single module
 
     Args:
         cur: database cursor
         database_engine: synapse database engine class
-        modname (str): fully qualified name of the module
-        names_and_streams (Iterable[(str, file)]): the names and streams of
-            schemas to be applied
+        modname: fully qualified name of the module
+        names_and_streams: the names and streams of schemas to be applied
     """
     cur.execute(
         "SELECT file FROM applied_module_schemas WHERE module_name = ?", (modname,),
@@ -549,7 +555,7 @@ def _apply_module_schema_files(cur, database_engine, modname, names_and_streams)
         )
 
 
-def get_statements(f):
+def get_statements(f: Iterable[str]) -> Generator[str, None, None]:
     statement_buffer = ""
     in_comment = False  # If we're in a /* ... */ style comment
 
@@ -594,17 +600,19 @@ def get_statements(f):
         statement_buffer = statements[-1].strip()
 
 
-def executescript(txn, schema_path):
+def executescript(txn: Cursor, schema_path: str) -> None:
     with open(schema_path, "r") as f:
         execute_statements_from_stream(txn, f)
 
 
-def execute_statements_from_stream(cur: Cursor, f: TextIO):
+def execute_statements_from_stream(cur: Cursor, f: TextIO) -> None:
     for statement in get_statements(f):
         cur.execute(statement)
 
 
-def _get_or_create_schema_state(txn, database_engine):
+def _get_or_create_schema_state(
+    txn: Cursor, database_engine: BaseDatabaseEngine
+) -> Optional[Tuple[int, List[str], bool]]:
     # Bluntly try creating the schema_version tables.
     schema_path = os.path.join(dir_path, "schema", "schema_version.sql")
     executescript(txn, schema_path)
@@ -612,7 +620,6 @@ def _get_or_create_schema_state(txn, database_engine):
     txn.execute("SELECT version, upgraded FROM schema_version")
     row = txn.fetchone()
     current_version = int(row[0]) if row else None
-    upgraded = bool(row[1]) if row else None
 
     if current_version:
         txn.execute(
@@ -620,6 +627,7 @@ def _get_or_create_schema_state(txn, database_engine):
             (current_version,),
         )
         applied_deltas = [d for d, in txn]
+        upgraded = bool(row[1])
         return current_version, applied_deltas, upgraded
 
     return None
@@ -634,5 +642,5 @@ class _DirectoryListing:
     `file_name` attr is kept first.
     """
 
-    file_name = attr.ib()
-    absolute_path = attr.ib()
+    file_name = attr.ib(type=str)
+    absolute_path = attr.ib(type=str)