diff --git a/synapse/__init__.py b/synapse/__init__.py
index 2e354f2cc6..f2d3ac68eb 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -48,7 +48,7 @@ try:
except ImportError:
pass
-__version__ = "1.24.0rc2"
+__version__ = "1.24.0"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index bfcaf68b2a..1951f6e178 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -31,7 +31,9 @@ from synapse.api.errors import (
MissingClientTokenError,
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.appservice import ApplicationService
from synapse.events import EventBase
+from synapse.http.site import SynapseRequest
from synapse.logging import opentracing as opentracing
from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import StateMap, UserID
@@ -474,7 +476,7 @@ class Auth:
now = self.hs.get_clock().time_msec()
return now < expiry
- def get_appservice_by_req(self, request):
+ def get_appservice_by_req(self, request: SynapseRequest) -> ApplicationService:
token = self.get_access_token_from_request(request)
service = self.store.get_app_service_by_token(token)
if not service:
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 895b38ae76..37ecdbe3d8 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -245,6 +245,8 @@ def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
# Set up the SIGHUP machinery.
if hasattr(signal, "SIGHUP"):
+ reactor = hs.get_reactor()
+
@wrap_as_background_process("sighup")
def handle_sighup(*args, **kwargs):
# Tell systemd our state, if we're using it. This will silently fail if
@@ -260,7 +262,9 @@ def start(hs: "synapse.server.HomeServer", listeners: Iterable[ListenerConfig]):
# is so that we're in a sane state, e.g. flushing the logs may fail
# if the sighup happens in the middle of writing a log entry.
def run_sighup(*args, **kwargs):
- hs.get_clock().call_later(0, handle_sighup, *args, **kwargs)
+ # `callFromThread` should be "signal safe" as well as thread
+ # safe.
+ reactor.callFromThread(handle_sighup, *args, **kwargs)
signal.signal(signal.SIGHUP, run_sighup)
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 1b511890aa..aa12c74358 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -266,7 +266,6 @@ class GenericWorkerPresence(BasePresenceHandler):
super().__init__(hs)
self.hs = hs
self.is_mine_id = hs.is_mine_id
- self.http_client = hs.get_simple_http_client()
self._presence_enabled = hs.config.use_presence
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 2b5465417f..bbb7407838 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -19,7 +19,7 @@ import gc
import logging
import os
import sys
-from typing import Iterable
+from typing import Iterable, Iterator
from twisted.application import service
from twisted.internet import defer, reactor
@@ -90,7 +90,7 @@ class SynapseHomeServer(HomeServer):
tls = listener_config.tls
site_tag = listener_config.http_options.tag
if site_tag is None:
- site_tag = port
+ site_tag = str(port)
# We always include a health resource.
resources = {"/health": HealthResource()}
@@ -107,7 +107,10 @@ class SynapseHomeServer(HomeServer):
logger.debug("Configuring additional resources: %r", additional_resources)
module_api = self.get_module_api()
for path, resmodule in additional_resources.items():
- handler_cls, config = load_module(resmodule)
+ handler_cls, config = load_module(
+ resmodule,
+ ("listeners", site_tag, "additional_resources", "<%s>" % (path,)),
+ )
handler = handler_cls(config, module_api)
if IResource.providedBy(handler):
resource = handler
@@ -342,7 +345,10 @@ def setup(config_options):
"Synapse Homeserver", config_options
)
except ConfigError as e:
- sys.stderr.write("\nERROR: %s\n" % (e,))
+ sys.stderr.write("\n")
+ for f in format_config_error(e):
+ sys.stderr.write(f)
+ sys.stderr.write("\n")
sys.exit(1)
if not config:
@@ -445,6 +451,38 @@ def setup(config_options):
return hs
+def format_config_error(e: ConfigError) -> Iterator[str]:
+ """
+ Formats a config error neatly
+
+ The idea is to format the immediate error, plus the "causes" of those errors,
+ hopefully in a way that makes sense to the user. For example:
+
+ Error in configuration at 'oidc_config.user_mapping_provider.config.display_name_template':
+ Failed to parse config for module 'JinjaOidcMappingProvider':
+ invalid jinja template:
+ unexpected end of template, expected 'end of print statement'.
+
+ Args:
+ e: the error to be formatted
+
+ Returns: An iterator which yields string fragments to be formatted
+ """
+ yield "Error in configuration"
+
+ if e.path:
+ yield " at '%s'" % (".".join(e.path),)
+
+ yield ":\n %s" % (e.msg,)
+
+ e = e.__cause__
+ indent = 1
+ while e:
+ indent += 1
+ yield ":\n%s%s" % (" " * indent, str(e))
+ e = e.__cause__
+
+
class SynapseService(service.Service):
"""
A twisted Service class that will start synapse. Used to run synapse
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index 85f65da4d9..2931a88207 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -23,7 +23,7 @@ import urllib.parse
from collections import OrderedDict
from hashlib import sha256
from textwrap import dedent
-from typing import Any, Callable, List, MutableMapping, Optional
+from typing import Any, Callable, Iterable, List, MutableMapping, Optional
import attr
import jinja2
@@ -32,7 +32,17 @@ import yaml
class ConfigError(Exception):
- pass
+ """Represents a problem parsing the configuration
+
+ Args:
+ msg: A textual description of the error.
+ path: Where appropriate, an indication of where in the configuration
+ the problem lies.
+ """
+
+ def __init__(self, msg: str, path: Optional[Iterable[str]] = None):
+ self.msg = msg
+ self.path = path
# We split these messages out to allow packages to override with package
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index b8faafa9bd..ed26e2fb60 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -1,4 +1,4 @@
-from typing import Any, List, Optional
+from typing import Any, Iterable, List, Optional
from synapse.config import (
api,
@@ -35,7 +35,10 @@ from synapse.config import (
workers,
)
-class ConfigError(Exception): ...
+class ConfigError(Exception):
+ def __init__(self, msg: str, path: Optional[Iterable[str]] = None):
+ self.msg = msg
+ self.path = path
MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS: str
MISSING_REPORT_STATS_SPIEL: str
diff --git a/synapse/config/_util.py b/synapse/config/_util.py
index c74969a977..1bbe83c317 100644
--- a/synapse/config/_util.py
+++ b/synapse/config/_util.py
@@ -38,14 +38,27 @@ def validate_config(
try:
jsonschema.validate(config, json_schema)
except jsonschema.ValidationError as e:
- # copy `config_path` before modifying it.
- path = list(config_path)
- for p in list(e.path):
- if isinstance(p, int):
- path.append("<item %i>" % p)
- else:
- path.append(str(p))
-
- raise ConfigError(
- "Unable to parse configuration: %s at %s" % (e.message, ".".join(path))
- )
+ raise json_error_to_config_error(e, config_path)
+
+
+def json_error_to_config_error(
+ e: jsonschema.ValidationError, config_path: Iterable[str]
+) -> ConfigError:
+ """Converts a json validation error to a user-readable ConfigError
+
+ Args:
+ e: the exception to be converted
+ config_path: the path within the config file. This will be used as a basis
+ for the error message.
+
+ Returns:
+ a ConfigError
+ """
+ # copy `config_path` before modifying it.
+ path = list(config_path)
+ for p in list(e.path):
+ if isinstance(p, int):
+ path.append("<item %i>" % p)
+ else:
+ path.append(str(p))
+ return ConfigError(e.message, path)
diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py
index cceffbfee2..7c8b64d84b 100644
--- a/synapse/config/emailconfig.py
+++ b/synapse/config/emailconfig.py
@@ -390,9 +390,8 @@ class EmailConfig(Config):
#validation_token_lifetime: 15m
# Directory in which Synapse will try to find the template files below.
- # If not set, default templates from within the Synapse package will be used.
- #
- # Do not uncomment this setting unless you want to customise the templates.
+ # If not set, or the files named below are not found within the template
+ # directory, default templates from within the Synapse package will be used.
#
# Synapse will look for the following templates in this directory:
#
diff --git a/synapse/config/federation.py b/synapse/config/federation.py
index ffd8fca54e..a03a419e23 100644
--- a/synapse/config/federation.py
+++ b/synapse/config/federation.py
@@ -12,12 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
from typing import Optional
-from netaddr import IPSet
-
-from synapse.config._base import Config, ConfigError
+from synapse.config._base import Config
from synapse.config._util import validate_config
@@ -36,23 +33,6 @@ class FederationConfig(Config):
for domain in federation_domain_whitelist:
self.federation_domain_whitelist[domain] = True
- self.federation_ip_range_blacklist = config.get(
- "federation_ip_range_blacklist", []
- )
-
- # Attempt to create an IPSet from the given ranges
- try:
- self.federation_ip_range_blacklist = IPSet(
- self.federation_ip_range_blacklist
- )
-
- # Always blacklist 0.0.0.0, ::
- self.federation_ip_range_blacklist.update(["0.0.0.0", "::"])
- except Exception as e:
- raise ConfigError(
- "Invalid range(s) provided in federation_ip_range_blacklist: %s" % e
- )
-
federation_metrics_domains = config.get("federation_metrics_domains") or []
validate_config(
_METRICS_FOR_DOMAINS_SCHEMA,
@@ -76,26 +56,17 @@ class FederationConfig(Config):
# - nyc.example.com
# - syd.example.com
- # Prevent federation requests from being sent to the following
- # blacklist IP address CIDR ranges. If this option is not specified, or
- # specified with an empty list, no ip range blacklist will be enforced.
- #
- # As of Synapse v1.4.0 this option also affects any outbound requests to identity
- # servers provided by user input.
+ # List of IP address CIDR ranges that should be allowed for federation,
+ # identity servers, push servers, and for checking key validity for
+ # third-party invite events. This is useful for specifying exceptions to
+ # wide-ranging blacklisted target IP ranges - e.g. for communication with
+ # a push server only visible in your network.
#
- # (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly
- # listed here, since they correspond to unroutable addresses.)
+ # This whitelist overrides ip_range_blacklist and defaults to an empty
+ # list.
#
- federation_ip_range_blacklist:
- - '127.0.0.0/8'
- - '10.0.0.0/8'
- - '172.16.0.0/12'
- - '192.168.0.0/16'
- - '100.64.0.0/10'
- - '169.254.0.0/16'
- - '::1/128'
- - 'fe80::/64'
- - 'fc00::/7'
+ #ip_range_whitelist:
+ # - '192.168.1.1'
# Report prometheus metrics on the age of PDUs being sent to and received from
# the following domains. This can be used to give an idea of "delay" on inbound
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index d4e887a3e0..4df3f93c1c 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -206,7 +206,7 @@ def _setup_stdlib_logging(config, log_config_path, logBeginner: LogBeginner) ->
# filter options, but care must when using e.g. MemoryHandler to buffer
# writes.
- log_context_filter = LoggingContextFilter(request="")
+ log_context_filter = LoggingContextFilter()
log_metadata_filter = MetadataFilter({"server_name": config.server_name})
old_factory = logging.getLogRecordFactory()
diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
index 69d188341c..1abf8ed405 100644
--- a/synapse/config/oidc_config.py
+++ b/synapse/config/oidc_config.py
@@ -66,7 +66,7 @@ class OIDCConfig(Config):
(
self.oidc_user_mapping_provider_class,
self.oidc_user_mapping_provider_config,
- ) = load_module(ump_config)
+ ) = load_module(ump_config, ("oidc_config", "user_mapping_provider"))
# Ensure loaded user mapping module has defined all necessary methods
required_methods = [
diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py
index 4fda8ae987..85d07c4f8f 100644
--- a/synapse/config/password_auth_providers.py
+++ b/synapse/config/password_auth_providers.py
@@ -36,7 +36,7 @@ class PasswordAuthProviderConfig(Config):
providers.append({"module": LDAP_PROVIDER, "config": ldap_config})
providers.extend(config.get("password_providers") or [])
- for provider in providers:
+ for i, provider in enumerate(providers):
mod_name = provider["module"]
# This is for backwards compat when the ldap auth provider resided
@@ -45,7 +45,8 @@ class PasswordAuthProviderConfig(Config):
mod_name = LDAP_PROVIDER
(provider_class, provider_config) = load_module(
- {"module": mod_name, "config": provider["config"]}
+ {"module": mod_name, "config": provider["config"]},
+ ("password_providers", "<item %i>" % i),
)
self.password_providers.append((provider_class, provider_config))
diff --git a/synapse/config/repository.py b/synapse/config/repository.py
index ba1e9d2361..850ac3ebd6 100644
--- a/synapse/config/repository.py
+++ b/synapse/config/repository.py
@@ -17,6 +17,9 @@ import os
from collections import namedtuple
from typing import Dict, List
+from netaddr import IPSet
+
+from synapse.config.server import DEFAULT_IP_RANGE_BLACKLIST
from synapse.python_dependencies import DependencyException, check_requirements
from synapse.util.module_loader import load_module
@@ -142,7 +145,7 @@ class ContentRepositoryConfig(Config):
# them to be started.
self.media_storage_providers = [] # type: List[tuple]
- for provider_config in storage_providers:
+ for i, provider_config in enumerate(storage_providers):
# We special case the module "file_system" so as not to need to
# expose FileStorageProviderBackend
if provider_config["module"] == "file_system":
@@ -151,7 +154,9 @@ class ContentRepositoryConfig(Config):
".FileStorageProviderBackend"
)
- provider_class, parsed_config = load_module(provider_config)
+ provider_class, parsed_config = load_module(
+ provider_config, ("media_storage_providers", "<item %i>" % i)
+ )
wrapper_config = MediaStorageProviderConfig(
provider_config.get("store_local", False),
@@ -182,9 +187,6 @@ class ContentRepositoryConfig(Config):
"to work"
)
- # netaddr is a dependency for url_preview
- from netaddr import IPSet
-
self.url_preview_ip_range_blacklist = IPSet(
config["url_preview_ip_range_blacklist"]
)
@@ -213,6 +215,10 @@ class ContentRepositoryConfig(Config):
# strip final NL
formatted_thumbnail_sizes = formatted_thumbnail_sizes[:-1]
+ ip_range_blacklist = "\n".join(
+ " # - '%s'" % ip for ip in DEFAULT_IP_RANGE_BLACKLIST
+ )
+
return (
r"""
## Media Store ##
@@ -283,15 +289,7 @@ class ContentRepositoryConfig(Config):
# you uncomment the following list as a starting point.
#
#url_preview_ip_range_blacklist:
- # - '127.0.0.0/8'
- # - '10.0.0.0/8'
- # - '172.16.0.0/12'
- # - '192.168.0.0/16'
- # - '100.64.0.0/10'
- # - '169.254.0.0/16'
- # - '::1/128'
- # - 'fe80::/64'
- # - 'fc00::/7'
+%(ip_range_blacklist)s
# List of IP address CIDR ranges that the URL preview spider is allowed
# to access even if they are specified in url_preview_ip_range_blacklist.
diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py
index 92e1b67528..9a3e1c3e7d 100644
--- a/synapse/config/room_directory.py
+++ b/synapse/config/room_directory.py
@@ -180,7 +180,7 @@ class _RoomDirectoryRule:
self._alias_regex = glob_to_regex(alias)
self._room_id_regex = glob_to_regex(room_id)
except Exception as e:
- raise ConfigError("Failed to parse glob into regex: %s", e)
+ raise ConfigError("Failed to parse glob into regex") from e
def matches(self, user_id, room_id, aliases):
"""Tests if this rule matches the given user_id, room_id and aliases.
diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py
index c1b8e98ae0..7b97d4f114 100644
--- a/synapse/config/saml2_config.py
+++ b/synapse/config/saml2_config.py
@@ -125,7 +125,7 @@ class SAML2Config(Config):
(
self.saml2_user_mapping_provider_class,
self.saml2_user_mapping_provider_config,
- ) = load_module(ump_dict)
+ ) = load_module(ump_dict, ("saml2_config", "user_mapping_provider"))
# Ensure loaded user mapping module has defined all necessary methods
# Note parse_config() is already checked during the call to load_module
diff --git a/synapse/config/server.py b/synapse/config/server.py
index 85aa49c02d..f3815e5add 100644
--- a/synapse/config/server.py
+++ b/synapse/config/server.py
@@ -23,6 +23,7 @@ from typing import Any, Dict, Iterable, List, Optional, Set
import attr
import yaml
+from netaddr import IPSet
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.endpoint import parse_and_validate_server_name
@@ -39,6 +40,34 @@ logger = logging.Logger(__name__)
# in the list.
DEFAULT_BIND_ADDRESSES = ["::", "0.0.0.0"]
+DEFAULT_IP_RANGE_BLACKLIST = [
+ # Localhost
+ "127.0.0.0/8",
+ # Private networks.
+ "10.0.0.0/8",
+ "172.16.0.0/12",
+ "192.168.0.0/16",
+ # Carrier grade NAT.
+ "100.64.0.0/10",
+ # Address registry.
+ "192.0.0.0/24",
+ # Link-local networks.
+ "169.254.0.0/16",
+ # Testing networks.
+ "198.18.0.0/15",
+ "192.0.2.0/24",
+ "198.51.100.0/24",
+ "203.0.113.0/24",
+ # Multicast.
+ "224.0.0.0/4",
+ # Localhost
+ "::1/128",
+ # Link-local addresses.
+ "fe80::/10",
+ # Unique local addresses.
+ "fc00::/7",
+]
+
DEFAULT_ROOM_VERSION = "6"
ROOM_COMPLEXITY_TOO_GREAT = (
@@ -256,6 +285,38 @@ class ServerConfig(Config):
# due to resource constraints
self.admin_contact = config.get("admin_contact", None)
+ ip_range_blacklist = config.get(
+ "ip_range_blacklist", DEFAULT_IP_RANGE_BLACKLIST
+ )
+
+ # Attempt to create an IPSet from the given ranges
+ try:
+ self.ip_range_blacklist = IPSet(ip_range_blacklist)
+ except Exception as e:
+ raise ConfigError("Invalid range(s) provided in ip_range_blacklist.") from e
+ # Always blacklist 0.0.0.0, ::
+ self.ip_range_blacklist.update(["0.0.0.0", "::"])
+
+ try:
+ self.ip_range_whitelist = IPSet(config.get("ip_range_whitelist", ()))
+ except Exception as e:
+ raise ConfigError("Invalid range(s) provided in ip_range_whitelist.") from e
+
+ # The federation_ip_range_blacklist is used for backwards-compatibility
+ # and only applies to federation and identity servers. If it is not given,
+ # default to ip_range_blacklist.
+ federation_ip_range_blacklist = config.get(
+ "federation_ip_range_blacklist", ip_range_blacklist
+ )
+ try:
+ self.federation_ip_range_blacklist = IPSet(federation_ip_range_blacklist)
+ except Exception as e:
+ raise ConfigError(
+ "Invalid range(s) provided in federation_ip_range_blacklist."
+ ) from e
+ # Always blacklist 0.0.0.0, ::
+ self.federation_ip_range_blacklist.update(["0.0.0.0", "::"])
+
if self.public_baseurl is not None:
if self.public_baseurl[-1] != "/":
self.public_baseurl += "/"
@@ -561,6 +622,10 @@ class ServerConfig(Config):
def generate_config_section(
self, server_name, data_dir_path, open_private_ports, listeners, **kwargs
):
+ ip_range_blacklist = "\n".join(
+ " # - '%s'" % ip for ip in DEFAULT_IP_RANGE_BLACKLIST
+ )
+
_, bind_port = parse_and_validate_server_name(server_name)
if bind_port is not None:
unsecure_port = bind_port - 400
@@ -752,6 +817,21 @@ class ServerConfig(Config):
#
#enable_search: false
+ # Prevent outgoing requests from being sent to the following blacklisted IP address
+ # CIDR ranges. If this option is not specified then it defaults to private IP
+ # address ranges (see the example below).
+ #
+ # The blacklist applies to the outbound requests for federation, identity servers,
+ # push servers, and for checking key validity for third-party invite events.
+ #
+ # (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly
+ # listed here, since they correspond to unroutable addresses.)
+ #
+ # This option replaces federation_ip_range_blacklist in Synapse v1.25.0.
+ #
+ #ip_range_blacklist:
+%(ip_range_blacklist)s
+
# List of ports that Synapse should listen on, their purpose and their
# configuration.
#
diff --git a/synapse/config/spam_checker.py b/synapse/config/spam_checker.py
index 3d067d29db..3d05abc158 100644
--- a/synapse/config/spam_checker.py
+++ b/synapse/config/spam_checker.py
@@ -33,13 +33,14 @@ class SpamCheckerConfig(Config):
# spam checker, and thus was simply a dictionary with module
# and config keys. Support this old behaviour by checking
# to see if the option resolves to a dictionary
- self.spam_checkers.append(load_module(spam_checkers))
+ self.spam_checkers.append(load_module(spam_checkers, ("spam_checker",)))
elif isinstance(spam_checkers, list):
- for spam_checker in spam_checkers:
+ for i, spam_checker in enumerate(spam_checkers):
+ config_path = ("spam_checker", "<item %i>" % i)
if not isinstance(spam_checker, dict):
- raise ConfigError("spam_checker syntax is incorrect")
+ raise ConfigError("expected a mapping", config_path)
- self.spam_checkers.append(load_module(spam_checker))
+ self.spam_checkers.append(load_module(spam_checker, config_path))
else:
raise ConfigError("spam_checker syntax is incorrect")
diff --git a/synapse/config/sso.py b/synapse/config/sso.py
index 4427676167..93bbd40937 100644
--- a/synapse/config/sso.py
+++ b/synapse/config/sso.py
@@ -93,11 +93,8 @@ class SSOConfig(Config):
# - https://my.custom.client/
# Directory in which Synapse will try to find the template files below.
- # If not set, default templates from within the Synapse package will be used.
- #
- # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
- # If you *do* uncomment it, you will need to make sure that all the templates
- # below are in the directory.
+ # If not set, or the files named below are not found within the template
+ # directory, default templates from within the Synapse package will be used.
#
# Synapse will look for the following templates in this directory:
#
diff --git a/synapse/config/third_party_event_rules.py b/synapse/config/third_party_event_rules.py
index 10a99c792e..c04e1c4e07 100644
--- a/synapse/config/third_party_event_rules.py
+++ b/synapse/config/third_party_event_rules.py
@@ -26,7 +26,9 @@ class ThirdPartyRulesConfig(Config):
provider = config.get("third_party_event_rules", None)
if provider is not None:
- self.third_party_event_rules = load_module(provider)
+ self.third_party_event_rules = load_module(
+ provider, ("third_party_event_rules",)
+ )
def generate_config_section(self, **kwargs):
return """\
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index 57ab097eba..7ca9efec52 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -85,6 +85,9 @@ class WorkerConfig(Config):
# The port on the main synapse for HTTP replication endpoint
self.worker_replication_http_port = config.get("worker_replication_http_port")
+ # The shared secret used for authentication when connecting to the main synapse.
+ self.worker_replication_secret = config.get("worker_replication_secret", None)
+
self.worker_name = config.get("worker_name", self.worker_app)
self.worker_main_http_uri = config.get("worker_main_http_uri", None)
@@ -185,6 +188,13 @@ class WorkerConfig(Config):
# data). If not provided this defaults to the main process.
#
#run_background_tasks_on: worker1
+
+ # A shared secret used by the replication APIs to authenticate HTTP requests
+ # from workers.
+ #
+ # By default this is unused and traffic is not authenticated.
+ #
+ #worker_replication_secret: ""
"""
def read_arguments(self, args):
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index c04ad77cf9..f23eacc0d7 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -578,7 +578,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
def __init__(self, hs):
super().__init__(hs)
self.clock = hs.get_clock()
- self.client = hs.get_http_client()
+ self.client = hs.get_federation_http_client()
self.key_servers = self.config.key_servers
async def get_keys(self, keys_to_fetch):
@@ -748,7 +748,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
def __init__(self, hs):
super().__init__(hs)
self.clock = hs.get_clock()
- self.client = hs.get_http_client()
+ self.client = hs.get_federation_http_client()
async def get_keys(self, keys_to_fetch):
"""
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 936896656a..e7e3a7b9a4 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -15,10 +15,11 @@
# limitations under the License.
import inspect
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from synapse.spam_checker_api import RegistrationBehaviour
from synapse.types import Collection
+from synapse.util.async_helpers import maybe_awaitable
if TYPE_CHECKING:
import synapse.events
@@ -39,7 +40,9 @@ class SpamChecker:
else:
self.spam_checkers.append(module(config=config))
- def check_event_for_spam(self, event: "synapse.events.EventBase") -> bool:
+ async def check_event_for_spam(
+ self, event: "synapse.events.EventBase"
+ ) -> Union[bool, str]:
"""Checks if a given event is considered "spammy" by this server.
If the server considers an event spammy, then it will be rejected if
@@ -50,15 +53,16 @@ class SpamChecker:
event: the event to be checked
Returns:
- True if the event is spammy.
+ True or a string if the event is spammy. If a string is returned it
+ will be used as the error message returned to the user.
"""
for spam_checker in self.spam_checkers:
- if spam_checker.check_event_for_spam(event):
+ if await maybe_awaitable(spam_checker.check_event_for_spam(event)):
return True
return False
- def user_may_invite(
+ async def user_may_invite(
self, inviter_userid: str, invitee_userid: str, room_id: str
) -> bool:
"""Checks if a given user may send an invite
@@ -75,14 +79,18 @@ class SpamChecker:
"""
for spam_checker in self.spam_checkers:
if (
- spam_checker.user_may_invite(inviter_userid, invitee_userid, room_id)
+ await maybe_awaitable(
+ spam_checker.user_may_invite(
+ inviter_userid, invitee_userid, room_id
+ )
+ )
is False
):
return False
return True
- def user_may_create_room(self, userid: str) -> bool:
+ async def user_may_create_room(self, userid: str) -> bool:
"""Checks if a given user may create a room
If this method returns false, the creation request will be rejected.
@@ -94,12 +102,15 @@ class SpamChecker:
True if the user may create a room, otherwise False
"""
for spam_checker in self.spam_checkers:
- if spam_checker.user_may_create_room(userid) is False:
+ if (
+ await maybe_awaitable(spam_checker.user_may_create_room(userid))
+ is False
+ ):
return False
return True
- def user_may_create_room_alias(self, userid: str, room_alias: str) -> bool:
+ async def user_may_create_room_alias(self, userid: str, room_alias: str) -> bool:
"""Checks if a given user may create a room alias
If this method returns false, the association request will be rejected.
@@ -112,12 +123,17 @@ class SpamChecker:
True if the user may create a room alias, otherwise False
"""
for spam_checker in self.spam_checkers:
- if spam_checker.user_may_create_room_alias(userid, room_alias) is False:
+ if (
+ await maybe_awaitable(
+ spam_checker.user_may_create_room_alias(userid, room_alias)
+ )
+ is False
+ ):
return False
return True
- def user_may_publish_room(self, userid: str, room_id: str) -> bool:
+ async def user_may_publish_room(self, userid: str, room_id: str) -> bool:
"""Checks if a given user may publish a room to the directory
If this method returns false, the publish request will be rejected.
@@ -130,12 +146,17 @@ class SpamChecker:
True if the user may publish the room, otherwise False
"""
for spam_checker in self.spam_checkers:
- if spam_checker.user_may_publish_room(userid, room_id) is False:
+ if (
+ await maybe_awaitable(
+ spam_checker.user_may_publish_room(userid, room_id)
+ )
+ is False
+ ):
return False
return True
- def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool:
+ async def check_username_for_spam(self, user_profile: Dict[str, str]) -> bool:
"""Checks if a user ID or display name are considered "spammy" by this server.
If the server considers a username spammy, then it will not be included in
@@ -157,12 +178,12 @@ class SpamChecker:
if checker:
# Make a copy of the user profile object to ensure the spam checker
# cannot modify it.
- if checker(user_profile.copy()):
+ if await maybe_awaitable(checker(user_profile.copy())):
return True
return False
- def check_registration_for_spam(
+ async def check_registration_for_spam(
self,
email_threepid: Optional[dict],
username: Optional[str],
@@ -185,7 +206,9 @@ class SpamChecker:
# spam checker
checker = getattr(spam_checker, "check_registration_for_spam", None)
if checker:
- behaviour = checker(email_threepid, username, request_info)
+ behaviour = await maybe_awaitable(
+ checker(email_threepid, username, request_info)
+ )
assert isinstance(behaviour, RegistrationBehaviour)
if behaviour != RegistrationBehaviour.ALLOW:
return behaviour
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 38aa47963f..383737520a 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -78,6 +78,7 @@ class FederationBase:
ctx = current_context()
+ @defer.inlineCallbacks
def callback(_, pdu: EventBase):
with PreserveLoggingContext(ctx):
if not check_event_content_hash(pdu):
@@ -105,7 +106,11 @@ class FederationBase:
)
return redacted_event
- if self.spam_checker.check_event_for_spam(pdu):
+ result = yield defer.ensureDeferred(
+ self.spam_checker.check_event_for_spam(pdu)
+ )
+
+ if result:
logger.warning(
"Event contains spam, redacting %s: %s",
pdu.event_id,
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index be43c22876..00a1738e7c 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -845,7 +845,6 @@ class FederationHandlerRegistry:
def __init__(self, hs: "HomeServer"):
self.config = hs.config
- self.http_client = hs.get_simple_http_client()
self.clock = hs.get_clock()
self._instance_name = hs.get_instance_name()
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 17a10f622e..abe9168c78 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -35,7 +35,7 @@ class TransportLayerClient:
def __init__(self, hs):
self.server_name = hs.hostname
- self.client = hs.get_http_client()
+ self.client = hs.get_federation_http_client()
@log_function
def get_room_state_ids(self, destination, room_id, event_id):
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index b53e7a20ec..434718ddfc 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -1462,7 +1462,7 @@ def register_servlets(hs, resource, authenticator, ratelimiter, servlet_groups=N
Args:
hs (synapse.server.HomeServer): homeserver
- resource (TransportLayerServer): resource class to register to
+ resource (JsonResource): resource class to register to
authenticator (Authenticator): authenticator to use
ratelimiter (util.ratelimitutils.FederationRateLimiter): ratelimiter to use
servlet_groups (list[str], optional): List of servlet groups to register.
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index bb81c0e81d..d29b066a56 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -32,6 +32,10 @@ logger = logging.getLogger(__name__)
class BaseHandler:
"""
Common base class for the event handlers.
+
+ Deprecated: new code should not use this. Instead, Handler classes should define the
+ fields they actually need. The utility methods should either be factored out to
+ standalone helper functions, or to different Handler classes.
"""
def __init__(self, hs: "HomeServer"):
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index c7dc07008a..21e568f226 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -14,7 +14,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import inspect
import logging
import time
import unicodedata
@@ -22,6 +21,7 @@ import urllib.parse
from typing import (
TYPE_CHECKING,
Any,
+ Awaitable,
Callable,
Dict,
Iterable,
@@ -36,6 +36,8 @@ import attr
import bcrypt
import pymacaroons
+from twisted.web.http import Request
+
from synapse.api.constants import LoginType
from synapse.api.errors import (
AuthError,
@@ -56,6 +58,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.module_api import ModuleApi
from synapse.types import JsonDict, Requester, UserID
from synapse.util import stringutils as stringutils
+from synapse.util.async_helpers import maybe_awaitable
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.threepids import canonicalise_email
@@ -193,39 +196,27 @@ class AuthHandler(BaseHandler):
self.hs = hs # FIXME better possibility to access registrationHandler later?
self.macaroon_gen = hs.get_macaroon_generator()
self._password_enabled = hs.config.password_enabled
- self._sso_enabled = (
- hs.config.cas_enabled or hs.config.saml2_enabled or hs.config.oidc_enabled
- )
-
- # we keep this as a list despite the O(N^2) implication so that we can
- # keep PASSWORD first and avoid confusing clients which pick the first
- # type in the list. (NB that the spec doesn't require us to do so and
- # clients which favour types that they don't understand over those that
- # they do are technically broken)
+ self._password_localdb_enabled = hs.config.password_localdb_enabled
# start out by assuming PASSWORD is enabled; we will remove it later if not.
- login_types = []
- if hs.config.password_localdb_enabled:
- login_types.append(LoginType.PASSWORD)
+ login_types = set()
+ if self._password_localdb_enabled:
+ login_types.add(LoginType.PASSWORD)
for provider in self.password_providers:
- if hasattr(provider, "get_supported_login_types"):
- for t in provider.get_supported_login_types().keys():
- if t not in login_types:
- login_types.append(t)
+ login_types.update(provider.get_supported_login_types().keys())
if not self._password_enabled:
+ login_types.discard(LoginType.PASSWORD)
+
+ # Some clients just pick the first type in the list. In this case, we want
+ # them to use PASSWORD (rather than token or whatever), so we want to make sure
+ # that comes first, where it's present.
+ self._supported_login_types = []
+ if LoginType.PASSWORD in login_types:
+ self._supported_login_types.append(LoginType.PASSWORD)
login_types.remove(LoginType.PASSWORD)
-
- self._supported_login_types = login_types
-
- # Login types and UI Auth types have a heavy overlap, but are not
- # necessarily identical. Login types have SSO (and other login types)
- # added in the rest layer, see synapse.rest.client.v1.login.LoginRestServerlet.on_GET.
- ui_auth_types = login_types.copy()
- if self._sso_enabled:
- ui_auth_types.append(LoginType.SSO)
- self._supported_ui_auth_types = ui_auth_types
+ self._supported_login_types.extend(login_types)
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
# as per `rc_login.failed_attempts`.
@@ -339,7 +330,10 @@ class AuthHandler(BaseHandler):
self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False)
# build a list of supported flows
- flows = [[login_type] for login_type in self._supported_ui_auth_types]
+ supported_ui_auth_types = await self._get_available_ui_auth_types(
+ requester.user
+ )
+ flows = [[login_type] for login_type in supported_ui_auth_types]
try:
result, params, session_id = await self.check_ui_auth(
@@ -351,7 +345,7 @@ class AuthHandler(BaseHandler):
raise
# find the completed login type
- for login_type in self._supported_ui_auth_types:
+ for login_type in supported_ui_auth_types:
if login_type not in result:
continue
@@ -367,6 +361,41 @@ class AuthHandler(BaseHandler):
return params, session_id
+ async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]:
+ """Get a list of the authentication types this user can use
+ """
+
+ ui_auth_types = set()
+
+ # if the HS supports password auth, and the user has a non-null password, we
+ # support password auth
+ if self._password_localdb_enabled and self._password_enabled:
+ lookupres = await self._find_user_id_and_pwd_hash(user.to_string())
+ if lookupres:
+ _, password_hash = lookupres
+ if password_hash:
+ ui_auth_types.add(LoginType.PASSWORD)
+
+ # also allow auth from password providers
+ for provider in self.password_providers:
+ for t in provider.get_supported_login_types().keys():
+ if t == LoginType.PASSWORD and not self._password_enabled:
+ continue
+ ui_auth_types.add(t)
+
+ # if sso is enabled, allow the user to log in via SSO iff they have a mapping
+ # from sso to mxid.
+ if self.hs.config.saml2.saml2_enabled or self.hs.config.oidc.oidc_enabled:
+ if await self.store.get_external_ids_by_user(user.to_string()):
+ ui_auth_types.add(LoginType.SSO)
+
+ # Our CAS impl does not (yet) correctly register users in user_external_ids,
+ # so always offer that if it's available.
+ if self.hs.config.cas.cas_enabled:
+ ui_auth_types.add(LoginType.SSO)
+
+ return ui_auth_types
+
def get_enabled_auth_types(self):
"""Return the enabled user-interactive authentication types
@@ -831,7 +860,7 @@ class AuthHandler(BaseHandler):
async def validate_login(
self, login_submission: Dict[str, Any], ratelimit: bool = False,
- ) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
+ ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
"""Authenticates the user for the /login API
Also used by the user-interactive auth flow to validate auth types which don't
@@ -974,7 +1003,7 @@ class AuthHandler(BaseHandler):
async def _validate_userid_login(
self, username: str, login_submission: Dict[str, Any],
- ) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
+ ) -> Tuple[str, Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
"""Helper for validate_login
Handles login, once we've mapped 3pids onto userids
@@ -1029,7 +1058,7 @@ class AuthHandler(BaseHandler):
if result:
return result
- if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
+ if login_type == LoginType.PASSWORD and self._password_localdb_enabled:
known_login_type = True
# we've already checked that there is a (valid) password field
@@ -1052,7 +1081,7 @@ class AuthHandler(BaseHandler):
async def check_password_provider_3pid(
self, medium: str, address: str, password: str
- ) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], None]]]:
+ ) -> Tuple[Optional[str], Optional[Callable[[Dict[str, str]], Awaitable[None]]]]:
"""Check if a password provider is able to validate a thirdparty login
Args:
@@ -1303,15 +1332,14 @@ class AuthHandler(BaseHandler):
)
async def complete_sso_ui_auth(
- self, registered_user_id: str, session_id: str, request: SynapseRequest,
+ self, registered_user_id: str, session_id: str, request: Request,
):
"""Having figured out a mxid for this user, complete the HTTP request
Args:
registered_user_id: The registered user ID to complete SSO login for.
+ session_id: The ID of the user-interactive auth session.
request: The request to complete.
- client_redirect_url: The URL to which to redirect the user at the end of the
- process.
"""
# Mark the stage of the authentication as successful.
# Save the user who authenticated with SSO, this will be used to ensure
@@ -1327,7 +1355,7 @@ class AuthHandler(BaseHandler):
async def complete_sso_login(
self,
registered_user_id: str,
- request: SynapseRequest,
+ request: Request,
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
):
@@ -1355,7 +1383,7 @@ class AuthHandler(BaseHandler):
def _complete_sso_login(
self,
registered_user_id: str,
- request: SynapseRequest,
+ request: Request,
client_redirect_url: str,
extra_attributes: Optional[JsonDict] = None,
):
@@ -1609,6 +1637,6 @@ class PasswordProvider:
# This might return an awaitable, if it does block the log out
# until it completes.
- result = g(user_id=user_id, device_id=device_id, access_token=access_token,)
- if inspect.isawaitable(result):
- await result
+ await maybe_awaitable(
+ g(user_id=user_id, device_id=device_id, access_token=access_token,)
+ )
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index ad5683d251..abcf86352d 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -133,7 +133,9 @@ class DirectoryHandler(BaseHandler):
403, "You must be in the room to create an alias for it"
)
- if not self.spam_checker.user_may_create_room_alias(user_id, room_alias):
+ if not await self.spam_checker.user_may_create_room_alias(
+ user_id, room_alias
+ ):
raise AuthError(403, "This user is not permitted to create this alias")
if not self.config.is_alias_creation_allowed(
@@ -409,7 +411,7 @@ class DirectoryHandler(BaseHandler):
"""
user_id = requester.user.to_string()
- if not self.spam_checker.user_may_publish_room(user_id, room_id):
+ if not await self.spam_checker.user_may_publish_room(user_id, room_id):
raise AuthError(
403, "This user is not permitted to publish rooms to the room list"
)
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index b9799090f7..fd8de8696d 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -140,7 +140,7 @@ class FederationHandler(BaseHandler):
self._message_handler = hs.get_message_handler()
self._server_notices_mxid = hs.config.server_notices_mxid
self.config = hs.config
- self.http_client = hs.get_simple_http_client()
+ self.http_client = hs.get_proxied_blacklisted_http_client()
self._instance_name = hs.get_instance_name()
self._replication = hs.get_replication_data_handler()
@@ -1593,7 +1593,7 @@ class FederationHandler(BaseHandler):
if self.hs.config.block_non_admin_invites:
raise SynapseError(403, "This server does not accept room invites")
- if not self.spam_checker.user_may_invite(
+ if not await self.spam_checker.user_may_invite(
event.sender, event.state_key, event.room_id
):
raise SynapseError(
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 9b3c6b4551..7301c24710 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -46,13 +46,13 @@ class IdentityHandler(BaseHandler):
def __init__(self, hs):
super().__init__(hs)
+ # An HTTP client for contacting trusted URLs.
self.http_client = SimpleHttpClient(hs)
- # We create a blacklisting instance of SimpleHttpClient for contacting identity
- # servers specified by clients
+ # An HTTP client for contacting identity servers specified by clients.
self.blacklisting_http_client = SimpleHttpClient(
hs, ip_blacklist=hs.config.federation_ip_range_blacklist
)
- self.federation_http_client = hs.get_http_client()
+ self.federation_http_client = hs.get_federation_http_client()
self.hs = hs
async def threepid_from_creds(
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 11420ea996..cbac43c536 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -744,7 +744,7 @@ class EventCreationHandler:
event.sender,
)
- spam_error = self.spam_checker.check_event_for_spam(event)
+ spam_error = await self.spam_checker.check_event_for_spam(event)
if spam_error:
if not isinstance(spam_error, str):
spam_error = "Spam is not permitted here"
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index c605f7082a..f626117f76 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -674,6 +674,21 @@ class OidcHandler(BaseHandler):
self._sso_handler.render_error(request, "invalid_token", str(e))
return
+ # first check if we're doing a UIA
+ if ui_auth_session_id:
+ try:
+ remote_user_id = self._remote_id_from_userinfo(userinfo)
+ except Exception as e:
+ logger.exception("Could not extract remote user id")
+ self._sso_handler.render_error(request, "mapping_error", str(e))
+ return
+
+ return await self._sso_handler.complete_sso_ui_auth_request(
+ self._auth_provider_id, remote_user_id, ui_auth_session_id, request
+ )
+
+ # otherwise, it's a login
+
# Pull out the user-agent and IP from the request.
user_agent = request.get_user_agent("")
ip_address = self.hs.get_ip_from_request(request)
@@ -698,14 +713,9 @@ class OidcHandler(BaseHandler):
extra_attributes = await get_extra_attributes(userinfo, token)
# and finally complete the login
- if ui_auth_session_id:
- await self._auth_handler.complete_sso_ui_auth(
- user_id, ui_auth_session_id, request
- )
- else:
- await self._auth_handler.complete_sso_login(
- user_id, request, client_redirect_url, extra_attributes
- )
+ await self._auth_handler.complete_sso_login(
+ user_id, request, client_redirect_url, extra_attributes
+ )
def _generate_oidc_session_token(
self,
@@ -856,14 +866,11 @@ class OidcHandler(BaseHandler):
The mxid of the user
"""
try:
- remote_user_id = self._user_mapping_provider.get_remote_user_id(userinfo)
+ remote_user_id = self._remote_id_from_userinfo(userinfo)
except Exception as e:
raise MappingException(
"Failed to extract subject from OIDC response: %s" % (e,)
)
- # Some OIDC providers use integer IDs, but Synapse expects external IDs
- # to be strings.
- remote_user_id = str(remote_user_id)
# Older mapping providers don't accept the `failures` argument, so we
# try and detect support.
@@ -933,6 +940,19 @@ class OidcHandler(BaseHandler):
grandfather_existing_users,
)
+ def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str:
+ """Extract the unique remote id from an OIDC UserInfo block
+
+ Args:
+ userinfo: An object representing the user given by the OIDC provider
+ Returns:
+ remote user id
+ """
+ remote_user_id = self._user_mapping_provider.get_remote_user_id(userinfo)
+ # Some OIDC providers use integer IDs, but Synapse expects external IDs
+ # to be strings.
+ return str(remote_user_id)
+
UserAttributeDict = TypedDict(
"UserAttributeDict", {"localpart": str, "display_name": Optional[str]}
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 153cbae7b9..e850e45e46 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -18,7 +18,6 @@ from typing import List, Tuple
from synapse.appservice import ApplicationService
from synapse.handlers._base import BaseHandler
from synapse.types import JsonDict, ReadReceipt, get_domain_from_id
-from synapse.util.async_helpers import maybe_awaitable
logger = logging.getLogger(__name__)
@@ -98,10 +97,8 @@ class ReceiptsHandler(BaseHandler):
self.notifier.on_new_event("receipt_key", max_batch_id, rooms=affected_room_ids)
# Note that the min here shouldn't be relied upon to be accurate.
- await maybe_awaitable(
- self.hs.get_pusherpool().on_new_receipts(
- min_batch_id, max_batch_id, affected_room_ids
- )
+ await self.hs.get_pusherpool().on_new_receipts(
+ min_batch_id, max_batch_id, affected_room_ids
)
return True
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 0d85fd0868..94b5610acd 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -187,7 +187,7 @@ class RegistrationHandler(BaseHandler):
"""
self.check_registration_ratelimit(address)
- result = self.spam_checker.check_registration_for_spam(
+ result = await self.spam_checker.check_registration_for_spam(
threepid, localpart, user_agent_ips or [],
)
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 930047e730..7583418946 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -358,7 +358,7 @@ class RoomCreationHandler(BaseHandler):
"""
user_id = requester.user.to_string()
- if not self.spam_checker.user_may_create_room(user_id):
+ if not await self.spam_checker.user_may_create_room(user_id):
raise SynapseError(403, "You are not permitted to create rooms")
creation_content = {
@@ -440,6 +440,7 @@ class RoomCreationHandler(BaseHandler):
invite_list=[],
initial_state=initial_state,
creation_content=creation_content,
+ ratelimit=False,
)
# Transfer membership events
@@ -608,7 +609,7 @@ class RoomCreationHandler(BaseHandler):
403, "You are not permitted to create rooms", Codes.FORBIDDEN
)
- if not is_requester_admin and not self.spam_checker.user_may_create_room(
+ if not is_requester_admin and not await self.spam_checker.user_may_create_room(
user_id
):
raise SynapseError(403, "You are not permitted to create rooms")
@@ -735,6 +736,7 @@ class RoomCreationHandler(BaseHandler):
room_alias=room_alias,
power_level_content_override=power_level_content_override,
creator_join_profile=creator_join_profile,
+ ratelimit=ratelimit,
)
if "name" in config:
@@ -838,6 +840,7 @@ class RoomCreationHandler(BaseHandler):
room_alias: Optional[RoomAlias] = None,
power_level_content_override: Optional[JsonDict] = None,
creator_join_profile: Optional[JsonDict] = None,
+ ratelimit: bool = True,
) -> int:
"""Sends the initial events into a new room.
@@ -884,7 +887,7 @@ class RoomCreationHandler(BaseHandler):
creator.user,
room_id,
"join",
- ratelimit=False,
+ ratelimit=ratelimit,
content=creator_join_profile,
)
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 4d8ffe8821..bea028b2bf 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -204,7 +204,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# Only rate-limit if the user actually joined the room, otherwise we'll end
# up blocking profile updates.
- if newly_joined:
+ if newly_joined and ratelimit:
time_now_s = self.clock.time()
(
allowed,
@@ -428,7 +428,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
)
block_invite = True
- if not self.spam_checker.user_may_invite(
+ if not await self.spam_checker.user_may_invite(
requester.user.to_string(), target.to_string(), room_id
):
logger.info("Blocking invite due to spam checker")
@@ -508,17 +508,20 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
raise AuthError(403, "Guest access not allowed")
if not is_host_in_room:
- time_now_s = self.clock.time()
- (
- allowed,
- time_allowed,
- ) = self._join_rate_limiter_remote.can_requester_do_action(requester,)
-
- if not allowed:
- raise LimitExceededError(
- retry_after_ms=int(1000 * (time_allowed - time_now_s))
+ if ratelimit:
+ time_now_s = self.clock.time()
+ (
+ allowed,
+ time_allowed,
+ ) = self._join_rate_limiter_remote.can_requester_do_action(
+ requester,
)
+ if not allowed:
+ raise LimitExceededError(
+ retry_after_ms=int(1000 * (time_allowed - time_now_s))
+ )
+
inviter = await self._get_inviter(target.to_string(), room_id)
if inviter and not self.hs.is_mine(inviter):
remote_room_hosts.append(inviter.domain)
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 76d4169fe2..f2ca1ddb53 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -34,7 +34,6 @@ from synapse.types import (
map_username_to_mxid_localpart,
mxid_localpart_allowed_characters,
)
-from synapse.util.async_helpers import Linearizer
from synapse.util.iterutils import chunk_seq
if TYPE_CHECKING:
@@ -81,9 +80,6 @@ class SamlHandler(BaseHandler):
# a map from saml session id to Saml2SessionData object
self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData]
- # a lock on the mappings
- self._mapping_lock = Linearizer(name="saml_mapping", clock=self.clock)
-
self._sso_handler = hs.get_sso_handler()
def handle_redirect_request(
@@ -183,6 +179,24 @@ class SamlHandler(BaseHandler):
saml2_auth.in_response_to, None
)
+ # first check if we're doing a UIA
+ if current_session and current_session.ui_auth_session_id:
+ try:
+ remote_user_id = self._remote_id_from_saml_response(saml2_auth, None)
+ except MappingException as e:
+ logger.exception("Failed to extract remote user id from SAML response")
+ self._sso_handler.render_error(request, "mapping_error", str(e))
+ return
+
+ return await self._sso_handler.complete_sso_ui_auth_request(
+ self._auth_provider_id,
+ remote_user_id,
+ current_session.ui_auth_session_id,
+ request,
+ )
+
+ # otherwise, we're handling a login request.
+
# Ensure that the attributes of the logged in user meet the required
# attributes.
for requirement in self._saml2_attribute_requirements:
@@ -206,14 +220,7 @@ class SamlHandler(BaseHandler):
self._sso_handler.render_error(request, "mapping_error", str(e))
return
- # Complete the interactive auth session or the login.
- if current_session and current_session.ui_auth_session_id:
- await self._auth_handler.complete_sso_ui_auth(
- user_id, current_session.ui_auth_session_id, request
- )
-
- else:
- await self._auth_handler.complete_sso_login(user_id, request, relay_state)
+ await self._auth_handler.complete_sso_login(user_id, request, relay_state)
async def _map_saml_response_to_user(
self,
@@ -239,16 +246,10 @@ class SamlHandler(BaseHandler):
RedirectException: some mapping providers may raise this if they need
to redirect to an interstitial page.
"""
-
- remote_user_id = self._user_mapping_provider.get_remote_user_id(
+ remote_user_id = self._remote_id_from_saml_response(
saml2_auth, client_redirect_url
)
- if not remote_user_id:
- raise MappingException(
- "Failed to extract remote user id from SAML response"
- )
-
async def saml_response_to_remapped_user_attributes(
failures: int,
) -> UserAttributes:
@@ -294,16 +295,44 @@ class SamlHandler(BaseHandler):
return None
- with (await self._mapping_lock.queue(self._auth_provider_id)):
- return await self._sso_handler.get_mxid_from_sso(
- self._auth_provider_id,
- remote_user_id,
- user_agent,
- ip_address,
- saml_response_to_remapped_user_attributes,
- grandfather_existing_users,
+ return await self._sso_handler.get_mxid_from_sso(
+ self._auth_provider_id,
+ remote_user_id,
+ user_agent,
+ ip_address,
+ saml_response_to_remapped_user_attributes,
+ grandfather_existing_users,
+ )
+
+ def _remote_id_from_saml_response(
+ self,
+ saml2_auth: saml2.response.AuthnResponse,
+ client_redirect_url: Optional[str],
+ ) -> str:
+ """Extract the unique remote id from a SAML2 AuthnResponse
+
+ Args:
+ saml2_auth: The parsed SAML2 response.
+ client_redirect_url: The redirect URL passed in by the client.
+ Returns:
+ remote user id
+
+ Raises:
+ MappingException if there was an error extracting the user id
+ """
+ # It's not obvious why we need to pass in the redirect URI to the mapping
+ # provider, but we do :/
+ remote_user_id = self._user_mapping_provider.get_remote_user_id(
+ saml2_auth, client_redirect_url
+ )
+
+ if not remote_user_id:
+ raise MappingException(
+ "Failed to extract remote user id from SAML response"
)
+ return remote_user_id
+
def expire_sessions(self):
expire_before = self.clock.time_msec() - self._saml2_session_lifetime
to_expire = set()
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 47ad96f97e..112a7d5b2c 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -17,10 +17,12 @@ from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional
import attr
+from twisted.web.http import Request
+
from synapse.api.errors import RedirectException
-from synapse.handlers._base import BaseHandler
from synapse.http.server import respond_with_html
from synapse.types import UserID, contains_invalid_mxid_characters
+from synapse.util.async_helpers import Linearizer
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -42,14 +44,19 @@ class UserAttributes:
emails = attr.ib(type=List[str], default=attr.Factory(list))
-class SsoHandler(BaseHandler):
+class SsoHandler:
# The number of attempts to ask the mapping provider for when generating an MXID.
_MAP_USERNAME_RETRIES = 1000
def __init__(self, hs: "HomeServer"):
- super().__init__(hs)
+ self._store = hs.get_datastore()
+ self._server_name = hs.hostname
self._registration_handler = hs.get_registration_handler()
self._error_template = hs.config.sso_error_template
+ self._auth_handler = hs.get_auth_handler()
+
+ # a lock on the mappings
+ self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock())
def render_error(
self, request, error: str, error_description: Optional[str] = None
@@ -95,7 +102,7 @@ class SsoHandler(BaseHandler):
)
# Check if we already have a mapping for this user.
- previously_registered_user_id = await self.store.get_user_by_external_id(
+ previously_registered_user_id = await self._store.get_user_by_external_id(
auth_provider_id, remote_user_id,
)
@@ -169,24 +176,38 @@ class SsoHandler(BaseHandler):
to an additional page. (e.g. to prompt for more information)
"""
- # first of all, check if we already have a mapping for this user
- previously_registered_user_id = await self.get_sso_user_by_remote_user_id(
- auth_provider_id, remote_user_id,
- )
- if previously_registered_user_id:
- return previously_registered_user_id
-
- # Check for grandfathering of users.
- if grandfather_existing_users:
- previously_registered_user_id = await grandfather_existing_users()
+ # grab a lock while we try to find a mapping for this user. This seems...
+ # optimistic, especially for implementations that end up redirecting to
+ # interstitial pages.
+ with await self._mapping_lock.queue(auth_provider_id):
+ # first of all, check if we already have a mapping for this user
+ previously_registered_user_id = await self.get_sso_user_by_remote_user_id(
+ auth_provider_id, remote_user_id,
+ )
if previously_registered_user_id:
- # Future logins should also match this user ID.
- await self.store.record_user_external_id(
- auth_provider_id, remote_user_id, previously_registered_user_id
- )
return previously_registered_user_id
- # Otherwise, generate a new user.
+ # Check for grandfathering of users.
+ if grandfather_existing_users:
+ previously_registered_user_id = await grandfather_existing_users()
+ if previously_registered_user_id:
+ # Future logins should also match this user ID.
+ await self._store.record_user_external_id(
+ auth_provider_id, remote_user_id, previously_registered_user_id
+ )
+ return previously_registered_user_id
+
+ # Otherwise, generate a new user.
+ attributes = await self._call_attribute_mapper(sso_to_matrix_id_mapper)
+ user_id = await self._register_mapped_user(
+ attributes, auth_provider_id, remote_user_id, user_agent, ip_address,
+ )
+ return user_id
+
+ async def _call_attribute_mapper(
+ self, sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
+ ) -> UserAttributes:
+ """Call the attribute mapper function in a loop, until we get a unique userid"""
for i in range(self._MAP_USERNAME_RETRIES):
try:
attributes = await sso_to_matrix_id_mapper(i)
@@ -214,8 +235,8 @@ class SsoHandler(BaseHandler):
)
# Check if this mxid already exists
- user_id = UserID(attributes.localpart, self.server_name).to_string()
- if not await self.store.get_users_by_id_case_insensitive(user_id):
+ user_id = UserID(attributes.localpart, self._server_name).to_string()
+ if not await self._store.get_users_by_id_case_insensitive(user_id):
# This mxid is free
break
else:
@@ -224,7 +245,16 @@ class SsoHandler(BaseHandler):
raise MappingException(
"Unable to generate a Matrix ID from the SSO response"
)
+ return attributes
+ async def _register_mapped_user(
+ self,
+ attributes: UserAttributes,
+ auth_provider_id: str,
+ remote_user_id: str,
+ user_agent: str,
+ ip_address: str,
+ ) -> str:
# Since the localpart is provided via a potentially untrusted module,
# ensure the MXID is valid before registering.
if contains_invalid_mxid_characters(attributes.localpart):
@@ -238,7 +268,47 @@ class SsoHandler(BaseHandler):
user_agent_ips=[(user_agent, ip_address)],
)
- await self.store.record_user_external_id(
+ await self._store.record_user_external_id(
auth_provider_id, remote_user_id, registered_user_id
)
return registered_user_id
+
+ async def complete_sso_ui_auth_request(
+ self,
+ auth_provider_id: str,
+ remote_user_id: str,
+ ui_auth_session_id: str,
+ request: Request,
+ ) -> None:
+ """
+ Given an SSO ID, retrieve the user ID for it and complete UIA.
+
+ Note that this requires that the user is mapped in the "user_external_ids"
+ table. This will be the case if they have ever logged in via SAML or OIDC in
+ recentish synapse versions, but may not be for older users.
+
+ Args:
+ auth_provider_id: A unique identifier for this SSO provider, e.g.
+ "oidc" or "saml".
+ remote_user_id: The unique identifier from the SSO provider.
+ ui_auth_session_id: The ID of the user-interactive auth session.
+ request: The request to complete.
+ """
+
+ user_id = await self.get_sso_user_by_remote_user_id(
+ auth_provider_id, remote_user_id,
+ )
+
+ if not user_id:
+ logger.warning(
+ "Remote user %s/%s has not previously logged in here: UIA will fail",
+ auth_provider_id,
+ remote_user_id,
+ )
+ # Let the UIA flow handle this the same as if they presented creds for a
+ # different user.
+ user_id = ""
+
+ await self._auth_handler.complete_sso_ui_auth(
+ user_id, ui_auth_session_id, request
+ )
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index afbebfc200..f263a638f8 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -81,11 +81,11 @@ class UserDirectoryHandler(StateDeltasHandler):
results = await self.store.search_user_dir(user_id, search_term, limit)
# Remove any spammy users from the results.
- results["results"] = [
- user
- for user in results["results"]
- if not self.spam_checker.check_username_for_spam(user)
- ]
+ non_spammy_users = []
+ for user in results["results"]:
+ if not await self.spam_checker.check_username_for_spam(user):
+ non_spammy_users.append(user)
+ results["results"] = non_spammy_users
return results
diff --git a/synapse/http/client.py b/synapse/http/client.py
index e5b13593f2..df7730078f 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -125,7 +125,7 @@ def _make_scheduler(reactor):
return _scheduler
-class IPBlacklistingResolver:
+class _IPBlacklistingResolver:
"""
A proxy for reactor.nameResolver which only produces non-blacklisted IP
addresses, preventing DNS rebinding attacks on URL preview.
@@ -199,6 +199,35 @@ class IPBlacklistingResolver:
return r
+@implementer(IReactorPluggableNameResolver)
+class BlacklistingReactorWrapper:
+ """
+ A Reactor wrapper which will prevent DNS resolution to blacklisted IP
+ addresses, to prevent DNS rebinding.
+ """
+
+ def __init__(
+ self,
+ reactor: IReactorPluggableNameResolver,
+ ip_whitelist: Optional[IPSet],
+ ip_blacklist: IPSet,
+ ):
+ self._reactor = reactor
+
+ # We need to use a DNS resolver which filters out blacklisted IP
+ # addresses, to prevent DNS rebinding.
+ self._nameResolver = _IPBlacklistingResolver(
+ self._reactor, ip_whitelist, ip_blacklist
+ )
+
+ def __getattr__(self, attr: str) -> Any:
+ # Passthrough to the real reactor except for the DNS resolver.
+ if attr == "nameResolver":
+ return self._nameResolver
+ else:
+ return getattr(self._reactor, attr)
+
+
class BlacklistingAgentWrapper(Agent):
"""
An Agent wrapper which will prevent access to IP addresses being accessed
@@ -292,22 +321,11 @@ class SimpleHttpClient:
self.user_agent = self.user_agent.encode("ascii")
if self._ip_blacklist:
- real_reactor = hs.get_reactor()
# If we have an IP blacklist, we need to use a DNS resolver which
# filters out blacklisted IP addresses, to prevent DNS rebinding.
- nameResolver = IPBlacklistingResolver(
- real_reactor, self._ip_whitelist, self._ip_blacklist
+ self.reactor = BlacklistingReactorWrapper(
+ hs.get_reactor(), self._ip_whitelist, self._ip_blacklist
)
-
- @implementer(IReactorPluggableNameResolver)
- class Reactor:
- def __getattr__(_self, attr):
- if attr == "nameResolver":
- return nameResolver
- else:
- return getattr(real_reactor, attr)
-
- self.reactor = Reactor()
else:
self.reactor = hs.get_reactor()
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index e77f9587d0..3b756a7dc2 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -16,7 +16,7 @@ import logging
import urllib.parse
from typing import List, Optional
-from netaddr import AddrFormatError, IPAddress
+from netaddr import AddrFormatError, IPAddress, IPSet
from zope.interface import implementer
from twisted.internet import defer
@@ -31,6 +31,7 @@ from twisted.web.http_headers import Headers
from twisted.web.iweb import IAgent, IAgentEndpointFactory, IBodyProducer
from synapse.crypto.context_factory import FederationPolicyForHTTPS
+from synapse.http.client import BlacklistingAgentWrapper
from synapse.http.federation.srv_resolver import Server, SrvResolver
from synapse.http.federation.well_known_resolver import WellKnownResolver
from synapse.logging.context import make_deferred_yieldable, run_in_background
@@ -70,6 +71,7 @@ class MatrixFederationAgent:
reactor: IReactorCore,
tls_client_options_factory: Optional[FederationPolicyForHTTPS],
user_agent: bytes,
+ ip_blacklist: IPSet,
_srv_resolver: Optional[SrvResolver] = None,
_well_known_resolver: Optional[WellKnownResolver] = None,
):
@@ -90,12 +92,18 @@ class MatrixFederationAgent:
self.user_agent = user_agent
if _well_known_resolver is None:
+ # Note that the name resolver has already been wrapped in a
+ # IPBlacklistingResolver by MatrixFederationHttpClient.
_well_known_resolver = WellKnownResolver(
self._reactor,
- agent=Agent(
+ agent=BlacklistingAgentWrapper(
+ Agent(
+ self._reactor,
+ pool=self._pool,
+ contextFactory=tls_client_options_factory,
+ ),
self._reactor,
- pool=self._pool,
- contextFactory=tls_client_options_factory,
+ ip_blacklist=ip_blacklist,
),
user_agent=self.user_agent,
)
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 4e27f93b7a..c962994727 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -26,11 +26,10 @@ import treq
from canonicaljson import encode_canonical_json
from prometheus_client import Counter
from signedjson.sign import sign_json
-from zope.interface import implementer
from twisted.internet import defer
from twisted.internet.error import DNSLookupError
-from twisted.internet.interfaces import IReactorPluggableNameResolver, IReactorTime
+from twisted.internet.interfaces import IReactorTime
from twisted.internet.task import _EPSILON, Cooperator
from twisted.web.http_headers import Headers
from twisted.web.iweb import IBodyProducer, IResponse
@@ -45,7 +44,7 @@ from synapse.api.errors import (
from synapse.http import QuieterFileBodyProducer
from synapse.http.client import (
BlacklistingAgentWrapper,
- IPBlacklistingResolver,
+ BlacklistingReactorWrapper,
encode_query_args,
readBodyToFile,
)
@@ -221,31 +220,22 @@ class MatrixFederationHttpClient:
self.signing_key = hs.signing_key
self.server_name = hs.hostname
- real_reactor = hs.get_reactor()
-
# We need to use a DNS resolver which filters out blacklisted IP
# addresses, to prevent DNS rebinding.
- nameResolver = IPBlacklistingResolver(
- real_reactor, None, hs.config.federation_ip_range_blacklist
+ self.reactor = BlacklistingReactorWrapper(
+ hs.get_reactor(), None, hs.config.federation_ip_range_blacklist
)
- @implementer(IReactorPluggableNameResolver)
- class Reactor:
- def __getattr__(_self, attr):
- if attr == "nameResolver":
- return nameResolver
- else:
- return getattr(real_reactor, attr)
-
- self.reactor = Reactor()
-
user_agent = hs.version_string
if hs.config.user_agent_suffix:
user_agent = "%s %s" % (user_agent, hs.config.user_agent_suffix)
user_agent = user_agent.encode("ascii")
self.agent = MatrixFederationAgent(
- self.reactor, tls_client_options_factory, user_agent
+ self.reactor,
+ tls_client_options_factory,
+ user_agent,
+ hs.config.federation_ip_range_blacklist,
)
# Use a BlacklistingAgentWrapper to prevent circumventing the IP
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 6a4e429a6c..e464bfe6c7 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -275,6 +275,10 @@ class DirectServeJsonResource(_AsyncResource):
formatting responses and errors as JSON.
"""
+ def __init__(self, canonical_json=False, extract_context=False):
+ super().__init__(extract_context)
+ self.canonical_json = canonical_json
+
def _send_response(
self, request: Request, code: int, response_object: Any,
):
@@ -318,9 +322,7 @@ class JsonResource(DirectServeJsonResource):
)
def __init__(self, hs, canonical_json=True, extract_context=False):
- super().__init__(extract_context)
-
- self.canonical_json = canonical_json
+ super().__init__(canonical_json, extract_context)
self.clock = hs.get_clock()
self.path_regexs = {}
self.hs = hs
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 5f0581dc3f..5a5790831b 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -128,8 +128,7 @@ class SynapseRequest(Request):
# create a LogContext for this request
request_id = self.get_request_id()
- logcontext = self.logcontext = LoggingContext(request_id)
- logcontext.request = request_id
+ self.logcontext = LoggingContext(request_id, request=request_id)
# override the Server header which is set by twisted
self.setHeader("Server", self.site.server_version_string)
diff --git a/synapse/logging/context.py b/synapse/logging/context.py
index ca0c774cc5..a507a83e93 100644
--- a/synapse/logging/context.py
+++ b/synapse/logging/context.py
@@ -203,10 +203,6 @@ class _Sentinel:
def copy_to(self, record):
pass
- def copy_to_twisted_log_entry(self, record):
- record["request"] = None
- record["scope"] = None
-
def start(self, rusage: "Optional[resource._RUsage]"):
pass
@@ -372,13 +368,6 @@ class LoggingContext:
# we also track the current scope:
record.scope = self.scope
- def copy_to_twisted_log_entry(self, record) -> None:
- """
- Copy logging fields from this context to a Twisted log record.
- """
- record["request"] = self.request
- record["scope"] = self.scope
-
def start(self, rusage: "Optional[resource._RUsage]") -> None:
"""
Record that this logcontext is currently running.
@@ -542,13 +531,10 @@ class LoggingContext:
class LoggingContextFilter(logging.Filter):
"""Logging filter that adds values from the current logging context to each
record.
- Args:
- **defaults: Default values to avoid formatters complaining about
- missing fields
"""
- def __init__(self, **defaults) -> None:
- self.defaults = defaults
+ def __init__(self, request: str = ""):
+ self._default_request = request
def filter(self, record) -> Literal[True]:
"""Add each fields from the logging contexts to the record.
@@ -556,14 +542,14 @@ class LoggingContextFilter(logging.Filter):
True to include the record in the log output.
"""
context = current_context()
- for key, value in self.defaults.items():
- setattr(record, key, value)
+ record.request = self._default_request
# context should never be None, but if it somehow ends up being, then
# we end up in a death spiral of infinite loops, so let's check, for
# robustness' sake.
if context is not None:
- context.copy_to(record)
+ # Logging is interested in the request.
+ record.request = context.request
return True
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index 658f6ecd72..70e0fa45d9 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import inspect
import logging
import threading
from functools import wraps
@@ -25,6 +24,7 @@ from twisted.internet import defer
from synapse.logging.context import LoggingContext, PreserveLoggingContext
from synapse.logging.opentracing import noop_context_manager, start_active_span
+from synapse.util.async_helpers import maybe_awaitable
if TYPE_CHECKING:
import resource
@@ -199,19 +199,13 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar
_background_process_start_count.labels(desc).inc()
_background_process_in_flight_count.labels(desc).inc()
- with BackgroundProcessLoggingContext(desc) as context:
- context.request = "%s-%i" % (desc, count)
+ with BackgroundProcessLoggingContext(desc, "%s-%i" % (desc, count)) as context:
try:
ctx = noop_context_manager()
if bg_start_span:
ctx = start_active_span(desc, tags={"request_id": context.request})
with ctx:
- result = func(*args, **kwargs)
-
- if inspect.isawaitable(result):
- result = await result
-
- return result
+ return await maybe_awaitable(func(*args, **kwargs))
except Exception:
logger.exception(
"Background process '%s' threw an exception", desc,
@@ -249,8 +243,8 @@ class BackgroundProcessLoggingContext(LoggingContext):
__slots__ = ["_proc"]
- def __init__(self, name: str):
- super().__init__(name)
+ def __init__(self, name: str, request: Optional[str] = None):
+ super().__init__(name, request=request)
self._proc = _BackgroundProcess(name, self)
diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py
index 5a437f9810..3d2e874838 100644
--- a/synapse/push/__init__.py
+++ b/synapse/push/__init__.py
@@ -13,7 +13,56 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import abc
+from typing import TYPE_CHECKING, Any, Dict, Optional
+
+from synapse.types import RoomStreamToken
+
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
+
+class Pusher(metaclass=abc.ABCMeta):
+ def __init__(self, hs: "HomeServer", pusherdict: Dict[str, Any]):
+ self.hs = hs
+ self.store = self.hs.get_datastore()
+ self.clock = self.hs.get_clock()
+
+ self.pusher_id = pusherdict["id"]
+ self.user_id = pusherdict["user_name"]
+ self.app_id = pusherdict["app_id"]
+ self.pushkey = pusherdict["pushkey"]
+
+ # This is the highest stream ordering we know it's safe to process.
+ # When new events arrive, we'll be given a window of new events: we
+ # should honour this rather than just looking for anything higher
+ # because of potential out-of-order event serialisation. This starts
+ # off as None though as we don't know any better.
+ self.max_stream_ordering = None # type: Optional[int]
+
+ @abc.abstractmethod
+ def on_new_notifications(self, max_token: RoomStreamToken) -> None:
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def on_started(self, have_notifs: bool) -> None:
+ """Called when this pusher has been started.
+
+ Args:
+ should_check_for_notifs: Whether we should immediately
+ check for push to send. Set to False only if it's known there
+ is nothing to send
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def on_stop(self) -> None:
+ raise NotImplementedError()
+
class PusherConfigException(Exception):
- def __init__(self, msg):
- super().__init__(msg)
+ """An error occurred when creating a pusher."""
diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py
index fabc9ba126..aaed28650d 100644
--- a/synapse/push/action_generator.py
+++ b/synapse/push/action_generator.py
@@ -14,19 +14,22 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING
+from synapse.events import EventBase
+from synapse.events.snapshot import EventContext
+from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator
from synapse.util.metrics import Measure
-from .bulk_push_rule_evaluator import BulkPushRuleEvaluator
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
class ActionGenerator:
- def __init__(self, hs):
- self.hs = hs
+ def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
- self.store = hs.get_datastore()
self.bulk_evaluator = BulkPushRuleEvaluator(hs)
# really we want to get all user ids and all profile tags too,
# since we want the actions for each profile tag for every user and
@@ -35,6 +38,8 @@ class ActionGenerator:
# event stream, so we just run the rules for a client with no profile
# tag (ie. we just need all the users).
- async def handle_push_actions_for_event(self, event, context):
+ async def handle_push_actions_for_event(
+ self, event: EventBase, context: EventContext
+ ) -> None:
with Measure(self.clock, "action_for_event_by_user"):
await self.bulk_evaluator.action_for_event_by_user(event, context)
diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py
index f5788c1de7..6211506990 100644
--- a/synapse/push/baserules.py
+++ b/synapse/push/baserules.py
@@ -15,16 +15,19 @@
# limitations under the License.
import copy
+from typing import Any, Dict, List
from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP
-def list_with_base_rules(rawrules, use_new_defaults=False):
+def list_with_base_rules(
+ rawrules: List[Dict[str, Any]], use_new_defaults: bool = False
+) -> List[Dict[str, Any]]:
"""Combine the list of rules set by the user with the default push rules
Args:
- rawrules(list): The rules the user has modified or set.
- use_new_defaults(bool): Whether to use the new experimental default rules when
+ rawrules: The rules the user has modified or set.
+ use_new_defaults: Whether to use the new experimental default rules when
appending or prepending default rules.
Returns:
@@ -94,7 +97,11 @@ def list_with_base_rules(rawrules, use_new_defaults=False):
return ruleslist
-def make_base_append_rules(kind, modified_base_rules, use_new_defaults=False):
+def make_base_append_rules(
+ kind: str,
+ modified_base_rules: Dict[str, Dict[str, Any]],
+ use_new_defaults: bool = False,
+) -> List[Dict[str, Any]]:
rules = []
if kind == "override":
@@ -116,6 +123,7 @@ def make_base_append_rules(kind, modified_base_rules, use_new_defaults=False):
rules = copy.deepcopy(rules)
for r in rules:
# Only modify the actions, keep the conditions the same.
+ assert isinstance(r["rule_id"], str)
modified = modified_base_rules.get(r["rule_id"])
if modified:
r["actions"] = modified["actions"]
@@ -123,7 +131,11 @@ def make_base_append_rules(kind, modified_base_rules, use_new_defaults=False):
return rules
-def make_base_prepend_rules(kind, modified_base_rules, use_new_defaults=False):
+def make_base_prepend_rules(
+ kind: str,
+ modified_base_rules: Dict[str, Dict[str, Any]],
+ use_new_defaults: bool = False,
+) -> List[Dict[str, Any]]:
rules = []
if kind == "override":
@@ -133,6 +145,7 @@ def make_base_prepend_rules(kind, modified_base_rules, use_new_defaults=False):
rules = copy.deepcopy(rules)
for r in rules:
# Only modify the actions, keep the conditions the same.
+ assert isinstance(r["rule_id"], str)
modified = modified_base_rules.get(r["rule_id"])
if modified:
r["actions"] = modified["actions"]
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 82a72dc34f..10f27e4378 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -15,6 +15,7 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
import attr
from prometheus_client import Counter
@@ -25,16 +26,16 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.state import POWER_KEY
from synapse.util.async_helpers import Linearizer
-from synapse.util.caches import register_cache
+from synapse.util.caches import CacheMetric, register_cache
from synapse.util.caches.descriptors import lru_cache
from synapse.util.caches.lrucache import LruCache
from .push_rule_evaluator import PushRuleEvaluatorForEvent
-logger = logging.getLogger(__name__)
-
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
-rules_by_room = {}
+logger = logging.getLogger(__name__)
push_rules_invalidation_counter = Counter(
@@ -101,7 +102,7 @@ class BulkPushRuleEvaluator:
room at once.
"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
@@ -113,7 +114,9 @@ class BulkPushRuleEvaluator:
resizable=False,
)
- async def _get_rules_for_event(self, event, context):
+ async def _get_rules_for_event(
+ self, event: EventBase, context: EventContext
+ ) -> Dict[str, List[Dict[str, Any]]]:
"""This gets the rules for all users in the room at the time of the event,
as well as the push rules for the invitee if the event is an invite.
@@ -140,11 +143,8 @@ class BulkPushRuleEvaluator:
return rules_by_user
@lru_cache()
- def _get_rules_for_room(self, room_id):
+ def _get_rules_for_room(self, room_id: str) -> "RulesForRoom":
"""Get the current RulesForRoom object for the given room id
-
- Returns:
- RulesForRoom
"""
# It's important that RulesForRoom gets added to self._get_rules_for_room.cache
# before any lookup methods get called on it as otherwise there may be
@@ -156,20 +156,21 @@ class BulkPushRuleEvaluator:
self.room_push_rule_cache_metrics,
)
- async def _get_power_levels_and_sender_level(self, event, context):
+ async def _get_power_levels_and_sender_level(
+ self, event: EventBase, context: EventContext
+ ) -> Tuple[dict, int]:
prev_state_ids = await context.get_prev_state_ids()
pl_event_id = prev_state_ids.get(POWER_KEY)
if pl_event_id:
# fastpath: if there's a power level event, that's all we need, and
# not having a power level event is an extreme edge case
- pl_event = await self.store.get_event(pl_event_id)
- auth_events = {POWER_KEY: pl_event}
+ auth_events = {POWER_KEY: await self.store.get_event(pl_event_id)}
else:
auth_events_ids = self.auth.compute_auth_events(
event, prev_state_ids, for_verification=False
)
- auth_events = await self.store.get_events(auth_events_ids)
- auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
+ auth_events_dict = await self.store.get_events(auth_events_ids)
+ auth_events = {(e.type, e.state_key): e for e in auth_events_dict.values()}
sender_level = get_user_power_level(event.sender, auth_events)
@@ -177,7 +178,9 @@ class BulkPushRuleEvaluator:
return pl_event.content if pl_event else {}, sender_level
- async def action_for_event_by_user(self, event, context) -> None:
+ async def action_for_event_by_user(
+ self, event: EventBase, context: EventContext
+ ) -> None:
"""Given an event and context, evaluate the push rules, check if the message
should increment the unread count, and insert the results into the
event_push_actions_staging table.
@@ -185,7 +188,7 @@ class BulkPushRuleEvaluator:
count_as_unread = _should_count_as_unread(event, context)
rules_by_user = await self._get_rules_for_event(event, context)
- actions_by_user = {}
+ actions_by_user = {} # type: Dict[str, List[Union[dict, str]]]
room_members = await self.store.get_joined_users_from_context(event, context)
@@ -198,7 +201,7 @@ class BulkPushRuleEvaluator:
event, len(room_members), sender_power_level, power_levels
)
- condition_cache = {}
+ condition_cache = {} # type: Dict[str, bool]
for uid, rules in rules_by_user.items():
if event.sender == uid:
@@ -249,7 +252,13 @@ class BulkPushRuleEvaluator:
)
-def _condition_checker(evaluator, conditions, uid, display_name, cache):
+def _condition_checker(
+ evaluator: PushRuleEvaluatorForEvent,
+ conditions: List[dict],
+ uid: str,
+ display_name: str,
+ cache: Dict[str, bool],
+) -> bool:
for cond in conditions:
_id = cond.get("_id", None)
if _id:
@@ -277,15 +286,19 @@ class RulesForRoom:
"""
def __init__(
- self, hs, room_id, rules_for_room_cache: LruCache, room_push_rule_cache_metrics
+ self,
+ hs: "HomeServer",
+ room_id: str,
+ rules_for_room_cache: LruCache,
+ room_push_rule_cache_metrics: CacheMetric,
):
"""
Args:
- hs (HomeServer)
- room_id (str)
+ hs: The HomeServer object.
+ room_id: The room ID.
rules_for_room_cache: The cache object that caches these
RoomsForUser objects.
- room_push_rule_cache_metrics (CacheMetric)
+ room_push_rule_cache_metrics: The metrics object
"""
self.room_id = room_id
self.is_mine_id = hs.is_mine_id
@@ -294,8 +307,10 @@ class RulesForRoom:
self.linearizer = Linearizer(name="rules_for_room")
- self.member_map = {} # event_id -> (user_id, state)
- self.rules_by_user = {} # user_id -> rules
+ # event_id -> (user_id, state)
+ self.member_map = {} # type: Dict[str, Tuple[str, str]]
+ # user_id -> rules
+ self.rules_by_user = {} # type: Dict[str, List[Dict[str, dict]]]
# The last state group we updated the caches for. If the state_group of
# a new event comes along, we know that we can just return the cached
@@ -315,7 +330,7 @@ class RulesForRoom:
# calculate push for)
# These never need to be invalidated as we will never set up push for
# them.
- self.uninteresting_user_set = set()
+ self.uninteresting_user_set = set() # type: Set[str]
# We need to be clever on the invalidating caches callbacks, as
# otherwise the invalidation callback holds a reference to the object,
@@ -325,7 +340,9 @@ class RulesForRoom:
# to self around in the callback.
self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id)
- async def get_rules(self, event, context):
+ async def get_rules(
+ self, event: EventBase, context: EventContext
+ ) -> Dict[str, List[Dict[str, dict]]]:
"""Given an event context return the rules for all users who are
currently in the room.
"""
@@ -356,6 +373,8 @@ class RulesForRoom:
else:
current_state_ids = await context.get_current_state_ids()
push_rules_delta_state_cache_metric.inc_misses()
+ # Ensure the state IDs exist.
+ assert current_state_ids is not None
push_rules_state_size_counter.inc(len(current_state_ids))
@@ -420,18 +439,23 @@ class RulesForRoom:
return ret_rules_by_user
async def _update_rules_with_member_event_ids(
- self, ret_rules_by_user, member_event_ids, state_group, event
- ):
+ self,
+ ret_rules_by_user: Dict[str, list],
+ member_event_ids: Dict[str, str],
+ state_group: Optional[int],
+ event: EventBase,
+ ) -> None:
"""Update the partially filled rules_by_user dict by fetching rules for
any newly joined users in the `member_event_ids` list.
Args:
- ret_rules_by_user (dict): Partiallly filled dict of push rules. Gets
+ ret_rules_by_user: Partially filled dict of push rules. Gets
updated with any new rules.
- member_event_ids (dict): Dict of user id to event id for membership events
+ member_event_ids: Dict of user id to event id for membership events
that have happened since the last time we filled rules_by_user
state_group: The state group we are currently computing push rules
for. Used when updating the cache.
+ event: The event we are currently computing push rules for.
"""
sequence = self.sequence
@@ -449,19 +473,19 @@ class RulesForRoom:
if logger.isEnabledFor(logging.DEBUG):
logger.debug("Found members %r: %r", self.room_id, members.values())
- user_ids = {
+ joined_user_ids = {
user_id
for user_id, membership in members.values()
if membership == Membership.JOIN
}
- logger.debug("Joined: %r", user_ids)
+ logger.debug("Joined: %r", joined_user_ids)
# Previously we only considered users with pushers or read receipts in that
# room. We can't do this anymore because we use push actions to calculate unread
# counts, which don't rely on the user having pushers or sent a read receipt into
# the room. Therefore we just need to filter for local users here.
- user_ids = list(filter(self.is_mine_id, user_ids))
+ user_ids = list(filter(self.is_mine_id, joined_user_ids))
rules_by_user = await self.store.bulk_get_push_rules(
user_ids, on_invalidate=self.invalidate_all_cb
@@ -473,7 +497,7 @@ class RulesForRoom:
self.update_cache(sequence, members, ret_rules_by_user, state_group)
- def invalidate_all(self):
+ def invalidate_all(self) -> None:
# Note: Don't hand this function directly to an invalidation callback
# as it keeps a reference to self and will stop this instance from being
# GC'd if it gets dropped from the rules_to_user cache. Instead use
@@ -485,7 +509,7 @@ class RulesForRoom:
self.rules_by_user = {}
push_rules_invalidation_counter.inc()
- def update_cache(self, sequence, members, rules_by_user, state_group):
+ def update_cache(self, sequence, members, rules_by_user, state_group) -> None:
if sequence == self.sequence:
self.member_map.update(members)
self.rules_by_user = rules_by_user
@@ -506,7 +530,7 @@ class _Invalidation:
cache = attr.ib(type=LruCache)
room_id = attr.ib(type=str)
- def __call__(self):
+ def __call__(self) -> None:
rules = self.cache.get(self.room_id, None, update_metrics=False)
if rules:
rules.invalidate_all()
diff --git a/synapse/push/clientformat.py b/synapse/push/clientformat.py
index a59b639f15..0cadba761a 100644
--- a/synapse/push/clientformat.py
+++ b/synapse/push/clientformat.py
@@ -14,24 +14,27 @@
# limitations under the License.
import copy
+from typing import Any, Dict, List, Optional
from synapse.push.rulekinds import PRIORITY_CLASS_INVERSE_MAP, PRIORITY_CLASS_MAP
+from synapse.types import UserID
-def format_push_rules_for_user(user, ruleslist):
+def format_push_rules_for_user(user: UserID, ruleslist) -> Dict[str, Dict[str, list]]:
"""Converts a list of rawrules and a enabled map into nested dictionaries
to match the Matrix client-server format for push rules"""
# We're going to be mutating this a lot, so do a deep copy
ruleslist = copy.deepcopy(ruleslist)
- rules = {"global": {}, "device": {}}
+ rules = {
+ "global": {},
+ "device": {},
+ } # type: Dict[str, Dict[str, List[Dict[str, Any]]]]
rules["global"] = _add_empty_priority_class_arrays(rules["global"])
for r in ruleslist:
- rulearray = None
-
template_name = _priority_class_to_template_name(r["priority_class"])
# Remove internal stuff.
@@ -57,13 +60,13 @@ def format_push_rules_for_user(user, ruleslist):
return rules
-def _add_empty_priority_class_arrays(d):
+def _add_empty_priority_class_arrays(d: Dict[str, list]) -> Dict[str, list]:
for pc in PRIORITY_CLASS_MAP.keys():
d[pc] = []
return d
-def _rule_to_template(rule):
+def _rule_to_template(rule: Dict[str, Any]) -> Optional[Dict[str, Any]]:
unscoped_rule_id = None
if "rule_id" in rule:
unscoped_rule_id = _rule_id_from_namespaced(rule["rule_id"])
@@ -82,6 +85,10 @@ def _rule_to_template(rule):
return None
templaterule = {"actions": rule["actions"]}
templaterule["pattern"] = thecond["pattern"]
+ else:
+ # This should not be reached unless this function is not kept in sync
+ # with PRIORITY_CLASS_INVERSE_MAP.
+ raise ValueError("Unexpected template_name: %s" % (template_name,))
if unscoped_rule_id:
templaterule["rule_id"] = unscoped_rule_id
@@ -90,9 +97,9 @@ def _rule_to_template(rule):
return templaterule
-def _rule_id_from_namespaced(in_rule_id):
+def _rule_id_from_namespaced(in_rule_id: str) -> str:
return in_rule_id.split("/")[-1]
-def _priority_class_to_template_name(pc):
+def _priority_class_to_template_name(pc: int) -> str:
return PRIORITY_CLASS_INVERSE_MAP[pc]
diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py
index c6763971ee..64a35c1994 100644
--- a/synapse/push/emailpusher.py
+++ b/synapse/push/emailpusher.py
@@ -14,12 +14,19 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from twisted.internet.base import DelayedCall
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.push import Pusher
+from synapse.push.mailer import Mailer
from synapse.types import RoomStreamToken
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
# The amount of time we always wait before ever emailing about a notification
@@ -46,7 +53,7 @@ THROTTLE_RESET_AFTER_MS = 12 * 60 * 60 * 1000
INCLUDE_ALL_UNREAD_NOTIFS = False
-class EmailPusher:
+class EmailPusher(Pusher):
"""
A pusher that sends email notifications about events (approximately)
when they happen.
@@ -54,37 +61,31 @@ class EmailPusher:
factor out the common parts
"""
- def __init__(self, hs, pusherdict, mailer):
- self.hs = hs
+ def __init__(self, hs: "HomeServer", pusherdict: Dict[str, Any], mailer: Mailer):
+ super().__init__(hs, pusherdict)
self.mailer = mailer
self.store = self.hs.get_datastore()
- self.clock = self.hs.get_clock()
- self.pusher_id = pusherdict["id"]
- self.user_id = pusherdict["user_name"]
- self.app_id = pusherdict["app_id"]
self.email = pusherdict["pushkey"]
self.last_stream_ordering = pusherdict["last_stream_ordering"]
- self.timed_call = None
- self.throttle_params = None
-
- # See httppusher
- self.max_stream_ordering = None
+ self.timed_call = None # type: Optional[DelayedCall]
+ self.throttle_params = {} # type: Dict[str, Dict[str, int]]
+ self._inited = False
self._is_processing = False
- def on_started(self, should_check_for_notifs):
+ def on_started(self, should_check_for_notifs: bool) -> None:
"""Called when this pusher has been started.
Args:
- should_check_for_notifs (bool): Whether we should immediately
+ should_check_for_notifs: Whether we should immediately
check for push to send. Set to False only if it's known there
is nothing to send
"""
if should_check_for_notifs and self.mailer is not None:
self._start_processing()
- def on_stop(self):
+ def on_stop(self) -> None:
if self.timed_call:
try:
self.timed_call.cancel()
@@ -92,7 +93,7 @@ class EmailPusher:
pass
self.timed_call = None
- def on_new_notifications(self, max_token: RoomStreamToken):
+ def on_new_notifications(self, max_token: RoomStreamToken) -> None:
# We just use the minimum stream ordering and ignore the vector clock
# component. This is safe to do as long as we *always* ignore the vector
# clock components.
@@ -106,23 +107,23 @@ class EmailPusher:
self.max_stream_ordering = max_stream_ordering
self._start_processing()
- def on_new_receipts(self, min_stream_id, max_stream_id):
+ def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
# We could wake up and cancel the timer but there tend to be quite a
# lot of read receipts so it's probably less work to just let the
# timer fire
pass
- def on_timer(self):
+ def on_timer(self) -> None:
self.timed_call = None
self._start_processing()
- def _start_processing(self):
+ def _start_processing(self) -> None:
if self._is_processing:
return
run_as_background_process("emailpush.process", self._process)
- def _pause_processing(self):
+ def _pause_processing(self) -> None:
"""Used by tests to temporarily pause processing of events.
Asserts that its not currently processing.
@@ -130,25 +131,26 @@ class EmailPusher:
assert not self._is_processing
self._is_processing = True
- def _resume_processing(self):
+ def _resume_processing(self) -> None:
"""Used by tests to resume processing of events after pausing.
"""
assert self._is_processing
self._is_processing = False
self._start_processing()
- async def _process(self):
+ async def _process(self) -> None:
# we should never get here if we are already processing
assert not self._is_processing
try:
self._is_processing = True
- if self.throttle_params is None:
+ if not self._inited:
# this is our first loop: load up the throttle params
self.throttle_params = await self.store.get_throttle_params_by_room(
self.pusher_id
)
+ self._inited = True
# if the max ordering changes while we're running _unsafe_process,
# call it again, and so on until we've caught up.
@@ -163,17 +165,19 @@ class EmailPusher:
finally:
self._is_processing = False
- async def _unsafe_process(self):
+ async def _unsafe_process(self) -> None:
"""
Main logic of the push loop without the wrapper function that sets
up logging, measures and guards against multiple instances of it
being run.
"""
start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering
- fn = self.store.get_unread_push_actions_for_user_in_range_for_email
- unprocessed = await fn(self.user_id, start, self.max_stream_ordering)
+ assert self.max_stream_ordering is not None
+ unprocessed = await self.store.get_unread_push_actions_for_user_in_range_for_email(
+ self.user_id, start, self.max_stream_ordering
+ )
- soonest_due_at = None
+ soonest_due_at = None # type: Optional[int]
if not unprocessed:
await self.save_last_stream_ordering_and_success(self.max_stream_ordering)
@@ -230,7 +234,9 @@ class EmailPusher:
self.seconds_until(soonest_due_at), self.on_timer
)
- async def save_last_stream_ordering_and_success(self, last_stream_ordering):
+ async def save_last_stream_ordering_and_success(
+ self, last_stream_ordering: Optional[int]
+ ) -> None:
if last_stream_ordering is None:
# This happens if we haven't yet processed anything
return
@@ -248,28 +254,30 @@ class EmailPusher:
# lets just stop and return.
self.on_stop()
- def seconds_until(self, ts_msec):
+ def seconds_until(self, ts_msec: int) -> float:
secs = (ts_msec - self.clock.time_msec()) / 1000
return max(secs, 0)
- def get_room_throttle_ms(self, room_id):
+ def get_room_throttle_ms(self, room_id: str) -> int:
if room_id in self.throttle_params:
return self.throttle_params[room_id]["throttle_ms"]
else:
return 0
- def get_room_last_sent_ts(self, room_id):
+ def get_room_last_sent_ts(self, room_id: str) -> int:
if room_id in self.throttle_params:
return self.throttle_params[room_id]["last_sent_ts"]
else:
return 0
- def room_ready_to_notify_at(self, room_id):
+ def room_ready_to_notify_at(self, room_id: str) -> int:
"""
Determines whether throttling should prevent us from sending an email
for the given room
- Returns: The timestamp when we are next allowed to send an email notif
- for this room
+
+ Returns:
+ The timestamp when we are next allowed to send an email notif
+ for this room
"""
last_sent_ts = self.get_room_last_sent_ts(room_id)
throttle_ms = self.get_room_throttle_ms(room_id)
@@ -277,7 +285,9 @@ class EmailPusher:
may_send_at = last_sent_ts + throttle_ms
return may_send_at
- async def sent_notif_update_throttle(self, room_id, notified_push_action):
+ async def sent_notif_update_throttle(
+ self, room_id: str, notified_push_action: dict
+ ) -> None:
# We have sent a notification, so update the throttle accordingly.
# If the event that triggered the notif happened more than
# THROTTLE_RESET_AFTER_MS after the previous one that triggered a
@@ -315,7 +325,7 @@ class EmailPusher:
self.pusher_id, room_id, self.throttle_params[room_id]
)
- async def send_notification(self, push_actions, reason):
+ async def send_notification(self, push_actions: List[dict], reason: dict) -> None:
logger.info("Sending notif email for user %r", self.user_id)
await self.mailer.send_notification_mail(
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index d011e0aced..995e86e31a 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -14,19 +14,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
+import urllib.parse
+from typing import TYPE_CHECKING, Any, Dict, Iterable, Union
from prometheus_client import Counter
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
from synapse.api.constants import EventTypes
+from synapse.events import EventBase
from synapse.logging import opentracing
from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.push import PusherConfigException
+from synapse.push import Pusher, PusherConfigException
from synapse.types import RoomStreamToken
from . import push_rule_evaluator, push_tools
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
http_push_processed_counter = Counter(
@@ -50,24 +56,18 @@ http_badges_failed_counter = Counter(
)
-class HttpPusher:
+class HttpPusher(Pusher):
INITIAL_BACKOFF_SEC = 1 # in seconds because that's what Twisted takes
MAX_BACKOFF_SEC = 60 * 60
# This one's in ms because we compare it against the clock
GIVE_UP_AFTER_MS = 24 * 60 * 60 * 1000
- def __init__(self, hs, pusherdict):
- self.hs = hs
- self.store = self.hs.get_datastore()
+ def __init__(self, hs: "HomeServer", pusherdict: Dict[str, Any]):
+ super().__init__(hs, pusherdict)
self.storage = self.hs.get_storage()
- self.clock = self.hs.get_clock()
- self.state_handler = self.hs.get_state_handler()
- self.user_id = pusherdict["user_name"]
- self.app_id = pusherdict["app_id"]
self.app_display_name = pusherdict["app_display_name"]
self.device_display_name = pusherdict["device_display_name"]
- self.pushkey = pusherdict["pushkey"]
self.pushkey_ts = pusherdict["ts"]
self.data = pusherdict["data"]
self.last_stream_ordering = pusherdict["last_stream_ordering"]
@@ -77,13 +77,6 @@ class HttpPusher:
self._is_processing = False
self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room
- # This is the highest stream ordering we know it's safe to process.
- # When new events arrive, we'll be given a window of new events: we
- # should honour this rather than just looking for anything higher
- # because of potential out-of-order event serialisation. This starts
- # off as None though as we don't know any better.
- self.max_stream_ordering = None
-
if "data" not in pusherdict:
raise PusherConfigException("No 'data' key for HTTP pusher")
self.data = pusherdict["data"]
@@ -97,30 +90,44 @@ class HttpPusher:
if self.data is None:
raise PusherConfigException("data can not be null for HTTP pusher")
+ # Validate that there's a URL and it is of the proper form.
if "url" not in self.data:
raise PusherConfigException("'url' required in data for HTTP pusher")
- self.url = self.data["url"]
- self.url = self.url.replace(
+
+ url = self.data["url"]
+ if not isinstance(url, str):
+ raise PusherConfigException("'url' must be a string")
+ url_parts = urllib.parse.urlparse(url)
+ # Note that the specification also says the scheme must be HTTPS, but
+ # it isn't up to the homeserver to verify that.
+ if url_parts.path != "/_matrix/push/v1/notify":
+ raise PusherConfigException(
+ "'url' must have a path of '/_matrix/push/v1/notify'"
+ )
+
+ url = url.replace(
"https://matrix.org/_matrix/push/v1/notify",
"http://10.103.0.7/_matrix/push/v1/notify",
)
- self.http_client = hs.get_proxied_http_client()
+
+ self.url = url
+ self.http_client = hs.get_proxied_blacklisted_http_client()
self.data_minus_url = {}
self.data_minus_url.update(self.data)
del self.data_minus_url["url"]
- def on_started(self, should_check_for_notifs):
+ def on_started(self, should_check_for_notifs: bool) -> None:
"""Called when this pusher has been started.
Args:
- should_check_for_notifs (bool): Whether we should immediately
+ should_check_for_notifs: Whether we should immediately
check for push to send. Set to False only if it's known there
is nothing to send
"""
if should_check_for_notifs:
self._start_processing()
- def on_new_notifications(self, max_token: RoomStreamToken):
+ def on_new_notifications(self, max_token: RoomStreamToken) -> None:
# We just use the minimum stream ordering and ignore the vector clock
# component. This is safe to do as long as we *always* ignore the vector
# clock components.
@@ -131,14 +138,14 @@ class HttpPusher:
)
self._start_processing()
- def on_new_receipts(self, min_stream_id, max_stream_id):
+ def on_new_receipts(self, min_stream_id: int, max_stream_id: int) -> None:
# Note that the min here shouldn't be relied upon to be accurate.
# We could check the receipts are actually m.read receipts here,
# but currently that's the only type of receipt anyway...
run_as_background_process("http_pusher.on_new_receipts", self._update_badge)
- async def _update_badge(self):
+ async def _update_badge(self) -> None:
# XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems
# to be largely redundant. perhaps we can remove it.
badge = await push_tools.get_badge_count(
@@ -148,10 +155,10 @@ class HttpPusher:
)
await self._send_badge(badge)
- def on_timer(self):
+ def on_timer(self) -> None:
self._start_processing()
- def on_stop(self):
+ def on_stop(self) -> None:
if self.timed_call:
try:
self.timed_call.cancel()
@@ -159,13 +166,13 @@ class HttpPusher:
pass
self.timed_call = None
- def _start_processing(self):
+ def _start_processing(self) -> None:
if self._is_processing:
return
run_as_background_process("httppush.process", self._process)
- async def _process(self):
+ async def _process(self) -> None:
# we should never get here if we are already processing
assert not self._is_processing
@@ -184,7 +191,7 @@ class HttpPusher:
finally:
self._is_processing = False
- async def _unsafe_process(self):
+ async def _unsafe_process(self) -> None:
"""
Looks for unset notifications and dispatch them, in order
Never call this directly: use _process which will only allow this to
@@ -192,6 +199,7 @@ class HttpPusher:
"""
fn = self.store.get_unread_push_actions_for_user_in_range_for_http
+ assert self.max_stream_ordering is not None
unprocessed = await fn(
self.user_id, self.last_stream_ordering, self.max_stream_ordering
)
@@ -261,17 +269,12 @@ class HttpPusher:
)
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.last_stream_ordering = push_action["stream_ordering"]
- pusher_still_exists = await self.store.update_pusher_last_stream_ordering(
+ await self.store.update_pusher_last_stream_ordering(
self.app_id,
self.pushkey,
self.user_id,
self.last_stream_ordering,
)
- if not pusher_still_exists:
- # The pusher has been deleted while we were processing, so
- # lets just stop and return.
- self.on_stop()
- return
self.failing_since = None
await self.store.update_pusher_failing_since(
@@ -287,7 +290,7 @@ class HttpPusher:
)
break
- async def _process_one(self, push_action):
+ async def _process_one(self, push_action: dict) -> bool:
if "notify" not in push_action["actions"]:
return True
@@ -318,7 +321,9 @@ class HttpPusher:
await self.hs.remove_pusher(self.app_id, pk, self.user_id)
return True
- async def _build_notification_dict(self, event, tweaks, badge):
+ async def _build_notification_dict(
+ self, event: EventBase, tweaks: Dict[str, bool], badge: int
+ ) -> Dict[str, Any]:
priority = "low"
if (
event.type == EventTypes.Encrypted
@@ -348,9 +353,7 @@ class HttpPusher:
}
return d
- ctx = await push_tools.get_context_for_event(
- self.storage, self.state_handler, event, self.user_id
- )
+ ctx = await push_tools.get_context_for_event(self.storage, event, self.user_id)
d = {
"notification": {
@@ -390,7 +393,9 @@ class HttpPusher:
return d
- async def dispatch_push(self, event, tweaks, badge):
+ async def dispatch_push(
+ self, event: EventBase, tweaks: Dict[str, bool], badge: int
+ ) -> Union[bool, Iterable[str]]:
notification_dict = await self._build_notification_dict(event, tweaks, badge)
if not notification_dict:
return []
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index 38195c8eea..9ff092e8bb 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -19,7 +19,7 @@ import logging
import urllib.parse
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
-from typing import Iterable, List, TypeVar
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, TypeVar
import bleach
import jinja2
@@ -27,16 +27,20 @@ import jinja2
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import StoreError
from synapse.config.emailconfig import EmailSubjectConfig
+from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable
from synapse.push.presentable_names import (
calculate_room_name,
descriptor_from_member_events,
name_from_member_event,
)
-from synapse.types import UserID
+from synapse.types import StateMap, UserID
from synapse.util.async_helpers import concurrently_execute
from synapse.visibility import filter_events_for_client
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
T = TypeVar("T")
@@ -93,7 +97,13 @@ ALLOWED_ATTRS = {
class Mailer:
- def __init__(self, hs, app_name, template_html, template_text):
+ def __init__(
+ self,
+ hs: "HomeServer",
+ app_name: str,
+ template_html: jinja2.Template,
+ template_text: jinja2.Template,
+ ):
self.hs = hs
self.template_html = template_html
self.template_text = template_text
@@ -108,17 +118,19 @@ class Mailer:
logger.info("Created Mailer for app_name %s" % app_name)
- async def send_password_reset_mail(self, email_address, token, client_secret, sid):
+ async def send_password_reset_mail(
+ self, email_address: str, token: str, client_secret: str, sid: str
+ ) -> None:
"""Send an email with a password reset link to a user
Args:
- email_address (str): Email address we're sending the password
+ email_address: Email address we're sending the password
reset to
- token (str): Unique token generated by the server to verify
+ token: Unique token generated by the server to verify
the email was received
- client_secret (str): Unique token generated by the client to
+ client_secret: Unique token generated by the client to
group together multiple email sending attempts
- sid (str): The generated session ID
+ sid: The generated session ID
"""
params = {"token": token, "client_secret": client_secret, "sid": sid}
link = (
@@ -136,17 +148,19 @@ class Mailer:
template_vars,
)
- async def send_registration_mail(self, email_address, token, client_secret, sid):
+ async def send_registration_mail(
+ self, email_address: str, token: str, client_secret: str, sid: str
+ ) -> None:
"""Send an email with a registration confirmation link to a user
Args:
- email_address (str): Email address we're sending the registration
+ email_address: Email address we're sending the registration
link to
- token (str): Unique token generated by the server to verify
+ token: Unique token generated by the server to verify
the email was received
- client_secret (str): Unique token generated by the client to
+ client_secret: Unique token generated by the client to
group together multiple email sending attempts
- sid (str): The generated session ID
+ sid: The generated session ID
"""
params = {"token": token, "client_secret": client_secret, "sid": sid}
link = (
@@ -164,18 +178,20 @@ class Mailer:
template_vars,
)
- async def send_add_threepid_mail(self, email_address, token, client_secret, sid):
+ async def send_add_threepid_mail(
+ self, email_address: str, token: str, client_secret: str, sid: str
+ ) -> None:
"""Send an email with a validation link to a user for adding a 3pid to their account
Args:
- email_address (str): Email address we're sending the validation link to
+ email_address: Email address we're sending the validation link to
- token (str): Unique token generated by the server to verify the email was received
+ token: Unique token generated by the server to verify the email was received
- client_secret (str): Unique token generated by the client to group together
+ client_secret: Unique token generated by the client to group together
multiple email sending attempts
- sid (str): The generated session ID
+ sid: The generated session ID
"""
params = {"token": token, "client_secret": client_secret, "sid": sid}
link = (
@@ -194,8 +210,13 @@ class Mailer:
)
async def send_notification_mail(
- self, app_id, user_id, email_address, push_actions, reason
- ):
+ self,
+ app_id: str,
+ user_id: str,
+ email_address: str,
+ push_actions: Iterable[Dict[str, Any]],
+ reason: Dict[str, Any],
+ ) -> None:
"""Send email regarding a user's room notifications"""
rooms_in_order = deduped_ordered_list([pa["room_id"] for pa in push_actions])
@@ -203,7 +224,7 @@ class Mailer:
[pa["event_id"] for pa in push_actions]
)
- notifs_by_room = {}
+ notifs_by_room = {} # type: Dict[str, List[Dict[str, Any]]]
for pa in push_actions:
notifs_by_room.setdefault(pa["room_id"], []).append(pa)
@@ -262,7 +283,9 @@ class Mailer:
await self.send_email(email_address, summary_text, template_vars)
- async def send_email(self, email_address, subject, extra_template_vars):
+ async def send_email(
+ self, email_address: str, subject: str, extra_template_vars: Dict[str, Any]
+ ) -> None:
"""Send an email with the given information and template text"""
try:
from_string = self.hs.config.email_notif_from % {"app": self.app_name}
@@ -315,8 +338,13 @@ class Mailer:
)
async def get_room_vars(
- self, room_id, user_id, notifs, notif_events, room_state_ids
- ):
+ self,
+ room_id: str,
+ user_id: str,
+ notifs: Iterable[Dict[str, Any]],
+ notif_events: Dict[str, EventBase],
+ room_state_ids: StateMap[str],
+ ) -> Dict[str, Any]:
# Check if one of the notifs is an invite event for the user.
is_invite = False
for n in notifs:
@@ -334,7 +362,7 @@ class Mailer:
"notifs": [],
"invite": is_invite,
"link": self.make_room_link(room_id),
- }
+ } # type: Dict[str, Any]
if not is_invite:
for n in notifs:
@@ -365,7 +393,13 @@ class Mailer:
return room_vars
- async def get_notif_vars(self, notif, user_id, notif_event, room_state_ids):
+ async def get_notif_vars(
+ self,
+ notif: Dict[str, Any],
+ user_id: str,
+ notif_event: EventBase,
+ room_state_ids: StateMap[str],
+ ) -> Dict[str, Any]:
results = await self.store.get_events_around(
notif["room_id"],
notif["event_id"],
@@ -391,7 +425,9 @@ class Mailer:
return ret
- async def get_message_vars(self, notif, event, room_state_ids):
+ async def get_message_vars(
+ self, notif: Dict[str, Any], event: EventBase, room_state_ids: StateMap[str]
+ ) -> Optional[Dict[str, Any]]:
if event.type != EventTypes.Message and event.type != EventTypes.Encrypted:
return None
@@ -432,7 +468,9 @@ class Mailer:
return ret
- def add_text_message_vars(self, messagevars, event):
+ def add_text_message_vars(
+ self, messagevars: Dict[str, Any], event: EventBase
+ ) -> None:
msgformat = event.content.get("format")
messagevars["format"] = msgformat
@@ -445,15 +483,18 @@ class Mailer:
elif body:
messagevars["body_text_html"] = safe_text(body)
- return messagevars
-
- def add_image_message_vars(self, messagevars, event):
+ def add_image_message_vars(
+ self, messagevars: Dict[str, Any], event: EventBase
+ ) -> None:
messagevars["image_url"] = event.content["url"]
- return messagevars
-
async def make_summary_text(
- self, notifs_by_room, room_state_ids, notif_events, user_id, reason
+ self,
+ notifs_by_room: Dict[str, List[Dict[str, Any]]],
+ room_state_ids: Dict[str, StateMap[str]],
+ notif_events: Dict[str, EventBase],
+ user_id: str,
+ reason: Dict[str, Any],
):
if len(notifs_by_room) == 1:
# Only one room has new stuff
@@ -580,7 +621,7 @@ class Mailer:
"app": self.app_name,
}
- def make_room_link(self, room_id):
+ def make_room_link(self, room_id: str) -> str:
if self.hs.config.email_riot_base_url:
base_url = "%s/#/room" % (self.hs.config.email_riot_base_url)
elif self.app_name == "Vector":
@@ -590,7 +631,7 @@ class Mailer:
base_url = "https://matrix.to/#"
return "%s/%s" % (base_url, room_id)
- def make_notif_link(self, notif):
+ def make_notif_link(self, notif: Dict[str, str]) -> str:
if self.hs.config.email_riot_base_url:
return "%s/#/room/%s/%s" % (
self.hs.config.email_riot_base_url,
@@ -606,7 +647,9 @@ class Mailer:
else:
return "https://matrix.to/#/%s/%s" % (notif["room_id"], notif["event_id"])
- def make_unsubscribe_link(self, user_id, app_id, email_address):
+ def make_unsubscribe_link(
+ self, user_id: str, app_id: str, email_address: str
+ ) -> str:
params = {
"access_token": self.macaroon_gen.generate_delete_pusher_token(user_id),
"app_id": app_id,
@@ -620,7 +663,7 @@ class Mailer:
)
-def safe_markup(raw_html):
+def safe_markup(raw_html: str) -> jinja2.Markup:
return jinja2.Markup(
bleach.linkify(
bleach.clean(
@@ -635,7 +678,7 @@ def safe_markup(raw_html):
)
-def safe_text(raw_text):
+def safe_text(raw_text: str) -> jinja2.Markup:
"""
Process text: treat it as HTML but escape any tags (ie. just escape the
HTML) then linkify it.
@@ -655,7 +698,7 @@ def deduped_ordered_list(it: Iterable[T]) -> List[T]:
return ret
-def string_ordinal_total(s):
+def string_ordinal_total(s: str) -> int:
tot = 0
for c in s:
tot += ord(c)
diff --git a/synapse/push/presentable_names.py b/synapse/push/presentable_names.py
index d8f4a453cd..7e50341d74 100644
--- a/synapse/push/presentable_names.py
+++ b/synapse/push/presentable_names.py
@@ -15,8 +15,14 @@
import logging
import re
+from typing import TYPE_CHECKING, Dict, Iterable, Optional
from synapse.api.constants import EventTypes
+from synapse.events import EventBase
+from synapse.types import StateMap
+
+if TYPE_CHECKING:
+ from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__)
@@ -28,25 +34,29 @@ ALL_ALONE = "Empty Room"
async def calculate_room_name(
- store,
- room_state_ids,
- user_id,
- fallback_to_members=True,
- fallback_to_single_member=True,
-):
+ store: "DataStore",
+ room_state_ids: StateMap[str],
+ user_id: str,
+ fallback_to_members: bool = True,
+ fallback_to_single_member: bool = True,
+) -> Optional[str]:
"""
Works out a user-facing name for the given room as per Matrix
spec recommendations.
Does not yet support internationalisation.
Args:
- room_state: Dictionary of the room's state
+ store: The data store to query.
+ room_state_ids: Dictionary of the room's state IDs.
user_id: The ID of the user to whom the room name is being presented
fallback_to_members: If False, return None instead of generating a name
based on the room's members if the room has no
title or aliases.
+ fallback_to_single_member: If False, return None instead of generating a
+ name based on the user who invited this user to the room if the room
+ has no title or aliases.
Returns:
- (string or None) A human readable name for the room.
+ A human readable name for the room, if possible.
"""
# does it have a name?
if (EventTypes.Name, "") in room_state_ids:
@@ -97,7 +107,7 @@ async def calculate_room_name(
name_from_member_event(inviter_member_event),
)
else:
- return
+ return None
else:
return "Room Invite"
@@ -150,19 +160,19 @@ async def calculate_room_name(
else:
return ALL_ALONE
elif len(other_members) == 1 and not fallback_to_single_member:
- return
- else:
- return descriptor_from_member_events(other_members)
+ return None
+
+ return descriptor_from_member_events(other_members)
-def descriptor_from_member_events(member_events):
+def descriptor_from_member_events(member_events: Iterable[EventBase]) -> str:
"""Get a description of the room based on the member events.
Args:
- member_events (Iterable[FrozenEvent])
+ member_events: The events of a room.
Returns:
- str
+ The room description
"""
member_events = list(member_events)
@@ -183,7 +193,7 @@ def descriptor_from_member_events(member_events):
)
-def name_from_member_event(member_event):
+def name_from_member_event(member_event: EventBase) -> str:
if (
member_event.content
and "displayname" in member_event.content
@@ -193,12 +203,12 @@ def name_from_member_event(member_event):
return member_event.state_key
-def _state_as_two_level_dict(state):
- ret = {}
+def _state_as_two_level_dict(state: StateMap[str]) -> Dict[str, Dict[str, str]]:
+ ret = {} # type: Dict[str, Dict[str, str]]
for k, v in state.items():
ret.setdefault(k[0], {})[k[1]] = v
return ret
-def _looks_like_an_alias(string):
+def _looks_like_an_alias(string: str) -> bool:
return ALIAS_RE.match(string) is not None
diff --git a/synapse/push/push_rule_evaluator.py b/synapse/push/push_rule_evaluator.py
index 2ce9e444ab..ba1877adcd 100644
--- a/synapse/push/push_rule_evaluator.py
+++ b/synapse/push/push_rule_evaluator.py
@@ -30,22 +30,30 @@ IS_GLOB = re.compile(r"[\?\*\[\]]")
INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
-def _room_member_count(ev, condition, room_member_count):
+def _room_member_count(
+ ev: EventBase, condition: Dict[str, Any], room_member_count: int
+) -> bool:
return _test_ineq_condition(condition, room_member_count)
-def _sender_notification_permission(ev, condition, sender_power_level, power_levels):
+def _sender_notification_permission(
+ ev: EventBase,
+ condition: Dict[str, Any],
+ sender_power_level: int,
+ power_levels: Dict[str, Union[int, Dict[str, int]]],
+) -> bool:
notif_level_key = condition.get("key")
if notif_level_key is None:
return False
notif_levels = power_levels.get("notifications", {})
+ assert isinstance(notif_levels, dict)
room_notif_level = notif_levels.get(notif_level_key, 50)
return sender_power_level >= room_notif_level
-def _test_ineq_condition(condition, number):
+def _test_ineq_condition(condition: Dict[str, Any], number: int) -> bool:
if "is" not in condition:
return False
m = INEQUALITY_EXPR.match(condition["is"])
@@ -110,7 +118,7 @@ class PushRuleEvaluatorForEvent:
event: EventBase,
room_member_count: int,
sender_power_level: int,
- power_levels: dict,
+ power_levels: Dict[str, Union[int, Dict[str, int]]],
):
self._event = event
self._room_member_count = room_member_count
@@ -120,7 +128,9 @@ class PushRuleEvaluatorForEvent:
# Maps strings of e.g. 'content.body' -> event["content"]["body"]
self._value_cache = _flatten_dict(event)
- def matches(self, condition: dict, user_id: str, display_name: str) -> bool:
+ def matches(
+ self, condition: Dict[str, Any], user_id: str, display_name: str
+ ) -> bool:
if condition["kind"] == "event_match":
return self._event_match(condition, user_id)
elif condition["kind"] == "contains_display_name":
@@ -261,7 +271,13 @@ def _re_word_boundary(r: str) -> str:
return r"(^|\W)%s(\W|$)" % (r,)
-def _flatten_dict(d, prefix=[], result=None):
+def _flatten_dict(
+ d: Union[EventBase, dict],
+ prefix: Optional[List[str]] = None,
+ result: Optional[Dict[str, str]] = None,
+) -> Dict[str, str]:
+ if prefix is None:
+ prefix = []
if result is None:
result = {}
for key, value in d.items():
diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py
index 6e7c880dc0..df34103224 100644
--- a/synapse/push/push_tools.py
+++ b/synapse/push/push_tools.py
@@ -12,6 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Dict
+
+from synapse.events import EventBase
from synapse.push.presentable_names import calculate_room_name, name_from_member_event
from synapse.storage import Storage
from synapse.storage.databases.main import DataStore
@@ -46,7 +49,9 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
return badge
-async def get_context_for_event(storage: Storage, state_handler, ev, user_id):
+async def get_context_for_event(
+ storage: Storage, ev: EventBase, user_id: str
+) -> Dict[str, str]:
ctx = {}
room_state_ids = await storage.state.get_state_ids_for_event(ev.event_id)
diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py
index 2a52e226e3..8f1072b094 100644
--- a/synapse/push/pusher.py
+++ b/synapse/push/pusher.py
@@ -14,25 +14,31 @@
# limitations under the License.
import logging
+from typing import TYPE_CHECKING, Any, Callable, Dict, Optional
+from synapse.push import Pusher
from synapse.push.emailpusher import EmailPusher
+from synapse.push.httppusher import HttpPusher
from synapse.push.mailer import Mailer
-from .httppusher import HttpPusher
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
class PusherFactory:
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.config = hs.config
- self.pusher_types = {"http": HttpPusher}
+ self.pusher_types = {
+ "http": HttpPusher
+ } # type: Dict[str, Callable[[HomeServer, dict], Pusher]]
logger.info("email enable notifs: %r", hs.config.email_enable_notifs)
if hs.config.email_enable_notifs:
- self.mailers = {} # app_name -> Mailer
+ self.mailers = {} # type: Dict[str, Mailer]
self._notif_template_html = hs.config.email_notif_template_html
self._notif_template_text = hs.config.email_notif_template_text
@@ -41,7 +47,7 @@ class PusherFactory:
logger.info("defined email pusher type")
- def create_pusher(self, pusherdict):
+ def create_pusher(self, pusherdict: Dict[str, Any]) -> Optional[Pusher]:
kind = pusherdict["kind"]
f = self.pusher_types.get(kind, None)
if not f:
@@ -49,7 +55,9 @@ class PusherFactory:
logger.debug("creating %s pusher for %r", kind, pusherdict)
return f(self.hs, pusherdict)
- def _create_email_pusher(self, _hs, pusherdict):
+ def _create_email_pusher(
+ self, _hs: "HomeServer", pusherdict: Dict[str, Any]
+ ) -> EmailPusher:
app_name = self._app_name_from_pusherdict(pusherdict)
mailer = self.mailers.get(app_name)
if not mailer:
@@ -62,7 +70,7 @@ class PusherFactory:
self.mailers[app_name] = mailer
return EmailPusher(self.hs, pusherdict, mailer)
- def _app_name_from_pusherdict(self, pusherdict):
+ def _app_name_from_pusherdict(self, pusherdict: Dict[str, Any]) -> str:
data = pusherdict["data"]
if isinstance(data, dict):
diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py
index f325964983..9fcc0b8a64 100644
--- a/synapse/push/pusherpool.py
+++ b/synapse/push/pusherpool.py
@@ -15,7 +15,7 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Dict, Union
+from typing import TYPE_CHECKING, Any, Dict, Optional
from prometheus_client import Gauge
@@ -23,9 +23,7 @@ from synapse.metrics.background_process_metrics import (
run_as_background_process,
wrap_as_background_process,
)
-from synapse.push import PusherConfigException
-from synapse.push.emailpusher import EmailPusher
-from synapse.push.httppusher import HttpPusher
+from synapse.push import Pusher, PusherConfigException
from synapse.push.pusher import PusherFactory
from synapse.types import RoomStreamToken
from synapse.util.async_helpers import concurrently_execute
@@ -77,7 +75,7 @@ class PusherPool:
self._last_room_stream_id_seen = self.store.get_room_max_stream_ordering()
# map from user id to app_id:pushkey to pusher
- self.pushers = {} # type: Dict[str, Dict[str, Union[HttpPusher, EmailPusher]]]
+ self.pushers = {} # type: Dict[str, Dict[str, Pusher]]
def start(self):
"""Starts the pushers off in a background process.
@@ -99,11 +97,11 @@ class PusherPool:
lang,
data,
profile_tag="",
- ):
+ ) -> Optional[Pusher]:
"""Creates a new pusher and adds it to the pool
Returns:
- EmailPusher|HttpPusher
+ The newly created pusher.
"""
time_now_msec = self.clock.time_msec()
@@ -267,17 +265,19 @@ class PusherPool:
except Exception:
logger.exception("Exception in pusher on_new_receipts")
- async def start_pusher_by_id(self, app_id, pushkey, user_id):
+ async def start_pusher_by_id(
+ self, app_id: str, pushkey: str, user_id: str
+ ) -> Optional[Pusher]:
"""Look up the details for the given pusher, and start it
Returns:
- EmailPusher|HttpPusher|None: The pusher started, if any
+ The pusher started, if any
"""
if not self._should_start_pushers:
- return
+ return None
if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
- return
+ return None
resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
@@ -303,19 +303,19 @@ class PusherPool:
logger.info("Started pushers")
- async def _start_pusher(self, pusherdict):
+ async def _start_pusher(self, pusherdict: Dict[str, Any]) -> Optional[Pusher]:
"""Start the given pusher
Args:
- pusherdict (dict): dict with the values pulled from the db table
+ pusherdict: dict with the values pulled from the db table
Returns:
- EmailPusher|HttpPusher
+ The newly created pusher or None.
"""
if not self._pusher_shard_config.should_handle(
self._instance_name, pusherdict["user_name"]
):
- return
+ return None
try:
p = self.pusher_factory.create_pusher(pusherdict)
@@ -328,15 +328,15 @@ class PusherPool:
pusherdict.get("pushkey"),
e,
)
- return
+ return None
except Exception:
logger.exception(
"Couldn't start pusher id %i: caught Exception", pusherdict["id"],
)
- return
+ return None
if not p:
- return
+ return None
appid_pushkey = "%s:%s" % (pusherdict["app_id"], pusherdict["pushkey"])
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index c899ca14d3..c97e0df1f5 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -96,7 +96,11 @@ CONDITIONAL_REQUIREMENTS = {
# python 3.5.2, as per https://github.com/itamarst/eliot/issues/418
'eliot<1.8.0;python_version<"3.5.3"',
],
- "saml2": ["pysaml2>=4.5.0"],
+ "saml2": [
+ # pysaml2 6.4.0 is incompatible with Python 3.5 (see https://github.com/IdentityPython/pysaml2/issues/749)
+ "pysaml2>=4.5.0,<6.4.0;python_version<'3.6'",
+ "pysaml2>=4.5.0;python_version>='3.6'",
+ ],
"oidc": ["authlib>=0.14.0"],
"systemd": ["systemd-python>=231"],
"url_preview": ["lxml>=3.5.0"],
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 2b3972cb14..1492ac922c 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -106,6 +106,25 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
assert self.METHOD in ("PUT", "POST", "GET")
+ self._replication_secret = None
+ if hs.config.worker.worker_replication_secret:
+ self._replication_secret = hs.config.worker.worker_replication_secret
+
+ def _check_auth(self, request) -> None:
+ # Get the authorization header.
+ auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
+
+ if len(auth_headers) > 1:
+ raise RuntimeError("Too many Authorization headers.")
+ parts = auth_headers[0].split(b" ")
+ if parts[0] == b"Bearer" and len(parts) == 2:
+ received_secret = parts[1].decode("ascii")
+ if self._replication_secret == received_secret:
+ # Success!
+ return
+
+ raise RuntimeError("Invalid Authorization header.")
+
@abc.abstractmethod
async def _serialize_payload(**kwargs):
"""Static method that is called when creating a request.
@@ -150,6 +169,12 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
outgoing_gauge = _pending_outgoing_requests.labels(cls.NAME)
+ replication_secret = None
+ if hs.config.worker.worker_replication_secret:
+ replication_secret = hs.config.worker.worker_replication_secret.encode(
+ "ascii"
+ )
+
@trace(opname="outgoing_replication_request")
@outgoing_gauge.track_inprogress()
async def send_request(instance_name="master", **kwargs):
@@ -202,6 +227,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
# the master, and so whether we should clean up or not.
while True:
headers = {} # type: Dict[bytes, List[bytes]]
+ # Add an authorization header, if configured.
+ if replication_secret:
+ headers[b"Authorization"] = [b"Bearer " + replication_secret]
inject_active_span_byte_dict(headers, None, check_destination=False)
try:
result = await request_func(uri, data, headers=headers)
@@ -236,21 +264,19 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
"""
url_args = list(self.PATH_ARGS)
- handler = self._handle_request
method = self.METHOD
if self.CACHE:
- handler = self._cached_handler # type: ignore
url_args.append("txn_id")
args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)
pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args))
http_server.register_paths(
- method, [pattern], handler, self.__class__.__name__,
+ method, [pattern], self._check_auth_and_handle, self.__class__.__name__,
)
- def _cached_handler(self, request, txn_id, **kwargs):
+ def _check_auth_and_handle(self, request, **kwargs):
"""Called on new incoming requests when caching is enabled. Checks
if there is a cached response for the request and returns that,
otherwise calls `_handle_request` and caches its response.
@@ -258,6 +284,15 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta):
# We just use the txn_id here, but we probably also want to use the
# other PATH_ARGS as well.
- assert self.CACHE
+ # Check the authorization headers before handling the request.
+ if self._replication_secret:
+ self._check_auth(request)
+
+ if self.CACHE:
+ txn_id = kwargs.pop("txn_id")
+
+ return self.response_cache.wrap(
+ txn_id, self._handle_request, request, **kwargs
+ )
- return self.response_cache.wrap(txn_id, self._handle_request, request, **kwargs)
+ return self._handle_request(request, **kwargs)
diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py
index a509e599c2..804da994ea 100644
--- a/synapse/replication/tcp/protocol.py
+++ b/synapse/replication/tcp/protocol.py
@@ -172,8 +172,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
# a logcontext which we use for processing incoming commands. We declare it as a
# background process so that the CPU stats get reported to prometheus.
ctx_name = "replication-conn-%s" % self.conn_id
- self._logging_context = BackgroundProcessLoggingContext(ctx_name)
- self._logging_context.request = ctx_name
+ self._logging_context = BackgroundProcessLoggingContext(ctx_name, ctx_name)
def connectionMade(self):
logger.info("[%s] Connection established", self.id())
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 25f89e4685..b902af8028 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
from http import HTTPStatus
-from typing import List, Optional
+from typing import TYPE_CHECKING, List, Optional, Tuple
from synapse.api.constants import EventTypes, JoinRules
from synapse.api.errors import Codes, NotFoundError, SynapseError
@@ -25,13 +25,17 @@ from synapse.http.servlet import (
parse_json_object_from_request,
parse_string,
)
+from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import (
admin_patterns,
assert_requester_is_admin,
assert_user_is_admin,
)
from synapse.storage.databases.main.room import RoomSortOrder
-from synapse.types import RoomAlias, RoomID, UserID, create_requester
+from synapse.types import JsonDict, RoomAlias, RoomID, UserID, create_requester
+
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -45,12 +49,14 @@ class ShutdownRoomRestServlet(RestServlet):
PATTERNS = admin_patterns("/shutdown_room/(?P<room_id>[^/]+)")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.room_shutdown_handler = hs.get_room_shutdown_handler()
- async def on_POST(self, request, room_id):
+ async def on_POST(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -86,13 +92,15 @@ class DeleteRoomRestServlet(RestServlet):
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/delete$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.room_shutdown_handler = hs.get_room_shutdown_handler()
self.pagination_handler = hs.get_pagination_handler()
- async def on_POST(self, request, room_id):
+ async def on_POST(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -146,12 +154,12 @@ class ListRoomRestServlet(RestServlet):
PATTERNS = admin_patterns("/rooms$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler()
- async def on_GET(self, request):
+ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -236,19 +244,24 @@ class RoomRestServlet(RestServlet):
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)$")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- async def on_GET(self, request, room_id):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
ret = await self.store.get_room_with_stats(room_id)
if not ret:
raise NotFoundError("Room not found")
- return 200, ret
+ members = await self.store.get_users_in_room(room_id)
+ ret["joined_local_devices"] = await self.store.count_devices_by_users(members)
+
+ return (200, ret)
class RoomMembersRestServlet(RestServlet):
@@ -258,12 +271,14 @@ class RoomMembersRestServlet(RestServlet):
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/members")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
- async def on_GET(self, request, room_id):
+ async def on_GET(
+ self, request: SynapseRequest, room_id: str
+ ) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
ret = await self.store.get_room(room_id)
@@ -280,14 +295,16 @@ class JoinRoomAliasServlet(RestServlet):
PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)")
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.room_member_handler = hs.get_room_member_handler()
self.admin_handler = hs.get_admin_handler()
self.state_handler = hs.get_state_handler()
- async def on_POST(self, request, room_identifier):
+ async def on_POST(
+ self, request: SynapseRequest, room_identifier: str
+ ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -314,7 +331,6 @@ class JoinRoomAliasServlet(RestServlet):
handler = self.room_member_handler
room_alias = RoomAlias.from_string(room_identifier)
room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias)
- room_id = room_id.to_string()
else:
raise SynapseError(
400, "%s was not legal room ID or room alias" % (room_identifier,)
diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py
index b0ff5e1ead..88cba369f5 100644
--- a/synapse/rest/admin/users.py
+++ b/synapse/rest/admin/users.py
@@ -320,9 +320,9 @@ class UserRestServletV2(RestServlet):
data={},
)
- if "avatar_url" in body and type(body["avatar_url"]) == str:
+ if "avatar_url" in body and isinstance(body["avatar_url"], str):
await self.profile_handler.set_avatar_url(
- user_id, requester, body["avatar_url"], True
+ target_user, requester, body["avatar_url"], True
)
ret = await self.admin_handler.get_user(target_user)
@@ -420,6 +420,9 @@ class UserRegisterServlet(RestServlet):
if user_type is not None and user_type not in UserTypes.ALL_USER_TYPES:
raise SynapseError(400, "Invalid user type")
+ if "mac" not in body:
+ raise SynapseError(400, "mac must be specified", errcode=Codes.BAD_JSON)
+
got_mac = body["mac"]
want_mac_builder = hmac.new(
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index d7ae148214..5f4c6703db 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from typing import Awaitable, Callable, Dict, Optional
+from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
@@ -30,6 +30,9 @@ from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
from synapse.types import JsonDict, UserID
+if TYPE_CHECKING:
+ from synapse.server import HomeServer
+
logger = logging.getLogger(__name__)
@@ -42,7 +45,7 @@ class LoginRestServlet(RestServlet):
JWT_TYPE_DEPRECATED = "m.login.jwt"
APPSERVICE_TYPE = "uk.half-shot.msc2778.login.application_service"
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
@@ -105,22 +108,27 @@ class LoginRestServlet(RestServlet):
return 200, {"flows": flows}
async def on_POST(self, request: SynapseRequest):
- self._address_ratelimiter.ratelimit(request.getClientIP())
-
login_submission = parse_json_object_from_request(request)
try:
if login_submission["type"] == LoginRestServlet.APPSERVICE_TYPE:
appservice = self.auth.get_appservice_by_req(request)
+
+ if appservice.is_rate_limited():
+ self._address_ratelimiter.ratelimit(request.getClientIP())
+
result = await self._do_appservice_login(login_submission, appservice)
elif self.jwt_enabled and (
login_submission["type"] == LoginRestServlet.JWT_TYPE
or login_submission["type"] == LoginRestServlet.JWT_TYPE_DEPRECATED
):
+ self._address_ratelimiter.ratelimit(request.getClientIP())
result = await self._do_jwt_login(login_submission)
elif login_submission["type"] == LoginRestServlet.TOKEN_TYPE:
+ self._address_ratelimiter.ratelimit(request.getClientIP())
result = await self._do_token_login(login_submission)
else:
+ self._address_ratelimiter.ratelimit(request.getClientIP())
result = await self._do_other_login(login_submission)
except KeyError:
raise SynapseError(400, "Missing JSON keys.")
@@ -159,7 +167,9 @@ class LoginRestServlet(RestServlet):
if not appservice.is_interested_in_user(qualified_user_id):
raise LoginError(403, "Invalid access_token", errcode=Codes.FORBIDDEN)
- return await self._complete_login(qualified_user_id, login_submission)
+ return await self._complete_login(
+ qualified_user_id, login_submission, ratelimit=appservice.is_rate_limited()
+ )
async def _do_other_login(self, login_submission: JsonDict) -> Dict[str, str]:
"""Handle non-token/saml/jwt logins
@@ -194,6 +204,7 @@ class LoginRestServlet(RestServlet):
login_submission: JsonDict,
callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None,
create_non_existent_users: bool = False,
+ ratelimit: bool = True,
) -> Dict[str, str]:
"""Called when we've successfully authed the user and now need to
actually login them in (e.g. create devices). This gets called on
@@ -208,6 +219,7 @@ class LoginRestServlet(RestServlet):
callback: Callback function to run after login.
create_non_existent_users: Whether to create the user if they don't
exist. Defaults to False.
+ ratelimit: Whether to ratelimit the login request.
Returns:
result: Dictionary of account information after successful login.
@@ -216,7 +228,8 @@ class LoginRestServlet(RestServlet):
# Before we actually log them in we check if they've already logged in
# too often. This happens here rather than before as we don't
# necessarily know the user before now.
- self._account_ratelimiter.ratelimit(user_id.lower())
+ if ratelimit:
+ self._account_ratelimiter.ratelimit(user_id.lower())
if create_non_existent_users:
canonical_uid = await self.auth_handler.check_user_exists(user_id)
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index a89ae6ddf9..9041e7ed76 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -451,7 +451,7 @@ class RegisterRestServlet(RestServlet):
# == Normal User Registration == (everyone else)
if not self._registration_enabled:
- raise SynapseError(403, "Registration has been disabled")
+ raise SynapseError(403, "Registration has been disabled", Codes.FORBIDDEN)
# For regular registration, convert the provided username to lowercase
# before attempting to register it. This should mean that people who try
diff --git a/synapse/rest/media/v1/_base.py b/synapse/rest/media/v1/_base.py
index 67aa993f19..47c2b44bff 100644
--- a/synapse/rest/media/v1/_base.py
+++ b/synapse/rest/media/v1/_base.py
@@ -155,6 +155,11 @@ def add_file_headers(request, media_type, file_size, upload_name):
request.setHeader(b"Cache-Control", b"public,max-age=86400,s-maxage=86400")
request.setHeader(b"Content-Length", b"%d" % (file_size,))
+ # Tell web crawlers to not index, archive, or follow links in media. This
+ # should help to prevent things in the media repo from showing up in web
+ # search results.
+ request.setHeader(b"X-Robots-Tag", "noindex, nofollow, noarchive, noimageindex")
+
# separators as defined in RFC2616. SP and HT are handled separately.
# see _can_encode_filename_as_token.
diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py
index 9cac74ebd8..83beb02b05 100644
--- a/synapse/rest/media/v1/media_repository.py
+++ b/synapse/rest/media/v1/media_repository.py
@@ -66,7 +66,7 @@ class MediaRepository:
def __init__(self, hs):
self.hs = hs
self.auth = hs.get_auth()
- self.client = hs.get_http_client()
+ self.client = hs.get_federation_http_client()
self.clock = hs.get_clock()
self.server_name = hs.hostname
self.store = hs.get_datastore()
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index dce6c4d168..1082389d9b 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -676,7 +676,11 @@ class PreviewUrlResource(DirectServeJsonResource):
logger.debug("No media removed from url cache")
-def decode_and_calc_og(body, media_uri, request_encoding=None):
+def decode_and_calc_og(body, media_uri, request_encoding=None) -> Dict[str, str]:
+ # If there's no body, nothing useful is going to be found.
+ if not body:
+ return {}
+
from lxml import etree
try:
diff --git a/synapse/rest/media/v1/storage_provider.py b/synapse/rest/media/v1/storage_provider.py
index 18c9ed48d6..67f67efde7 100644
--- a/synapse/rest/media/v1/storage_provider.py
+++ b/synapse/rest/media/v1/storage_provider.py
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import inspect
import logging
import os
import shutil
@@ -21,6 +20,7 @@ from typing import Optional
from synapse.config._base import Config
from synapse.logging.context import defer_to_thread, run_in_background
+from synapse.util.async_helpers import maybe_awaitable
from ._base import FileInfo, Responder
from .media_storage import FileResponder
@@ -91,16 +91,14 @@ class StorageProviderWrapper(StorageProvider):
if self.store_synchronous:
# store_file is supposed to return an Awaitable, but guard
# against improper implementations.
- result = self.backend.store_file(path, file_info)
- if inspect.isawaitable(result):
- return await result
+ return await maybe_awaitable(self.backend.store_file(path, file_info))
else:
# TODO: Handle errors.
async def store():
try:
- result = self.backend.store_file(path, file_info)
- if inspect.isawaitable(result):
- return await result
+ return await maybe_awaitable(
+ self.backend.store_file(path, file_info)
+ )
except Exception:
logger.exception("Error storing file")
@@ -110,9 +108,7 @@ class StorageProviderWrapper(StorageProvider):
async def fetch(self, path, file_info):
# store_file is supposed to return an Awaitable, but guard
# against improper implementations.
- result = self.backend.fetch(path, file_info)
- if inspect.isawaitable(result):
- return await result
+ return await maybe_awaitable(self.backend.fetch(path, file_info))
class FileStorageProviderBackend(StorageProvider):
diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py
index d76f7389e1..42febc9afc 100644
--- a/synapse/rest/media/v1/upload_resource.py
+++ b/synapse/rest/media/v1/upload_resource.py
@@ -44,7 +44,7 @@ class UploadResource(DirectServeJsonResource):
requester = await self.auth.get_user_by_req(request)
# TODO: The checks here are a bit late. The content will have
# already been uploaded to a tmp file at this point
- content_length = request.getHeader(b"Content-Length").decode("ascii")
+ content_length = request.getHeader("Content-Length")
if content_length is None:
raise SynapseError(msg="Request must specify a Content-Length", code=400)
if int(content_length) > self.max_upload_size:
diff --git a/synapse/server.py b/synapse/server.py
index b017e3489f..a198b0eb46 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -350,17 +350,47 @@ class HomeServer(metaclass=abc.ABCMeta):
@cache_in_self
def get_simple_http_client(self) -> SimpleHttpClient:
+ """
+ An HTTP client with no special configuration.
+ """
return SimpleHttpClient(self)
@cache_in_self
def get_proxied_http_client(self) -> SimpleHttpClient:
+ """
+ An HTTP client that uses configured HTTP(S) proxies.
+ """
+ return SimpleHttpClient(
+ self,
+ http_proxy=os.getenvb(b"http_proxy"),
+ https_proxy=os.getenvb(b"HTTPS_PROXY"),
+ )
+
+ @cache_in_self
+ def get_proxied_blacklisted_http_client(self) -> SimpleHttpClient:
+ """
+ An HTTP client that uses configured HTTP(S) proxies and blacklists IPs
+ based on the IP range blacklist/whitelist.
+ """
return SimpleHttpClient(
self,
+ ip_whitelist=self.config.ip_range_whitelist,
+ ip_blacklist=self.config.ip_range_blacklist,
http_proxy=os.getenvb(b"http_proxy"),
https_proxy=os.getenvb(b"HTTPS_PROXY"),
)
@cache_in_self
+ def get_federation_http_client(self) -> MatrixFederationHttpClient:
+ """
+ An HTTP client for federation.
+ """
+ tls_client_options_factory = context_factory.FederationPolicyForHTTPS(
+ self.config
+ )
+ return MatrixFederationHttpClient(self, tls_client_options_factory)
+
+ @cache_in_self
def get_room_creation_handler(self) -> RoomCreationHandler:
return RoomCreationHandler(self)
@@ -515,13 +545,6 @@ class HomeServer(metaclass=abc.ABCMeta):
return PusherPool(self)
@cache_in_self
- def get_http_client(self) -> MatrixFederationHttpClient:
- tls_client_options_factory = context_factory.FederationPolicyForHTTPS(
- self.config
- )
- return MatrixFederationHttpClient(self, tls_client_options_factory)
-
- @cache_in_self
def get_media_repository_resource(self) -> MediaRepositoryResource:
# build the media repo resource. This indirects through the HomeServer
# to ensure that we only have a single instance of
@@ -595,7 +618,7 @@ class HomeServer(metaclass=abc.ABCMeta):
return StatsHandler(self)
@cache_in_self
- def get_spam_checker(self):
+ def get_spam_checker(self) -> SpamChecker:
return SpamChecker(self)
@cache_in_self
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 1fa3b280b4..84f59c7d85 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -783,7 +783,7 @@ class StateResolutionStore:
)
def get_auth_chain_difference(
- self, state_sets: List[Set[str]]
+ self, room_id: str, state_sets: List[Set[str]]
) -> Awaitable[Set[str]]:
"""Given sets of state events figure out the auth chain difference (as
per state res v2 algorithm).
@@ -796,4 +796,4 @@ class StateResolutionStore:
An awaitable that resolves to a set of event IDs.
"""
- return self.store.get_auth_chain_difference(state_sets)
+ return self.store.get_auth_chain_difference(room_id, state_sets)
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index f57df0d728..f85124bf81 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -38,7 +38,7 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
-from synapse.types import MutableStateMap, StateMap
+from synapse.types import Collection, MutableStateMap, StateMap
from synapse.util import Clock
logger = logging.getLogger(__name__)
@@ -97,7 +97,9 @@ async def resolve_events_with_store(
# Also fetch all auth events that appear in only some of the state sets'
# auth chains.
- auth_diff = await _get_auth_chain_difference(state_sets, event_map, state_res_store)
+ auth_diff = await _get_auth_chain_difference(
+ room_id, state_sets, event_map, state_res_store
+ )
full_conflicted_set = set(
itertools.chain(
@@ -236,6 +238,7 @@ async def _get_power_level_for_sender(
async def _get_auth_chain_difference(
+ room_id: str,
state_sets: Sequence[StateMap[str]],
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
@@ -252,9 +255,90 @@ async def _get_auth_chain_difference(
Set of event IDs
"""
+ # The `StateResolutionStore.get_auth_chain_difference` function assumes that
+ # all events passed to it (and their auth chains) have been persisted
+ # previously. This is not the case for any events in the `event_map`, and so
+ # we need to manually handle those events.
+ #
+ # We do this by:
+ # 1. calculating the auth chain difference for the state sets based on the
+ # events in `event_map` alone
+ # 2. replacing any events in the state_sets that are also in `event_map`
+ # with their auth events (recursively), and then calling
+ # `store.get_auth_chain_difference` as normal
+ # 3. adding the results of 1 and 2 together.
+
+ # Map from event ID in `event_map` to their auth event IDs, and their auth
+ # event IDs if they appear in the `event_map`. This is the intersection of
+ # the event's auth chain with the events in the `event_map` *plus* their
+ # auth event IDs.
+ events_to_auth_chain = {} # type: Dict[str, Set[str]]
+ for event in event_map.values():
+ chain = {event.event_id}
+ events_to_auth_chain[event.event_id] = chain
+
+ to_search = [event]
+ while to_search:
+ for auth_id in to_search.pop().auth_event_ids():
+ chain.add(auth_id)
+ auth_event = event_map.get(auth_id)
+ if auth_event:
+ to_search.append(auth_event)
+
+ # We now a) calculate the auth chain difference for the unpersisted events
+ # and b) work out the state sets to pass to the store.
+ #
+ # Note: If the `event_map` is empty (which is the common case), we can do a
+ # much simpler calculation.
+ if event_map:
+ # The list of state sets to pass to the store, where each state set is a set
+ # of the event ids making up the state. This is similar to `state_sets`,
+ # except that (a) we only have event ids, not the complete
+ # ((type, state_key)->event_id) mappings; and (b) we have stripped out
+ # unpersisted events and replaced them with the persisted events in
+ # their auth chain.
+ state_sets_ids = [] # type: List[Set[str]]
+
+ # For each state set, the unpersisted event IDs reachable (by their auth
+ # chain) from the events in that set.
+ unpersisted_set_ids = [] # type: List[Set[str]]
+
+ for state_set in state_sets:
+ set_ids = set() # type: Set[str]
+ state_sets_ids.append(set_ids)
+
+ unpersisted_ids = set() # type: Set[str]
+ unpersisted_set_ids.append(unpersisted_ids)
+
+ for event_id in state_set.values():
+ event_chain = events_to_auth_chain.get(event_id)
+ if event_chain is not None:
+ # We have an event in `event_map`. We add all the auth
+ # events that it references (that aren't also in `event_map`).
+ set_ids.update(e for e in event_chain if e not in event_map)
+
+ # We also add the full chain of unpersisted event IDs
+ # referenced by this state set, so that we can work out the
+ # auth chain difference of the unpersisted events.
+ unpersisted_ids.update(e for e in event_chain if e in event_map)
+ else:
+ set_ids.add(event_id)
+
+ # The auth chain difference of the unpersisted events of the state sets
+ # is calculated by taking the difference between the union and
+ # intersections.
+ union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:])
+ intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])
+
+ difference_from_event_map = union - intersection # type: Collection[str]
+ else:
+ difference_from_event_map = ()
+ state_sets_ids = [set(state_set.values()) for state_set in state_sets]
+
difference = await state_res_store.get_auth_chain_difference(
- [set(state_set.values()) for state_set in state_sets]
+ room_id, state_sets_ids
)
+ difference.update(difference_from_event_map)
return difference
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index dfb4f87b8f..9097677648 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -57,6 +57,38 @@ class DeviceWorkerStore(SQLBaseStore):
self._prune_old_outbound_device_pokes, 60 * 60 * 1000
)
+ async def count_devices_by_users(self, user_ids: Optional[List[str]] = None) -> int:
+ """Retrieve number of all devices of given users.
+ Only returns number of devices that are not marked as hidden.
+
+ Args:
+ user_ids: The IDs of the users which owns devices
+ Returns:
+ Number of devices of this users.
+ """
+
+ def count_devices_by_users_txn(txn, user_ids):
+ sql = """
+ SELECT count(*)
+ FROM devices
+ WHERE
+ hidden = '0' AND
+ """
+
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "user_id", user_ids
+ )
+
+ txn.execute(sql + clause, args)
+ return txn.fetchone()[0]
+
+ if not user_ids:
+ return 0
+
+ return await self.db_pool.runInteraction(
+ "count_devices_by_users", count_devices_by_users_txn, user_ids
+ )
+
async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
"""Retrieve a device. Only returns devices that are not marked as
hidden.
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 2e07c37340..ebffd89251 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -137,7 +137,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return list(results)
- async def get_auth_chain_difference(self, state_sets: List[Set[str]]) -> Set[str]:
+ async def get_auth_chain_difference(
+ self, room_id: str, state_sets: List[Set[str]]
+ ) -> Set[str]:
"""Given sets of state events figure out the auth chain difference (as
per state res v2 algorithm).
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index fedb8a6c26..ff96c34c2e 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -463,6 +463,23 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="get_user_by_external_id",
)
+ async def get_external_ids_by_user(self, mxid: str) -> List[Tuple[str, str]]:
+ """Look up external ids for the given user
+
+ Args:
+ mxid: the MXID to be looked up
+
+ Returns:
+ Tuples of (auth_provider, external_id)
+ """
+ res = await self.db_pool.simple_select_list(
+ table="user_external_ids",
+ keyvalues={"user_id": mxid},
+ retcols=("auth_provider", "external_id"),
+ desc="get_external_ids_by_user",
+ )
+ return [(r["auth_provider"], r["external_id"]) for r in res]
+
async def count_all_users(self):
"""Counts all users registered on the homeserver."""
@@ -963,6 +980,14 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
"users_set_deactivated_flag", self._background_update_set_deactivated_flag
)
+ self.db_pool.updates.register_background_index_update(
+ "user_external_ids_user_id_idx",
+ index_name="user_external_ids_user_id_idx",
+ table="user_external_ids",
+ columns=["user_id"],
+ unique=False,
+ )
+
async def _background_update_set_deactivated_flag(self, progress, batch_size):
"""Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
for each of them.
diff --git a/synapse/storage/databases/main/schema/delta/58/25user_external_ids_user_id_idx.sql b/synapse/storage/databases/main/schema/delta/58/25user_external_ids_user_id_idx.sql
new file mode 100644
index 0000000000..8f5e65aa71
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/25user_external_ids_user_id_idx.sql
@@ -0,0 +1,17 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (5825, 'user_external_ids_user_id_idx', '{}');
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 382f0cf3f0..9a873c8e8e 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -15,10 +15,12 @@
# limitations under the License.
import collections
+import inspect
import logging
from contextlib import contextmanager
from typing import (
Any,
+ Awaitable,
Callable,
Dict,
Hashable,
@@ -542,11 +544,11 @@ class DoneAwaitable:
raise StopIteration(self.value)
-def maybe_awaitable(value):
+def maybe_awaitable(value: Union[Awaitable[R], R]) -> Awaitable[R]:
"""Convert a value to an awaitable if not already an awaitable.
"""
-
- if hasattr(value, "__await__"):
+ if inspect.isawaitable(value):
+ assert isinstance(value, Awaitable)
return value
return DoneAwaitable(value)
diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py
index f73e95393c..a6ee9edaec 100644
--- a/synapse/util/distributor.py
+++ b/synapse/util/distributor.py
@@ -12,13 +12,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import inspect
import logging
from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.util.async_helpers import maybe_awaitable
logger = logging.getLogger(__name__)
@@ -105,10 +105,7 @@ class Signal:
async def do(observer):
try:
- result = observer(*args, **kwargs)
- if inspect.isawaitable(result):
- result = await result
- return result
+ return await maybe_awaitable(observer(*args, **kwargs))
except Exception as e:
logger.warning(
"%s signal observer %s failed: %r", self.name, observer, e,
diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py
index 94b59afb38..1ee61851e4 100644
--- a/synapse/util/module_loader.py
+++ b/synapse/util/module_loader.py
@@ -15,28 +15,56 @@
import importlib
import importlib.util
+import itertools
+from typing import Any, Iterable, Tuple, Type
+
+import jsonschema
from synapse.config._base import ConfigError
+from synapse.config._util import json_error_to_config_error
-def load_module(provider):
+def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]:
""" Loads a synapse module with its config
- Take a dict with keys 'module' (the module name) and 'config'
- (the config dict).
+
+ Args:
+ provider: a dict with keys 'module' (the module name) and 'config'
+ (the config dict).
+ config_path: the path within the config file. This will be used as a basis
+ for any error message.
Returns
Tuple of (provider class, parsed config object)
"""
+
+ modulename = provider.get("module")
+ if not isinstance(modulename, str):
+ raise ConfigError(
+ "expected a string", path=itertools.chain(config_path, ("module",))
+ )
+
# We need to import the module, and then pick the class out of
# that, so we split based on the last dot.
- module, clz = provider["module"].rsplit(".", 1)
+ module, clz = modulename.rsplit(".", 1)
module = importlib.import_module(module)
provider_class = getattr(module, clz)
+ module_config = provider.get("config")
try:
- provider_config = provider_class.parse_config(provider.get("config"))
+ provider_config = provider_class.parse_config(module_config)
+ except jsonschema.ValidationError as e:
+ raise json_error_to_config_error(e, itertools.chain(config_path, ("config",)))
+ except ConfigError as e:
+ raise _wrap_config_error(
+ "Failed to parse config for module %r" % (modulename,),
+ prefix=itertools.chain(config_path, ("config",)),
+ e=e,
+ )
except Exception as e:
- raise ConfigError("Failed to parse config for %r: %s" % (provider["module"], e))
+ raise ConfigError(
+ "Failed to parse config for module %r" % (modulename,),
+ path=itertools.chain(config_path, ("config",)),
+ ) from e
return provider_class, provider_config
@@ -56,3 +84,27 @@ def load_python_module(location: str):
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod) # type: ignore
return mod
+
+
+def _wrap_config_error(
+ msg: str, prefix: Iterable[str], e: ConfigError
+) -> "ConfigError":
+ """Wrap a relative ConfigError with a new path
+
+ This is useful when we have a ConfigError with a relative path due to a problem
+ parsing part of the config, and we now need to set it in context.
+ """
+ path = prefix
+ if e.path:
+ path = itertools.chain(prefix, e.path)
+
+ e1 = ConfigError(msg, path)
+
+ # ideally we would set the 'cause' of the new exception to the original exception;
+ # however now that we have merged the path into our own, the stringification of
+ # e will be incorrect, so instead we create a new exception with just the "msg"
+ # part.
+
+ e1.__cause__ = Exception(e.msg)
+ e1.__cause__.__cause__ = e.__cause__
+ return e1
|