diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 31f6530978..08619404bb 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -18,7 +18,9 @@
import argparse
import errno
import os
+from collections import OrderedDict
from textwrap import dedent
+from typing import Any, MutableMapping, Optional
from six import integer_types
@@ -51,7 +53,56 @@ Missing mandatory `server_name` config option.
"""
+def path_exists(file_path):
+ """Check if a file exists
+
+ Unlike os.path.exists, this throws an exception if there is an error
+ checking if the file exists (for example, if there is a perms error on
+ the parent dir).
+
+ Returns:
+ bool: True if the file exists; False if not.
+ """
+ try:
+ os.stat(file_path)
+ return True
+ except OSError as e:
+ if e.errno != errno.ENOENT:
+ raise e
+ return False
+
+
class Config(object):
+ """
+ A configuration section, containing configuration keys and values.
+
+ Attributes:
+ section (str): The section title of this config object, such as
+ "tls" or "logger". This is used to refer to it on the root
+ logger (for example, `config.tls.some_option`). Must be
+ defined in subclasses.
+ """
+
+ section = None
+
+ def __init__(self, root_config=None):
+ self.root = root_config
+
+ def __getattr__(self, item: str) -> Any:
+ """
+ Try and fetch a configuration option that does not exist on this class.
+
+ This is so that existing configs that rely on `self.value`, where value
+ is actually from a different config section, continue to work.
+ """
+ if item in ["generate_config_section", "read_config"]:
+ raise AttributeError(item)
+
+ if self.root is None:
+ raise AttributeError(item)
+ else:
+ return self.root._get_unclassed_config(self.section, item)
+
@staticmethod
def parse_size(value):
if isinstance(value, integer_types):
@@ -88,22 +139,7 @@ class Config(object):
@classmethod
def path_exists(cls, file_path):
- """Check if a file exists
-
- Unlike os.path.exists, this throws an exception if there is an error
- checking if the file exists (for example, if there is a perms error on
- the parent dir).
-
- Returns:
- bool: True if the file exists; False if not.
- """
- try:
- os.stat(file_path)
- return True
- except OSError as e:
- if e.errno != errno.ENOENT:
- raise e
- return False
+ return path_exists(file_path)
@classmethod
def check_file(cls, file_path, config_name):
@@ -136,42 +172,106 @@ class Config(object):
with open(file_path) as file_stream:
return file_stream.read()
- def invoke_all(self, name, *args, **kargs):
- """Invoke all instance methods with the given name and arguments in the
- class's MRO.
+
+class RootConfig(object):
+ """
+ Holder of an application's configuration.
+
+ What configuration this object holds is defined by `config_classes`, a list
+ of Config classes that will be instantiated and given the contents of a
+ configuration file to read. They can then be accessed on this class by their
+ section name, defined in the Config or dynamically set to be the name of the
+ class, lower-cased and with "Config" removed.
+ """
+
+ config_classes = []
+
+ def __init__(self):
+ self._configs = OrderedDict()
+
+ for config_class in self.config_classes:
+ if config_class.section is None:
+ raise ValueError("%r requires a section name" % (config_class,))
+
+ try:
+ conf = config_class(self)
+ except Exception as e:
+ raise Exception("Failed making %s: %r" % (config_class.section, e))
+ self._configs[config_class.section] = conf
+
+ def __getattr__(self, item: str) -> Any:
+ """
+ Redirect lookups on this object either to config objects, or values on
+ config objects, so that `config.tls.blah` works, as well as legacy uses
+ of things like `config.server_name`. It will first look up the config
+ section name, and then values on those config classes.
+ """
+ if item in self._configs.keys():
+ return self._configs[item]
+
+ return self._get_unclassed_config(None, item)
+
+ def _get_unclassed_config(self, asking_section: Optional[str], item: str):
+ """
+ Fetch a config value from one of the instantiated config classes that
+ has not been fetched directly.
+
+ Args:
+ asking_section: If this check is coming from a Config child, which
+ one? This section will not be asked if it has the value.
+ item: The configuration value key.
+
+ Raises:
+ AttributeError if no config classes have the config key. The body
+ will contain what sections were checked.
+ """
+ for key, val in self._configs.items():
+ if key == asking_section:
+ continue
+
+ if item in dir(val):
+ return getattr(val, item)
+
+ raise AttributeError(item, "not found in %s" % (list(self._configs.keys()),))
+
+ def invoke_all(self, func_name: str, *args, **kwargs) -> MutableMapping[str, Any]:
+ """
+ Invoke a function on all instantiated config objects this RootConfig is
+ configured to use.
Args:
- name (str): Name of function to invoke
+ func_name: Name of function to invoke
*args
**kwargs
-
Returns:
- list: The list of the return values from each method called
+ ordered dictionary of config section name and the result of the
+ function from it.
"""
- results = []
- for cls in type(self).mro():
- if name in cls.__dict__:
- results.append(getattr(cls, name)(self, *args, **kargs))
- return results
+ res = OrderedDict()
+
+ for name, config in self._configs.items():
+ if hasattr(config, func_name):
+ res[name] = getattr(config, func_name)(*args, **kwargs)
+
+ return res
@classmethod
- def invoke_all_static(cls, name, *args, **kargs):
- """Invoke all static methods with the given name and arguments in the
- class's MRO.
+ def invoke_all_static(cls, func_name: str, *args, **kwargs):
+ """
+ Invoke a static function on config objects this RootConfig is
+ configured to use.
Args:
- name (str): Name of function to invoke
+ func_name: Name of function to invoke
*args
**kwargs
-
Returns:
- list: The list of the return values from each method called
+ ordered dictionary of config section name and the result of the
+ function from it.
"""
- results = []
- for c in cls.mro():
- if name in c.__dict__:
- results.append(getattr(c, name)(*args, **kargs))
- return results
+ for config in cls.config_classes:
+ if hasattr(config, func_name):
+ getattr(config, func_name)(*args, **kwargs)
def generate_config(
self,
@@ -187,7 +287,8 @@ class Config(object):
tls_private_key_path=None,
acme_domain=None,
):
- """Build a default configuration file
+ """
+ Build a default configuration file
This is used when the user explicitly asks us to generate a config file
(eg with --generate_config).
@@ -242,6 +343,7 @@ class Config(object):
Returns:
str: the yaml config file
"""
+
return "\n\n".join(
dedent(conf)
for conf in self.invoke_all(
@@ -257,7 +359,7 @@ class Config(object):
tls_certificate_path=tls_certificate_path,
tls_private_key_path=tls_private_key_path,
acme_domain=acme_domain,
- )
+ ).values()
)
@classmethod
@@ -444,7 +546,7 @@ class Config(object):
)
(config_path,) = config_files
- if not cls.path_exists(config_path):
+ if not path_exists(config_path):
print("Generating config file %s" % (config_path,))
if config_args.data_directory:
@@ -469,7 +571,7 @@ class Config(object):
open_private_ports=config_args.open_private_ports,
)
- if not cls.path_exists(config_dir_path):
+ if not path_exists(config_dir_path):
os.makedirs(config_dir_path)
with open(config_path, "w") as config_file:
config_file.write("# vim:ft=yaml\n\n")
@@ -518,7 +620,7 @@ class Config(object):
return obj
- def parse_config_dict(self, config_dict, config_dir_path, data_dir_path):
+ def parse_config_dict(self, config_dict, config_dir_path=None, data_dir_path=None):
"""Read the information from the config dict into this Config object.
Args:
@@ -607,3 +709,6 @@ def find_config_files(search_paths):
else:
config_files.append(config_path)
return config_files
+
+
+__all__ = ["Config", "RootConfig"]
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
new file mode 100644
index 0000000000..86bc965ee4
--- /dev/null
+++ b/synapse/config/_base.pyi
@@ -0,0 +1,135 @@
+from typing import Any, List, Optional
+
+from synapse.config import (
+ api,
+ appservice,
+ captcha,
+ cas,
+ consent_config,
+ database,
+ emailconfig,
+ groups,
+ jwt_config,
+ key,
+ logger,
+ metrics,
+ password,
+ password_auth_providers,
+ push,
+ ratelimiting,
+ registration,
+ repository,
+ room_directory,
+ saml2_config,
+ server,
+ server_notices_config,
+ spam_checker,
+ stats,
+ third_party_event_rules,
+ tls,
+ tracer,
+ user_directory,
+ voip,
+ workers,
+)
+
+class ConfigError(Exception): ...
+
+MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS: str
+MISSING_REPORT_STATS_SPIEL: str
+MISSING_SERVER_NAME: str
+
+def path_exists(file_path: str): ...
+
+class RootConfig:
+ server: server.ServerConfig
+ tls: tls.TlsConfig
+ database: database.DatabaseConfig
+ logging: logger.LoggingConfig
+ ratelimit: ratelimiting.RatelimitConfig
+ media: repository.ContentRepositoryConfig
+ captcha: captcha.CaptchaConfig
+ voip: voip.VoipConfig
+ registration: registration.RegistrationConfig
+ metrics: metrics.MetricsConfig
+ api: api.ApiConfig
+ appservice: appservice.AppServiceConfig
+ key: key.KeyConfig
+ saml2: saml2_config.SAML2Config
+ cas: cas.CasConfig
+ jwt: jwt_config.JWTConfig
+ password: password.PasswordConfig
+ email: emailconfig.EmailConfig
+ worker: workers.WorkerConfig
+ authproviders: password_auth_providers.PasswordAuthProviderConfig
+ push: push.PushConfig
+ spamchecker: spam_checker.SpamCheckerConfig
+ groups: groups.GroupsConfig
+ userdirectory: user_directory.UserDirectoryConfig
+ consent: consent_config.ConsentConfig
+ stats: stats.StatsConfig
+ servernotices: server_notices_config.ServerNoticesConfig
+ roomdirectory: room_directory.RoomDirectoryConfig
+ thirdpartyrules: third_party_event_rules.ThirdPartyRulesConfig
+ tracer: tracer.TracerConfig
+
+ config_classes: List = ...
+ def __init__(self) -> None: ...
+ def invoke_all(self, func_name: str, *args: Any, **kwargs: Any): ...
+ @classmethod
+ def invoke_all_static(cls, func_name: str, *args: Any, **kwargs: Any) -> None: ...
+ def __getattr__(self, item: str): ...
+ def parse_config_dict(
+ self,
+ config_dict: Any,
+ config_dir_path: Optional[Any] = ...,
+ data_dir_path: Optional[Any] = ...,
+ ) -> None: ...
+ read_config: Any = ...
+ def generate_config(
+ self,
+ config_dir_path: str,
+ data_dir_path: str,
+ server_name: str,
+ generate_secrets: bool = ...,
+ report_stats: Optional[str] = ...,
+ open_private_ports: bool = ...,
+ listeners: Optional[Any] = ...,
+ database_conf: Optional[Any] = ...,
+ tls_certificate_path: Optional[str] = ...,
+ tls_private_key_path: Optional[str] = ...,
+ acme_domain: Optional[str] = ...,
+ ): ...
+ @classmethod
+ def load_or_generate_config(cls, description: Any, argv: Any): ...
+ @classmethod
+ def load_config(cls, description: Any, argv: Any): ...
+ @classmethod
+ def add_arguments_to_parser(cls, config_parser: Any) -> None: ...
+ @classmethod
+ def load_config_with_parser(cls, parser: Any, argv: Any): ...
+ def generate_missing_files(
+ self, config_dict: dict, config_dir_path: str
+ ) -> None: ...
+
+class Config:
+ root: RootConfig
+ def __init__(self, root_config: Optional[RootConfig] = ...) -> None: ...
+ def __getattr__(self, item: str, from_root: bool = ...): ...
+ @staticmethod
+ def parse_size(value: Any): ...
+ @staticmethod
+ def parse_duration(value: Any): ...
+ @staticmethod
+ def abspath(file_path: Optional[str]): ...
+ @classmethod
+ def path_exists(cls, file_path: str): ...
+ @classmethod
+ def check_file(cls, file_path: str, config_name: str): ...
+ @classmethod
+ def ensure_directory(cls, dir_path: str): ...
+ @classmethod
+ def read_file(cls, file_path: str, config_name: str): ...
+
+def read_config_files(config_files: List[str]): ...
+def find_config_files(search_paths: List[str]): ...
diff --git a/synapse/config/api.py b/synapse/config/api.py
index dddea79a8a..74cd53a8ed 100644
--- a/synapse/config/api.py
+++ b/synapse/config/api.py
@@ -18,6 +18,8 @@ from ._base import Config
class ApiConfig(Config):
+ section = "api"
+
def read_config(self, config, **kwargs):
self.room_invite_state_types = config.get(
"room_invite_state_types",
diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py
index 28d36b1bc3..9b4682222d 100644
--- a/synapse/config/appservice.py
+++ b/synapse/config/appservice.py
@@ -30,6 +30,8 @@ logger = logging.getLogger(__name__)
class AppServiceConfig(Config):
+ section = "appservice"
+
def read_config(self, config, **kwargs):
self.app_service_config_files = config.get("app_service_config_files", [])
self.notify_appservices = config.get("notify_appservices", True)
diff --git a/synapse/config/captcha.py b/synapse/config/captcha.py
index 8dac8152cf..44bd5c6799 100644
--- a/synapse/config/captcha.py
+++ b/synapse/config/captcha.py
@@ -16,6 +16,8 @@ from ._base import Config
class CaptchaConfig(Config):
+ section = "captcha"
+
def read_config(self, config, **kwargs):
self.recaptcha_private_key = config.get("recaptcha_private_key")
self.recaptcha_public_key = config.get("recaptcha_public_key")
diff --git a/synapse/config/cas.py b/synapse/config/cas.py
index ebe34d933b..b916c3aa66 100644
--- a/synapse/config/cas.py
+++ b/synapse/config/cas.py
@@ -22,6 +22,8 @@ class CasConfig(Config):
cas_server_url: URL of CAS server
"""
+ section = "cas"
+
def read_config(self, config, **kwargs):
cas_config = config.get("cas_config", None)
if cas_config:
diff --git a/synapse/config/consent_config.py b/synapse/config/consent_config.py
index 48976e17b1..62c4c44d60 100644
--- a/synapse/config/consent_config.py
+++ b/synapse/config/consent_config.py
@@ -73,6 +73,9 @@ DEFAULT_CONFIG = """\
class ConsentConfig(Config):
+
+ section = "consent"
+
def __init__(self, *args):
super(ConsentConfig, self).__init__(*args)
diff --git a/synapse/config/database.py b/synapse/config/database.py
index 118aafbd4a..0e2509f0b1 100644
--- a/synapse/config/database.py
+++ b/synapse/config/database.py
@@ -21,6 +21,8 @@ from ._base import Config
class DatabaseConfig(Config):
+ section = "database"
+
def read_config(self, config, **kwargs):
self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K"))
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index d9b43de660..658897a77e 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -28,6 +28,8 @@ from ._base import Config, ConfigError
class EmailConfig(Config):
+ section = "email"
+
def read_config(self, config, **kwargs):
# TODO: We should separate better the email configuration from the notification
# and account validity config.
diff --git a/synapse/config/groups.py b/synapse/config/groups.py
index 2a522b5f44..d6862d9a64 100644
--- a/synapse/config/groups.py
+++ b/synapse/config/groups.py
@@ -17,6 +17,8 @@ from ._base import Config
class GroupsConfig(Config):
+ section = "groups"
+
def read_config(self, config, **kwargs):
self.enable_group_creation = config.get("enable_group_creation", False)
self.group_creation_prefix = config.get("group_creation_prefix", "")
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index 72acad4f18..6e348671c7 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from ._base import RootConfig
from .api import ApiConfig
from .appservice import AppServiceConfig
from .captcha import CaptchaConfig
@@ -46,36 +47,37 @@ from .voip import VoipConfig
from .workers import WorkerConfig
-class HomeServerConfig(
- ServerConfig,
- TlsConfig,
- DatabaseConfig,
- LoggingConfig,
- RatelimitConfig,
- ContentRepositoryConfig,
- CaptchaConfig,
- VoipConfig,
- RegistrationConfig,
- MetricsConfig,
- ApiConfig,
- AppServiceConfig,
- KeyConfig,
- SAML2Config,
- CasConfig,
- JWTConfig,
- PasswordConfig,
- EmailConfig,
- WorkerConfig,
- PasswordAuthProviderConfig,
- PushConfig,
- SpamCheckerConfig,
- GroupsConfig,
- UserDirectoryConfig,
- ConsentConfig,
- StatsConfig,
- ServerNoticesConfig,
- RoomDirectoryConfig,
- ThirdPartyRulesConfig,
- TracerConfig,
-):
- pass
+class HomeServerConfig(RootConfig):
+
+ config_classes = [
+ ServerConfig,
+ TlsConfig,
+ DatabaseConfig,
+ LoggingConfig,
+ RatelimitConfig,
+ ContentRepositoryConfig,
+ CaptchaConfig,
+ VoipConfig,
+ RegistrationConfig,
+ MetricsConfig,
+ ApiConfig,
+ AppServiceConfig,
+ KeyConfig,
+ SAML2Config,
+ CasConfig,
+ JWTConfig,
+ PasswordConfig,
+ EmailConfig,
+ WorkerConfig,
+ PasswordAuthProviderConfig,
+ PushConfig,
+ SpamCheckerConfig,
+ GroupsConfig,
+ UserDirectoryConfig,
+ ConsentConfig,
+ StatsConfig,
+ ServerNoticesConfig,
+ RoomDirectoryConfig,
+ ThirdPartyRulesConfig,
+ TracerConfig,
+ ]
diff --git a/synapse/config/jwt_config.py b/synapse/config/jwt_config.py
index 36d87cef03..a568726985 100644
--- a/synapse/config/jwt_config.py
+++ b/synapse/config/jwt_config.py
@@ -23,6 +23,8 @@ MISSING_JWT = """Missing jwt library. This is required for jwt login.
class JWTConfig(Config):
+ section = "jwt"
+
def read_config(self, config, **kwargs):
jwt_config = config.get("jwt_config", None)
if jwt_config:
diff --git a/synapse/config/key.py b/synapse/config/key.py
index f039f96e9c..ec5d430afb 100644
--- a/synapse/config/key.py
+++ b/synapse/config/key.py
@@ -92,6 +92,8 @@ class TrustedKeyServer(object):
class KeyConfig(Config):
+ section = "key"
+
def read_config(self, config, config_dir_path, **kwargs):
# the signing key can be specified inline or in a separate file
if "signing_key" in config:
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index 767ecfdf09..d609ec111b 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -84,6 +84,8 @@ root:
class LoggingConfig(Config):
+ section = "logging"
+
def read_config(self, config, **kwargs):
self.log_config = self.abspath(config.get("log_config"))
self.no_redirect_stdio = config.get("no_redirect_stdio", False)
diff --git a/synapse/config/metrics.py b/synapse/config/metrics.py
index ec35a6b868..282a43bddb 100644
--- a/synapse/config/metrics.py
+++ b/synapse/config/metrics.py
@@ -34,6 +34,8 @@ class MetricsFlags(object):
class MetricsConfig(Config):
+ section = "metrics"
+
def read_config(self, config, **kwargs):
self.enable_metrics = config.get("enable_metrics", False)
self.report_stats = config.get("report_stats", None)
diff --git a/synapse/config/password.py b/synapse/config/password.py
index d5b5953f2f..2a634ac751 100644
--- a/synapse/config/password.py
+++ b/synapse/config/password.py
@@ -20,6 +20,8 @@ class PasswordConfig(Config):
"""Password login configuration
"""
+ section = "password"
+
def read_config(self, config, **kwargs):
password_config = config.get("password_config", {})
if password_config is None:
diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py
index c50e244394..9746bbc681 100644
--- a/synapse/config/password_auth_providers.py
+++ b/synapse/config/password_auth_providers.py
@@ -23,6 +23,8 @@ LDAP_PROVIDER = "ldap_auth_provider.LdapAuthProvider"
class PasswordAuthProviderConfig(Config):
+ section = "authproviders"
+
def read_config(self, config, **kwargs):
self.password_providers = [] # type: List[Any]
providers = []
diff --git a/synapse/config/push.py b/synapse/config/push.py
index 1b932722a5..0910958649 100644
--- a/synapse/config/push.py
+++ b/synapse/config/push.py
@@ -18,6 +18,8 @@ from ._base import Config
class PushConfig(Config):
+ section = "push"
+
def read_config(self, config, **kwargs):
push_config = config.get("push", {})
self.push_include_content = push_config.get("include_content", True)
diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py
index 587e2862b7..947f653e03 100644
--- a/synapse/config/ratelimiting.py
+++ b/synapse/config/ratelimiting.py
@@ -36,6 +36,8 @@ class FederationRateLimitConfig(object):
class RatelimitConfig(Config):
+ section = "ratelimiting"
+
def read_config(self, config, **kwargs):
# Load the new-style messages config if it exists. Otherwise fall back
diff --git a/synapse/config/registration.py b/synapse/config/registration.py
index bef89e2bf4..b3e3e6dda2 100644
--- a/synapse/config/registration.py
+++ b/synapse/config/registration.py
@@ -24,6 +24,8 @@ from synapse.util.stringutils import random_string_with_symbols
class AccountValidityConfig(Config):
+ section = "accountvalidity"
+
def __init__(self, config, synapse_config):
self.enabled = config.get("enabled", False)
self.renew_by_email_enabled = "renew_at" in config
@@ -77,6 +79,8 @@ class AccountValidityConfig(Config):
class RegistrationConfig(Config):
+ section = "registration"
+
def read_config(self, config, **kwargs):
self.enable_registration = bool(
strtobool(str(config.get("enable_registration", False)))
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index 14740891f3..d0205e14b9 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -78,6 +78,8 @@ def parse_thumbnail_requirements(thumbnail_sizes):
class ContentRepositoryConfig(Config):
+ section = "media"
+
def read_config(self, config, **kwargs):
# Only enable the media repo if either the media repo is enabled or the
diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py
index a92693017b..7c9f05bde4 100644
--- a/synapse/config/room_directory.py
+++ b/synapse/config/room_directory.py
@@ -19,6 +19,8 @@ from ._base import Config, ConfigError
class RoomDirectoryConfig(Config):
+ section = "roomdirectory"
+
def read_config(self, config, **kwargs):
self.enable_room_list_search = config.get("enable_room_list_search", True)
diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index ab34b41ca8..c407e13680 100644
--- a/synapse/config/saml2_config.py
+++ b/synapse/config/saml2_config.py
@@ -55,6 +55,8 @@ def _dict_merge(merge_dict, into_dict):
class SAML2Config(Config):
+ section = "saml2"
+
def read_config(self, config, **kwargs):
self.saml2_enabled = False
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 709bd387e5..afc4d6a4ab 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -58,6 +58,8 @@ on how to configure the new listener.
class ServerConfig(Config):
+ section = "server"
+
def read_config(self, config, **kwargs):
self.server_name = config["server_name"]
self.server_context = config.get("server_context", None)
diff --git a/synapse/config/server_notices_config.py b/synapse/config/server_notices_config.py
index 6d4285ef93..6ea2ea8869 100644
--- a/synapse/config/server_notices_config.py
+++ b/synapse/config/server_notices_config.py
@@ -59,6 +59,8 @@ class ServerNoticesConfig(Config):
None if server notices are not enabled.
"""
+ section = "servernotices"
+
def __init__(self, *args):
super(ServerNoticesConfig, self).__init__(*args)
self.server_notices_mxid = None
diff --git a/synapse/config/spam_checker.py b/synapse/config/spam_checker.py
index e40797ab50..36e0ddab5c 100644
--- a/synapse/config/spam_checker.py
+++ b/synapse/config/spam_checker.py
@@ -19,6 +19,8 @@ from ._base import Config
class SpamCheckerConfig(Config):
+ section = "spamchecker"
+
def read_config(self, config, **kwargs):
self.spam_checker = None
diff --git a/synapse/config/stats.py b/synapse/config/stats.py
index b18ddbd1fa..62485189ea 100644
--- a/synapse/config/stats.py
+++ b/synapse/config/stats.py
@@ -25,6 +25,8 @@ class StatsConfig(Config):
Configuration for the behaviour of synapse's stats engine
"""
+ section = "stats"
+
def read_config(self, config, **kwargs):
self.stats_enabled = True
self.stats_bucket_size = 86400 * 1000
diff --git a/synapse/config/third_party_event_rules.py b/synapse/config/third_party_event_rules.py
index b3431441b9..10a99c792e 100644
--- a/synapse/config/third_party_event_rules.py
+++ b/synapse/config/third_party_event_rules.py
@@ -19,6 +19,8 @@ from ._base import Config
class ThirdPartyRulesConfig(Config):
+ section = "thirdpartyrules"
+
def read_config(self, config, **kwargs):
self.third_party_event_rules = None
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index fc47ba3e9a..f06341eb67 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -18,6 +18,7 @@ import os
import warnings
from datetime import datetime
from hashlib import sha256
+from typing import List
import six
@@ -33,7 +34,9 @@ logger = logging.getLogger(__name__)
class TlsConfig(Config):
- def read_config(self, config, config_dir_path, **kwargs):
+ section = "tls"
+
+ def read_config(self, config: dict, config_dir_path: str, **kwargs):
acme_config = config.get("acme", None)
if acme_config is None:
@@ -57,7 +60,7 @@ class TlsConfig(Config):
self.tls_certificate_file = self.abspath(config.get("tls_certificate_path"))
self.tls_private_key_file = self.abspath(config.get("tls_private_key_path"))
- if self.has_tls_listener():
+ if self.root.server.has_tls_listener():
if not self.tls_certificate_file:
raise ConfigError(
"tls_certificate_path must be specified if TLS-enabled listeners are "
@@ -108,7 +111,7 @@ class TlsConfig(Config):
)
# Support globs (*) in whitelist values
- self.federation_certificate_verification_whitelist = []
+ self.federation_certificate_verification_whitelist = [] # type: List[str]
for entry in fed_whitelist_entries:
try:
entry_regex = glob_to_regex(entry.encode("ascii").decode("ascii"))
diff --git a/synapse/config/tracer.py b/synapse/config/tracer.py
index 85d99a3166..8be1346113 100644
--- a/synapse/config/tracer.py
+++ b/synapse/config/tracer.py
@@ -19,6 +19,8 @@ from ._base import Config, ConfigError
class TracerConfig(Config):
+ section = "tracing"
+
def read_config(self, config, **kwargs):
opentracing_config = config.get("opentracing")
if opentracing_config is None:
diff --git a/synapse/config/user_directory.py b/synapse/config/user_directory.py
index f6313e17d4..c8d19c5d6b 100644
--- a/synapse/config/user_directory.py
+++ b/synapse/config/user_directory.py
@@ -21,6 +21,8 @@ class UserDirectoryConfig(Config):
Configuration for the behaviour of the /user_directory API
"""
+ section = "userdirectory"
+
def read_config(self, config, **kwargs):
self.user_directory_search_enabled = True
self.user_directory_search_all_users = False
diff --git a/synapse/config/voip.py b/synapse/config/voip.py
index 2ca0e1cf70..a68a3068aa 100644
--- a/synapse/config/voip.py
+++ b/synapse/config/voip.py
@@ -16,6 +16,8 @@ from ._base import Config
class VoipConfig(Config):
+ section = "voip"
+
def read_config(self, config, **kwargs):
self.turn_uris = config.get("turn_uris", [])
self.turn_shared_secret = config.get("turn_shared_secret")
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index 1ec4998625..fef72ed974 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -21,6 +21,8 @@ class WorkerConfig(Config):
They have their own pid_file and listener configuration. They use the
replication_url to talk to the main synapse process."""
+ section = "worker"
+
def read_config(self, config, **kwargs):
self.worker_app = config.get("worker_app")
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 053cf66b28..2a5f1a007d 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -803,17 +803,25 @@ class PresenceHandler(object):
# Loop round handling deltas until we're up to date
while True:
with Measure(self.clock, "presence_delta"):
- deltas = yield self.store.get_current_state_deltas(self._event_pos)
- if not deltas:
+ room_max_stream_ordering = self.store.get_room_max_stream_ordering()
+ if self._event_pos == room_max_stream_ordering:
return
+ logger.debug(
+ "Processing presence stats %s->%s",
+ self._event_pos,
+ room_max_stream_ordering,
+ )
+ max_pos, deltas = yield self.store.get_current_state_deltas(
+ self._event_pos, room_max_stream_ordering
+ )
yield self._handle_state_delta(deltas)
- self._event_pos = deltas[-1]["stream_id"]
+ self._event_pos = max_pos
# Expose current event processing position to prometheus
synapse.metrics.event_processing_positions.labels("presence").set(
- self._event_pos
+ max_pos
)
@defer.inlineCallbacks
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 1edc657f8a..380e2fad5e 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -203,23 +203,11 @@ class RoomMemberHandler(object):
prev_member_event = yield self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN
if newly_joined:
- yield self._user_joined_room(target, room_id)
-
- # Copy over direct message status and room tags if this is a join
- # on an upgraded room
-
- # Check if this is an upgraded room
- predecessor = yield self.store.get_room_predecessor(room_id)
-
- if predecessor:
- # It is an upgraded room. Copy over old tags
- yield self.copy_room_tags_and_direct_to_room(
- predecessor["room_id"], room_id, user_id
- )
- # Copy over push rules
- yield self.store.copy_push_rules_from_room_to_room_for_user(
- predecessor["room_id"], room_id, user_id
+ # Copy over user state if we're joining an upgraded room
+ yield self.copy_user_state_if_room_upgrade(
+ room_id, requester.user.to_string()
)
+ yield self._user_joined_room(target, room_id)
elif event.membership == Membership.LEAVE:
if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id)
@@ -463,10 +451,16 @@ class RoomMemberHandler(object):
if requester.is_guest:
content["kind"] = "guest"
- ret = yield self._remote_join(
+ remote_join_response = yield self._remote_join(
requester, remote_room_hosts, room_id, target, content
)
- return ret
+
+ # Copy over user state if this is a join on an remote upgraded room
+ yield self.copy_user_state_if_room_upgrade(
+ room_id, requester.user.to_string()
+ )
+
+ return remote_join_response
elif effective_membership_state == Membership.LEAVE:
if not is_host_in_room:
@@ -504,6 +498,38 @@ class RoomMemberHandler(object):
return res
@defer.inlineCallbacks
+ def copy_user_state_if_room_upgrade(self, new_room_id, user_id):
+ """Copy user-specific information when they join a new room if that new room is the
+ result of a room upgrade
+
+ Args:
+ new_room_id (str): The ID of the room the user is joining
+ user_id (str): The ID of the user
+
+ Returns:
+ Deferred
+ """
+ # Check if the new room is an upgraded room
+ predecessor = yield self.store.get_room_predecessor(new_room_id)
+ if not predecessor:
+ return
+
+ logger.debug(
+ "Found predecessor for %s: %s. Copying over room tags and push " "rules",
+ new_room_id,
+ predecessor,
+ )
+
+ # It is an upgraded room. Copy over old tags
+ yield self.copy_room_tags_and_direct_to_room(
+ predecessor["room_id"], new_room_id, user_id
+ )
+ # Copy over push rules
+ yield self.store.copy_push_rules_from_room_to_room_for_user(
+ predecessor["room_id"], new_room_id, user_id
+ )
+
+ @defer.inlineCallbacks
def send_membership_event(self, requester, event, context, ratelimit=True):
"""
Change the membership status of a user in a room.
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index c62b113115..466daf9202 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -87,21 +87,23 @@ class StatsHandler(StateDeltasHandler):
# Be sure to read the max stream_ordering *before* checking if there are any outstanding
# deltas, since there is otherwise a chance that we could miss updates which arrive
# after we check the deltas.
- room_max_stream_ordering = yield self.store.get_room_max_stream_ordering()
+ room_max_stream_ordering = self.store.get_room_max_stream_ordering()
if self.pos == room_max_stream_ordering:
break
- deltas = yield self.store.get_current_state_deltas(self.pos)
+ logger.debug(
+ "Processing room stats %s->%s", self.pos, room_max_stream_ordering
+ )
+ max_pos, deltas = yield self.store.get_current_state_deltas(
+ self.pos, room_max_stream_ordering
+ )
if deltas:
logger.debug("Handling %d state deltas", len(deltas))
room_deltas, user_deltas = yield self._handle_deltas(deltas)
-
- max_pos = deltas[-1]["stream_id"]
else:
room_deltas = {}
user_deltas = {}
- max_pos = room_max_stream_ordering
# Then count deltas for total_events and total_event_bytes.
room_count, user_count = yield self.store.get_changes_room_total_events_and_bytes(
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index e53669e40d..624f05ab5b 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -138,21 +138,28 @@ class UserDirectoryHandler(StateDeltasHandler):
# Loop round handling deltas until we're up to date
while True:
with Measure(self.clock, "user_dir_delta"):
- deltas = yield self.store.get_current_state_deltas(self.pos)
- if not deltas:
+ room_max_stream_ordering = self.store.get_room_max_stream_ordering()
+ if self.pos == room_max_stream_ordering:
return
+ logger.debug(
+ "Processing user stats %s->%s", self.pos, room_max_stream_ordering
+ )
+ max_pos, deltas = yield self.store.get_current_state_deltas(
+ self.pos, room_max_stream_ordering
+ )
+
logger.info("Handling %d state deltas", len(deltas))
yield self._handle_deltas(deltas)
- self.pos = deltas[-1]["stream_id"]
+ self.pos = max_pos
# Expose current event processing position to prometheus
synapse.metrics.event_processing_positions.labels("user_dir").set(
- self.pos
+ max_pos
)
- yield self.store.update_user_directory_stream_pos(self.pos)
+ yield self.store.update_user_directory_stream_pos(max_pos)
@defer.inlineCallbacks
def _handle_deltas(self, deltas):
diff --git a/synapse/storage/schema/delta/56/unique_user_filter_index.py b/synapse/storage/schema/delta/56/unique_user_filter_index.py
index 60031f23ca..1de8b54961 100644
--- a/synapse/storage/schema/delta/56/unique_user_filter_index.py
+++ b/synapse/storage/schema/delta/56/unique_user_filter_index.py
@@ -5,42 +5,48 @@ from synapse.storage.engines import PostgresEngine
logger = logging.getLogger(__name__)
+"""
+This migration updates the user_filters table as follows:
+
+ - drops any (user_id, filter_id) duplicates
+ - makes the columns NON-NULLable
+ - turns the index into a UNIQUE index
+"""
+
+
def run_upgrade(cur, database_engine, *args, **kwargs):
+ pass
+
+
+def run_create(cur, database_engine, *args, **kwargs):
if isinstance(database_engine, PostgresEngine):
select_clause = """
- CREATE TEMPORARY TABLE user_filters_migration AS
SELECT DISTINCT ON (user_id, filter_id) user_id, filter_id, filter_json
- FROM user_filters;
+ FROM user_filters
"""
else:
select_clause = """
- CREATE TEMPORARY TABLE user_filters_migration AS
- SELECT * FROM user_filters GROUP BY user_id, filter_id;
+ SELECT * FROM user_filters GROUP BY user_id, filter_id
"""
- sql = (
- """
- BEGIN;
- %s
- DROP INDEX user_filters_by_user_id_filter_id;
- DELETE FROM user_filters;
- ALTER TABLE user_filters
- ALTER COLUMN user_id SET NOT NULL,
- ALTER COLUMN filter_id SET NOT NULL,
- ALTER COLUMN filter_json SET NOT NULL;
- INSERT INTO user_filters(user_id, filter_id, filter_json)
- SELECT * FROM user_filters_migration;
- DROP TABLE user_filters_migration;
- CREATE UNIQUE INDEX user_filters_by_user_id_filter_id_unique
- ON user_filters(user_id, filter_id);
- END;
- """
- % select_clause
+ sql = """
+ DROP TABLE IF EXISTS user_filters_migration;
+ DROP INDEX IF EXISTS user_filters_unique;
+ CREATE TABLE user_filters_migration (
+ user_id TEXT NOT NULL,
+ filter_id BIGINT NOT NULL,
+ filter_json BYTEA NOT NULL
+ );
+ INSERT INTO user_filters_migration (user_id, filter_id, filter_json)
+ %s;
+ CREATE UNIQUE INDEX user_filters_unique ON user_filters_migration
+ (user_id, filter_id);
+ DROP TABLE user_filters;
+ ALTER TABLE user_filters_migration RENAME TO user_filters;
+ """ % (
+ select_clause,
)
+
if isinstance(database_engine, PostgresEngine):
cur.execute(sql)
else:
cur.executescript(sql)
-
-
-def run_create(cur, database_engine, *args, **kwargs):
- pass
diff --git a/synapse/storage/state_deltas.py b/synapse/storage/state_deltas.py
index 5fdb442104..28f33ec18f 100644
--- a/synapse/storage/state_deltas.py
+++ b/synapse/storage/state_deltas.py
@@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
class StateDeltasStore(SQLBaseStore):
- def get_current_state_deltas(self, prev_stream_id):
+ def get_current_state_deltas(self, prev_stream_id: int, max_stream_id: int):
"""Fetch a list of room state changes since the given stream id
Each entry in the result contains the following fields:
@@ -36,15 +36,27 @@ class StateDeltasStore(SQLBaseStore):
Args:
prev_stream_id (int): point to get changes since (exclusive)
+ max_stream_id (int): the point that we know has been correctly persisted
+ - ie, an upper limit to return changes from.
Returns:
- Deferred[list[dict]]: results
+ Deferred[tuple[int, list[dict]]: A tuple consisting of:
+ - the stream id which these results go up to
+ - list of current_state_delta_stream rows. If it is empty, we are
+ up to date.
"""
prev_stream_id = int(prev_stream_id)
+
+ # check we're not going backwards
+ assert prev_stream_id <= max_stream_id
+
if not self._curr_state_delta_stream_cache.has_any_entity_changed(
prev_stream_id
):
- return []
+ # if the CSDs haven't changed between prev_stream_id and now, we
+ # know for certain that they haven't changed between prev_stream_id and
+ # max_stream_id.
+ return max_stream_id, []
def get_current_state_deltas_txn(txn):
# First we calculate the max stream id that will give us less than
@@ -54,21 +66,29 @@ class StateDeltasStore(SQLBaseStore):
sql = """
SELECT stream_id, count(*)
FROM current_state_delta_stream
- WHERE stream_id > ?
+ WHERE stream_id > ? AND stream_id <= ?
GROUP BY stream_id
ORDER BY stream_id ASC
LIMIT 100
"""
- txn.execute(sql, (prev_stream_id,))
+ txn.execute(sql, (prev_stream_id, max_stream_id))
total = 0
- max_stream_id = prev_stream_id
- for max_stream_id, count in txn:
+
+ for stream_id, count in txn:
total += count
if total > 100:
# We arbitarily limit to 100 entries to ensure we don't
# select toooo many.
+ logger.debug(
+ "Clipping current_state_delta_stream rows to stream_id %i",
+ stream_id,
+ )
+ clipped_stream_id = stream_id
break
+ else:
+ # if there's no problem, we may as well go right up to the max_stream_id
+ clipped_stream_id = max_stream_id
# Now actually get the deltas
sql = """
@@ -77,8 +97,8 @@ class StateDeltasStore(SQLBaseStore):
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC
"""
- txn.execute(sql, (prev_stream_id, max_stream_id))
- return self.cursor_to_dict(txn)
+ txn.execute(sql, (prev_stream_id, clipped_stream_id))
+ return clipped_stream_id, self.cursor_to_dict(txn)
return self.runInteraction(
"get_current_state_deltas", get_current_state_deltas_txn
|