diff --git a/changelog.d/11377.bugfix b/changelog.d/11377.bugfix
new file mode 100644
index 0000000000..9831fb7bbe
--- /dev/null
+++ b/changelog.d/11377.bugfix
@@ -0,0 +1 @@
+Fix a bug introduced in v1.45.0 where the `read_templates` method of the module API would error.
diff --git a/changelog.d/11377.misc b/changelog.d/11377.misc
new file mode 100644
index 0000000000..3dac625576
--- /dev/null
+++ b/changelog.d/11377.misc
@@ -0,0 +1 @@
+Add type hints to configuration classes.
diff --git a/mypy.ini b/mypy.ini
index 308cfd95d8..bc4f59154d 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -151,6 +151,9 @@ disallow_untyped_defs = True
[mypy-synapse.app.*]
disallow_untyped_defs = True
+[mypy-synapse.config._base]
+disallow_untyped_defs = True
+
[mypy-synapse.crypto.*]
disallow_untyped_defs = True
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 7c4428a138..1265738dc1 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -20,7 +20,18 @@ import os
from collections import OrderedDict
from hashlib import sha256
from textwrap import dedent
-from typing import Any, Iterable, List, MutableMapping, Optional, Union
+from typing import (
+ Any,
+ Dict,
+ Iterable,
+ List,
+ MutableMapping,
+ Optional,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+)
import attr
import jinja2
@@ -78,7 +89,7 @@ CONFIG_FILE_HEADER = """\
"""
-def path_exists(file_path):
+def path_exists(file_path: str) -> bool:
"""Check if a file exists
Unlike os.path.exists, this throws an exception if there is an error
@@ -86,7 +97,7 @@ def path_exists(file_path):
the parent dir).
Returns:
- bool: True if the file exists; False if not.
+ True if the file exists; False if not.
"""
try:
os.stat(file_path)
@@ -102,15 +113,15 @@ class Config:
A configuration section, containing configuration keys and values.
Attributes:
- section (str): The section title of this config object, such as
+ section: The section title of this config object, such as
"tls" or "logger". This is used to refer to it on the root
logger (for example, `config.tls.some_option`). Must be
defined in subclasses.
"""
- section = None
+ section: str
- def __init__(self, root_config=None):
+ def __init__(self, root_config: "RootConfig" = None):
self.root = root_config
# Get the path to the default Synapse template directory
@@ -119,7 +130,7 @@ class Config:
)
@staticmethod
- def parse_size(value):
+ def parse_size(value: Union[str, int]) -> int:
if isinstance(value, int):
return value
sizes = {"K": 1024, "M": 1024 * 1024}
@@ -162,15 +173,15 @@ class Config:
return int(value) * size
@staticmethod
- def abspath(file_path):
+ def abspath(file_path: str) -> str:
return os.path.abspath(file_path) if file_path else file_path
@classmethod
- def path_exists(cls, file_path):
+ def path_exists(cls, file_path: str) -> bool:
return path_exists(file_path)
@classmethod
- def check_file(cls, file_path, config_name):
+ def check_file(cls, file_path: Optional[str], config_name: str) -> str:
if file_path is None:
raise ConfigError("Missing config for %s." % (config_name,))
try:
@@ -183,7 +194,7 @@ class Config:
return cls.abspath(file_path)
@classmethod
- def ensure_directory(cls, dir_path):
+ def ensure_directory(cls, dir_path: str) -> str:
dir_path = cls.abspath(dir_path)
os.makedirs(dir_path, exist_ok=True)
if not os.path.isdir(dir_path):
@@ -191,7 +202,7 @@ class Config:
return dir_path
@classmethod
- def read_file(cls, file_path, config_name):
+ def read_file(cls, file_path: Any, config_name: str) -> str:
"""Deprecated: call read_file directly"""
return read_file(file_path, (config_name,))
@@ -284,6 +295,9 @@ class Config:
return [env.get_template(filename) for filename in filenames]
+TRootConfig = TypeVar("TRootConfig", bound="RootConfig")
+
+
class RootConfig:
"""
Holder of an application's configuration.
@@ -308,7 +322,9 @@ class RootConfig:
raise Exception("Failed making %s: %r" % (config_class.section, e))
setattr(self, config_class.section, conf)
- def invoke_all(self, func_name: str, *args, **kwargs) -> MutableMapping[str, Any]:
+ def invoke_all(
+ self, func_name: str, *args: Any, **kwargs: Any
+ ) -> MutableMapping[str, Any]:
"""
Invoke a function on all instantiated config objects this RootConfig is
configured to use.
@@ -317,6 +333,7 @@ class RootConfig:
func_name: Name of function to invoke
*args
**kwargs
+
Returns:
ordered dictionary of config section name and the result of the
function from it.
@@ -332,7 +349,7 @@ class RootConfig:
return res
@classmethod
- def invoke_all_static(cls, func_name: str, *args, **kwargs):
+ def invoke_all_static(cls, func_name: str, *args: Any, **kwargs: any) -> None:
"""
Invoke a static function on config objects this RootConfig is
configured to use.
@@ -341,6 +358,7 @@ class RootConfig:
func_name: Name of function to invoke
*args
**kwargs
+
Returns:
ordered dictionary of config section name and the result of the
function from it.
@@ -351,16 +369,16 @@ class RootConfig:
def generate_config(
self,
- config_dir_path,
- data_dir_path,
- server_name,
- generate_secrets=False,
- report_stats=None,
- open_private_ports=False,
- listeners=None,
- tls_certificate_path=None,
- tls_private_key_path=None,
- ):
+ config_dir_path: str,
+ data_dir_path: str,
+ server_name: str,
+ generate_secrets: bool = False,
+ report_stats: Optional[bool] = None,
+ open_private_ports: bool = False,
+ listeners: Optional[List[dict]] = None,
+ tls_certificate_path: Optional[str] = None,
+ tls_private_key_path: Optional[str] = None,
+ ) -> str:
"""
Build a default configuration file
@@ -368,27 +386,27 @@ class RootConfig:
(eg with --generate_config).
Args:
- config_dir_path (str): The path where the config files are kept. Used to
+ config_dir_path: The path where the config files are kept. Used to
create filenames for things like the log config and the signing key.
- data_dir_path (str): The path where the data files are kept. Used to create
+ data_dir_path: The path where the data files are kept. Used to create
filenames for things like the database and media store.
- server_name (str): The server name. Used to initialise the server_name
+ server_name: The server name. Used to initialise the server_name
config param, but also used in the names of some of the config files.
- generate_secrets (bool): True if we should generate new secrets for things
+ generate_secrets: True if we should generate new secrets for things
like the macaroon_secret_key. If False, these parameters will be left
unset.
- report_stats (bool|None): Initial setting for the report_stats setting.
+ report_stats: Initial setting for the report_stats setting.
If None, report_stats will be left unset.
- open_private_ports (bool): True to leave private ports (such as the non-TLS
+ open_private_ports: True to leave private ports (such as the non-TLS
HTTP listener) open to the internet.
- listeners (list(dict)|None): A list of descriptions of the listeners
- synapse should start with each of which specifies a port (str), a list of
+ listeners: A list of descriptions of the listeners synapse should
+ start with each of which specifies a port (int), a list of
resources (list(str)), tls (bool) and type (str). For example:
[{
"port": 8448,
@@ -403,16 +421,12 @@ class RootConfig:
"type": "http",
}],
+ tls_certificate_path: The path to the tls certificate.
- database (str|None): The database type to configure, either `psycog2`
- or `sqlite3`.
-
- tls_certificate_path (str|None): The path to the tls certificate.
-
- tls_private_key_path (str|None): The path to the tls private key.
+ tls_private_key_path: The path to the tls private key.
Returns:
- str: the yaml config file
+ The yaml config file
"""
return CONFIG_FILE_HEADER + "\n\n".join(
@@ -432,12 +446,15 @@ class RootConfig:
)
@classmethod
- def load_config(cls, description, argv):
+ def load_config(
+ cls: Type[TRootConfig], description: str, argv: List[str]
+ ) -> TRootConfig:
"""Parse the commandline and config files
Doesn't support config-file-generation: used by the worker apps.
- Returns: Config object.
+ Returns:
+ Config object.
"""
config_parser = argparse.ArgumentParser(description=description)
cls.add_arguments_to_parser(config_parser)
@@ -446,7 +463,7 @@ class RootConfig:
return obj
@classmethod
- def add_arguments_to_parser(cls, config_parser):
+ def add_arguments_to_parser(cls, config_parser: argparse.ArgumentParser) -> None:
"""Adds all the config flags to an ArgumentParser.
Doesn't support config-file-generation: used by the worker apps.
@@ -454,7 +471,7 @@ class RootConfig:
Used for workers where we want to add extra flags/subcommands.
Args:
- config_parser (ArgumentParser): App description
+ config_parser: App description
"""
config_parser.add_argument(
@@ -477,7 +494,9 @@ class RootConfig:
cls.invoke_all_static("add_arguments", config_parser)
@classmethod
- def load_config_with_parser(cls, parser, argv):
+ def load_config_with_parser(
+ cls: Type[TRootConfig], parser: argparse.ArgumentParser, argv: List[str]
+ ) -> Tuple[TRootConfig, argparse.Namespace]:
"""Parse the commandline and config files with the given parser
Doesn't support config-file-generation: used by the worker apps.
@@ -485,13 +504,12 @@ class RootConfig:
Used for workers where we want to add extra flags/subcommands.
Args:
- parser (ArgumentParser)
- argv (list[str])
+ parser
+ argv
Returns:
- tuple[HomeServerConfig, argparse.Namespace]: Returns the parsed
- config object and the parsed argparse.Namespace object from
- `parser.parse_args(..)`
+ Returns the parsed config object and the parsed argparse.Namespace
+ object from parser.parse_args(..)`
"""
obj = cls()
@@ -520,12 +538,15 @@ class RootConfig:
return obj, config_args
@classmethod
- def load_or_generate_config(cls, description, argv):
+ def load_or_generate_config(
+ cls: Type[TRootConfig], description: str, argv: List[str]
+ ) -> Optional[TRootConfig]:
"""Parse the commandline and config files
Supports generation of config files, so is used for the main homeserver app.
- Returns: Config object, or None if --generate-config or --generate-keys was set
+ Returns:
+ Config object, or None if --generate-config or --generate-keys was set
"""
parser = argparse.ArgumentParser(description=description)
parser.add_argument(
@@ -680,16 +701,21 @@ class RootConfig:
return obj
- def parse_config_dict(self, config_dict, config_dir_path=None, data_dir_path=None):
+ def parse_config_dict(
+ self,
+ config_dict: Dict[str, Any],
+ config_dir_path: Optional[str] = None,
+ data_dir_path: Optional[str] = None,
+ ) -> None:
"""Read the information from the config dict into this Config object.
Args:
- config_dict (dict): Configuration data, as read from the yaml
+ config_dict: Configuration data, as read from the yaml
- config_dir_path (str): The path where the config files are kept. Used to
+ config_dir_path: The path where the config files are kept. Used to
create filenames for things like the log config and the signing key.
- data_dir_path (str): The path where the data files are kept. Used to create
+ data_dir_path: The path where the data files are kept. Used to create
filenames for things like the database and media store.
"""
self.invoke_all(
@@ -699,17 +725,20 @@ class RootConfig:
data_dir_path=data_dir_path,
)
- def generate_missing_files(self, config_dict, config_dir_path):
+ def generate_missing_files(
+ self, config_dict: Dict[str, Any], config_dir_path: str
+ ) -> None:
self.invoke_all("generate_files", config_dict, config_dir_path)
-def read_config_files(config_files):
+def read_config_files(config_files: Iterable[str]) -> Dict[str, Any]:
"""Read the config files into a dict
Args:
- config_files (iterable[str]): A list of the config files to read
+ config_files: A list of the config files to read
- Returns: dict
+ Returns:
+ The configuration dictionary.
"""
specified_config = {}
for config_file in config_files:
@@ -733,17 +762,17 @@ def read_config_files(config_files):
return specified_config
-def find_config_files(search_paths):
+def find_config_files(search_paths: List[str]) -> List[str]:
"""Finds config files using a list of search paths. If a path is a file
then that file path is added to the list. If a search path is a directory
then all the "*.yaml" files in that directory are added to the list in
sorted order.
Args:
- search_paths(list(str)): A list of paths to search.
+ search_paths: A list of paths to search.
Returns:
- list(str): A list of file paths.
+ A list of file paths.
"""
config_files = []
@@ -777,7 +806,7 @@ def find_config_files(search_paths):
return config_files
-@attr.s
+@attr.s(auto_attribs=True)
class ShardedWorkerHandlingConfig:
"""Algorithm for choosing which instance is responsible for handling some
sharded work.
@@ -787,7 +816,7 @@ class ShardedWorkerHandlingConfig:
below).
"""
- instances = attr.ib(type=List[str])
+ instances: List[str]
def should_handle(self, instance_name: str, key: str) -> bool:
"""Whether this instance is responsible for handling the given key."""
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index c1d9069798..1eb5f5a68c 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -1,4 +1,18 @@
-from typing import Any, Iterable, List, Optional
+import argparse
+from typing import (
+ Any,
+ Dict,
+ Iterable,
+ List,
+ MutableMapping,
+ Optional,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+)
+
+import jinja2
from synapse.config import (
account_validity,
@@ -19,6 +33,7 @@ from synapse.config import (
logger,
metrics,
modules,
+ oembed,
oidc,
password_auth_providers,
push,
@@ -27,6 +42,7 @@ from synapse.config import (
registration,
repository,
retention,
+ room,
room_directory,
saml2,
server,
@@ -51,7 +67,9 @@ MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS: str
MISSING_REPORT_STATS_SPIEL: str
MISSING_SERVER_NAME: str
-def path_exists(file_path: str): ...
+def path_exists(file_path: str) -> bool: ...
+
+TRootConfig = TypeVar("TRootConfig", bound="RootConfig")
class RootConfig:
server: server.ServerConfig
@@ -61,6 +79,7 @@ class RootConfig:
logging: logger.LoggingConfig
ratelimiting: ratelimiting.RatelimitConfig
media: repository.ContentRepositoryConfig
+ oembed: oembed.OembedConfig
captcha: captcha.CaptchaConfig
voip: voip.VoipConfig
registration: registration.RegistrationConfig
@@ -80,6 +99,7 @@ class RootConfig:
authproviders: password_auth_providers.PasswordAuthProviderConfig
push: push.PushConfig
spamchecker: spam_checker.SpamCheckerConfig
+ room: room.RoomConfig
groups: groups.GroupsConfig
userdirectory: user_directory.UserDirectoryConfig
consent: consent.ConsentConfig
@@ -87,72 +107,85 @@ class RootConfig:
servernotices: server_notices.ServerNoticesConfig
roomdirectory: room_directory.RoomDirectoryConfig
thirdpartyrules: third_party_event_rules.ThirdPartyRulesConfig
- tracer: tracer.TracerConfig
+ tracing: tracer.TracerConfig
redis: redis.RedisConfig
modules: modules.ModulesConfig
caches: cache.CacheConfig
federation: federation.FederationConfig
retention: retention.RetentionConfig
- config_classes: List = ...
+ config_classes: List[Type["Config"]] = ...
def __init__(self) -> None: ...
- def invoke_all(self, func_name: str, *args: Any, **kwargs: Any): ...
+ def invoke_all(
+ self, func_name: str, *args: Any, **kwargs: Any
+ ) -> MutableMapping[str, Any]: ...
@classmethod
def invoke_all_static(cls, func_name: str, *args: Any, **kwargs: Any) -> None: ...
- def __getattr__(self, item: str): ...
def parse_config_dict(
self,
- config_dict: Any,
- config_dir_path: Optional[Any] = ...,
- data_dir_path: Optional[Any] = ...,
+ config_dict: Dict[str, Any],
+ config_dir_path: Optional[str] = ...,
+ data_dir_path: Optional[str] = ...,
) -> None: ...
- read_config: Any = ...
def generate_config(
self,
config_dir_path: str,
data_dir_path: str,
server_name: str,
generate_secrets: bool = ...,
- report_stats: Optional[str] = ...,
+ report_stats: Optional[bool] = ...,
open_private_ports: bool = ...,
listeners: Optional[Any] = ...,
- database_conf: Optional[Any] = ...,
tls_certificate_path: Optional[str] = ...,
tls_private_key_path: Optional[str] = ...,
- ): ...
+ ) -> str: ...
@classmethod
- def load_or_generate_config(cls, description: Any, argv: Any): ...
+ def load_or_generate_config(
+ cls: Type[TRootConfig], description: str, argv: List[str]
+ ) -> Optional[TRootConfig]: ...
@classmethod
- def load_config(cls, description: Any, argv: Any): ...
+ def load_config(
+ cls: Type[TRootConfig], description: str, argv: List[str]
+ ) -> TRootConfig: ...
@classmethod
- def add_arguments_to_parser(cls, config_parser: Any) -> None: ...
+ def add_arguments_to_parser(
+ cls, config_parser: argparse.ArgumentParser
+ ) -> None: ...
@classmethod
- def load_config_with_parser(cls, parser: Any, argv: Any): ...
+ def load_config_with_parser(
+ cls: Type[TRootConfig], parser: argparse.ArgumentParser, argv: List[str]
+ ) -> Tuple[TRootConfig, argparse.Namespace]: ...
def generate_missing_files(
self, config_dict: dict, config_dir_path: str
) -> None: ...
class Config:
root: RootConfig
+ default_template_dir: str
def __init__(self, root_config: Optional[RootConfig] = ...) -> None: ...
- def __getattr__(self, item: str, from_root: bool = ...): ...
@staticmethod
- def parse_size(value: Any): ...
+ def parse_size(value: Union[str, int]) -> int: ...
@staticmethod
- def parse_duration(value: Any): ...
+ def parse_duration(value: Union[str, int]) -> int: ...
@staticmethod
- def abspath(file_path: Optional[str]): ...
+ def abspath(file_path: Optional[str]) -> str: ...
@classmethod
- def path_exists(cls, file_path: str): ...
+ def path_exists(cls, file_path: str) -> bool: ...
@classmethod
- def check_file(cls, file_path: str, config_name: str): ...
+ def check_file(cls, file_path: str, config_name: str) -> str: ...
@classmethod
- def ensure_directory(cls, dir_path: str): ...
+ def ensure_directory(cls, dir_path: str) -> str: ...
@classmethod
- def read_file(cls, file_path: str, config_name: str): ...
+ def read_file(cls, file_path: str, config_name: str) -> str: ...
+ def read_template(self, filenames: str) -> jinja2.Template: ...
+ def read_templates(
+ self,
+ filenames: List[str],
+ custom_template_directories: Optional[Iterable[str]] = None,
+ ) -> List[jinja2.Template]: ...
-def read_config_files(config_files: List[str]): ...
-def find_config_files(search_paths: List[str]): ...
+def read_config_files(config_files: Iterable[str]) -> Dict[str, Any]: ...
+def find_config_files(search_paths: List[str]) -> List[str]: ...
class ShardedWorkerHandlingConfig:
instances: List[str]
diff --git a/synapse/config/cache.py b/synapse/config/cache.py
index d119427ad8..f054455534 100644
--- a/synapse/config/cache.py
+++ b/synapse/config/cache.py
@@ -15,7 +15,7 @@
import os
import re
import threading
-from typing import Callable, Dict
+from typing import Callable, Dict, Optional
from synapse.python_dependencies import DependencyException, check_requirements
@@ -217,7 +217,7 @@ class CacheConfig(Config):
expiry_time = cache_config.get("expiry_time")
if expiry_time:
- self.expiry_time_msec = self.parse_duration(expiry_time)
+ self.expiry_time_msec: Optional[int] = self.parse_duration(expiry_time)
else:
self.expiry_time_msec = None
diff --git a/synapse/config/key.py b/synapse/config/key.py
index 015dbb8a67..035ee2416b 100644
--- a/synapse/config/key.py
+++ b/synapse/config/key.py
@@ -16,6 +16,7 @@
import hashlib
import logging
import os
+from typing import Any, Dict
import attr
import jsonschema
@@ -312,7 +313,7 @@ class KeyConfig(Config):
)
return keys
- def generate_files(self, config, config_dir_path):
+ def generate_files(self, config: Dict[str, Any], config_dir_path: str) -> None:
if "signing_key" in config:
return
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index 5252e61a99..63aab0babe 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -18,7 +18,7 @@ import os
import sys
import threading
from string import Template
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Dict
import yaml
from zope.interface import implementer
@@ -185,7 +185,7 @@ class LoggingConfig(Config):
help=argparse.SUPPRESS,
)
- def generate_files(self, config, config_dir_path):
+ def generate_files(self, config: Dict[str, Any], config_dir_path: str) -> None:
log_config = config.get("log_config")
if log_config and not os.path.exists(log_config):
log_file = self.abspath("homeserver.log")
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 7bc0030a9e..8445e9dd05 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -421,7 +421,7 @@ class ServerConfig(Config):
# before redacting them.
redaction_retention_period = config.get("redaction_retention_period", "7d")
if redaction_retention_period is not None:
- self.redaction_retention_period = self.parse_duration(
+ self.redaction_retention_period: Optional[int] = self.parse_duration(
redaction_retention_period
)
else:
@@ -430,7 +430,7 @@ class ServerConfig(Config):
# How long to keep entries in the `users_ips` table.
user_ips_max_age = config.get("user_ips_max_age", "28d")
if user_ips_max_age is not None:
- self.user_ips_max_age = self.parse_duration(user_ips_max_age)
+ self.user_ips_max_age: Optional[int] = self.parse_duration(user_ips_max_age)
else:
self.user_ips_max_age = None
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index 21e5ddd15f..4ca111618f 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -245,7 +245,7 @@ class TlsConfig(Config):
cert_path = self.tls_certificate_file
logger.info("Loading TLS certificate from %s", cert_path)
cert_pem = self.read_file(cert_path, "tls_certificate_path")
- cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem)
+ cert = crypto.load_certificate(crypto.FILETYPE_PEM, cert_pem.encode())
return cert
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index ac8e8142f1..96d7a8f2a9 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -1014,7 +1014,7 @@ class ModuleApi:
A list containing the loaded templates, with the orders matching the one of
the filenames parameter.
"""
- return self._hs.config.read_templates(
+ return self._hs.config.server.read_templates(
filenames,
(td for td in (self.custom_template_dir, custom_template_directory) if td),
)
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 8478463a2a..0e8c168667 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -1198,8 +1198,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
expiration_ts = now_ms + self._account_validity_period
if use_delta:
+ assert self._account_validity_startup_job_max_delta is not None
expiration_ts = random.randrange(
- expiration_ts - self._account_validity_startup_job_max_delta,
+ int(expiration_ts - self._account_validity_startup_job_max_delta),
expiration_ts,
)
diff --git a/tests/config/test_load.py b/tests/config/test_load.py
index d8668d56b2..69a4e9413b 100644
--- a/tests/config/test_load.py
+++ b/tests/config/test_load.py
@@ -46,15 +46,16 @@ class ConfigLoadingFileTestCase(ConfigFileTestCase):
"was: %r" % (config.key.macaroon_secret_key,)
)
- config = HomeServerConfig.load_or_generate_config("", ["-c", self.config_file])
+ config2 = HomeServerConfig.load_or_generate_config("", ["-c", self.config_file])
+ assert config2 is not None
self.assertTrue(
- hasattr(config.key, "macaroon_secret_key"),
+ hasattr(config2.key, "macaroon_secret_key"),
"Want config to have attr macaroon_secret_key",
)
- if len(config.key.macaroon_secret_key) < 5:
+ if len(config2.key.macaroon_secret_key) < 5:
self.fail(
"Want macaroon secret key to be string of at least length 5,"
- "was: %r" % (config.key.macaroon_secret_key,)
+ "was: %r" % (config2.key.macaroon_secret_key,)
)
def test_load_succeeds_if_macaroon_secret_key_missing(self):
@@ -62,6 +63,9 @@ class ConfigLoadingFileTestCase(ConfigFileTestCase):
config1 = HomeServerConfig.load_config("", ["-c", self.config_file])
config2 = HomeServerConfig.load_config("", ["-c", self.config_file])
config3 = HomeServerConfig.load_or_generate_config("", ["-c", self.config_file])
+ assert config1 is not None
+ assert config2 is not None
+ assert config3 is not None
self.assertEqual(
config1.key.macaroon_secret_key, config2.key.macaroon_secret_key
)
@@ -78,14 +82,16 @@ class ConfigLoadingFileTestCase(ConfigFileTestCase):
config = HomeServerConfig.load_config("", ["-c", self.config_file])
self.assertFalse(config.registration.enable_registration)
- config = HomeServerConfig.load_or_generate_config("", ["-c", self.config_file])
- self.assertFalse(config.registration.enable_registration)
+ config2 = HomeServerConfig.load_or_generate_config("", ["-c", self.config_file])
+ assert config2 is not None
+ self.assertFalse(config2.registration.enable_registration)
# Check that either config value is clobbered by the command line.
- config = HomeServerConfig.load_or_generate_config(
+ config3 = HomeServerConfig.load_or_generate_config(
"", ["-c", self.config_file, "--enable-registration"]
)
- self.assertTrue(config.registration.enable_registration)
+ assert config3 is not None
+ self.assertTrue(config3.registration.enable_registration)
def test_stats_enabled(self):
self.generate_config_and_remove_lines_containing("enable_metrics")
|