summary refs log tree commit diff
path: root/synapse/config/_base.py
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2021-11-23 10:21:19 -0500
committerGitHub <noreply@github.com>2021-11-23 15:21:19 +0000
commit55669bd3de7137553085e9f1c16a686ff657108c (patch)
tree28637b8af5497395537cf11b0f627e05d32670da /synapse/config/_base.py
parentRemove code invalidated by deprecated config flag 'trust_identity_servers_for... (diff)
downloadsynapse-55669bd3de7137553085e9f1c16a686ff657108c.tar.xz
Add missing type hints to config base classes (#11377)
Diffstat (limited to 'synapse/config/_base.py')
-rw-r--r--synapse/config/_base.py157
1 files changed, 93 insertions, 64 deletions
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 7c4428a138..1265738dc1 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -20,7 +20,18 @@ import os
 from collections import OrderedDict
 from hashlib import sha256
 from textwrap import dedent
-from typing import Any, Iterable, List, MutableMapping, Optional, Union
+from typing import (
+    Any,
+    Dict,
+    Iterable,
+    List,
+    MutableMapping,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+)
 
 import attr
 import jinja2
@@ -78,7 +89,7 @@ CONFIG_FILE_HEADER = """\
 """
 
 
-def path_exists(file_path):
+def path_exists(file_path: str) -> bool:
     """Check if a file exists
 
     Unlike os.path.exists, this throws an exception if there is an error
@@ -86,7 +97,7 @@ def path_exists(file_path):
     the parent dir).
 
     Returns:
-        bool: True if the file exists; False if not.
+        True if the file exists; False if not.
     """
     try:
         os.stat(file_path)
@@ -102,15 +113,15 @@ class Config:
     A configuration section, containing configuration keys and values.
 
     Attributes:
-        section (str): The section title of this config object, such as
+        section: The section title of this config object, such as
             "tls" or "logger". This is used to refer to it on the root
             logger (for example, `config.tls.some_option`). Must be
             defined in subclasses.
     """
 
-    section = None
+    section: str
 
-    def __init__(self, root_config=None):
+    def __init__(self, root_config: "RootConfig" = None):
         self.root = root_config
 
         # Get the path to the default Synapse template directory
@@ -119,7 +130,7 @@ class Config:
         )
 
     @staticmethod
-    def parse_size(value):
+    def parse_size(value: Union[str, int]) -> int:
         if isinstance(value, int):
             return value
         sizes = {"K": 1024, "M": 1024 * 1024}
@@ -162,15 +173,15 @@ class Config:
         return int(value) * size
 
     @staticmethod
-    def abspath(file_path):
+    def abspath(file_path: str) -> str:
         return os.path.abspath(file_path) if file_path else file_path
 
     @classmethod
-    def path_exists(cls, file_path):
+    def path_exists(cls, file_path: str) -> bool:
         return path_exists(file_path)
 
     @classmethod
-    def check_file(cls, file_path, config_name):
+    def check_file(cls, file_path: Optional[str], config_name: str) -> str:
         if file_path is None:
             raise ConfigError("Missing config for %s." % (config_name,))
         try:
@@ -183,7 +194,7 @@ class Config:
         return cls.abspath(file_path)
 
     @classmethod
-    def ensure_directory(cls, dir_path):
+    def ensure_directory(cls, dir_path: str) -> str:
         dir_path = cls.abspath(dir_path)
         os.makedirs(dir_path, exist_ok=True)
         if not os.path.isdir(dir_path):
@@ -191,7 +202,7 @@ class Config:
         return dir_path
 
     @classmethod
-    def read_file(cls, file_path, config_name):
+    def read_file(cls, file_path: Any, config_name: str) -> str:
         """Deprecated: call read_file directly"""
         return read_file(file_path, (config_name,))
 
@@ -284,6 +295,9 @@ class Config:
         return [env.get_template(filename) for filename in filenames]
 
 
+TRootConfig = TypeVar("TRootConfig", bound="RootConfig")
+
+
 class RootConfig:
     """
     Holder of an application's configuration.
@@ -308,7 +322,9 @@ class RootConfig:
                 raise Exception("Failed making %s: %r" % (config_class.section, e))
             setattr(self, config_class.section, conf)
 
-    def invoke_all(self, func_name: str, *args, **kwargs) -> MutableMapping[str, Any]:
+    def invoke_all(
+        self, func_name: str, *args: Any, **kwargs: Any
+    ) -> MutableMapping[str, Any]:
         """
         Invoke a function on all instantiated config objects this RootConfig is
         configured to use.
@@ -317,6 +333,7 @@ class RootConfig:
             func_name: Name of function to invoke
             *args
             **kwargs
+
         Returns:
             ordered dictionary of config section name and the result of the
             function from it.
@@ -332,7 +349,7 @@ class RootConfig:
         return res
 
     @classmethod
-    def invoke_all_static(cls, func_name: str, *args, **kwargs):
+    def invoke_all_static(cls, func_name: str, *args: Any, **kwargs: any) -> None:
         """
         Invoke a static function on config objects this RootConfig is
         configured to use.
@@ -341,6 +358,7 @@ class RootConfig:
             func_name: Name of function to invoke
             *args
             **kwargs
+
         Returns:
             ordered dictionary of config section name and the result of the
             function from it.
@@ -351,16 +369,16 @@ class RootConfig:
 
     def generate_config(
         self,
-        config_dir_path,
-        data_dir_path,
-        server_name,
-        generate_secrets=False,
-        report_stats=None,
-        open_private_ports=False,
-        listeners=None,
-        tls_certificate_path=None,
-        tls_private_key_path=None,
-    ):
+        config_dir_path: str,
+        data_dir_path: str,
+        server_name: str,
+        generate_secrets: bool = False,
+        report_stats: Optional[bool] = None,
+        open_private_ports: bool = False,
+        listeners: Optional[List[dict]] = None,
+        tls_certificate_path: Optional[str] = None,
+        tls_private_key_path: Optional[str] = None,
+    ) -> str:
         """
         Build a default configuration file
 
@@ -368,27 +386,27 @@ class RootConfig:
         (eg with --generate_config).
 
         Args:
-            config_dir_path (str): The path where the config files are kept. Used to
+            config_dir_path: The path where the config files are kept. Used to
                 create filenames for things like the log config and the signing key.
 
-            data_dir_path (str): The path where the data files are kept. Used to create
+            data_dir_path: The path where the data files are kept. Used to create
                 filenames for things like the database and media store.
 
-            server_name (str): The server name. Used to initialise the server_name
+            server_name: The server name. Used to initialise the server_name
                 config param, but also used in the names of some of the config files.
 
-            generate_secrets (bool): True if we should generate new secrets for things
+            generate_secrets: True if we should generate new secrets for things
                 like the macaroon_secret_key. If False, these parameters will be left
                 unset.
 
-            report_stats (bool|None): Initial setting for the report_stats setting.
+            report_stats: Initial setting for the report_stats setting.
                 If None, report_stats will be left unset.
 
-            open_private_ports (bool): True to leave private ports (such as the non-TLS
+            open_private_ports: True to leave private ports (such as the non-TLS
                 HTTP listener) open to the internet.
 
-            listeners (list(dict)|None): A list of descriptions of the listeners
-                synapse should start with each of which specifies a port (str), a list of
+            listeners: A list of descriptions of the listeners synapse should
+                start with each of which specifies a port (int), a list of
                 resources (list(str)), tls (bool) and type (str). For example:
                 [{
                     "port": 8448,
@@ -403,16 +421,12 @@ class RootConfig:
                     "type": "http",
                 }],
 
+            tls_certificate_path: The path to the tls certificate.
 
-            database (str|None): The database type to configure, either `psycog2`
-                or `sqlite3`.
-
-            tls_certificate_path (str|None): The path to the tls certificate.
-
-            tls_private_key_path (str|None): The path to the tls private key.
+            tls_private_key_path: The path to the tls private key.
 
         Returns:
-            str: the yaml config file
+            The yaml config file
         """
 
         return CONFIG_FILE_HEADER + "\n\n".join(
@@ -432,12 +446,15 @@ class RootConfig:
         )
 
     @classmethod
-    def load_config(cls, description, argv):
+    def load_config(
+        cls: Type[TRootConfig], description: str, argv: List[str]
+    ) -> TRootConfig:
         """Parse the commandline and config files
 
         Doesn't support config-file-generation: used by the worker apps.
 
-        Returns: Config object.
+        Returns:
+            Config object.
         """
         config_parser = argparse.ArgumentParser(description=description)
         cls.add_arguments_to_parser(config_parser)
@@ -446,7 +463,7 @@ class RootConfig:
         return obj
 
     @classmethod
-    def add_arguments_to_parser(cls, config_parser):
+    def add_arguments_to_parser(cls, config_parser: argparse.ArgumentParser) -> None:
         """Adds all the config flags to an ArgumentParser.
 
         Doesn't support config-file-generation: used by the worker apps.
@@ -454,7 +471,7 @@ class RootConfig:
         Used for workers where we want to add extra flags/subcommands.
 
         Args:
-            config_parser (ArgumentParser): App description
+            config_parser: App description
         """
 
         config_parser.add_argument(
@@ -477,7 +494,9 @@ class RootConfig:
         cls.invoke_all_static("add_arguments", config_parser)
 
     @classmethod
-    def load_config_with_parser(cls, parser, argv):
+    def load_config_with_parser(
+        cls: Type[TRootConfig], parser: argparse.ArgumentParser, argv: List[str]
+    ) -> Tuple[TRootConfig, argparse.Namespace]:
         """Parse the commandline and config files with the given parser
 
         Doesn't support config-file-generation: used by the worker apps.
@@ -485,13 +504,12 @@ class RootConfig:
         Used for workers where we want to add extra flags/subcommands.
 
         Args:
-            parser (ArgumentParser)
-            argv (list[str])
+            parser
+            argv
 
         Returns:
-            tuple[HomeServerConfig, argparse.Namespace]: Returns the parsed
-            config object and the parsed argparse.Namespace object from
-            `parser.parse_args(..)`
+            Returns the parsed config object and the parsed argparse.Namespace
+            object from parser.parse_args(..)`
         """
 
         obj = cls()
@@ -520,12 +538,15 @@ class RootConfig:
         return obj, config_args
 
     @classmethod
-    def load_or_generate_config(cls, description, argv):
+    def load_or_generate_config(
+        cls: Type[TRootConfig], description: str, argv: List[str]
+    ) -> Optional[TRootConfig]:
         """Parse the commandline and config files
 
         Supports generation of config files, so is used for the main homeserver app.
 
-        Returns: Config object, or None if --generate-config or --generate-keys was set
+        Returns:
+            Config object, or None if --generate-config or --generate-keys was set
         """
         parser = argparse.ArgumentParser(description=description)
         parser.add_argument(
@@ -680,16 +701,21 @@ class RootConfig:
 
         return obj
 
-    def parse_config_dict(self, config_dict, config_dir_path=None, data_dir_path=None):
+    def parse_config_dict(
+        self,
+        config_dict: Dict[str, Any],
+        config_dir_path: Optional[str] = None,
+        data_dir_path: Optional[str] = None,
+    ) -> None:
         """Read the information from the config dict into this Config object.
 
         Args:
-            config_dict (dict): Configuration data, as read from the yaml
+            config_dict: Configuration data, as read from the yaml
 
-            config_dir_path (str): The path where the config files are kept. Used to
+            config_dir_path: The path where the config files are kept. Used to
                 create filenames for things like the log config and the signing key.
 
-            data_dir_path (str): The path where the data files are kept. Used to create
+            data_dir_path: The path where the data files are kept. Used to create
                 filenames for things like the database and media store.
         """
         self.invoke_all(
@@ -699,17 +725,20 @@ class RootConfig:
             data_dir_path=data_dir_path,
         )
 
-    def generate_missing_files(self, config_dict, config_dir_path):
+    def generate_missing_files(
+        self, config_dict: Dict[str, Any], config_dir_path: str
+    ) -> None:
         self.invoke_all("generate_files", config_dict, config_dir_path)
 
 
-def read_config_files(config_files):
+def read_config_files(config_files: Iterable[str]) -> Dict[str, Any]:
     """Read the config files into a dict
 
     Args:
-        config_files (iterable[str]): A list of the config files to read
+        config_files: A list of the config files to read
 
-    Returns: dict
+    Returns:
+        The configuration dictionary.
     """
     specified_config = {}
     for config_file in config_files:
@@ -733,17 +762,17 @@ def read_config_files(config_files):
     return specified_config
 
 
-def find_config_files(search_paths):
+def find_config_files(search_paths: List[str]) -> List[str]:
     """Finds config files using a list of search paths. If a path is a file
     then that file path is added to the list. If a search path is a directory
     then all the "*.yaml" files in that directory are added to the list in
     sorted order.
 
     Args:
-        search_paths(list(str)): A list of paths to search.
+        search_paths: A list of paths to search.
 
     Returns:
-        list(str): A list of file paths.
+        A list of file paths.
     """
 
     config_files = []
@@ -777,7 +806,7 @@ def find_config_files(search_paths):
     return config_files
 
 
-@attr.s
+@attr.s(auto_attribs=True)
 class ShardedWorkerHandlingConfig:
     """Algorithm for choosing which instance is responsible for handling some
     sharded work.
@@ -787,7 +816,7 @@ class ShardedWorkerHandlingConfig:
     below).
     """
 
-    instances = attr.ib(type=List[str])
+    instances: List[str]
 
     def should_handle(self, instance_name: str, key: str) -> bool:
         """Whether this instance is responsible for handling the given key."""