From 30fba6210834a4ecd91badf0c8f3eb278b72e746 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 2 Dec 2020 11:09:24 -0500 Subject: Apply an IP range blacklist to push and key revocation requests. (#8821) Replaces the `federation_ip_range_blacklist` configuration setting with an `ip_range_blacklist` setting with wider scope. It now applies to: * Federation * Identity servers * Push notifications * Checking key validitity for third-party invite events The old `federation_ip_range_blacklist` setting is still honored if present, but with reduced scope (it only applies to federation and identity servers). --- synapse/config/federation.py | 40 +++++++++++++++++++++++++--------------- 1 file changed, 25 insertions(+), 15 deletions(-) (limited to 'synapse/config') diff --git a/synapse/config/federation.py b/synapse/config/federation.py index ffd8fca54e..27ccf61c3c 100644 --- a/synapse/config/federation.py +++ b/synapse/config/federation.py @@ -36,22 +36,30 @@ class FederationConfig(Config): for domain in federation_domain_whitelist: self.federation_domain_whitelist[domain] = True - self.federation_ip_range_blacklist = config.get( - "federation_ip_range_blacklist", [] - ) + ip_range_blacklist = config.get("ip_range_blacklist", []) # Attempt to create an IPSet from the given ranges try: - self.federation_ip_range_blacklist = IPSet( - self.federation_ip_range_blacklist - ) - - # Always blacklist 0.0.0.0, :: - self.federation_ip_range_blacklist.update(["0.0.0.0", "::"]) + self.ip_range_blacklist = IPSet(ip_range_blacklist) + except Exception as e: + raise ConfigError("Invalid range(s) provided in ip_range_blacklist: %s" % e) + # Always blacklist 0.0.0.0, :: + self.ip_range_blacklist.update(["0.0.0.0", "::"]) + + # The federation_ip_range_blacklist is used for backwards-compatibility + # and only applies to federation and identity servers. If it is not given, + # default to ip_range_blacklist. + federation_ip_range_blacklist = config.get( + "federation_ip_range_blacklist", ip_range_blacklist + ) + try: + self.federation_ip_range_blacklist = IPSet(federation_ip_range_blacklist) except Exception as e: raise ConfigError( "Invalid range(s) provided in federation_ip_range_blacklist: %s" % e ) + # Always blacklist 0.0.0.0, :: + self.federation_ip_range_blacklist.update(["0.0.0.0", "::"]) federation_metrics_domains = config.get("federation_metrics_domains") or [] validate_config( @@ -76,17 +84,19 @@ class FederationConfig(Config): # - nyc.example.com # - syd.example.com - # Prevent federation requests from being sent to the following - # blacklist IP address CIDR ranges. If this option is not specified, or - # specified with an empty list, no ip range blacklist will be enforced. + # Prevent outgoing requests from being sent to the following blacklisted IP address + # CIDR ranges. If this option is not specified, or specified with an empty list, + # no IP range blacklist will be enforced. # - # As of Synapse v1.4.0 this option also affects any outbound requests to identity - # servers provided by user input. + # The blacklist applies to the outbound requests for federation, identity servers, + # push servers, and for checking key validitity for third-party invite events. # # (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly # listed here, since they correspond to unroutable addresses.) # - federation_ip_range_blacklist: + # This option replaces federation_ip_range_blacklist in Synapse v1.24.0. + # + ip_range_blacklist: - '127.0.0.0/8' - '10.0.0.0/8' - '172.16.0.0/12' -- cgit 1.5.1 From 96358cb42410a4be6268eaa3ffec229c550208ea Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 4 Dec 2020 10:56:28 -0500 Subject: Add authentication to replication endpoints. (#8853) Authentication is done by checking a shared secret provided in the Synapse configuration file. --- changelog.d/8853.feature | 1 + docs/sample_config.yaml | 7 ++ docs/workers.md | 6 +- synapse/config/workers.py | 10 +++ synapse/replication/http/_base.py | 47 ++++++++-- tests/replication/test_auth.py | 119 ++++++++++++++++++++++++++ tests/replication/test_client_reader_shard.py | 9 +- 7 files changed, 184 insertions(+), 15 deletions(-) create mode 100644 changelog.d/8853.feature create mode 100644 tests/replication/test_auth.py (limited to 'synapse/config') diff --git a/changelog.d/8853.feature b/changelog.d/8853.feature new file mode 100644 index 0000000000..63c59f4ff2 --- /dev/null +++ b/changelog.d/8853.feature @@ -0,0 +1 @@ +Add optional HTTP authentication to replication endpoints. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 6dbccf5932..8712c580c0 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -2589,6 +2589,13 @@ opentracing: # #run_background_tasks_on: worker1 +# A shared secret used by the replication APIs to authenticate HTTP requests +# from workers. +# +# By default this is unused and traffic is not authenticated. +# +#worker_replication_secret: "" + # Configuration for Redis when using workers. This *must* be enabled when # using workers (unless using old style direct TCP configuration). diff --git a/docs/workers.md b/docs/workers.md index c53d1bd2ff..efe97af31a 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -89,7 +89,8 @@ shared configuration file. Normally, only a couple of changes are needed to make an existing configuration file suitable for use with workers. First, you need to enable an "HTTP replication listener" for the main process; and secondly, you need to enable redis-based -replication. For example: +replication. Optionally, a shared secret can be used to authenticate HTTP +traffic between workers. For example: ```yaml @@ -103,6 +104,9 @@ listeners: resources: - names: [replication] +# Add a random shared secret to authenticate traffic. +worker_replication_secret: "" + redis: enabled: true ``` diff --git a/synapse/config/workers.py b/synapse/config/workers.py index 57ab097eba..7ca9efec52 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -85,6 +85,9 @@ class WorkerConfig(Config): # The port on the main synapse for HTTP replication endpoint self.worker_replication_http_port = config.get("worker_replication_http_port") + # The shared secret used for authentication when connecting to the main synapse. + self.worker_replication_secret = config.get("worker_replication_secret", None) + self.worker_name = config.get("worker_name", self.worker_app) self.worker_main_http_uri = config.get("worker_main_http_uri", None) @@ -185,6 +188,13 @@ class WorkerConfig(Config): # data). If not provided this defaults to the main process. # #run_background_tasks_on: worker1 + + # A shared secret used by the replication APIs to authenticate HTTP requests + # from workers. + # + # By default this is unused and traffic is not authenticated. + # + #worker_replication_secret: "" """ def read_arguments(self, args): diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py index 2b3972cb14..1492ac922c 100644 --- a/synapse/replication/http/_base.py +++ b/synapse/replication/http/_base.py @@ -106,6 +106,25 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): assert self.METHOD in ("PUT", "POST", "GET") + self._replication_secret = None + if hs.config.worker.worker_replication_secret: + self._replication_secret = hs.config.worker.worker_replication_secret + + def _check_auth(self, request) -> None: + # Get the authorization header. + auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") + + if len(auth_headers) > 1: + raise RuntimeError("Too many Authorization headers.") + parts = auth_headers[0].split(b" ") + if parts[0] == b"Bearer" and len(parts) == 2: + received_secret = parts[1].decode("ascii") + if self._replication_secret == received_secret: + # Success! + return + + raise RuntimeError("Invalid Authorization header.") + @abc.abstractmethod async def _serialize_payload(**kwargs): """Static method that is called when creating a request. @@ -150,6 +169,12 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): outgoing_gauge = _pending_outgoing_requests.labels(cls.NAME) + replication_secret = None + if hs.config.worker.worker_replication_secret: + replication_secret = hs.config.worker.worker_replication_secret.encode( + "ascii" + ) + @trace(opname="outgoing_replication_request") @outgoing_gauge.track_inprogress() async def send_request(instance_name="master", **kwargs): @@ -202,6 +227,9 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): # the master, and so whether we should clean up or not. while True: headers = {} # type: Dict[bytes, List[bytes]] + # Add an authorization header, if configured. + if replication_secret: + headers[b"Authorization"] = [b"Bearer " + replication_secret] inject_active_span_byte_dict(headers, None, check_destination=False) try: result = await request_func(uri, data, headers=headers) @@ -236,21 +264,19 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): """ url_args = list(self.PATH_ARGS) - handler = self._handle_request method = self.METHOD if self.CACHE: - handler = self._cached_handler # type: ignore url_args.append("txn_id") args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args) pattern = re.compile("^/_synapse/replication/%s/%s$" % (self.NAME, args)) http_server.register_paths( - method, [pattern], handler, self.__class__.__name__, + method, [pattern], self._check_auth_and_handle, self.__class__.__name__, ) - def _cached_handler(self, request, txn_id, **kwargs): + def _check_auth_and_handle(self, request, **kwargs): """Called on new incoming requests when caching is enabled. Checks if there is a cached response for the request and returns that, otherwise calls `_handle_request` and caches its response. @@ -258,6 +284,15 @@ class ReplicationEndpoint(metaclass=abc.ABCMeta): # We just use the txn_id here, but we probably also want to use the # other PATH_ARGS as well. - assert self.CACHE + # Check the authorization headers before handling the request. + if self._replication_secret: + self._check_auth(request) + + if self.CACHE: + txn_id = kwargs.pop("txn_id") + + return self.response_cache.wrap( + txn_id, self._handle_request, request, **kwargs + ) - return self.response_cache.wrap(txn_id, self._handle_request, request, **kwargs) + return self._handle_request(request, **kwargs) diff --git a/tests/replication/test_auth.py b/tests/replication/test_auth.py new file mode 100644 index 0000000000..fe9e4d5f9a --- /dev/null +++ b/tests/replication/test_auth.py @@ -0,0 +1,119 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import Tuple + +from synapse.http.site import SynapseRequest +from synapse.rest.client.v2_alpha import register + +from tests.replication._base import BaseMultiWorkerStreamTestCase +from tests.server import FakeChannel, make_request +from tests.unittest import override_config + +logger = logging.getLogger(__name__) + + +class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase): + """Test the authentication of HTTP calls between workers.""" + + servlets = [register.register_servlets] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + # This isn't a real configuration option but is used to provide the main + # homeserver and worker homeserver different options. + main_replication_secret = config.pop("main_replication_secret", None) + if main_replication_secret: + config["worker_replication_secret"] = main_replication_secret + return self.setup_test_homeserver(config=config) + + def _get_worker_hs_config(self) -> dict: + config = self.default_config() + config["worker_app"] = "synapse.app.client_reader" + config["worker_replication_host"] = "testserv" + config["worker_replication_http_port"] = "8765" + + return config + + def _test_register(self) -> Tuple[SynapseRequest, FakeChannel]: + """Run the actual test: + + 1. Create a worker homeserver. + 2. Start registration by providing a user/password. + 3. Complete registration by providing dummy auth (this hits the main synapse). + 4. Return the final request. + + """ + worker_hs = self.make_worker_hs("synapse.app.client_reader") + site = self._hs_to_site[worker_hs] + + request_1, channel_1 = make_request( + self.reactor, + site, + "POST", + "register", + {"username": "user", "type": "m.login.password", "password": "bar"}, + ) # type: SynapseRequest, FakeChannel + self.assertEqual(request_1.code, 401) + + # Grab the session + session = channel_1.json_body["session"] + + # also complete the dummy auth + return make_request( + self.reactor, + site, + "POST", + "register", + {"auth": {"session": session, "type": "m.login.dummy"}}, + ) + + def test_no_auth(self): + """With no authentication the request should finish. + """ + request, channel = self._test_register() + self.assertEqual(request.code, 200) + + # We're given a registered user. + self.assertEqual(channel.json_body["user_id"], "@user:test") + + @override_config({"main_replication_secret": "my-secret"}) + def test_missing_auth(self): + """If the main process expects a secret that is not provided, an error results. + """ + request, channel = self._test_register() + self.assertEqual(request.code, 500) + + @override_config( + { + "main_replication_secret": "my-secret", + "worker_replication_secret": "wrong-secret", + } + ) + def test_unauthorized(self): + """If the main process receives the wrong secret, an error results. + """ + request, channel = self._test_register() + self.assertEqual(request.code, 500) + + @override_config({"worker_replication_secret": "my-secret"}) + def test_authorized(self): + """The request should finish when the worker provides the authentication header. + """ + request, channel = self._test_register() + self.assertEqual(request.code, 200) + + # We're given a registered user. + self.assertEqual(channel.json_body["user_id"], "@user:test") diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py index 96801db473..fdaad3d8ad 100644 --- a/tests/replication/test_client_reader_shard.py +++ b/tests/replication/test_client_reader_shard.py @@ -14,27 +14,20 @@ # limitations under the License. import logging -from synapse.api.constants import LoginType from synapse.http.site import SynapseRequest from synapse.rest.client.v2_alpha import register from tests.replication._base import BaseMultiWorkerStreamTestCase -from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker from tests.server import FakeChannel, make_request logger = logging.getLogger(__name__) class ClientReaderTestCase(BaseMultiWorkerStreamTestCase): - """Base class for tests of the replication streams""" + """Test using one or more client readers for registration.""" servlets = [register.register_servlets] - def prepare(self, reactor, clock, hs): - self.recaptcha_checker = DummyRecaptchaChecker(hs) - auth_handler = hs.get_auth_handler() - auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker - def _get_worker_hs_config(self) -> dict: config = self.default_config() config["worker_app"] = "synapse.app.client_reader" -- cgit 1.5.1 From 025fa06fc743bda7c4769b19991c40a1fb5d12ba Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 8 Dec 2020 14:03:08 +0000 Subject: Clarify config template comments (#8891) --- changelog.d/8891.doc | 1 + docs/sample_config.yaml | 12 ++++-------- synapse/config/emailconfig.py | 5 ++--- synapse/config/sso.py | 7 ++----- 4 files changed, 9 insertions(+), 16 deletions(-) create mode 100644 changelog.d/8891.doc (limited to 'synapse/config') diff --git a/changelog.d/8891.doc b/changelog.d/8891.doc new file mode 100644 index 0000000000..c3947fe7c2 --- /dev/null +++ b/changelog.d/8891.doc @@ -0,0 +1 @@ +Clarify comments around template directories in `sample_config.yaml`. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 8712c580c0..68c8f4f0e2 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1879,11 +1879,8 @@ sso: # - https://my.custom.client/ # Directory in which Synapse will try to find the template files below. - # If not set, default templates from within the Synapse package will be used. - # - # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates. - # If you *do* uncomment it, you will need to make sure that all the templates - # below are in the directory. + # If not set, or the files named below are not found within the template + # directory, default templates from within the Synapse package will be used. # # Synapse will look for the following templates in this directory: # @@ -2113,9 +2110,8 @@ email: #validation_token_lifetime: 15m # Directory in which Synapse will try to find the template files below. - # If not set, default templates from within the Synapse package will be used. - # - # Do not uncomment this setting unless you want to customise the templates. + # If not set, or the files named below are not found within the template + # directory, default templates from within the Synapse package will be used. # # Synapse will look for the following templates in this directory: # diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index cceffbfee2..7c8b64d84b 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -390,9 +390,8 @@ class EmailConfig(Config): #validation_token_lifetime: 15m # Directory in which Synapse will try to find the template files below. - # If not set, default templates from within the Synapse package will be used. - # - # Do not uncomment this setting unless you want to customise the templates. + # If not set, or the files named below are not found within the template + # directory, default templates from within the Synapse package will be used. # # Synapse will look for the following templates in this directory: # diff --git a/synapse/config/sso.py b/synapse/config/sso.py index 4427676167..93bbd40937 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -93,11 +93,8 @@ class SSOConfig(Config): # - https://my.custom.client/ # Directory in which Synapse will try to find the template files below. - # If not set, default templates from within the Synapse package will be used. - # - # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates. - # If you *do* uncomment it, you will need to make sure that all the templates - # below are in the directory. + # If not set, or the files named below are not found within the template + # directory, default templates from within the Synapse package will be used. # # Synapse will look for the following templates in this directory: # -- cgit 1.5.1 From ab7a24cc6bbffa5ba67b42731c45b1d4d33f3ae3 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 8 Dec 2020 14:04:35 +0000 Subject: Better formatting for config errors from modules (#8874) The idea is that the parse_config method of extension modules can raise either a ConfigError or a JsonValidationError, and it will be magically turned into a legible error message. There's a few components to it: * Separating the "path" and the "message" parts of a ConfigError, so that we can fiddle with the path bit to turn it into an absolute path. * Generally improving the way ConfigErrors get printed. * Passing in the config path to load_module so that it can wrap any exceptions that get caught appropriately. --- changelog.d/8874.feature | 1 + synapse/app/homeserver.py | 46 ++++++++++++++++++++-- synapse/config/_base.py | 14 ++++++- synapse/config/_base.pyi | 7 +++- synapse/config/_util.py | 35 +++++++++++------ synapse/config/oidc_config.py | 2 +- synapse/config/password_auth_providers.py | 5 ++- synapse/config/repository.py | 6 ++- synapse/config/room_directory.py | 2 +- synapse/config/saml2_config.py | 2 +- synapse/config/spam_checker.py | 9 +++-- synapse/config/third_party_event_rules.py | 4 +- synapse/util/module_loader.py | 64 ++++++++++++++++++++++++++++--- 13 files changed, 160 insertions(+), 37 deletions(-) create mode 100644 changelog.d/8874.feature (limited to 'synapse/config') diff --git a/changelog.d/8874.feature b/changelog.d/8874.feature new file mode 100644 index 0000000000..720665ecac --- /dev/null +++ b/changelog.d/8874.feature @@ -0,0 +1 @@ +Improve the error messages printed as a result of configuration problems for extension modules. diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 2b5465417f..bbb7407838 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -19,7 +19,7 @@ import gc import logging import os import sys -from typing import Iterable +from typing import Iterable, Iterator from twisted.application import service from twisted.internet import defer, reactor @@ -90,7 +90,7 @@ class SynapseHomeServer(HomeServer): tls = listener_config.tls site_tag = listener_config.http_options.tag if site_tag is None: - site_tag = port + site_tag = str(port) # We always include a health resource. resources = {"/health": HealthResource()} @@ -107,7 +107,10 @@ class SynapseHomeServer(HomeServer): logger.debug("Configuring additional resources: %r", additional_resources) module_api = self.get_module_api() for path, resmodule in additional_resources.items(): - handler_cls, config = load_module(resmodule) + handler_cls, config = load_module( + resmodule, + ("listeners", site_tag, "additional_resources", "<%s>" % (path,)), + ) handler = handler_cls(config, module_api) if IResource.providedBy(handler): resource = handler @@ -342,7 +345,10 @@ def setup(config_options): "Synapse Homeserver", config_options ) except ConfigError as e: - sys.stderr.write("\nERROR: %s\n" % (e,)) + sys.stderr.write("\n") + for f in format_config_error(e): + sys.stderr.write(f) + sys.stderr.write("\n") sys.exit(1) if not config: @@ -445,6 +451,38 @@ def setup(config_options): return hs +def format_config_error(e: ConfigError) -> Iterator[str]: + """ + Formats a config error neatly + + The idea is to format the immediate error, plus the "causes" of those errors, + hopefully in a way that makes sense to the user. For example: + + Error in configuration at 'oidc_config.user_mapping_provider.config.display_name_template': + Failed to parse config for module 'JinjaOidcMappingProvider': + invalid jinja template: + unexpected end of template, expected 'end of print statement'. + + Args: + e: the error to be formatted + + Returns: An iterator which yields string fragments to be formatted + """ + yield "Error in configuration" + + if e.path: + yield " at '%s'" % (".".join(e.path),) + + yield ":\n %s" % (e.msg,) + + e = e.__cause__ + indent = 1 + while e: + indent += 1 + yield ":\n%s%s" % (" " * indent, str(e)) + e = e.__cause__ + + class SynapseService(service.Service): """ A twisted Service class that will start synapse. Used to run synapse diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 85f65da4d9..2931a88207 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -23,7 +23,7 @@ import urllib.parse from collections import OrderedDict from hashlib import sha256 from textwrap import dedent -from typing import Any, Callable, List, MutableMapping, Optional +from typing import Any, Callable, Iterable, List, MutableMapping, Optional import attr import jinja2 @@ -32,7 +32,17 @@ import yaml class ConfigError(Exception): - pass + """Represents a problem parsing the configuration + + Args: + msg: A textual description of the error. + path: Where appropriate, an indication of where in the configuration + the problem lies. + """ + + def __init__(self, msg: str, path: Optional[Iterable[str]] = None): + self.msg = msg + self.path = path # We split these messages out to allow packages to override with package diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index b8faafa9bd..ed26e2fb60 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -1,4 +1,4 @@ -from typing import Any, List, Optional +from typing import Any, Iterable, List, Optional from synapse.config import ( api, @@ -35,7 +35,10 @@ from synapse.config import ( workers, ) -class ConfigError(Exception): ... +class ConfigError(Exception): + def __init__(self, msg: str, path: Optional[Iterable[str]] = None): + self.msg = msg + self.path = path MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS: str MISSING_REPORT_STATS_SPIEL: str diff --git a/synapse/config/_util.py b/synapse/config/_util.py index c74969a977..1bbe83c317 100644 --- a/synapse/config/_util.py +++ b/synapse/config/_util.py @@ -38,14 +38,27 @@ def validate_config( try: jsonschema.validate(config, json_schema) except jsonschema.ValidationError as e: - # copy `config_path` before modifying it. - path = list(config_path) - for p in list(e.path): - if isinstance(p, int): - path.append("" % p) - else: - path.append(str(p)) - - raise ConfigError( - "Unable to parse configuration: %s at %s" % (e.message, ".".join(path)) - ) + raise json_error_to_config_error(e, config_path) + + +def json_error_to_config_error( + e: jsonschema.ValidationError, config_path: Iterable[str] +) -> ConfigError: + """Converts a json validation error to a user-readable ConfigError + + Args: + e: the exception to be converted + config_path: the path within the config file. This will be used as a basis + for the error message. + + Returns: + a ConfigError + """ + # copy `config_path` before modifying it. + path = list(config_path) + for p in list(e.path): + if isinstance(p, int): + path.append("" % p) + else: + path.append(str(p)) + return ConfigError(e.message, path) diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index 69d188341c..1abf8ed405 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -66,7 +66,7 @@ class OIDCConfig(Config): ( self.oidc_user_mapping_provider_class, self.oidc_user_mapping_provider_config, - ) = load_module(ump_config) + ) = load_module(ump_config, ("oidc_config", "user_mapping_provider")) # Ensure loaded user mapping module has defined all necessary methods required_methods = [ diff --git a/synapse/config/password_auth_providers.py b/synapse/config/password_auth_providers.py index 4fda8ae987..85d07c4f8f 100644 --- a/synapse/config/password_auth_providers.py +++ b/synapse/config/password_auth_providers.py @@ -36,7 +36,7 @@ class PasswordAuthProviderConfig(Config): providers.append({"module": LDAP_PROVIDER, "config": ldap_config}) providers.extend(config.get("password_providers") or []) - for provider in providers: + for i, provider in enumerate(providers): mod_name = provider["module"] # This is for backwards compat when the ldap auth provider resided @@ -45,7 +45,8 @@ class PasswordAuthProviderConfig(Config): mod_name = LDAP_PROVIDER (provider_class, provider_config) = load_module( - {"module": mod_name, "config": provider["config"]} + {"module": mod_name, "config": provider["config"]}, + ("password_providers", "" % i), ) self.password_providers.append((provider_class, provider_config)) diff --git a/synapse/config/repository.py b/synapse/config/repository.py index ba1e9d2361..17ce9145ef 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -142,7 +142,7 @@ class ContentRepositoryConfig(Config): # them to be started. self.media_storage_providers = [] # type: List[tuple] - for provider_config in storage_providers: + for i, provider_config in enumerate(storage_providers): # We special case the module "file_system" so as not to need to # expose FileStorageProviderBackend if provider_config["module"] == "file_system": @@ -151,7 +151,9 @@ class ContentRepositoryConfig(Config): ".FileStorageProviderBackend" ) - provider_class, parsed_config = load_module(provider_config) + provider_class, parsed_config = load_module( + provider_config, ("media_storage_providers", "" % i) + ) wrapper_config = MediaStorageProviderConfig( provider_config.get("store_local", False), diff --git a/synapse/config/room_directory.py b/synapse/config/room_directory.py index 92e1b67528..9a3e1c3e7d 100644 --- a/synapse/config/room_directory.py +++ b/synapse/config/room_directory.py @@ -180,7 +180,7 @@ class _RoomDirectoryRule: self._alias_regex = glob_to_regex(alias) self._room_id_regex = glob_to_regex(room_id) except Exception as e: - raise ConfigError("Failed to parse glob into regex: %s", e) + raise ConfigError("Failed to parse glob into regex") from e def matches(self, user_id, room_id, aliases): """Tests if this rule matches the given user_id, room_id and aliases. diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index c1b8e98ae0..7b97d4f114 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -125,7 +125,7 @@ class SAML2Config(Config): ( self.saml2_user_mapping_provider_class, self.saml2_user_mapping_provider_config, - ) = load_module(ump_dict) + ) = load_module(ump_dict, ("saml2_config", "user_mapping_provider")) # Ensure loaded user mapping module has defined all necessary methods # Note parse_config() is already checked during the call to load_module diff --git a/synapse/config/spam_checker.py b/synapse/config/spam_checker.py index 3d067d29db..3d05abc158 100644 --- a/synapse/config/spam_checker.py +++ b/synapse/config/spam_checker.py @@ -33,13 +33,14 @@ class SpamCheckerConfig(Config): # spam checker, and thus was simply a dictionary with module # and config keys. Support this old behaviour by checking # to see if the option resolves to a dictionary - self.spam_checkers.append(load_module(spam_checkers)) + self.spam_checkers.append(load_module(spam_checkers, ("spam_checker",))) elif isinstance(spam_checkers, list): - for spam_checker in spam_checkers: + for i, spam_checker in enumerate(spam_checkers): + config_path = ("spam_checker", "" % i) if not isinstance(spam_checker, dict): - raise ConfigError("spam_checker syntax is incorrect") + raise ConfigError("expected a mapping", config_path) - self.spam_checkers.append(load_module(spam_checker)) + self.spam_checkers.append(load_module(spam_checker, config_path)) else: raise ConfigError("spam_checker syntax is incorrect") diff --git a/synapse/config/third_party_event_rules.py b/synapse/config/third_party_event_rules.py index 10a99c792e..c04e1c4e07 100644 --- a/synapse/config/third_party_event_rules.py +++ b/synapse/config/third_party_event_rules.py @@ -26,7 +26,9 @@ class ThirdPartyRulesConfig(Config): provider = config.get("third_party_event_rules", None) if provider is not None: - self.third_party_event_rules = load_module(provider) + self.third_party_event_rules = load_module( + provider, ("third_party_event_rules",) + ) def generate_config_section(self, **kwargs): return """\ diff --git a/synapse/util/module_loader.py b/synapse/util/module_loader.py index 94b59afb38..1ee61851e4 100644 --- a/synapse/util/module_loader.py +++ b/synapse/util/module_loader.py @@ -15,28 +15,56 @@ import importlib import importlib.util +import itertools +from typing import Any, Iterable, Tuple, Type + +import jsonschema from synapse.config._base import ConfigError +from synapse.config._util import json_error_to_config_error -def load_module(provider): +def load_module(provider: dict, config_path: Iterable[str]) -> Tuple[Type, Any]: """ Loads a synapse module with its config - Take a dict with keys 'module' (the module name) and 'config' - (the config dict). + + Args: + provider: a dict with keys 'module' (the module name) and 'config' + (the config dict). + config_path: the path within the config file. This will be used as a basis + for any error message. Returns Tuple of (provider class, parsed config object) """ + + modulename = provider.get("module") + if not isinstance(modulename, str): + raise ConfigError( + "expected a string", path=itertools.chain(config_path, ("module",)) + ) + # We need to import the module, and then pick the class out of # that, so we split based on the last dot. - module, clz = provider["module"].rsplit(".", 1) + module, clz = modulename.rsplit(".", 1) module = importlib.import_module(module) provider_class = getattr(module, clz) + module_config = provider.get("config") try: - provider_config = provider_class.parse_config(provider.get("config")) + provider_config = provider_class.parse_config(module_config) + except jsonschema.ValidationError as e: + raise json_error_to_config_error(e, itertools.chain(config_path, ("config",))) + except ConfigError as e: + raise _wrap_config_error( + "Failed to parse config for module %r" % (modulename,), + prefix=itertools.chain(config_path, ("config",)), + e=e, + ) except Exception as e: - raise ConfigError("Failed to parse config for %r: %s" % (provider["module"], e)) + raise ConfigError( + "Failed to parse config for module %r" % (modulename,), + path=itertools.chain(config_path, ("config",)), + ) from e return provider_class, provider_config @@ -56,3 +84,27 @@ def load_python_module(location: str): mod = importlib.util.module_from_spec(spec) spec.loader.exec_module(mod) # type: ignore return mod + + +def _wrap_config_error( + msg: str, prefix: Iterable[str], e: ConfigError +) -> "ConfigError": + """Wrap a relative ConfigError with a new path + + This is useful when we have a ConfigError with a relative path due to a problem + parsing part of the config, and we now need to set it in context. + """ + path = prefix + if e.path: + path = itertools.chain(prefix, e.path) + + e1 = ConfigError(msg, path) + + # ideally we would set the 'cause' of the new exception to the original exception; + # however now that we have merged the path into our own, the stringification of + # e will be incorrect, so instead we create a new exception with just the "msg" + # part. + + e1.__cause__ = Exception(e.msg) + e1.__cause__.__cause__ = e.__cause__ + return e1 -- cgit 1.5.1 From 344ab0b53abc0291d79882f8bdc1a853f7495ed4 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 9 Dec 2020 13:56:06 -0500 Subject: Default to blacklisting reserved IP ranges and add a whitelist. (#8870) This defaults `ip_range_blacklist` to reserved IP ranges and also adds an `ip_range_whitelist` setting to override it. --- INSTALL.md | 7 ++- UPGRADE.rst | 21 ++++++++ changelog.d/8821.bugfix | 2 +- changelog.d/8870.bugfix | 1 + docs/sample_config.yaml | 66 ++++++++++++++++-------- synapse/config/federation.py | 59 ++++------------------ synapse/config/repository.py | 20 +++----- synapse/config/server.py | 80 ++++++++++++++++++++++++++++++ synapse/server.py | 3 +- tests/replication/test_multi_media_repo.py | 2 +- 10 files changed, 172 insertions(+), 89 deletions(-) create mode 100644 changelog.d/8870.bugfix (limited to 'synapse/config') diff --git a/INSTALL.md b/INSTALL.md index eaeb690092..eb5f506de9 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -557,10 +557,9 @@ This is critical from a security perspective to stop arbitrary Matrix users spidering 'internal' URLs on your network. At the very least we recommend that your loopback and RFC1918 IP addresses are blacklisted. -This also requires the optional `lxml` and `netaddr` python dependencies to be -installed. This in turn requires the `libxml2` library to be available - on -Debian/Ubuntu this means `apt-get install libxml2-dev`, or equivalent for -your OS. +This also requires the optional `lxml` python dependency to be installed. This +in turn requires the `libxml2` library to be available - on Debian/Ubuntu this +means `apt-get install libxml2-dev`, or equivalent for your OS. # Troubleshooting Installation diff --git a/UPGRADE.rst b/UPGRADE.rst index 6825b567e9..54a40bd42f 100644 --- a/UPGRADE.rst +++ b/UPGRADE.rst @@ -75,6 +75,27 @@ for example: wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb +Upgrading to v1.25.0 +==================== + +Blacklisting IP ranges +---------------------- + +Synapse v1.25.0 includes new settings, ``ip_range_blacklist`` and +``ip_range_whitelist``, for controlling outgoing requests from Synapse for federation, +identity servers, push, and for checking key validity for third-party invite events. +The previous setting, ``federation_ip_range_blacklist``, is deprecated. The new +``ip_range_blacklist`` defaults to private IP ranges if it is not defined. + +If you have never customised ``federation_ip_range_blacklist`` it is recommended +that you remove that setting. + +If you have customised ``federation_ip_range_blacklist`` you should update the +setting name to ``ip_range_blacklist``. + +If you have a custom push server that is reached via private IP space you may +need to customise ``ip_range_blacklist`` or ``ip_range_whitelist``. + Upgrading to v1.24.0 ==================== diff --git a/changelog.d/8821.bugfix b/changelog.d/8821.bugfix index 8ddfbf31ce..39f53174ad 100644 --- a/changelog.d/8821.bugfix +++ b/changelog.d/8821.bugfix @@ -1 +1 @@ -Apply the `federation_ip_range_blacklist` to push and key revocation requests. +Apply an IP range blacklist to push and key revocation requests. diff --git a/changelog.d/8870.bugfix b/changelog.d/8870.bugfix new file mode 100644 index 0000000000..39f53174ad --- /dev/null +++ b/changelog.d/8870.bugfix @@ -0,0 +1 @@ +Apply an IP range blacklist to push and key revocation requests. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 68c8f4f0e2..f196781c1c 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -144,6 +144,35 @@ pid_file: DATADIR/homeserver.pid # #enable_search: false +# Prevent outgoing requests from being sent to the following blacklisted IP address +# CIDR ranges. If this option is not specified then it defaults to private IP +# address ranges (see the example below). +# +# The blacklist applies to the outbound requests for federation, identity servers, +# push servers, and for checking key validity for third-party invite events. +# +# (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly +# listed here, since they correspond to unroutable addresses.) +# +# This option replaces federation_ip_range_blacklist in Synapse v1.25.0. +# +#ip_range_blacklist: +# - '127.0.0.0/8' +# - '10.0.0.0/8' +# - '172.16.0.0/12' +# - '192.168.0.0/16' +# - '100.64.0.0/10' +# - '192.0.0.0/24' +# - '169.254.0.0/16' +# - '198.18.0.0/15' +# - '192.0.2.0/24' +# - '198.51.100.0/24' +# - '203.0.113.0/24' +# - '224.0.0.0/4' +# - '::1/128' +# - 'fe80::/10' +# - 'fc00::/7' + # List of ports that Synapse should listen on, their purpose and their # configuration. # @@ -642,28 +671,17 @@ acme: # - nyc.example.com # - syd.example.com -# Prevent outgoing requests from being sent to the following blacklisted IP address -# CIDR ranges. If this option is not specified, or specified with an empty list, -# no IP range blacklist will be enforced. +# List of IP address CIDR ranges that should be allowed for federation, +# identity servers, push servers, and for checking key validity for +# third-party invite events. This is useful for specifying exceptions to +# wide-ranging blacklisted target IP ranges - e.g. for communication with +# a push server only visible in your network. # -# The blacklist applies to the outbound requests for federation, identity servers, -# push servers, and for checking key validitity for third-party invite events. -# -# (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly -# listed here, since they correspond to unroutable addresses.) -# -# This option replaces federation_ip_range_blacklist in Synapse v1.24.0. +# This whitelist overrides ip_range_blacklist and defaults to an empty +# list. # -ip_range_blacklist: - - '127.0.0.0/8' - - '10.0.0.0/8' - - '172.16.0.0/12' - - '192.168.0.0/16' - - '100.64.0.0/10' - - '169.254.0.0/16' - - '::1/128' - - 'fe80::/64' - - 'fc00::/7' +#ip_range_whitelist: +# - '192.168.1.1' # Report prometheus metrics on the age of PDUs being sent to and received from # the following domains. This can be used to give an idea of "delay" on inbound @@ -955,9 +973,15 @@ media_store_path: "DATADIR/media_store" # - '172.16.0.0/12' # - '192.168.0.0/16' # - '100.64.0.0/10' +# - '192.0.0.0/24' # - '169.254.0.0/16' +# - '198.18.0.0/15' +# - '192.0.2.0/24' +# - '198.51.100.0/24' +# - '203.0.113.0/24' +# - '224.0.0.0/4' # - '::1/128' -# - 'fe80::/64' +# - 'fe80::/10' # - 'fc00::/7' # List of IP address CIDR ranges that the URL preview spider is allowed diff --git a/synapse/config/federation.py b/synapse/config/federation.py index 27ccf61c3c..a03a419e23 100644 --- a/synapse/config/federation.py +++ b/synapse/config/federation.py @@ -12,12 +12,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - from typing import Optional -from netaddr import IPSet - -from synapse.config._base import Config, ConfigError +from synapse.config._base import Config from synapse.config._util import validate_config @@ -36,31 +33,6 @@ class FederationConfig(Config): for domain in federation_domain_whitelist: self.federation_domain_whitelist[domain] = True - ip_range_blacklist = config.get("ip_range_blacklist", []) - - # Attempt to create an IPSet from the given ranges - try: - self.ip_range_blacklist = IPSet(ip_range_blacklist) - except Exception as e: - raise ConfigError("Invalid range(s) provided in ip_range_blacklist: %s" % e) - # Always blacklist 0.0.0.0, :: - self.ip_range_blacklist.update(["0.0.0.0", "::"]) - - # The federation_ip_range_blacklist is used for backwards-compatibility - # and only applies to federation and identity servers. If it is not given, - # default to ip_range_blacklist. - federation_ip_range_blacklist = config.get( - "federation_ip_range_blacklist", ip_range_blacklist - ) - try: - self.federation_ip_range_blacklist = IPSet(federation_ip_range_blacklist) - except Exception as e: - raise ConfigError( - "Invalid range(s) provided in federation_ip_range_blacklist: %s" % e - ) - # Always blacklist 0.0.0.0, :: - self.federation_ip_range_blacklist.update(["0.0.0.0", "::"]) - federation_metrics_domains = config.get("federation_metrics_domains") or [] validate_config( _METRICS_FOR_DOMAINS_SCHEMA, @@ -84,28 +56,17 @@ class FederationConfig(Config): # - nyc.example.com # - syd.example.com - # Prevent outgoing requests from being sent to the following blacklisted IP address - # CIDR ranges. If this option is not specified, or specified with an empty list, - # no IP range blacklist will be enforced. - # - # The blacklist applies to the outbound requests for federation, identity servers, - # push servers, and for checking key validitity for third-party invite events. - # - # (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly - # listed here, since they correspond to unroutable addresses.) + # List of IP address CIDR ranges that should be allowed for federation, + # identity servers, push servers, and for checking key validity for + # third-party invite events. This is useful for specifying exceptions to + # wide-ranging blacklisted target IP ranges - e.g. for communication with + # a push server only visible in your network. # - # This option replaces federation_ip_range_blacklist in Synapse v1.24.0. + # This whitelist overrides ip_range_blacklist and defaults to an empty + # list. # - ip_range_blacklist: - - '127.0.0.0/8' - - '10.0.0.0/8' - - '172.16.0.0/12' - - '192.168.0.0/16' - - '100.64.0.0/10' - - '169.254.0.0/16' - - '::1/128' - - 'fe80::/64' - - 'fc00::/7' + #ip_range_whitelist: + # - '192.168.1.1' # Report prometheus metrics on the age of PDUs being sent to and received from # the following domains. This can be used to give an idea of "delay" on inbound diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 17ce9145ef..850ac3ebd6 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -17,6 +17,9 @@ import os from collections import namedtuple from typing import Dict, List +from netaddr import IPSet + +from synapse.config.server import DEFAULT_IP_RANGE_BLACKLIST from synapse.python_dependencies import DependencyException, check_requirements from synapse.util.module_loader import load_module @@ -184,9 +187,6 @@ class ContentRepositoryConfig(Config): "to work" ) - # netaddr is a dependency for url_preview - from netaddr import IPSet - self.url_preview_ip_range_blacklist = IPSet( config["url_preview_ip_range_blacklist"] ) @@ -215,6 +215,10 @@ class ContentRepositoryConfig(Config): # strip final NL formatted_thumbnail_sizes = formatted_thumbnail_sizes[:-1] + ip_range_blacklist = "\n".join( + " # - '%s'" % ip for ip in DEFAULT_IP_RANGE_BLACKLIST + ) + return ( r""" ## Media Store ## @@ -285,15 +289,7 @@ class ContentRepositoryConfig(Config): # you uncomment the following list as a starting point. # #url_preview_ip_range_blacklist: - # - '127.0.0.0/8' - # - '10.0.0.0/8' - # - '172.16.0.0/12' - # - '192.168.0.0/16' - # - '100.64.0.0/10' - # - '169.254.0.0/16' - # - '::1/128' - # - 'fe80::/64' - # - 'fc00::/7' +%(ip_range_blacklist)s # List of IP address CIDR ranges that the URL preview spider is allowed # to access even if they are specified in url_preview_ip_range_blacklist. diff --git a/synapse/config/server.py b/synapse/config/server.py index 85aa49c02d..f3815e5add 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -23,6 +23,7 @@ from typing import Any, Dict, Iterable, List, Optional, Set import attr import yaml +from netaddr import IPSet from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.http.endpoint import parse_and_validate_server_name @@ -39,6 +40,34 @@ logger = logging.Logger(__name__) # in the list. DEFAULT_BIND_ADDRESSES = ["::", "0.0.0.0"] +DEFAULT_IP_RANGE_BLACKLIST = [ + # Localhost + "127.0.0.0/8", + # Private networks. + "10.0.0.0/8", + "172.16.0.0/12", + "192.168.0.0/16", + # Carrier grade NAT. + "100.64.0.0/10", + # Address registry. + "192.0.0.0/24", + # Link-local networks. + "169.254.0.0/16", + # Testing networks. + "198.18.0.0/15", + "192.0.2.0/24", + "198.51.100.0/24", + "203.0.113.0/24", + # Multicast. + "224.0.0.0/4", + # Localhost + "::1/128", + # Link-local addresses. + "fe80::/10", + # Unique local addresses. + "fc00::/7", +] + DEFAULT_ROOM_VERSION = "6" ROOM_COMPLEXITY_TOO_GREAT = ( @@ -256,6 +285,38 @@ class ServerConfig(Config): # due to resource constraints self.admin_contact = config.get("admin_contact", None) + ip_range_blacklist = config.get( + "ip_range_blacklist", DEFAULT_IP_RANGE_BLACKLIST + ) + + # Attempt to create an IPSet from the given ranges + try: + self.ip_range_blacklist = IPSet(ip_range_blacklist) + except Exception as e: + raise ConfigError("Invalid range(s) provided in ip_range_blacklist.") from e + # Always blacklist 0.0.0.0, :: + self.ip_range_blacklist.update(["0.0.0.0", "::"]) + + try: + self.ip_range_whitelist = IPSet(config.get("ip_range_whitelist", ())) + except Exception as e: + raise ConfigError("Invalid range(s) provided in ip_range_whitelist.") from e + + # The federation_ip_range_blacklist is used for backwards-compatibility + # and only applies to federation and identity servers. If it is not given, + # default to ip_range_blacklist. + federation_ip_range_blacklist = config.get( + "federation_ip_range_blacklist", ip_range_blacklist + ) + try: + self.federation_ip_range_blacklist = IPSet(federation_ip_range_blacklist) + except Exception as e: + raise ConfigError( + "Invalid range(s) provided in federation_ip_range_blacklist." + ) from e + # Always blacklist 0.0.0.0, :: + self.federation_ip_range_blacklist.update(["0.0.0.0", "::"]) + if self.public_baseurl is not None: if self.public_baseurl[-1] != "/": self.public_baseurl += "/" @@ -561,6 +622,10 @@ class ServerConfig(Config): def generate_config_section( self, server_name, data_dir_path, open_private_ports, listeners, **kwargs ): + ip_range_blacklist = "\n".join( + " # - '%s'" % ip for ip in DEFAULT_IP_RANGE_BLACKLIST + ) + _, bind_port = parse_and_validate_server_name(server_name) if bind_port is not None: unsecure_port = bind_port - 400 @@ -752,6 +817,21 @@ class ServerConfig(Config): # #enable_search: false + # Prevent outgoing requests from being sent to the following blacklisted IP address + # CIDR ranges. If this option is not specified then it defaults to private IP + # address ranges (see the example below). + # + # The blacklist applies to the outbound requests for federation, identity servers, + # push servers, and for checking key validity for third-party invite events. + # + # (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly + # listed here, since they correspond to unroutable addresses.) + # + # This option replaces federation_ip_range_blacklist in Synapse v1.25.0. + # + #ip_range_blacklist: +%(ip_range_blacklist)s + # List of ports that Synapse should listen on, their purpose and their # configuration. # diff --git a/synapse/server.py b/synapse/server.py index 9af759626e..043810ad31 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -370,10 +370,11 @@ class HomeServer(metaclass=abc.ABCMeta): def get_proxied_blacklisted_http_client(self) -> SimpleHttpClient: """ An HTTP client that uses configured HTTP(S) proxies and blacklists IPs - based on the IP range blacklist. + based on the IP range blacklist/whitelist. """ return SimpleHttpClient( self, + ip_whitelist=self.config.ip_range_whitelist, ip_blacklist=self.config.ip_range_blacklist, http_proxy=os.getenvb(b"http_proxy"), https_proxy=os.getenvb(b"HTTPS_PROXY"), diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py index 48b574ccbe..83afd9fd2f 100644 --- a/tests/replication/test_multi_media_repo.py +++ b/tests/replication/test_multi_media_repo.py @@ -48,7 +48,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): self.user_id = self.register_user("user", "pass") self.access_token = self.login("user", "pass") - self.reactor.lookups["example.com"] = "127.0.0.2" + self.reactor.lookups["example.com"] = "1.2.3.4" def default_config(self): conf = super().default_config() -- cgit 1.5.1 From 1619802228033455ff6e5863c52556996b38e8c6 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 14 Dec 2020 14:19:47 -0500 Subject: Various clean-ups to the logging context code (#8935) --- changelog.d/8916.misc | 2 +- changelog.d/8935.misc | 1 + synapse/config/logger.py | 2 +- synapse/http/site.py | 3 +-- synapse/logging/context.py | 24 +++++------------------- synapse/metrics/background_process_metrics.py | 7 +++---- synapse/replication/tcp/protocol.py | 3 +-- tests/handlers/test_federation.py | 6 +++--- tests/logging/test_terse_json.py | 7 ++----- tests/test_federation.py | 2 +- tests/test_utils/logging_setup.py | 2 +- 11 files changed, 20 insertions(+), 39 deletions(-) create mode 100644 changelog.d/8935.misc (limited to 'synapse/config') diff --git a/changelog.d/8916.misc b/changelog.d/8916.misc index c71ef480e6..bf94135fd5 100644 --- a/changelog.d/8916.misc +++ b/changelog.d/8916.misc @@ -1 +1 @@ -Improve structured logging tests. +Various clean-ups to the structured logging and logging context code. diff --git a/changelog.d/8935.misc b/changelog.d/8935.misc new file mode 100644 index 0000000000..bf94135fd5 --- /dev/null +++ b/changelog.d/8935.misc @@ -0,0 +1 @@ +Various clean-ups to the structured logging and logging context code. diff --git a/synapse/config/logger.py b/synapse/config/logger.py index d4e887a3e0..4df3f93c1c 100644 --- a/synapse/config/logger.py +++ b/synapse/config/logger.py @@ -206,7 +206,7 @@ def _setup_stdlib_logging(config, log_config_path, logBeginner: LogBeginner) -> # filter options, but care must when using e.g. MemoryHandler to buffer # writes. - log_context_filter = LoggingContextFilter(request="") + log_context_filter = LoggingContextFilter() log_metadata_filter = MetadataFilter({"server_name": config.server_name}) old_factory = logging.getLogRecordFactory() diff --git a/synapse/http/site.py b/synapse/http/site.py index 5f0581dc3f..5a5790831b 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -128,8 +128,7 @@ class SynapseRequest(Request): # create a LogContext for this request request_id = self.get_request_id() - logcontext = self.logcontext = LoggingContext(request_id) - logcontext.request = request_id + self.logcontext = LoggingContext(request_id, request=request_id) # override the Server header which is set by twisted self.setHeader("Server", self.site.server_version_string) diff --git a/synapse/logging/context.py b/synapse/logging/context.py index ca0c774cc5..a507a83e93 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -203,10 +203,6 @@ class _Sentinel: def copy_to(self, record): pass - def copy_to_twisted_log_entry(self, record): - record["request"] = None - record["scope"] = None - def start(self, rusage: "Optional[resource._RUsage]"): pass @@ -372,13 +368,6 @@ class LoggingContext: # we also track the current scope: record.scope = self.scope - def copy_to_twisted_log_entry(self, record) -> None: - """ - Copy logging fields from this context to a Twisted log record. - """ - record["request"] = self.request - record["scope"] = self.scope - def start(self, rusage: "Optional[resource._RUsage]") -> None: """ Record that this logcontext is currently running. @@ -542,13 +531,10 @@ class LoggingContext: class LoggingContextFilter(logging.Filter): """Logging filter that adds values from the current logging context to each record. - Args: - **defaults: Default values to avoid formatters complaining about - missing fields """ - def __init__(self, **defaults) -> None: - self.defaults = defaults + def __init__(self, request: str = ""): + self._default_request = request def filter(self, record) -> Literal[True]: """Add each fields from the logging contexts to the record. @@ -556,14 +542,14 @@ class LoggingContextFilter(logging.Filter): True to include the record in the log output. """ context = current_context() - for key, value in self.defaults.items(): - setattr(record, key, value) + record.request = self._default_request # context should never be None, but if it somehow ends up being, then # we end up in a death spiral of infinite loops, so let's check, for # robustness' sake. if context is not None: - context.copy_to(record) + # Logging is interested in the request. + record.request = context.request return True diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py index 76b7decf26..70e0fa45d9 100644 --- a/synapse/metrics/background_process_metrics.py +++ b/synapse/metrics/background_process_metrics.py @@ -199,8 +199,7 @@ def run_as_background_process(desc: str, func, *args, bg_start_span=True, **kwar _background_process_start_count.labels(desc).inc() _background_process_in_flight_count.labels(desc).inc() - with BackgroundProcessLoggingContext(desc) as context: - context.request = "%s-%i" % (desc, count) + with BackgroundProcessLoggingContext(desc, "%s-%i" % (desc, count)) as context: try: ctx = noop_context_manager() if bg_start_span: @@ -244,8 +243,8 @@ class BackgroundProcessLoggingContext(LoggingContext): __slots__ = ["_proc"] - def __init__(self, name: str): - super().__init__(name) + def __init__(self, name: str, request: Optional[str] = None): + super().__init__(name, request=request) self._proc = _BackgroundProcess(name, self) diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index a509e599c2..804da994ea 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -172,8 +172,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): # a logcontext which we use for processing incoming commands. We declare it as a # background process so that the CPU stats get reported to prometheus. ctx_name = "replication-conn-%s" % self.conn_id - self._logging_context = BackgroundProcessLoggingContext(ctx_name) - self._logging_context.request = ctx_name + self._logging_context = BackgroundProcessLoggingContext(ctx_name, ctx_name) def connectionMade(self): logger.info("[%s] Connection established", self.id()) diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index d0452e1490..0b24b89a2e 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -126,7 +126,7 @@ class FederationTestCase(unittest.HomeserverTestCase): room_version, ) - with LoggingContext(request="send_rejected"): + with LoggingContext("send_rejected"): d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev) self.get_success(d) @@ -178,7 +178,7 @@ class FederationTestCase(unittest.HomeserverTestCase): room_version, ) - with LoggingContext(request="send_rejected"): + with LoggingContext("send_rejected"): d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev) self.get_success(d) @@ -198,7 +198,7 @@ class FederationTestCase(unittest.HomeserverTestCase): # the auth code requires that a signature exists, but doesn't check that # signature... go figure. join_event.signatures[other_server] = {"x": "y"} - with LoggingContext(request="send_join"): + with LoggingContext("send_join"): d = run_in_background( self.handler.on_send_join_request, other_server, join_event ) diff --git a/tests/logging/test_terse_json.py b/tests/logging/test_terse_json.py index f6e7e5fdaa..48a74e2eee 100644 --- a/tests/logging/test_terse_json.py +++ b/tests/logging/test_terse_json.py @@ -117,11 +117,10 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): """ handler = logging.StreamHandler(self.output) handler.setFormatter(JsonFormatter()) - handler.addFilter(LoggingContextFilter(request="")) + handler.addFilter(LoggingContextFilter()) logger = self.get_logger(handler) - with LoggingContext() as context_one: - context_one.request = "test" + with LoggingContext(request="test"): logger.info("Hello there, %s!", "wally") log = self.get_log_line() @@ -132,9 +131,7 @@ class TerseJsonTestCase(LoggerCleanupMixin, TestCase): "level", "namespace", "request", - "scope", ] self.assertCountEqual(log.keys(), expected_log_keys) self.assertEqual(log["log"], "Hello there, wally!") self.assertEqual(log["request"], "test") - self.assertIsNone(log["scope"]) diff --git a/tests/test_federation.py b/tests/test_federation.py index fa45f8b3b7..fc9aab32d0 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -134,7 +134,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): } ) - with LoggingContext(request="lying_event"): + with LoggingContext(): failure = self.get_failure( self.handler.on_receive_pdu( "test.serv", lying_event, sent_to_us_directly=True diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py index fdfb840b62..52ae5c5713 100644 --- a/tests/test_utils/logging_setup.py +++ b/tests/test_utils/logging_setup.py @@ -48,7 +48,7 @@ def setup_logging(): handler = ToTwistedHandler() formatter = logging.Formatter(log_format) handler.setFormatter(formatter) - handler.addFilter(LoggingContextFilter(request="")) + handler.addFilter(LoggingContextFilter()) root_logger.addHandler(handler) log_level = os.environ.get("SYNAPSE_TEST_LOG_LEVEL", "ERROR") -- cgit 1.5.1 From 44b7d4c6d6d5e8d78bd0154b407defea4a35aebd Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 16 Dec 2020 14:40:47 -0500 Subject: Fix the sample config location for the ip_range_whitelist setting. (#8954) Move it from the federation section to the server section to match ip_range_blacklist. --- changelog.d/8954.feature | 1 + docs/sample_config.yaml | 24 ++++++++++++------------ synapse/config/federation.py | 12 ------------ synapse/config/server.py | 12 ++++++++++++ 4 files changed, 25 insertions(+), 24 deletions(-) create mode 100644 changelog.d/8954.feature (limited to 'synapse/config') diff --git a/changelog.d/8954.feature b/changelog.d/8954.feature new file mode 100644 index 0000000000..39f53174ad --- /dev/null +++ b/changelog.d/8954.feature @@ -0,0 +1 @@ +Apply an IP range blacklist to push and key revocation requests. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index f196781c1c..75a01094d5 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -173,6 +173,18 @@ pid_file: DATADIR/homeserver.pid # - 'fe80::/10' # - 'fc00::/7' +# List of IP address CIDR ranges that should be allowed for federation, +# identity servers, push servers, and for checking key validity for +# third-party invite events. This is useful for specifying exceptions to +# wide-ranging blacklisted target IP ranges - e.g. for communication with +# a push server only visible in your network. +# +# This whitelist overrides ip_range_blacklist and defaults to an empty +# list. +# +#ip_range_whitelist: +# - '192.168.1.1' + # List of ports that Synapse should listen on, their purpose and their # configuration. # @@ -671,18 +683,6 @@ acme: # - nyc.example.com # - syd.example.com -# List of IP address CIDR ranges that should be allowed for federation, -# identity servers, push servers, and for checking key validity for -# third-party invite events. This is useful for specifying exceptions to -# wide-ranging blacklisted target IP ranges - e.g. for communication with -# a push server only visible in your network. -# -# This whitelist overrides ip_range_blacklist and defaults to an empty -# list. -# -#ip_range_whitelist: -# - '192.168.1.1' - # Report prometheus metrics on the age of PDUs being sent to and received from # the following domains. This can be used to give an idea of "delay" on inbound # and outbound federation, though be aware that any delay can be due to problems diff --git a/synapse/config/federation.py b/synapse/config/federation.py index a03a419e23..9f3c57e6a1 100644 --- a/synapse/config/federation.py +++ b/synapse/config/federation.py @@ -56,18 +56,6 @@ class FederationConfig(Config): # - nyc.example.com # - syd.example.com - # List of IP address CIDR ranges that should be allowed for federation, - # identity servers, push servers, and for checking key validity for - # third-party invite events. This is useful for specifying exceptions to - # wide-ranging blacklisted target IP ranges - e.g. for communication with - # a push server only visible in your network. - # - # This whitelist overrides ip_range_blacklist and defaults to an empty - # list. - # - #ip_range_whitelist: - # - '192.168.1.1' - # Report prometheus metrics on the age of PDUs being sent to and received from # the following domains. This can be used to give an idea of "delay" on inbound # and outbound federation, though be aware that any delay can be due to problems diff --git a/synapse/config/server.py b/synapse/config/server.py index f3815e5add..7242a4aa8e 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -832,6 +832,18 @@ class ServerConfig(Config): #ip_range_blacklist: %(ip_range_blacklist)s + # List of IP address CIDR ranges that should be allowed for federation, + # identity servers, push servers, and for checking key validity for + # third-party invite events. This is useful for specifying exceptions to + # wide-ranging blacklisted target IP ranges - e.g. for communication with + # a push server only visible in your network. + # + # This whitelist overrides ip_range_blacklist and defaults to an empty + # list. + # + #ip_range_whitelist: + # - '192.168.1.1' + # List of ports that Synapse should listen on, their purpose and their # configuration. # -- cgit 1.5.1 From 5d4c330ed979b0d60efe5f80fd76de8f162263a1 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 18 Dec 2020 07:33:57 -0500 Subject: Allow re-using a UI auth validation for a period of time (#8970) --- changelog.d/8970.feature | 1 + docs/sample_config.yaml | 15 +++ synapse/config/_base.pyi | 4 +- synapse/config/auth.py | 110 +++++++++++++++++++++ synapse/config/homeserver.py | 4 +- synapse/config/password.py | 90 ----------------- synapse/handlers/auth.py | 32 ++++-- synapse/rest/client/v2_alpha/account.py | 10 +- synapse/storage/databases/main/registration.py | 38 +++++++ .../delta/58/26access_token_last_validated.sql | 18 ++++ tests/rest/client/v2_alpha/test_auth.py | 94 ++++++++++++------ 11 files changed, 280 insertions(+), 136 deletions(-) create mode 100644 changelog.d/8970.feature create mode 100644 synapse/config/auth.py delete mode 100644 synapse/config/password.py create mode 100644 synapse/storage/databases/main/schema/delta/58/26access_token_last_validated.sql (limited to 'synapse/config') diff --git a/changelog.d/8970.feature b/changelog.d/8970.feature new file mode 100644 index 0000000000..6d5b3303a6 --- /dev/null +++ b/changelog.d/8970.feature @@ -0,0 +1 @@ +Allow re-using an user-interactive authentication session for a period of time. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 75a01094d5..549c581a97 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -2068,6 +2068,21 @@ password_config: # #require_uppercase: true +ui_auth: + # The number of milliseconds to allow a user-interactive authentication + # session to be active. + # + # This defaults to 0, meaning the user is queried for their credentials + # before every action, but this can be overridden to alow a single + # validation to be re-used. This weakens the protections afforded by + # the user-interactive authentication process, by allowing for multiple + # (and potentially different) operations to use the same validation session. + # + # Uncomment below to allow for credential validation to last for 15 + # seconds. + # + #session_timeout: 15000 + # Configuration for sending emails from Synapse. # diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index ed26e2fb60..29aa064e57 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -3,6 +3,7 @@ from typing import Any, Iterable, List, Optional from synapse.config import ( api, appservice, + auth, captcha, cas, consent_config, @@ -14,7 +15,6 @@ from synapse.config import ( logger, metrics, oidc_config, - password, password_auth_providers, push, ratelimiting, @@ -65,7 +65,7 @@ class RootConfig: sso: sso.SSOConfig oidc: oidc_config.OIDCConfig jwt: jwt_config.JWTConfig - password: password.PasswordConfig + auth: auth.AuthConfig email: emailconfig.EmailConfig worker: workers.WorkerConfig authproviders: password_auth_providers.PasswordAuthProviderConfig diff --git a/synapse/config/auth.py b/synapse/config/auth.py new file mode 100644 index 0000000000..2b3e2ce87b --- /dev/null +++ b/synapse/config/auth.py @@ -0,0 +1,110 @@ +# -*- coding: utf-8 -*- +# Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import Config + + +class AuthConfig(Config): + """Password and login configuration + """ + + section = "auth" + + def read_config(self, config, **kwargs): + password_config = config.get("password_config", {}) + if password_config is None: + password_config = {} + + self.password_enabled = password_config.get("enabled", True) + self.password_localdb_enabled = password_config.get("localdb_enabled", True) + self.password_pepper = password_config.get("pepper", "") + + # Password policy + self.password_policy = password_config.get("policy") or {} + self.password_policy_enabled = self.password_policy.get("enabled", False) + + # User-interactive authentication + ui_auth = config.get("ui_auth") or {} + self.ui_auth_session_timeout = ui_auth.get("session_timeout", 0) + + def generate_config_section(self, config_dir_path, server_name, **kwargs): + return """\ + password_config: + # Uncomment to disable password login + # + #enabled: false + + # Uncomment to disable authentication against the local password + # database. This is ignored if `enabled` is false, and is only useful + # if you have other password_providers. + # + #localdb_enabled: false + + # Uncomment and change to a secret random string for extra security. + # DO NOT CHANGE THIS AFTER INITIAL SETUP! + # + #pepper: "EVEN_MORE_SECRET" + + # Define and enforce a password policy. Each parameter is optional. + # This is an implementation of MSC2000. + # + policy: + # Whether to enforce the password policy. + # Defaults to 'false'. + # + #enabled: true + + # Minimum accepted length for a password. + # Defaults to 0. + # + #minimum_length: 15 + + # Whether a password must contain at least one digit. + # Defaults to 'false'. + # + #require_digit: true + + # Whether a password must contain at least one symbol. + # A symbol is any character that's not a number or a letter. + # Defaults to 'false'. + # + #require_symbol: true + + # Whether a password must contain at least one lowercase letter. + # Defaults to 'false'. + # + #require_lowercase: true + + # Whether a password must contain at least one lowercase letter. + # Defaults to 'false'. + # + #require_uppercase: true + + ui_auth: + # The number of milliseconds to allow a user-interactive authentication + # session to be active. + # + # This defaults to 0, meaning the user is queried for their credentials + # before every action, but this can be overridden to alow a single + # validation to be re-used. This weakens the protections afforded by + # the user-interactive authentication process, by allowing for multiple + # (and potentially different) operations to use the same validation session. + # + # Uncomment below to allow for credential validation to last for 15 + # seconds. + # + #session_timeout: 15000 + """ diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index be65554524..4bd2b3587b 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -17,6 +17,7 @@ from ._base import RootConfig from .api import ApiConfig from .appservice import AppServiceConfig +from .auth import AuthConfig from .cache import CacheConfig from .captcha import CaptchaConfig from .cas import CasConfig @@ -30,7 +31,6 @@ from .key import KeyConfig from .logger import LoggingConfig from .metrics import MetricsConfig from .oidc_config import OIDCConfig -from .password import PasswordConfig from .password_auth_providers import PasswordAuthProviderConfig from .push import PushConfig from .ratelimiting import RatelimitConfig @@ -76,7 +76,7 @@ class HomeServerConfig(RootConfig): CasConfig, SSOConfig, JWTConfig, - PasswordConfig, + AuthConfig, EmailConfig, PasswordAuthProviderConfig, PushConfig, diff --git a/synapse/config/password.py b/synapse/config/password.py deleted file mode 100644 index 9c0ea8c30a..0000000000 --- a/synapse/config/password.py +++ /dev/null @@ -1,90 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2015, 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from ._base import Config - - -class PasswordConfig(Config): - """Password login configuration - """ - - section = "password" - - def read_config(self, config, **kwargs): - password_config = config.get("password_config", {}) - if password_config is None: - password_config = {} - - self.password_enabled = password_config.get("enabled", True) - self.password_localdb_enabled = password_config.get("localdb_enabled", True) - self.password_pepper = password_config.get("pepper", "") - - # Password policy - self.password_policy = password_config.get("policy") or {} - self.password_policy_enabled = self.password_policy.get("enabled", False) - - def generate_config_section(self, config_dir_path, server_name, **kwargs): - return """\ - password_config: - # Uncomment to disable password login - # - #enabled: false - - # Uncomment to disable authentication against the local password - # database. This is ignored if `enabled` is false, and is only useful - # if you have other password_providers. - # - #localdb_enabled: false - - # Uncomment and change to a secret random string for extra security. - # DO NOT CHANGE THIS AFTER INITIAL SETUP! - # - #pepper: "EVEN_MORE_SECRET" - - # Define and enforce a password policy. Each parameter is optional. - # This is an implementation of MSC2000. - # - policy: - # Whether to enforce the password policy. - # Defaults to 'false'. - # - #enabled: true - - # Minimum accepted length for a password. - # Defaults to 0. - # - #minimum_length: 15 - - # Whether a password must contain at least one digit. - # Defaults to 'false'. - # - #require_digit: true - - # Whether a password must contain at least one symbol. - # A symbol is any character that's not a number or a letter. - # Defaults to 'false'. - # - #require_symbol: true - - # Whether a password must contain at least one lowercase letter. - # Defaults to 'false'. - # - #require_lowercase: true - - # Whether a password must contain at least one lowercase letter. - # Defaults to 'false'. - # - #require_uppercase: true - """ diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 57ff461f92..f4434673dc 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -226,6 +226,9 @@ class AuthHandler(BaseHandler): burst_count=self.hs.config.rc_login_failed_attempts.burst_count, ) + # The number of seconds to keep a UI auth session active. + self._ui_auth_session_timeout = hs.config.ui_auth_session_timeout + # Ratelimitier for failed /login attempts self._failed_login_attempts_ratelimiter = Ratelimiter( clock=hs.get_clock(), @@ -283,7 +286,7 @@ class AuthHandler(BaseHandler): request_body: Dict[str, Any], clientip: str, description: str, - ) -> Tuple[dict, str]: + ) -> Tuple[dict, Optional[str]]: """ Checks that the user is who they claim to be, via a UI auth. @@ -310,7 +313,8 @@ class AuthHandler(BaseHandler): have been given only in a previous call). 'session_id' is the ID of this session, either passed in by the - client or assigned by this call + client or assigned by this call. This is None if UI auth was + skipped (by re-using a previous validation). Raises: InteractiveAuthIncompleteError if the client has not yet completed @@ -324,6 +328,16 @@ class AuthHandler(BaseHandler): """ + if self._ui_auth_session_timeout: + last_validated = await self.store.get_access_token_last_validated( + requester.access_token_id + ) + if self.clock.time_msec() - last_validated < self._ui_auth_session_timeout: + # Return the input parameters, minus the auth key, which matches + # the logic in check_ui_auth. + request_body.pop("auth", None) + return request_body, None + user_id = requester.user.to_string() # Check if we should be ratelimited due to too many previous failed attempts @@ -359,6 +373,9 @@ class AuthHandler(BaseHandler): if user_id != requester.user.to_string(): raise AuthError(403, "Invalid auth") + # Note that the access token has been validated. + await self.store.update_access_token_last_validated(requester.access_token_id) + return params, session_id async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]: @@ -452,13 +469,10 @@ class AuthHandler(BaseHandler): all the stages in any of the permitted flows. """ - authdict = None sid = None # type: Optional[str] - if clientdict and "auth" in clientdict: - authdict = clientdict["auth"] - del clientdict["auth"] - if "session" in authdict: - sid = authdict["session"] + authdict = clientdict.pop("auth", {}) + if "session" in authdict: + sid = authdict["session"] # Convert the URI and method to strings. uri = request.uri.decode("utf-8") @@ -563,6 +577,8 @@ class AuthHandler(BaseHandler): creds = await self.store.get_completed_ui_auth_stages(session.session_id) for f in flows: + # If all the required credentials have been supplied, the user has + # successfully completed the UI auth process! if len(set(f) - set(creds)) == 0: # it's very useful to know what args are stored, but this can # include the password in the case of registering, so only log diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index eebee44a44..d837bde1d6 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -254,14 +254,18 @@ class PasswordRestServlet(RestServlet): logger.error("Auth succeeded but no known type! %r", result.keys()) raise SynapseError(500, "", Codes.UNKNOWN) - # If we have a password in this request, prefer it. Otherwise, there - # must be a password hash from an earlier request. + # If we have a password in this request, prefer it. Otherwise, use the + # password hash from an earlier request. if new_password: password_hash = await self.auth_handler.hash(new_password) - else: + elif session_id is not None: password_hash = await self.auth_handler.get_session_data( session_id, "password_hash", None ) + else: + # UI validation was skipped, but the request did not include a new + # password. + password_hash = None if not password_hash: raise SynapseError(400, "Missing params: password", Codes.MISSING_PARAM) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index ff96c34c2e..8d05288ed4 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -943,6 +943,42 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore): desc="del_user_pending_deactivation", ) + async def get_access_token_last_validated(self, token_id: int) -> int: + """Retrieves the time (in milliseconds) of the last validation of an access token. + + Args: + token_id: The ID of the access token to update. + Raises: + StoreError if the access token was not found. + + Returns: + The last validation time. + """ + result = await self.db_pool.simple_select_one_onecol( + "access_tokens", {"id": token_id}, "last_validated" + ) + + # If this token has not been validated (since starting to track this), + # return 0 instead of None. + return result or 0 + + async def update_access_token_last_validated(self, token_id: int) -> None: + """Updates the last time an access token was validated. + + Args: + token_id: The ID of the access token to update. + Raises: + StoreError if there was a problem updating this. + """ + now = self._clock.time_msec() + + await self.db_pool.simple_update_one( + "access_tokens", + {"id": token_id}, + {"last_validated": now}, + desc="update_access_token_last_validated", + ) + class RegistrationBackgroundUpdateStore(RegistrationWorkerStore): def __init__(self, database: DatabasePool, db_conn: Connection, hs: "HomeServer"): @@ -1150,6 +1186,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): The token ID """ next_id = self._access_tokens_id_gen.get_next() + now = self._clock.time_msec() await self.db_pool.simple_insert( "access_tokens", @@ -1160,6 +1197,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore): "device_id": device_id, "valid_until_ms": valid_until_ms, "puppets_user_id": puppets_user_id, + "last_validated": now, }, desc="add_access_token_to_user", ) diff --git a/synapse/storage/databases/main/schema/delta/58/26access_token_last_validated.sql b/synapse/storage/databases/main/schema/delta/58/26access_token_last_validated.sql new file mode 100644 index 0000000000..1a101cd5eb --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/26access_token_last_validated.sql @@ -0,0 +1,18 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- The last time this access token was "validated" (i.e. logged in or succeeded +-- at user-interactive authentication). +ALTER TABLE access_tokens ADD COLUMN last_validated BIGINT; diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py index 51323b3da3..ac66a4e0b7 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Union +from typing import Union from twisted.internet.defer import succeed @@ -177,13 +177,8 @@ class UIAuthTests(unittest.HomeserverTestCase): def prepare(self, reactor, clock, hs): self.user_pass = "pass" self.user = self.register_user("test", self.user_pass) - self.user_tok = self.login("test", self.user_pass) - - def get_device_ids(self, access_token: str) -> List[str]: - # Get the list of devices so one can be deleted. - channel = self.make_request("GET", "devices", access_token=access_token,) - self.assertEqual(channel.code, 200) - return [d["device_id"] for d in channel.json_body["devices"]] + self.device_id = "dev1" + self.user_tok = self.login("test", self.user_pass, self.device_id) def delete_device( self, @@ -219,11 +214,9 @@ class UIAuthTests(unittest.HomeserverTestCase): """ Test user interactive authentication outside of registration. """ - device_id = self.get_device_ids(self.user_tok)[0] - # Attempt to delete this device. # Returns a 401 as per the spec - channel = self.delete_device(self.user_tok, device_id, 401) + channel = self.delete_device(self.user_tok, self.device_id, 401) # Grab the session session = channel.json_body["session"] @@ -233,7 +226,7 @@ class UIAuthTests(unittest.HomeserverTestCase): # Make another request providing the UI auth flow. self.delete_device( self.user_tok, - device_id, + self.device_id, 200, { "auth": { @@ -252,14 +245,13 @@ class UIAuthTests(unittest.HomeserverTestCase): UIA - check that still works. """ - device_id = self.get_device_ids(self.user_tok)[0] - channel = self.delete_device(self.user_tok, device_id, 401) + channel = self.delete_device(self.user_tok, self.device_id, 401) session = channel.json_body["session"] # Make another request providing the UI auth flow. self.delete_device( self.user_tok, - device_id, + self.device_id, 200, { "auth": { @@ -282,14 +274,11 @@ class UIAuthTests(unittest.HomeserverTestCase): session ID should be rejected. """ # Create a second login. - self.login("test", self.user_pass) - - device_ids = self.get_device_ids(self.user_tok) - self.assertEqual(len(device_ids), 2) + self.login("test", self.user_pass, "dev2") # Attempt to delete the first device. # Returns a 401 as per the spec - channel = self.delete_devices(401, {"devices": [device_ids[0]]}) + channel = self.delete_devices(401, {"devices": [self.device_id]}) # Grab the session session = channel.json_body["session"] @@ -301,7 +290,7 @@ class UIAuthTests(unittest.HomeserverTestCase): self.delete_devices( 200, { - "devices": [device_ids[1]], + "devices": ["dev2"], "auth": { "type": "m.login.password", "identifier": {"type": "m.id.user", "user": self.user}, @@ -316,14 +305,11 @@ class UIAuthTests(unittest.HomeserverTestCase): The initial requested URI cannot be modified during the user interactive authentication session. """ # Create a second login. - self.login("test", self.user_pass) - - device_ids = self.get_device_ids(self.user_tok) - self.assertEqual(len(device_ids), 2) + self.login("test", self.user_pass, "dev2") # Attempt to delete the first device. # Returns a 401 as per the spec - channel = self.delete_device(self.user_tok, device_ids[0], 401) + channel = self.delete_device(self.user_tok, self.device_id, 401) # Grab the session session = channel.json_body["session"] @@ -332,9 +318,11 @@ class UIAuthTests(unittest.HomeserverTestCase): # Make another request providing the UI auth flow, but try to delete the # second device. This results in an error. + # + # This makes use of the fact that the device ID is embedded into the URL. self.delete_device( self.user_tok, - device_ids[1], + "dev2", 403, { "auth": { @@ -346,6 +334,52 @@ class UIAuthTests(unittest.HomeserverTestCase): }, ) + @unittest.override_config({"ui_auth": {"session_timeout": 5 * 1000}}) + def test_can_reuse_session(self): + """ + The session can be reused if configured. + + Compare to test_cannot_change_uri. + """ + # Create a second and third login. + self.login("test", self.user_pass, "dev2") + self.login("test", self.user_pass, "dev3") + + # Attempt to delete a device. This works since the user just logged in. + self.delete_device(self.user_tok, "dev2", 200) + + # Move the clock forward past the validation timeout. + self.reactor.advance(6) + + # Deleting another devices throws the user into UI auth. + channel = self.delete_device(self.user_tok, "dev3", 401) + + # Grab the session + session = channel.json_body["session"] + # Ensure that flows are what is expected. + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + + # Make another request providing the UI auth flow. + self.delete_device( + self.user_tok, + "dev3", + 200, + { + "auth": { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": self.user}, + "password": self.user_pass, + "session": session, + }, + }, + ) + + # Make another request, but try to delete the first device. This works + # due to re-using the previous session. + # + # Note that *no auth* information is provided, not even a session iD! + self.delete_device(self.user_tok, self.device_id, 200) + def test_does_not_offer_password_for_sso_user(self): login_resp = self.helper.login_via_oidc("username") user_tok = login_resp["access_token"] @@ -361,8 +395,7 @@ class UIAuthTests(unittest.HomeserverTestCase): def test_does_not_offer_sso_for_password_user(self): # now call the device deletion API: we should get the option to auth with SSO # and not password. - device_ids = self.get_device_ids(self.user_tok) - channel = self.delete_device(self.user_tok, device_ids[0], 401) + channel = self.delete_device(self.user_tok, self.device_id, 401) flows = channel.json_body["flows"] self.assertEqual(flows, [{"stages": ["m.login.password"]}]) @@ -373,8 +406,7 @@ class UIAuthTests(unittest.HomeserverTestCase): login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart) self.assertEqual(login_resp["user_id"], self.user) - device_ids = self.get_device_ids(self.user_tok) - channel = self.delete_device(self.user_tok, device_ids[0], 401) + channel = self.delete_device(self.user_tok, self.device_id, 401) flows = channel.json_body["flows"] # we have no particular expectations of ordering here -- cgit 1.5.1 From 28877fade90a5cfb3457c9e6c70924dbbe8af715 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 18 Dec 2020 14:19:46 +0000 Subject: Implement a username picker for synapse (#8942) The final part (for now) of my work to implement a username picker in synapse itself. The idea is that we allow `UsernameMappingProvider`s to return `localpart=None`, in which case, rather than redirecting the browser back to the client, we redirect to a username-picker resource, which allows the user to enter a username. We *then* complete the SSO flow (including doing the client permission checks). The static resources for the username picker itself (in https://github.com/matrix-org/synapse/tree/rav/username_picker/synapse/res/username_picker) are essentially lifted wholesale from https://github.com/matrix-org/matrix-synapse-saml-mozilla/tree/master/matrix_synapse_saml_mozilla/res. As the comment says, we might want to think about making them customisable, but that can be a follow-up. Fixes #8876. --- changelog.d/8942.feature | 1 + docs/sample_config.yaml | 5 +- docs/sso_mapping_providers.md | 28 +-- synapse/app/homeserver.py | 2 + synapse/config/oidc_config.py | 5 +- synapse/handlers/oidc_handler.py | 59 +++---- synapse/handlers/sso.py | 254 ++++++++++++++++++++++++++- synapse/res/username_picker/index.html | 19 ++ synapse/res/username_picker/script.js | 95 ++++++++++ synapse/res/username_picker/style.css | 27 +++ synapse/rest/synapse/client/pick_username.py | 88 ++++++++++ synapse/types.py | 8 +- tests/handlers/test_oidc.py | 143 ++++++++++++++- tests/unittest.py | 8 +- 14 files changed, 683 insertions(+), 59 deletions(-) create mode 100644 changelog.d/8942.feature create mode 100644 synapse/res/username_picker/index.html create mode 100644 synapse/res/username_picker/script.js create mode 100644 synapse/res/username_picker/style.css create mode 100644 synapse/rest/synapse/client/pick_username.py (limited to 'synapse/config') diff --git a/changelog.d/8942.feature b/changelog.d/8942.feature new file mode 100644 index 0000000000..d450ef4998 --- /dev/null +++ b/changelog.d/8942.feature @@ -0,0 +1 @@ +Add support for allowing users to pick their own user ID during a single-sign-on login. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 549c581a97..077cb619c7 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1825,9 +1825,10 @@ oidc_config: # * user: The claims returned by the UserInfo Endpoint and/or in the ID # Token # - # This must be configured if using the default mapping provider. + # If this is not set, the user will be prompted to choose their + # own username. # - localpart_template: "{{ user.preferred_username }}" + #localpart_template: "{{ user.preferred_username }}" # Jinja2 template for the display name to set on first login. # diff --git a/docs/sso_mapping_providers.md b/docs/sso_mapping_providers.md index 7714b1d844..e1d6ede7ba 100644 --- a/docs/sso_mapping_providers.md +++ b/docs/sso_mapping_providers.md @@ -15,12 +15,18 @@ where SAML mapping providers come into play. SSO mapping providers are currently supported for OpenID and SAML SSO configurations. Please see the details below for how to implement your own. -It is the responsibility of the mapping provider to normalise the SSO attributes -and map them to a valid Matrix ID. The -[specification for Matrix IDs](https://matrix.org/docs/spec/appendices#user-identifiers) -has some information about what is considered valid. Alternately an easy way to -ensure it is valid is to use a Synapse utility function: -`synapse.types.map_username_to_mxid_localpart`. +It is up to the mapping provider whether the user should be assigned a predefined +Matrix ID based on the SSO attributes, or if the user should be allowed to +choose their own username. + +In the first case - where users are automatically allocated a Matrix ID - it is +the responsibility of the mapping provider to normalise the SSO attributes and +map them to a valid Matrix ID. The [specification for Matrix +IDs](https://matrix.org/docs/spec/appendices#user-identifiers) has some +information about what is considered valid. + +If the mapping provider does not assign a Matrix ID, then Synapse will +automatically serve an HTML page allowing the user to pick their own username. External mapping providers are provided to Synapse in the form of an external Python module. You can retrieve this module from [PyPI](https://pypi.org) or elsewhere, @@ -80,8 +86,9 @@ A custom mapping provider must specify the following methods: with failures=1. The method should then return a different `localpart` value, such as `john.doe1`. - Returns a dictionary with two keys: - - localpart: A required string, used to generate the Matrix ID. - - displayname: An optional string, the display name for the user. + - `localpart`: A string, used to generate the Matrix ID. If this is + `None`, the user is prompted to pick their own username. + - `displayname`: An optional string, the display name for the user. * `get_extra_attributes(self, userinfo, token)` - This method must be async. - Arguments: @@ -165,12 +172,13 @@ A custom mapping provider must specify the following methods: redirected to. - This method must return a dictionary, which will then be used by Synapse to build a new user. The following keys are allowed: - * `mxid_localpart` - Required. The mxid localpart of the new user. + * `mxid_localpart` - The mxid localpart of the new user. If this is + `None`, the user is prompted to pick their own username. * `displayname` - The displayname of the new user. If not provided, will default to the value of `mxid_localpart`. * `emails` - A list of emails for the new user. If not provided, will default to an empty list. - + Alternatively it can raise a `synapse.api.errors.RedirectException` to redirect the user to another page. This is useful to prompt the user for additional information, e.g. if you want them to provide their own username. diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index bbb7407838..8d9b53be53 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -63,6 +63,7 @@ from synapse.rest import ClientRestResource from synapse.rest.admin import AdminRestResource from synapse.rest.health import HealthResource from synapse.rest.key.v2 import KeyApiV2Resource +from synapse.rest.synapse.client.pick_username import pick_username_resource from synapse.rest.well_known import WellKnownResource from synapse.server import HomeServer from synapse.storage import DataStore @@ -192,6 +193,7 @@ class SynapseHomeServer(HomeServer): "/_matrix/client/versions": client_resource, "/.well-known/matrix/client": WellKnownResource(self), "/_synapse/admin": AdminRestResource(self), + "/_synapse/client/pick_username": pick_username_resource(self), } ) diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index 1abf8ed405..4e3055282d 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -203,9 +203,10 @@ class OIDCConfig(Config): # * user: The claims returned by the UserInfo Endpoint and/or in the ID # Token # - # This must be configured if using the default mapping provider. + # If this is not set, the user will be prompted to choose their + # own username. # - localpart_template: "{{{{ user.preferred_username }}}}" + #localpart_template: "{{{{ user.preferred_username }}}}" # Jinja2 template for the display name to set on first login. # diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index cbd11a1382..709f8dfc13 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -947,7 +947,7 @@ class OidcHandler(BaseHandler): UserAttributeDict = TypedDict( - "UserAttributeDict", {"localpart": str, "display_name": Optional[str]} + "UserAttributeDict", {"localpart": Optional[str], "display_name": Optional[str]} ) C = TypeVar("C") @@ -1028,10 +1028,10 @@ env = Environment(finalize=jinja_finalize) @attr.s class JinjaOidcMappingConfig: - subject_claim = attr.ib() # type: str - localpart_template = attr.ib() # type: Template - display_name_template = attr.ib() # type: Optional[Template] - extra_attributes = attr.ib() # type: Dict[str, Template] + subject_claim = attr.ib(type=str) + localpart_template = attr.ib(type=Optional[Template]) + display_name_template = attr.ib(type=Optional[Template]) + extra_attributes = attr.ib(type=Dict[str, Template]) class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): @@ -1047,18 +1047,14 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): def parse_config(config: dict) -> JinjaOidcMappingConfig: subject_claim = config.get("subject_claim", "sub") - if "localpart_template" not in config: - raise ConfigError( - "missing key: oidc_config.user_mapping_provider.config.localpart_template" - ) - - try: - localpart_template = env.from_string(config["localpart_template"]) - except Exception as e: - raise ConfigError( - "invalid jinja template for oidc_config.user_mapping_provider.config.localpart_template: %r" - % (e,) - ) + localpart_template = None # type: Optional[Template] + if "localpart_template" in config: + try: + localpart_template = env.from_string(config["localpart_template"]) + except Exception as e: + raise ConfigError( + "invalid jinja template", path=["localpart_template"] + ) from e display_name_template = None # type: Optional[Template] if "display_name_template" in config: @@ -1066,26 +1062,22 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): display_name_template = env.from_string(config["display_name_template"]) except Exception as e: raise ConfigError( - "invalid jinja template for oidc_config.user_mapping_provider.config.display_name_template: %r" - % (e,) - ) + "invalid jinja template", path=["display_name_template"] + ) from e extra_attributes = {} # type Dict[str, Template] if "extra_attributes" in config: extra_attributes_config = config.get("extra_attributes") or {} if not isinstance(extra_attributes_config, dict): - raise ConfigError( - "oidc_config.user_mapping_provider.config.extra_attributes must be a dict" - ) + raise ConfigError("must be a dict", path=["extra_attributes"]) for key, value in extra_attributes_config.items(): try: extra_attributes[key] = env.from_string(value) except Exception as e: raise ConfigError( - "invalid jinja template for oidc_config.user_mapping_provider.config.extra_attributes.%s: %r" - % (key, e) - ) + "invalid jinja template", path=["extra_attributes", key] + ) from e return JinjaOidcMappingConfig( subject_claim=subject_claim, @@ -1100,14 +1092,17 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): async def map_user_attributes( self, userinfo: UserInfo, token: Token, failures: int ) -> UserAttributeDict: - localpart = self._config.localpart_template.render(user=userinfo).strip() + localpart = None + + if self._config.localpart_template: + localpart = self._config.localpart_template.render(user=userinfo).strip() - # Ensure only valid characters are included in the MXID. - localpart = map_username_to_mxid_localpart(localpart) + # Ensure only valid characters are included in the MXID. + localpart = map_username_to_mxid_localpart(localpart) - # Append suffix integer if last call to this function failed to produce - # a usable mxid. - localpart += str(failures) if failures else "" + # Append suffix integer if last call to this function failed to produce + # a usable mxid. + localpart += str(failures) if failures else "" display_name = None # type: Optional[str] if self._config.display_name_template is not None: diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index f054b66a53..548b02211b 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -13,17 +13,19 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional +from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional import attr +from typing_extensions import NoReturn from twisted.web.http import Request -from synapse.api.errors import RedirectException +from synapse.api.errors import RedirectException, SynapseError from synapse.http.server import respond_with_html from synapse.http.site import SynapseRequest from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters from synapse.util.async_helpers import Linearizer +from synapse.util.stringutils import random_string if TYPE_CHECKING: from synapse.server import HomeServer @@ -40,16 +42,52 @@ class MappingException(Exception): @attr.s class UserAttributes: - localpart = attr.ib(type=str) + # the localpart of the mxid that the mapper has assigned to the user. + # if `None`, the mapper has not picked a userid, and the user should be prompted to + # enter one. + localpart = attr.ib(type=Optional[str]) display_name = attr.ib(type=Optional[str], default=None) emails = attr.ib(type=List[str], default=attr.Factory(list)) +@attr.s(slots=True) +class UsernameMappingSession: + """Data we track about SSO sessions""" + + # A unique identifier for this SSO provider, e.g. "oidc" or "saml". + auth_provider_id = attr.ib(type=str) + + # user ID on the IdP server + remote_user_id = attr.ib(type=str) + + # attributes returned by the ID mapper + display_name = attr.ib(type=Optional[str]) + emails = attr.ib(type=List[str]) + + # An optional dictionary of extra attributes to be provided to the client in the + # login response. + extra_login_attributes = attr.ib(type=Optional[JsonDict]) + + # where to redirect the client back to + client_redirect_url = attr.ib(type=str) + + # expiry time for the session, in milliseconds + expiry_time_ms = attr.ib(type=int) + + +# the HTTP cookie used to track the mapping session id +USERNAME_MAPPING_SESSION_COOKIE_NAME = b"username_mapping_session" + + class SsoHandler: # The number of attempts to ask the mapping provider for when generating an MXID. _MAP_USERNAME_RETRIES = 1000 + # the time a UsernameMappingSession remains valid for + _MAPPING_SESSION_VALIDITY_PERIOD_MS = 15 * 60 * 1000 + def __init__(self, hs: "HomeServer"): + self._clock = hs.get_clock() self._store = hs.get_datastore() self._server_name = hs.hostname self._registration_handler = hs.get_registration_handler() @@ -59,6 +97,9 @@ class SsoHandler: # a lock on the mappings self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock()) + # a map from session id to session data + self._username_mapping_sessions = {} # type: Dict[str, UsernameMappingSession] + def render_error( self, request, error: str, error_description: Optional[str] = None ) -> None: @@ -206,6 +247,18 @@ class SsoHandler: # Otherwise, generate a new user. if not user_id: attributes = await self._call_attribute_mapper(sso_to_matrix_id_mapper) + + if attributes.localpart is None: + # the mapper doesn't return a username. bail out with a redirect to + # the username picker. + await self._redirect_to_username_picker( + auth_provider_id, + remote_user_id, + attributes, + client_redirect_url, + extra_login_attributes, + ) + user_id = await self._register_mapped_user( attributes, auth_provider_id, @@ -243,10 +296,8 @@ class SsoHandler: ) if not attributes.localpart: - raise MappingException( - "Error parsing SSO response: SSO mapping provider plugin " - "did not return a localpart value" - ) + # the mapper has not picked a localpart + return attributes # Check if this mxid already exists user_id = UserID(attributes.localpart, self._server_name).to_string() @@ -261,6 +312,59 @@ class SsoHandler: ) return attributes + async def _redirect_to_username_picker( + self, + auth_provider_id: str, + remote_user_id: str, + attributes: UserAttributes, + client_redirect_url: str, + extra_login_attributes: Optional[JsonDict], + ) -> NoReturn: + """Creates a UsernameMappingSession and redirects the browser + + Called if the user mapping provider doesn't return a localpart for a new user. + Raises a RedirectException which redirects the browser to the username picker. + + Args: + auth_provider_id: A unique identifier for this SSO provider, e.g. + "oidc" or "saml". + + remote_user_id: The unique identifier from the SSO provider. + + attributes: the user attributes returned by the user mapping provider. + + client_redirect_url: The redirect URL passed in by the client, which we + will eventually redirect back to. + + extra_login_attributes: An optional dictionary of extra + attributes to be provided to the client in the login response. + + Raises: + RedirectException + """ + session_id = random_string(16) + now = self._clock.time_msec() + session = UsernameMappingSession( + auth_provider_id=auth_provider_id, + remote_user_id=remote_user_id, + display_name=attributes.display_name, + emails=attributes.emails, + client_redirect_url=client_redirect_url, + expiry_time_ms=now + self._MAPPING_SESSION_VALIDITY_PERIOD_MS, + extra_login_attributes=extra_login_attributes, + ) + + self._username_mapping_sessions[session_id] = session + logger.info("Recorded registration session id %s", session_id) + + # Set the cookie and redirect to the username picker + e = RedirectException(b"/_synapse/client/pick_username") + e.cookies.append( + b"%s=%s; path=/" + % (USERNAME_MAPPING_SESSION_COOKIE_NAME, session_id.encode("ascii")) + ) + raise e + async def _register_mapped_user( self, attributes: UserAttributes, @@ -269,9 +373,38 @@ class SsoHandler: user_agent: str, ip_address: str, ) -> str: + """Register a new SSO user. + + This is called once we have successfully mapped the remote user id onto a local + user id, one way or another. + + Args: + attributes: user attributes returned by the user mapping provider, + including a non-empty localpart. + + auth_provider_id: A unique identifier for this SSO provider, e.g. + "oidc" or "saml". + + remote_user_id: The unique identifier from the SSO provider. + + user_agent: The user-agent in the HTTP request (used for potential + shadow-banning.) + + ip_address: The IP address of the requester (used for potential + shadow-banning.) + + Raises: + a MappingException if the localpart is invalid. + + a SynapseError with code 400 and errcode Codes.USER_IN_USE if the localpart + is already taken. + """ + # Since the localpart is provided via a potentially untrusted module, # ensure the MXID is valid before registering. - if contains_invalid_mxid_characters(attributes.localpart): + if not attributes.localpart or contains_invalid_mxid_characters( + attributes.localpart + ): raise MappingException("localpart is invalid: %s" % (attributes.localpart,)) logger.debug("Mapped SSO user to local part %s", attributes.localpart) @@ -326,3 +459,108 @@ class SsoHandler: await self._auth_handler.complete_sso_ui_auth( user_id, ui_auth_session_id, request ) + + async def check_username_availability( + self, localpart: str, session_id: str, + ) -> bool: + """Handle an "is username available" callback check + + Args: + localpart: desired localpart + session_id: the session id for the username picker + Returns: + True if the username is available + Raises: + SynapseError if the localpart is invalid or the session is unknown + """ + + # make sure that there is a valid mapping session, to stop people dictionary- + # scanning for accounts + + self._expire_old_sessions() + session = self._username_mapping_sessions.get(session_id) + if not session: + logger.info("Couldn't find session id %s", session_id) + raise SynapseError(400, "unknown session") + + logger.info( + "[session %s] Checking for availability of username %s", + session_id, + localpart, + ) + + if contains_invalid_mxid_characters(localpart): + raise SynapseError(400, "localpart is invalid: %s" % (localpart,)) + user_id = UserID(localpart, self._server_name).to_string() + user_infos = await self._store.get_users_by_id_case_insensitive(user_id) + + logger.info("[session %s] users: %s", session_id, user_infos) + return not user_infos + + async def handle_submit_username_request( + self, request: SynapseRequest, localpart: str, session_id: str + ) -> None: + """Handle a request to the username-picker 'submit' endpoint + + Will serve an HTTP response to the request. + + Args: + request: HTTP request + localpart: localpart requested by the user + session_id: ID of the username mapping session, extracted from a cookie + """ + self._expire_old_sessions() + session = self._username_mapping_sessions.get(session_id) + if not session: + logger.info("Couldn't find session id %s", session_id) + raise SynapseError(400, "unknown session") + + logger.info("[session %s] Registering localpart %s", session_id, localpart) + + attributes = UserAttributes( + localpart=localpart, + display_name=session.display_name, + emails=session.emails, + ) + + # the following will raise a 400 error if the username has been taken in the + # meantime. + user_id = await self._register_mapped_user( + attributes, + session.auth_provider_id, + session.remote_user_id, + request.get_user_agent(""), + request.getClientIP(), + ) + + logger.info("[session %s] Registered userid %s", session_id, user_id) + + # delete the mapping session and the cookie + del self._username_mapping_sessions[session_id] + + # delete the cookie + request.addCookie( + USERNAME_MAPPING_SESSION_COOKIE_NAME, + b"", + expires=b"Thu, 01 Jan 1970 00:00:00 GMT", + path=b"/", + ) + + await self._auth_handler.complete_sso_login( + user_id, + request, + session.client_redirect_url, + session.extra_login_attributes, + ) + + def _expire_old_sessions(self): + to_expire = [] + now = int(self._clock.time_msec()) + + for session_id, session in self._username_mapping_sessions.items(): + if session.expiry_time_ms <= now: + to_expire.append(session_id) + + for session_id in to_expire: + logger.info("Expiring mapping session %s", session_id) + del self._username_mapping_sessions[session_id] diff --git a/synapse/res/username_picker/index.html b/synapse/res/username_picker/index.html new file mode 100644 index 0000000000..37ea8bb6d8 --- /dev/null +++ b/synapse/res/username_picker/index.html @@ -0,0 +1,19 @@ + + + + Synapse Login + + + +
+
+ + + +
+ + + +
+ + diff --git a/synapse/res/username_picker/script.js b/synapse/res/username_picker/script.js new file mode 100644 index 0000000000..416a7c6f41 --- /dev/null +++ b/synapse/res/username_picker/script.js @@ -0,0 +1,95 @@ +let inputField = document.getElementById("field-username"); +let inputForm = document.getElementById("form"); +let submitButton = document.getElementById("button-submit"); +let message = document.getElementById("message"); + +// Submit username and receive response +function showMessage(messageText) { + // Unhide the message text + message.classList.remove("hidden"); + + message.textContent = messageText; +}; + +function doSubmit() { + showMessage("Success. Please wait a moment for your browser to redirect."); + + // remove the event handler before re-submitting the form. + delete inputForm.onsubmit; + inputForm.submit(); +} + +function onResponse(response) { + // Display message + showMessage(response); + + // Enable submit button and input field + submitButton.classList.remove('button--disabled'); + submitButton.value = "Submit"; +}; + +let allowedUsernameCharacters = RegExp("[^a-z0-9\\.\\_\\=\\-\\/]"); +function usernameIsValid(username) { + return !allowedUsernameCharacters.test(username); +} +let allowedCharactersString = "lowercase letters, digits, ., _, -, /, ="; + +function buildQueryString(params) { + return Object.keys(params) + .map(k => encodeURIComponent(k) + '=' + encodeURIComponent(params[k])) + .join('&'); +} + +function submitUsername(username) { + if(username.length == 0) { + onResponse("Please enter a username."); + return; + } + if(!usernameIsValid(username)) { + onResponse("Invalid username. Only the following characters are allowed: " + allowedCharactersString); + return; + } + + // if this browser doesn't support fetch, skip the availability check. + if(!window.fetch) { + doSubmit(); + return; + } + + let check_uri = 'check?' + buildQueryString({"username": username}); + fetch(check_uri, { + // include the cookie + "credentials": "same-origin", + }).then((response) => { + if(!response.ok) { + // for non-200 responses, raise the body of the response as an exception + return response.text().then((text) => { throw text; }); + } else { + return response.json(); + } + }).then((json) => { + if(json.error) { + throw json.error; + } else if(json.available) { + doSubmit(); + } else { + onResponse("This username is not available, please choose another."); + } + }).catch((err) => { + onResponse("Error checking username availability: " + err); + }); +} + +function clickSubmit() { + event.preventDefault(); + if(submitButton.classList.contains('button--disabled')) { return; } + + // Disable submit button and input field + submitButton.classList.add('button--disabled'); + + // Submit username + submitButton.value = "Checking..."; + submitUsername(inputField.value); +}; + +inputForm.onsubmit = clickSubmit; diff --git a/synapse/res/username_picker/style.css b/synapse/res/username_picker/style.css new file mode 100644 index 0000000000..745bd4c684 --- /dev/null +++ b/synapse/res/username_picker/style.css @@ -0,0 +1,27 @@ +input[type="text"] { + font-size: 100%; + background-color: #ededf0; + border: 1px solid #fff; + border-radius: .2em; + padding: .5em .9em; + display: block; + width: 26em; +} + +.button--disabled { + border-color: #fff; + background-color: transparent; + color: #000; + text-transform: none; +} + +.hidden { + display: none; +} + +.tooltip { + background-color: #f9f9fa; + padding: 1em; + margin: 1em 0; +} + diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py new file mode 100644 index 0000000000..d3b6803e65 --- /dev/null +++ b/synapse/rest/synapse/client/pick_username.py @@ -0,0 +1,88 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import TYPE_CHECKING + +import pkg_resources + +from twisted.web.http import Request +from twisted.web.resource import Resource +from twisted.web.static import File + +from synapse.api.errors import SynapseError +from synapse.handlers.sso import USERNAME_MAPPING_SESSION_COOKIE_NAME +from synapse.http.server import DirectServeHtmlResource, DirectServeJsonResource +from synapse.http.servlet import parse_string +from synapse.http.site import SynapseRequest + +if TYPE_CHECKING: + from synapse.server import HomeServer + + +def pick_username_resource(hs: "HomeServer") -> Resource: + """Factory method to generate the username picker resource. + + This resource gets mounted under /_synapse/client/pick_username. The top-level + resource is just a File resource which serves up the static files in the resources + "res" directory, but it has a couple of children: + + * "submit", which does the mechanics of registering the new user, and redirects the + browser back to the client URL + + * "check": checks if a userid is free. + """ + + # XXX should we make this path customisable so that admins can restyle it? + base_path = pkg_resources.resource_filename("synapse", "res/username_picker") + + res = File(base_path) + res.putChild(b"submit", SubmitResource(hs)) + res.putChild(b"check", AvailabilityCheckResource(hs)) + + return res + + +class AvailabilityCheckResource(DirectServeJsonResource): + def __init__(self, hs: "HomeServer"): + super().__init__() + self._sso_handler = hs.get_sso_handler() + + async def _async_render_GET(self, request: Request): + localpart = parse_string(request, "username", required=True) + + session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME) + if not session_id: + raise SynapseError(code=400, msg="missing session_id") + + is_available = await self._sso_handler.check_username_availability( + localpart, session_id.decode("ascii", errors="replace") + ) + return 200, {"available": is_available} + + +class SubmitResource(DirectServeHtmlResource): + def __init__(self, hs: "HomeServer"): + super().__init__() + self._sso_handler = hs.get_sso_handler() + + async def _async_render_POST(self, request: SynapseRequest): + localpart = parse_string(request, "username", required=True) + + session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME) + if not session_id: + raise SynapseError(code=400, msg="missing session_id") + + await self._sso_handler.handle_submit_username_request( + request, localpart, session_id.decode("ascii", errors="replace") + ) diff --git a/synapse/types.py b/synapse/types.py index 3ab6bdbe06..c7d4e95809 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -349,15 +349,17 @@ NON_MXID_CHARACTER_PATTERN = re.compile( ) -def map_username_to_mxid_localpart(username, case_sensitive=False): +def map_username_to_mxid_localpart( + username: Union[str, bytes], case_sensitive: bool = False +) -> str: """Map a username onto a string suitable for a MXID This follows the algorithm laid out at https://matrix.org/docs/spec/appendices.html#mapping-from-other-character-sets. Args: - username (unicode|bytes): username to be mapped - case_sensitive (bool): true if TEST and test should be mapped + username: username to be mapped + case_sensitive: true if TEST and test should be mapped onto different mxids Returns: diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index c54f1c5797..368d600b33 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -13,14 +13,21 @@ # See the License for the specific language governing permissions and # limitations under the License. import json -from urllib.parse import parse_qs, urlparse +import re +from typing import Dict +from urllib.parse import parse_qs, urlencode, urlparse from mock import ANY, Mock, patch import pymacaroons +from twisted.web.resource import Resource + +from synapse.api.errors import RedirectException from synapse.handlers.oidc_handler import OidcError from synapse.handlers.sso import MappingException +from synapse.rest.client.v1 import login +from synapse.rest.synapse.client.pick_username import pick_username_resource from synapse.server import HomeServer from synapse.types import UserID @@ -793,6 +800,140 @@ class OidcHandlerTestCase(HomeserverTestCase): "mapping_error", "Unable to generate a Matrix ID from the SSO response" ) + def test_empty_localpart(self): + """Attempts to map onto an empty localpart should be rejected.""" + userinfo = { + "sub": "tester", + "username": "", + } + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + self.assertRenderedError("mapping_error", "localpart is invalid: ") + + @override_config( + { + "oidc_config": { + "user_mapping_provider": { + "config": {"localpart_template": "{{ user.username }}"} + } + } + } + ) + def test_null_localpart(self): + """Mapping onto a null localpart via an empty OIDC attribute should be rejected""" + userinfo = { + "sub": "tester", + "username": None, + } + self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) + self.assertRenderedError("mapping_error", "localpart is invalid: ") + + +class UsernamePickerTestCase(HomeserverTestCase): + servlets = [login.register_servlets] + + def default_config(self): + config = super().default_config() + config["public_baseurl"] = BASE_URL + oidc_config = { + "enabled": True, + "client_id": CLIENT_ID, + "client_secret": CLIENT_SECRET, + "issuer": ISSUER, + "scopes": SCOPES, + "user_mapping_provider": { + "config": {"display_name_template": "{{ user.displayname }}"} + }, + } + + # Update this config with what's in the default config so that + # override_config works as expected. + oidc_config.update(config.get("oidc_config", {})) + config["oidc_config"] = oidc_config + + # whitelist this client URI so we redirect straight to it rather than + # serving a confirmation page + config["sso"] = {"client_whitelist": ["https://whitelisted.client"]} + return config + + def create_resource_dict(self) -> Dict[str, Resource]: + d = super().create_resource_dict() + d["/_synapse/client/pick_username"] = pick_username_resource(self.hs) + return d + + def test_username_picker(self): + """Test the happy path of a username picker flow.""" + client_redirect_url = "https://whitelisted.client" + + # first of all, mock up an OIDC callback to the OidcHandler, which should + # raise a RedirectException + userinfo = {"sub": "tester", "displayname": "Jonny"} + f = self.get_failure( + _make_callback_with_userinfo( + self.hs, userinfo, client_redirect_url=client_redirect_url + ), + RedirectException, + ) + + # check the Location and cookies returned by the RedirectException + self.assertEqual(f.value.location, b"/_synapse/client/pick_username") + cookieheader = f.value.cookies[0] + regex = re.compile(b"^username_mapping_session=([a-zA-Z]+);") + m = regex.search(cookieheader) + if not m: + self.fail("cookie header %s does not match %s" % (cookieheader, regex)) + + # introspect the sso handler a bit to check that the username mapping session + # looks ok. + session_id = m.group(1).decode("ascii") + username_mapping_sessions = self.hs.get_sso_handler()._username_mapping_sessions + self.assertIn( + session_id, username_mapping_sessions, "session id not found in map" + ) + session = username_mapping_sessions[session_id] + self.assertEqual(session.remote_user_id, "tester") + self.assertEqual(session.display_name, "Jonny") + self.assertEqual(session.client_redirect_url, client_redirect_url) + + # the expiry time should be about 15 minutes away + expected_expiry = self.clock.time_msec() + (15 * 60 * 1000) + self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000) + + # Now, submit a username to the username picker, which should serve a redirect + # back to the client + submit_path = f.value.location + b"/submit" + content = urlencode({b"username": b"bobby"}).encode("utf8") + chan = self.make_request( + "POST", + path=submit_path, + content=content, + content_is_form=True, + custom_headers=[ + ("Cookie", cookieheader), + # old versions of twisted don't do form-parsing without a valid + # content-length header. + ("Content-Length", str(len(content))), + ], + ) + self.assertEqual(chan.code, 302, chan.result) + location_headers = chan.headers.getRawHeaders("Location") + # ensure that the returned location starts with the requested redirect URL + self.assertEqual( + location_headers[0][: len(client_redirect_url)], client_redirect_url + ) + + # fish the login token out of the returned redirect uri + parts = urlparse(location_headers[0]) + query = parse_qs(parts.query) + login_token = query["loginToken"][0] + + # finally, submit the matrix login token to the login API, which gives us our + # matrix access token, mxid, and device id. + chan = self.make_request( + "POST", "/login", content={"type": "m.login.token", "token": login_token}, + ) + self.assertEqual(chan.code, 200, chan.result) + self.assertEqual(chan.json_body["user_id"], "@bobby:test") + async def _make_callback_with_userinfo( hs: HomeServer, userinfo: dict, client_redirect_url: str = "http://client/redirect" diff --git a/tests/unittest.py b/tests/unittest.py index 39e5e7b85c..af7f752c5a 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -20,7 +20,7 @@ import hmac import inspect import logging import time -from typing import Dict, Optional, Type, TypeVar, Union +from typing import Dict, Iterable, Optional, Tuple, Type, TypeVar, Union from mock import Mock, patch @@ -383,6 +383,9 @@ class HomeserverTestCase(TestCase): federation_auth_origin: str = None, content_is_form: bool = False, await_result: bool = True, + custom_headers: Optional[ + Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] + ] = None, ) -> FakeChannel: """ Create a SynapseRequest at the path using the method and containing the @@ -405,6 +408,8 @@ class HomeserverTestCase(TestCase): true (the default), will pump the test reactor until the the renderer tells the channel the request is finished. + custom_headers: (name, value) pairs to add as request headers + Returns: The FakeChannel object which stores the result of the request. """ @@ -420,6 +425,7 @@ class HomeserverTestCase(TestCase): federation_auth_origin, content_is_form, await_result, + custom_headers, ) def setup_test_homeserver(self, *args, **kwargs): -- cgit 1.5.1 From 56e00ca85e502247112a95ab8c452c83ab5fc4b0 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 18 Dec 2020 11:01:57 -0500 Subject: Send the location of the web client to the IS when inviting via 3PIDs. (#8930) Adds a new setting `email.invite_client_location` which, if defined, is passed to the identity server during invites. --- changelog.d/8930.feature | 1 + docs/sample_config.yaml | 6 ++++++ synapse/config/emailconfig.py | 22 ++++++++++++++++++++++ synapse/handlers/identity.py | 5 +++++ 4 files changed, 34 insertions(+) create mode 100644 changelog.d/8930.feature (limited to 'synapse/config') diff --git a/changelog.d/8930.feature b/changelog.d/8930.feature new file mode 100644 index 0000000000..cb305b5266 --- /dev/null +++ b/changelog.d/8930.feature @@ -0,0 +1 @@ +Add an `email.invite_client_location` configuration option to send a web client location to the invite endpoint on the identity server which allows customisation of the email template. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 077cb619c7..0b4dd115fb 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -2149,6 +2149,12 @@ email: # #validation_token_lifetime: 15m + # The web client location to direct users to during an invite. This is passed + # to the identity server as the org.matrix.web_client_location key. Defaults + # to unset, giving no guidance to the identity server. + # + #invite_client_location: https://app.element.io + # Directory in which Synapse will try to find the template files below. # If not set, or the files named below are not found within the template # directory, default templates from within the Synapse package will be used. diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index 7c8b64d84b..d4328c46b9 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -322,6 +322,22 @@ class EmailConfig(Config): self.email_subjects = EmailSubjectConfig(**subjects) + # The invite client location should be a HTTP(S) URL or None. + self.invite_client_location = email_config.get("invite_client_location") or None + if self.invite_client_location: + if not isinstance(self.invite_client_location, str): + raise ConfigError( + "Config option email.invite_client_location must be type str" + ) + if not ( + self.invite_client_location.startswith("http://") + or self.invite_client_location.startswith("https://") + ): + raise ConfigError( + "Config option email.invite_client_location must be a http or https URL", + path=("email", "invite_client_location"), + ) + def generate_config_section(self, config_dir_path, server_name, **kwargs): return ( """\ @@ -389,6 +405,12 @@ class EmailConfig(Config): # #validation_token_lifetime: 15m + # The web client location to direct users to during an invite. This is passed + # to the identity server as the org.matrix.web_client_location key. Defaults + # to unset, giving no guidance to the identity server. + # + #invite_client_location: https://app.element.io + # Directory in which Synapse will try to find the template files below. # If not set, or the files named below are not found within the template # directory, default templates from within the Synapse package will be used. diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 7301c24710..c05036ad1f 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -55,6 +55,8 @@ class IdentityHandler(BaseHandler): self.federation_http_client = hs.get_federation_http_client() self.hs = hs + self._web_client_location = hs.config.invite_client_location + async def threepid_from_creds( self, id_server: str, creds: Dict[str, str] ) -> Optional[JsonDict]: @@ -803,6 +805,9 @@ class IdentityHandler(BaseHandler): "sender_display_name": inviter_display_name, "sender_avatar_url": inviter_avatar_url, } + # If a custom web client location is available, include it in the request. + if self._web_client_location: + invite_config["org.matrix.web_client_location"] = self._web_client_location # Add the identity service access token to the JSON body and use the v2 # Identity Service endpoints if id_access_token is present -- cgit 1.5.1 From cfcf5541b463d4d360ef40a2982d702e9d6fb76a Mon Sep 17 00:00:00 2001 From: Jerin J Titus <72017981+jerinjtitus@users.noreply.github.com> Date: Tue, 29 Dec 2020 20:00:48 +0530 Subject: Update the value of group_creation_prefix in sample config. (#8992) Removes the trailing slash with causes issues with matrix.to/Element. --- changelog.d/8992.doc | 1 + docs/sample_config.yaml | 2 +- synapse/config/groups.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) create mode 100644 changelog.d/8992.doc (limited to 'synapse/config') diff --git a/changelog.d/8992.doc b/changelog.d/8992.doc new file mode 100644 index 0000000000..6a47bda26b --- /dev/null +++ b/changelog.d/8992.doc @@ -0,0 +1 @@ +Update the example value of `group_creation_prefix` in the sample configuration. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 0b4dd115fb..dd981609ac 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -2366,7 +2366,7 @@ spam_checker: # If enabled, non server admins can only create groups with local parts # starting with this prefix # -#group_creation_prefix: "unofficial/" +#group_creation_prefix: "unofficial_" diff --git a/synapse/config/groups.py b/synapse/config/groups.py index d6862d9a64..7b7860ea71 100644 --- a/synapse/config/groups.py +++ b/synapse/config/groups.py @@ -32,5 +32,5 @@ class GroupsConfig(Config): # If enabled, non server admins can only create groups with local parts # starting with this prefix # - #group_creation_prefix: "unofficial/" + #group_creation_prefix: "unofficial_" """ -- cgit 1.5.1 From 111b673fc1bbd3d51302d915f2ad2c044ed7d3b8 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 5 Jan 2021 11:25:28 +0000 Subject: Add initial support for a "pick your IdP" page (#9017) During login, if there are multiple IdPs enabled, offer the user a choice of IdPs. --- changelog.d/9017.feature | 1 + docs/sample_config.yaml | 25 ++++++++ synapse/app/homeserver.py | 2 + synapse/config/sso.py | 27 ++++++++ synapse/handlers/cas_handler.py | 3 + synapse/handlers/oidc_handler.py | 3 + synapse/handlers/saml_handler.py | 3 + synapse/handlers/sso.py | 18 +++++- synapse/res/templates/sso_login_idp_picker.html | 28 +++++++++ synapse/rest/synapse/client/pick_idp.py | 82 +++++++++++++++++++++++++ synapse/static/client/login/style.css | 5 ++ 11 files changed, 194 insertions(+), 3 deletions(-) create mode 100644 changelog.d/9017.feature create mode 100644 synapse/res/templates/sso_login_idp_picker.html create mode 100644 synapse/rest/synapse/client/pick_idp.py (limited to 'synapse/config') diff --git a/changelog.d/9017.feature b/changelog.d/9017.feature new file mode 100644 index 0000000000..01a24dcf49 --- /dev/null +++ b/changelog.d/9017.feature @@ -0,0 +1 @@ +Add support for multiple SSO Identity Providers. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index dd981609ac..c8ae46d1b3 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1909,6 +1909,31 @@ sso: # # Synapse will look for the following templates in this directory: # + # * HTML page to prompt the user to choose an Identity Provider during + # login: 'sso_login_idp_picker.html'. + # + # This is only used if multiple SSO Identity Providers are configured. + # + # When rendering, this template is given the following variables: + # * redirect_url: the URL that the user will be redirected to after + # login. Needs manual escaping (see + # https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping). + # + # * server_name: the homeserver's name. + # + # * providers: a list of available Identity Providers. Each element is + # an object with the following attributes: + # * idp_id: unique identifier for the IdP + # * idp_name: user-facing name for the IdP + # + # The rendered HTML page should contain a form which submits its results + # back as a GET request, with the following query parameters: + # + # * redirectUrl: the client redirect URI (ie, the `redirect_url` passed + # to the template) + # + # * idp: the 'idp_id' of the chosen IDP. + # # * HTML page for a confirmation step before redirecting back to the client # with the login token: 'sso_redirect_confirm.html'. # diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 8d9b53be53..b1d9817a6a 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -63,6 +63,7 @@ from synapse.rest import ClientRestResource from synapse.rest.admin import AdminRestResource from synapse.rest.health import HealthResource from synapse.rest.key.v2 import KeyApiV2Resource +from synapse.rest.synapse.client.pick_idp import PickIdpResource from synapse.rest.synapse.client.pick_username import pick_username_resource from synapse.rest.well_known import WellKnownResource from synapse.server import HomeServer @@ -194,6 +195,7 @@ class SynapseHomeServer(HomeServer): "/.well-known/matrix/client": WellKnownResource(self), "/_synapse/admin": AdminRestResource(self), "/_synapse/client/pick_username": pick_username_resource(self), + "/_synapse/client/pick_idp": PickIdpResource(self), } ) diff --git a/synapse/config/sso.py b/synapse/config/sso.py index 93bbd40937..1aeb1c5c92 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -31,6 +31,7 @@ class SSOConfig(Config): # Read templates from disk ( + self.sso_login_idp_picker_template, self.sso_redirect_confirm_template, self.sso_auth_confirm_template, self.sso_error_template, @@ -38,6 +39,7 @@ class SSOConfig(Config): sso_auth_success_template, ) = self.read_templates( [ + "sso_login_idp_picker.html", "sso_redirect_confirm.html", "sso_auth_confirm.html", "sso_error.html", @@ -98,6 +100,31 @@ class SSOConfig(Config): # # Synapse will look for the following templates in this directory: # + # * HTML page to prompt the user to choose an Identity Provider during + # login: 'sso_login_idp_picker.html'. + # + # This is only used if multiple SSO Identity Providers are configured. + # + # When rendering, this template is given the following variables: + # * redirect_url: the URL that the user will be redirected to after + # login. Needs manual escaping (see + # https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping). + # + # * server_name: the homeserver's name. + # + # * providers: a list of available Identity Providers. Each element is + # an object with the following attributes: + # * idp_id: unique identifier for the IdP + # * idp_name: user-facing name for the IdP + # + # The rendered HTML page should contain a form which submits its results + # back as a GET request, with the following query parameters: + # + # * redirectUrl: the client redirect URI (ie, the `redirect_url` passed + # to the template) + # + # * idp: the 'idp_id' of the chosen IDP. + # # * HTML page for a confirmation step before redirecting back to the client # with the login token: 'sso_redirect_confirm.html'. # diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index 295974c521..f3430c6713 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -77,6 +77,9 @@ class CasHandler: # identifier for the external_ids table self.idp_id = "cas" + # user-facing name of this auth provider + self.idp_name = "CAS" + self._sso_handler = hs.get_sso_handler() self._sso_handler.register_identity_provider(self) diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 3e2b60eb7b..6835c6c462 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -121,6 +121,9 @@ class OidcHandler(BaseHandler): # identifier for the external_ids table self.idp_id = "oidc" + # user-facing name of this auth provider + self.idp_name = "OIDC" + self._sso_handler = hs.get_sso_handler() self._sso_handler.register_identity_provider(self) diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index 6106237f1f..a8376543c9 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -75,6 +75,9 @@ class SamlHandler(BaseHandler): # identifier for the external_ids table self.idp_id = "saml" + # user-facing name of this auth provider + self.idp_name = "SAML" + # a map from saml session id to Saml2SessionData object self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData] diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index d8fb8cdd05..2da1ea2223 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -14,7 +14,8 @@ # limitations under the License. import abc import logging -from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Optional +from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Mapping, Optional +from urllib.parse import urlencode import attr from typing_extensions import NoReturn, Protocol @@ -66,6 +67,11 @@ class SsoIdentityProvider(Protocol): Eg, "saml", "cas", "github" """ + @property + @abc.abstractmethod + def idp_name(self) -> str: + """User-facing name for this provider""" + @abc.abstractmethod async def handle_redirect_request( self, @@ -156,6 +162,10 @@ class SsoHandler: assert p_id not in self._identity_providers self._identity_providers[p_id] = p + def get_identity_providers(self) -> Mapping[str, SsoIdentityProvider]: + """Get the configured identity providers""" + return self._identity_providers + def render_error( self, request: Request, @@ -203,8 +213,10 @@ class SsoHandler: ap = next(iter(self._identity_providers.values())) return await ap.handle_redirect_request(request, client_redirect_url) - # otherwise, we have a configuration error - raise Exception("Multiple SSO identity providers have been configured!") + # otherwise, redirect to the IDP picker + return "/_synapse/client/pick_idp?" + urlencode( + (("redirectUrl", client_redirect_url),) + ) async def get_sso_user_by_remote_user_id( self, auth_provider_id: str, remote_user_id: str diff --git a/synapse/res/templates/sso_login_idp_picker.html b/synapse/res/templates/sso_login_idp_picker.html new file mode 100644 index 0000000000..f53c9cd679 --- /dev/null +++ b/synapse/res/templates/sso_login_idp_picker.html @@ -0,0 +1,28 @@ + + + + + + {{server_name | e}} Login + + +
+

{{server_name | e}} Login

+ +
+ + diff --git a/synapse/rest/synapse/client/pick_idp.py b/synapse/rest/synapse/client/pick_idp.py new file mode 100644 index 0000000000..e5b720bbca --- /dev/null +++ b/synapse/rest/synapse/client/pick_idp.py @@ -0,0 +1,82 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import TYPE_CHECKING + +from synapse.http.server import ( + DirectServeHtmlResource, + finish_request, + respond_with_html, +) +from synapse.http.servlet import parse_string +from synapse.http.site import SynapseRequest + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class PickIdpResource(DirectServeHtmlResource): + """IdP picker resource. + + This resource gets mounted under /_synapse/client/pick_idp. It serves an HTML page + which prompts the user to choose an Identity Provider from the list. + """ + + def __init__(self, hs: "HomeServer"): + super().__init__() + self._sso_handler = hs.get_sso_handler() + self._sso_login_idp_picker_template = ( + hs.config.sso.sso_login_idp_picker_template + ) + self._server_name = hs.hostname + + async def _async_render_GET(self, request: SynapseRequest) -> None: + client_redirect_url = parse_string(request, "redirectUrl", required=True) + idp = parse_string(request, "idp", required=False) + + # if we need to pick an IdP, do so + if not idp: + return await self._serve_id_picker(request, client_redirect_url) + + # otherwise, redirect to the IdP's redirect URI + providers = self._sso_handler.get_identity_providers() + auth_provider = providers.get(idp) + if not auth_provider: + logger.info("Unknown idp %r", idp) + self._sso_handler.render_error( + request, "unknown_idp", "Unknown identity provider ID" + ) + return + + sso_url = await auth_provider.handle_redirect_request( + request, client_redirect_url.encode("utf8") + ) + logger.info("Redirecting to %s", sso_url) + request.redirect(sso_url) + finish_request(request) + + async def _serve_id_picker( + self, request: SynapseRequest, client_redirect_url: str + ) -> None: + # otherwise, serve up the IdP picker + providers = self._sso_handler.get_identity_providers() + html = self._sso_login_idp_picker_template.render( + redirect_url=client_redirect_url, + server_name=self._server_name, + providers=providers.values(), + ) + respond_with_html(request, 200, html) diff --git a/synapse/static/client/login/style.css b/synapse/static/client/login/style.css index 83e4f6abc8..dd76714a92 100644 --- a/synapse/static/client/login/style.css +++ b/synapse/static/client/login/style.css @@ -31,6 +31,11 @@ form { margin: 10px 0 0 0; } +ul.radiobuttons { + text-align: left; + list-style: none; +} + /* * Add some padding to the viewport. */ -- cgit 1.5.1 From b530eaa262b9c8af378f976e5d2628e8c02b10d8 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 7 Jan 2021 20:19:26 +0000 Subject: Allow running sendToDevice on workers (#9044) --- changelog.d/9044.feature | 1 + scripts/synapse_port_db | 27 ++++ synapse/app/generic_worker.py | 3 + synapse/config/workers.py | 10 +- synapse/handlers/devicemessage.py | 31 +++-- synapse/replication/slave/storage/deviceinbox.py | 32 +---- synapse/replication/tcp/handler.py | 9 ++ synapse/storage/databases/main/__init__.py | 33 ----- synapse/storage/databases/main/deviceinbox.py | 147 ++++++++++++++++----- .../schema/delta/59/02shard_send_to_device.sql | 18 +++ .../03shard_send_to_device_sequence.sql.postgres | 25 ++++ 11 files changed, 231 insertions(+), 105 deletions(-) create mode 100644 changelog.d/9044.feature create mode 100644 synapse/storage/databases/main/schema/delta/59/02shard_send_to_device.sql create mode 100644 synapse/storage/databases/main/schema/delta/59/03shard_send_to_device_sequence.sql.postgres (limited to 'synapse/config') diff --git a/changelog.d/9044.feature b/changelog.d/9044.feature new file mode 100644 index 0000000000..4ec319f1f2 --- /dev/null +++ b/changelog.d/9044.feature @@ -0,0 +1 @@ +Add experimental support for handling and persistence of to-device messages to happen on worker processes. diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db index 5ad17aa90f..22dd169bfb 100755 --- a/scripts/synapse_port_db +++ b/scripts/synapse_port_db @@ -629,6 +629,7 @@ class Porter(object): await self._setup_state_group_id_seq() await self._setup_user_id_seq() await self._setup_events_stream_seqs() + await self._setup_device_inbox_seq() # Step 3. Get tables. self.progress.set_state("Fetching tables") @@ -911,6 +912,32 @@ class Porter(object): "_setup_events_stream_seqs", _setup_events_stream_seqs_set_pos, ) + async def _setup_device_inbox_seq(self): + """Set the device inbox sequence to the correct value. + """ + curr_local_id = await self.sqlite_store.db_pool.simple_select_one_onecol( + table="device_inbox", + keyvalues={}, + retcol="COALESCE(MAX(stream_id), 1)", + allow_none=True, + ) + + curr_federation_id = await self.sqlite_store.db_pool.simple_select_one_onecol( + table="device_federation_outbox", + keyvalues={}, + retcol="COALESCE(MAX(stream_id), 1)", + allow_none=True, + ) + + next_id = max(curr_local_id, curr_federation_id) + 1 + + def r(txn): + txn.execute( + "ALTER SEQUENCE device_inbox_sequence RESTART WITH %s", (next_id,) + ) + + return self.postgres_store.db_pool.runInteraction("_setup_device_inbox_seq", r) + ############################################## # The following is simply UI stuff diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index fa23d9bb20..4428472707 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -108,6 +108,7 @@ from synapse.rest.client.v2_alpha.account_data import ( ) from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet from synapse.rest.client.v2_alpha.register import RegisterRestServlet +from synapse.rest.client.v2_alpha.sendtodevice import SendToDeviceRestServlet from synapse.rest.client.versions import VersionsRestServlet from synapse.rest.health import HealthResource from synapse.rest.key.v2 import KeyApiV2Resource @@ -520,6 +521,8 @@ class GenericWorkerServer(HomeServer): room.register_deprecated_servlets(self, resource) InitialSyncRestServlet(self).register(resource) + SendToDeviceRestServlet(self).register(resource) + user_directory.register_servlets(self, resource) # If presence is disabled, use the stub servlet that does diff --git a/synapse/config/workers.py b/synapse/config/workers.py index 7ca9efec52..364583f48b 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -53,6 +53,9 @@ class WriterLocations: default=["master"], type=List[str], converter=_instance_to_list_converter ) typing = attr.ib(default="master", type=str) + to_device = attr.ib( + default=["master"], type=List[str], converter=_instance_to_list_converter, + ) class WorkerConfig(Config): @@ -124,7 +127,7 @@ class WorkerConfig(Config): # Check that the configured writers for events and typing also appears in # `instance_map`. - for stream in ("events", "typing"): + for stream in ("events", "typing", "to_device"): instances = _instance_to_list_converter(getattr(self.writers, stream)) for instance in instances: if instance != "master" and instance not in self.instance_map: @@ -133,6 +136,11 @@ class WorkerConfig(Config): % (instance, stream) ) + if len(self.writers.to_device) != 1: + raise ConfigError( + "Must only specify one instance to handle `to_device` messages." + ) + self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events) # Whether this worker should run background tasks or not. diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index eb10d2b4bd..fc974a82e8 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -45,11 +45,25 @@ class DeviceMessageHandler: self.store = hs.get_datastore() self.notifier = hs.get_notifier() self.is_mine = hs.is_mine - self.federation = hs.get_federation_sender() - hs.get_federation_registry().register_edu_handler( - "m.direct_to_device", self.on_direct_to_device_edu - ) + # We only need to poke the federation sender explicitly if its on the + # same instance. Other federation sender instances will get notified by + # `synapse.app.generic_worker.FederationSenderHandler` when it sees it + # in the to-device replication stream. + self.federation_sender = None + if hs.should_send_federation(): + self.federation_sender = hs.get_federation_sender() + + # If we can handle the to device EDUs we do so, otherwise we route them + # to the appropriate worker. + if hs.get_instance_name() in hs.config.worker.writers.to_device: + hs.get_federation_registry().register_edu_handler( + "m.direct_to_device", self.on_direct_to_device_edu + ) + else: + hs.get_federation_registry().register_instances_for_edu( + "m.direct_to_device", hs.config.worker.writers.to_device, + ) # The handler to call when we think a user's device list might be out of # sync. We do all device list resyncing on the master instance, so if @@ -204,7 +218,8 @@ class DeviceMessageHandler: ) log_kv({"remote_messages": remote_messages}) - for destination in remote_messages.keys(): - # Enqueue a new federation transaction to send the new - # device messages to each remote destination. - self.federation.send_device_messages(destination) + if self.federation_sender: + for destination in remote_messages.keys(): + # Enqueue a new federation transaction to send the new + # device messages to each remote destination. + self.federation_sender.send_device_messages(destination) diff --git a/synapse/replication/slave/storage/deviceinbox.py b/synapse/replication/slave/storage/deviceinbox.py index 62b68dd6e9..1260f6d141 100644 --- a/synapse/replication/slave/storage/deviceinbox.py +++ b/synapse/replication/slave/storage/deviceinbox.py @@ -14,38 +14,8 @@ # limitations under the License. from synapse.replication.slave.storage._base import BaseSlavedStore -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker -from synapse.replication.tcp.streams import ToDeviceStream -from synapse.storage.database import DatabasePool from synapse.storage.databases.main.deviceinbox import DeviceInboxWorkerStore -from synapse.util.caches.stream_change_cache import StreamChangeCache class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore): - def __init__(self, database: DatabasePool, db_conn, hs): - super().__init__(database, db_conn, hs) - self._device_inbox_id_gen = SlavedIdTracker( - db_conn, "device_inbox", "stream_id" - ) - self._device_inbox_stream_cache = StreamChangeCache( - "DeviceInboxStreamChangeCache", - self._device_inbox_id_gen.get_current_token(), - ) - self._device_federation_outbox_stream_cache = StreamChangeCache( - "DeviceFederationOutboxStreamChangeCache", - self._device_inbox_id_gen.get_current_token(), - ) - - def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == ToDeviceStream.NAME: - self._device_inbox_id_gen.advance(instance_name, token) - for row in rows: - if row.entity.startswith("@"): - self._device_inbox_stream_cache.entity_has_changed( - row.entity, token - ) - else: - self._device_federation_outbox_stream_cache.entity_has_changed( - row.entity, token - ) - return super().process_replication_rows(stream_name, instance_name, token, rows) + pass diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 95e5502bf2..1f89249475 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -56,6 +56,7 @@ from synapse.replication.tcp.streams import ( EventsStream, FederationStream, Stream, + ToDeviceStream, TypingStream, ) @@ -115,6 +116,14 @@ class ReplicationCommandHandler: continue + if isinstance(stream, ToDeviceStream): + # Only add ToDeviceStream as a source on instances in charge of + # sending to device messages. + if hs.get_instance_name() in hs.config.worker.writers.to_device: + self._streams_to_replicate.append(stream) + + continue + if isinstance(stream, TypingStream): # Only add TypingStream as a source on the instance in charge of # typing. diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 701748f93b..c4de07a0a8 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -127,9 +127,6 @@ class DataStore( self._presence_id_gen = StreamIdGenerator( db_conn, "presence_stream", "stream_id" ) - self._device_inbox_id_gen = StreamIdGenerator( - db_conn, "device_inbox", "stream_id" - ) self._public_room_id_gen = StreamIdGenerator( db_conn, "public_room_list_stream", "stream_id" ) @@ -189,36 +186,6 @@ class DataStore( prefilled_cache=presence_cache_prefill, ) - max_device_inbox_id = self._device_inbox_id_gen.get_current_token() - device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict( - db_conn, - "device_inbox", - entity_column="user_id", - stream_column="stream_id", - max_value=max_device_inbox_id, - limit=1000, - ) - self._device_inbox_stream_cache = StreamChangeCache( - "DeviceInboxStreamChangeCache", - min_device_inbox_id, - prefilled_cache=device_inbox_prefill, - ) - # The federation outbox and the local device inbox uses the same - # stream_id generator. - device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict( - db_conn, - "device_federation_outbox", - entity_column="destination", - stream_column="stream_id", - max_value=max_device_inbox_id, - limit=1000, - ) - self._device_federation_outbox_stream_cache = StreamChangeCache( - "DeviceFederationOutboxStreamChangeCache", - min_device_outbox_id, - prefilled_cache=device_outbox_prefill, - ) - device_list_max = self._device_list_id_gen.get_current_token() self._device_list_stream_cache = StreamChangeCache( "DeviceListStreamChangeCache", device_list_max diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index eb72c21155..58d3f71e45 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -17,10 +17,14 @@ import logging from typing import List, Tuple from synapse.logging.opentracing import log_kv, set_tag, trace +from synapse.replication.tcp.streams import ToDeviceStream from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool +from synapse.storage.engines import PostgresEngine +from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.util import json_encoder from synapse.util.caches.expiringcache import ExpiringCache +from synapse.util.caches.stream_change_cache import StreamChangeCache logger = logging.getLogger(__name__) @@ -29,6 +33,8 @@ class DeviceInboxWorkerStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): super().__init__(database, db_conn, hs) + self._instance_name = hs.get_instance_name() + # Map of (user_id, device_id) to the last stream_id that has been # deleted up to. This is so that we can no op deletions. self._last_device_delete_cache = ExpiringCache( @@ -38,6 +44,73 @@ class DeviceInboxWorkerStore(SQLBaseStore): expiry_ms=30 * 60 * 1000, ) + if isinstance(database.engine, PostgresEngine): + self._can_write_to_device = ( + self._instance_name in hs.config.worker.writers.to_device + ) + + self._device_inbox_id_gen = MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + stream_name="to_device", + instance_name=self._instance_name, + table="device_inbox", + instance_column="instance_name", + id_column="stream_id", + sequence_name="device_inbox_sequence", + writers=hs.config.worker.writers.to_device, + ) + else: + self._can_write_to_device = True + self._device_inbox_id_gen = StreamIdGenerator( + db_conn, "device_inbox", "stream_id" + ) + + max_device_inbox_id = self._device_inbox_id_gen.get_current_token() + device_inbox_prefill, min_device_inbox_id = self.db_pool.get_cache_dict( + db_conn, + "device_inbox", + entity_column="user_id", + stream_column="stream_id", + max_value=max_device_inbox_id, + limit=1000, + ) + self._device_inbox_stream_cache = StreamChangeCache( + "DeviceInboxStreamChangeCache", + min_device_inbox_id, + prefilled_cache=device_inbox_prefill, + ) + + # The federation outbox and the local device inbox uses the same + # stream_id generator. + device_outbox_prefill, min_device_outbox_id = self.db_pool.get_cache_dict( + db_conn, + "device_federation_outbox", + entity_column="destination", + stream_column="stream_id", + max_value=max_device_inbox_id, + limit=1000, + ) + self._device_federation_outbox_stream_cache = StreamChangeCache( + "DeviceFederationOutboxStreamChangeCache", + min_device_outbox_id, + prefilled_cache=device_outbox_prefill, + ) + + def process_replication_rows(self, stream_name, instance_name, token, rows): + if stream_name == ToDeviceStream.NAME: + self._device_inbox_id_gen.advance(instance_name, token) + for row in rows: + if row.entity.startswith("@"): + self._device_inbox_stream_cache.entity_has_changed( + row.entity, token + ) + else: + self._device_federation_outbox_stream_cache.entity_has_changed( + row.entity, token + ) + return super().process_replication_rows(stream_name, instance_name, token, rows) + def get_to_device_stream_token(self): return self._device_inbox_id_gen.get_current_token() @@ -290,38 +363,6 @@ class DeviceInboxWorkerStore(SQLBaseStore): "get_all_new_device_messages", get_all_new_device_messages_txn ) - -class DeviceInboxBackgroundUpdateStore(SQLBaseStore): - DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" - - def __init__(self, database: DatabasePool, db_conn, hs): - super().__init__(database, db_conn, hs) - - self.db_pool.updates.register_background_index_update( - "device_inbox_stream_index", - index_name="device_inbox_stream_id_user_id", - table="device_inbox", - columns=["stream_id", "user_id"], - ) - - self.db_pool.updates.register_background_update_handler( - self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox - ) - - async def _background_drop_index_device_inbox(self, progress, batch_size): - def reindex_txn(conn): - txn = conn.cursor() - txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id") - txn.close() - - await self.db_pool.runWithConnection(reindex_txn) - - await self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID) - - return 1 - - -class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore): @trace async def add_messages_to_device_inbox( self, @@ -340,6 +381,8 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) The new stream_id. """ + assert self._can_write_to_device + def add_messages_txn(txn, now_ms, stream_id): # Add the local messages directly to the local inbox. self._add_messages_to_local_device_inbox_txn( @@ -358,6 +401,7 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) "stream_id": stream_id, "queued_ts": now_ms, "messages_json": json_encoder.encode(edu), + "instance_name": self._instance_name, } for destination, edu in remote_messages_by_destination.items() ], @@ -380,6 +424,8 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) async def add_messages_from_remote_to_device_inbox( self, origin: str, message_id: str, local_messages_by_user_then_device: dict ) -> int: + assert self._can_write_to_device + def add_messages_txn(txn, now_ms, stream_id): # Check if we've already inserted a matching message_id for that # origin. This can happen if the origin doesn't receive our @@ -428,6 +474,8 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) def _add_messages_to_local_device_inbox_txn( self, txn, stream_id, messages_by_user_then_device ): + assert self._can_write_to_device + local_by_user_then_device = {} for user_id, messages_by_device in messages_by_user_then_device.items(): messages_json_for_user = {} @@ -481,8 +529,43 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore) "device_id": device_id, "stream_id": stream_id, "message_json": message_json, + "instance_name": self._instance_name, } for user_id, messages_by_device in local_by_user_then_device.items() for device_id, message_json in messages_by_device.items() ], ) + + +class DeviceInboxBackgroundUpdateStore(SQLBaseStore): + DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop" + + def __init__(self, database: DatabasePool, db_conn, hs): + super().__init__(database, db_conn, hs) + + self.db_pool.updates.register_background_index_update( + "device_inbox_stream_index", + index_name="device_inbox_stream_id_user_id", + table="device_inbox", + columns=["stream_id", "user_id"], + ) + + self.db_pool.updates.register_background_update_handler( + self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox + ) + + async def _background_drop_index_device_inbox(self, progress, batch_size): + def reindex_txn(conn): + txn = conn.cursor() + txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id") + txn.close() + + await self.db_pool.runWithConnection(reindex_txn) + + await self.db_pool.updates._end_background_update(self.DEVICE_INBOX_STREAM_ID) + + return 1 + + +class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore): + pass diff --git a/synapse/storage/databases/main/schema/delta/59/02shard_send_to_device.sql b/synapse/storage/databases/main/schema/delta/59/02shard_send_to_device.sql new file mode 100644 index 0000000000..d781a92fec --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/59/02shard_send_to_device.sql @@ -0,0 +1,18 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE device_inbox ADD COLUMN instance_name TEXT; +ALTER TABLE device_federation_inbox ADD COLUMN instance_name TEXT; +ALTER TABLE device_federation_outbox ADD COLUMN instance_name TEXT; diff --git a/synapse/storage/databases/main/schema/delta/59/03shard_send_to_device_sequence.sql.postgres b/synapse/storage/databases/main/schema/delta/59/03shard_send_to_device_sequence.sql.postgres new file mode 100644 index 0000000000..45a845a3a5 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/59/03shard_send_to_device_sequence.sql.postgres @@ -0,0 +1,25 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE SEQUENCE IF NOT EXISTS device_inbox_sequence; + +-- We need to take the max across both device_inbox and device_federation_outbox +-- tables as they share the ID generator +SELECT setval('device_inbox_sequence', ( + SELECT GREATEST( + (SELECT COALESCE(MAX(stream_id), 1) FROM device_inbox), + (SELECT COALESCE(MAX(stream_id), 1) FROM device_federation_outbox) + ) +)); -- cgit 1.5.1 From d32870ffa5a2353d93e5723787d5f4dcbf14b32d Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 8 Jan 2021 14:23:04 +0000 Subject: Fix validate_config on nested objects (#9054) --- changelog.d/9054.bugfix | 1 + synapse/config/_util.py | 2 +- tests/config/test_util.py | 53 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 55 insertions(+), 1 deletion(-) create mode 100644 changelog.d/9054.bugfix create mode 100644 tests/config/test_util.py (limited to 'synapse/config') diff --git a/changelog.d/9054.bugfix b/changelog.d/9054.bugfix new file mode 100644 index 0000000000..0bfe951f17 --- /dev/null +++ b/changelog.d/9054.bugfix @@ -0,0 +1 @@ +Fix a minor bug which could cause confusing error messages from invalid configurations. diff --git a/synapse/config/_util.py b/synapse/config/_util.py index 1bbe83c317..8fce7f6bb1 100644 --- a/synapse/config/_util.py +++ b/synapse/config/_util.py @@ -56,7 +56,7 @@ def json_error_to_config_error( """ # copy `config_path` before modifying it. path = list(config_path) - for p in list(e.path): + for p in list(e.absolute_path): if isinstance(p, int): path.append("" % p) else: diff --git a/tests/config/test_util.py b/tests/config/test_util.py new file mode 100644 index 0000000000..10363e3765 --- /dev/null +++ b/tests/config/test_util.py @@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.config import ConfigError +from synapse.config._util import validate_config + +from tests.unittest import TestCase + + +class ValidateConfigTestCase(TestCase): + """Test cases for synapse.config._util.validate_config""" + + def test_bad_object_in_array(self): + """malformed objects within an array should be validated correctly""" + + # consider a structure: + # + # array_of_objs: + # - r: 1 + # foo: 2 + # + # - r: 2 + # bar: 3 + # + # ... where each entry must contain an "r": check that the path + # to the required item is correclty reported. + + schema = { + "type": "object", + "properties": { + "array_of_objs": { + "type": "array", + "items": {"type": "object", "required": ["r"]}, + }, + }, + } + + with self.assertRaises(ConfigError) as c: + validate_config(schema, {"array_of_objs": [{}]}, ("base",)) + + self.assertEqual(c.exception.path, ["base", "array_of_objs", ""]) -- cgit 1.5.1 From 7cc9509eca0d754b763253dd3c25cec688b47639 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 18 Dec 2020 12:13:03 +0000 Subject: Extract OIDCProviderConfig object Collect all the config options which related to an OIDC provider into a single object. --- synapse/config/oidc_config.py | 165 ++++++++++++++++++++++++++++----------- synapse/handlers/oidc_handler.py | 37 +++++---- 2 files changed, 140 insertions(+), 62 deletions(-) (limited to 'synapse/config') diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index 4e3055282d..9f36e63849 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2020 Quentin Gliech +# Copyright 2020 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +14,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional, Type + +import attr + from synapse.python_dependencies import DependencyException, check_requirements +from synapse.types import Collection, JsonDict from synapse.util.module_loader import load_module from ._base import Config, ConfigError @@ -25,65 +31,29 @@ class OIDCConfig(Config): section = "oidc" def read_config(self, config, **kwargs): - self.oidc_enabled = False + self.oidc_provider = None # type: Optional[OidcProviderConfig] oidc_config = config.get("oidc_config") + if oidc_config and oidc_config.get("enabled", False): + self.oidc_provider = _parse_oidc_config_dict(oidc_config) - if not oidc_config or not oidc_config.get("enabled", False): + if not self.oidc_provider: return try: check_requirements("oidc") except DependencyException as e: - raise ConfigError(e.message) + raise ConfigError(e.message) from e public_baseurl = self.public_baseurl if public_baseurl is None: raise ConfigError("oidc_config requires a public_baseurl to be set") self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback" - self.oidc_enabled = True - self.oidc_discover = oidc_config.get("discover", True) - self.oidc_issuer = oidc_config["issuer"] - self.oidc_client_id = oidc_config["client_id"] - self.oidc_client_secret = oidc_config["client_secret"] - self.oidc_client_auth_method = oidc_config.get( - "client_auth_method", "client_secret_basic" - ) - self.oidc_scopes = oidc_config.get("scopes", ["openid"]) - self.oidc_authorization_endpoint = oidc_config.get("authorization_endpoint") - self.oidc_token_endpoint = oidc_config.get("token_endpoint") - self.oidc_userinfo_endpoint = oidc_config.get("userinfo_endpoint") - self.oidc_jwks_uri = oidc_config.get("jwks_uri") - self.oidc_skip_verification = oidc_config.get("skip_verification", False) - self.oidc_user_profile_method = oidc_config.get("user_profile_method", "auto") - self.oidc_allow_existing_users = oidc_config.get("allow_existing_users", False) - - ump_config = oidc_config.get("user_mapping_provider", {}) - ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER) - ump_config.setdefault("config", {}) - - ( - self.oidc_user_mapping_provider_class, - self.oidc_user_mapping_provider_config, - ) = load_module(ump_config, ("oidc_config", "user_mapping_provider")) - - # Ensure loaded user mapping module has defined all necessary methods - required_methods = [ - "get_remote_user_id", - "map_user_attributes", - ] - missing_methods = [ - method - for method in required_methods - if not hasattr(self.oidc_user_mapping_provider_class, method) - ] - if missing_methods: - raise ConfigError( - "Class specified by oidc_config." - "user_mapping_provider.module is missing required " - "methods: %s" % (", ".join(missing_methods),) - ) + @property + def oidc_enabled(self) -> bool: + # OIDC is enabled if we have a provider + return bool(self.oidc_provider) def generate_config_section(self, config_dir_path, server_name, **kwargs): return """\ @@ -224,3 +194,108 @@ class OIDCConfig(Config): """.format( mapping_provider=DEFAULT_USER_MAPPING_PROVIDER ) + + +def _parse_oidc_config_dict(oidc_config: JsonDict) -> "OidcProviderConfig": + """Take the configuration dict and parse it into an OidcProviderConfig + + Raises: + ConfigError if the configuration is malformed. + """ + ump_config = oidc_config.get("user_mapping_provider", {}) + ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER) + ump_config.setdefault("config", {}) + + (user_mapping_provider_class, user_mapping_provider_config,) = load_module( + ump_config, ("oidc_config", "user_mapping_provider") + ) + + # Ensure loaded user mapping module has defined all necessary methods + required_methods = [ + "get_remote_user_id", + "map_user_attributes", + ] + missing_methods = [ + method + for method in required_methods + if not hasattr(user_mapping_provider_class, method) + ] + if missing_methods: + raise ConfigError( + "Class specified by oidc_config." + "user_mapping_provider.module is missing required " + "methods: %s" % (", ".join(missing_methods),) + ) + + return OidcProviderConfig( + discover=oidc_config.get("discover", True), + issuer=oidc_config["issuer"], + client_id=oidc_config["client_id"], + client_secret=oidc_config["client_secret"], + client_auth_method=oidc_config.get("client_auth_method", "client_secret_basic"), + scopes=oidc_config.get("scopes", ["openid"]), + authorization_endpoint=oidc_config.get("authorization_endpoint"), + token_endpoint=oidc_config.get("token_endpoint"), + userinfo_endpoint=oidc_config.get("userinfo_endpoint"), + jwks_uri=oidc_config.get("jwks_uri"), + skip_verification=oidc_config.get("skip_verification", False), + user_profile_method=oidc_config.get("user_profile_method", "auto"), + allow_existing_users=oidc_config.get("allow_existing_users", False), + user_mapping_provider_class=user_mapping_provider_class, + user_mapping_provider_config=user_mapping_provider_config, + ) + + +@attr.s +class OidcProviderConfig: + # whether the OIDC discovery mechanism is used to discover endpoints + discover = attr.ib(type=bool) + + # the OIDC issuer. Used to validate tokens and (if discovery is enabled) to + # discover the provider's endpoints. + issuer = attr.ib(type=str) + + # oauth2 client id to use + client_id = attr.ib(type=str) + + # oauth2 client secret to use + client_secret = attr.ib(type=str) + + # auth method to use when exchanging the token. + # Valid values are 'client_secret_basic', 'client_secret_post' and + # 'none'. + client_auth_method = attr.ib(type=str) + + # list of scopes to request + scopes = attr.ib(type=Collection[str]) + + # the oauth2 authorization endpoint. Required if discovery is disabled. + authorization_endpoint = attr.ib(type=Optional[str]) + + # the oauth2 token endpoint. Required if discovery is disabled. + token_endpoint = attr.ib(type=Optional[str]) + + # the OIDC userinfo endpoint. Required if discovery is disabled and the + # "openid" scope is not requested. + userinfo_endpoint = attr.ib(type=Optional[str]) + + # URI where to fetch the JWKS. Required if discovery is disabled and the + # "openid" scope is used. + jwks_uri = attr.ib(type=Optional[str]) + + # Whether to skip metadata verification + skip_verification = attr.ib(type=bool) + + # Whether to fetch the user profile from the userinfo endpoint. Valid + # values are: "auto" or "userinfo_endpoint". + user_profile_method = attr.ib(type=str) + + # whether to allow a user logging in via OIDC to match a pre-existing account + # instead of failing + allow_existing_users = attr.ib(type=bool) + + # the class of the user mapping provider + user_mapping_provider_class = attr.ib(type=Type) + + # the config of the user mapping provider + user_mapping_provider_config = attr.ib() diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 88097639ef..84754e5c9c 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -94,27 +94,30 @@ class OidcHandler: self._token_generator = OidcSessionTokenGenerator(hs) self._callback_url = hs.config.oidc_callback_url # type: str - self._scopes = hs.config.oidc_scopes # type: List[str] - self._user_profile_method = hs.config.oidc_user_profile_method # type: str + + provider = hs.config.oidc.oidc_provider + # we should not have been instantiated if there is no configured provider. + assert provider is not None + + self._scopes = provider.scopes + self._user_profile_method = provider.user_profile_method self._client_auth = ClientAuth( - hs.config.oidc_client_id, - hs.config.oidc_client_secret, - hs.config.oidc_client_auth_method, + provider.client_id, provider.client_secret, provider.client_auth_method, ) # type: ClientAuth - self._client_auth_method = hs.config.oidc_client_auth_method # type: str + self._client_auth_method = provider.client_auth_method self._provider_metadata = OpenIDProviderMetadata( - issuer=hs.config.oidc_issuer, - authorization_endpoint=hs.config.oidc_authorization_endpoint, - token_endpoint=hs.config.oidc_token_endpoint, - userinfo_endpoint=hs.config.oidc_userinfo_endpoint, - jwks_uri=hs.config.oidc_jwks_uri, + issuer=provider.issuer, + authorization_endpoint=provider.authorization_endpoint, + token_endpoint=provider.token_endpoint, + userinfo_endpoint=provider.userinfo_endpoint, + jwks_uri=provider.jwks_uri, ) # type: OpenIDProviderMetadata - self._provider_needs_discovery = hs.config.oidc_discover # type: bool - self._user_mapping_provider = hs.config.oidc_user_mapping_provider_class( - hs.config.oidc_user_mapping_provider_config - ) # type: OidcMappingProvider - self._skip_verification = hs.config.oidc_skip_verification # type: bool - self._allow_existing_users = hs.config.oidc_allow_existing_users # type: bool + self._provider_needs_discovery = provider.discover + self._user_mapping_provider = provider.user_mapping_provider_class( + provider.user_mapping_provider_config + ) + self._skip_verification = provider.skip_verification + self._allow_existing_users = provider.allow_existing_users self._http_client = hs.get_proxied_http_client() self._server_name = hs.config.server_name # type: str -- cgit 1.5.1 From dc3c83a9339961e6d52378eeabb68069ac0714cd Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Fri, 18 Dec 2020 13:34:59 +0000 Subject: Add jsonschema verification for the oidc provider config --- synapse/config/oidc_config.py | 50 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) (limited to 'synapse/config') diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index 9f36e63849..c705de5694 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -18,6 +18,7 @@ from typing import Optional, Type import attr +from synapse.config._util import validate_config from synapse.python_dependencies import DependencyException, check_requirements from synapse.types import Collection, JsonDict from synapse.util.module_loader import load_module @@ -31,10 +32,13 @@ class OIDCConfig(Config): section = "oidc" def read_config(self, config, **kwargs): + validate_config(MAIN_CONFIG_SCHEMA, config, ()) + self.oidc_provider = None # type: Optional[OidcProviderConfig] oidc_config = config.get("oidc_config") if oidc_config and oidc_config.get("enabled", False): + validate_config(OIDC_PROVIDER_CONFIG_SCHEMA, oidc_config, "oidc_config") self.oidc_provider = _parse_oidc_config_dict(oidc_config) if not self.oidc_provider: @@ -196,6 +200,52 @@ class OIDCConfig(Config): ) +# jsonschema definition of the configuration settings for an oidc identity provider +OIDC_PROVIDER_CONFIG_SCHEMA = { + "type": "object", + "required": ["issuer", "client_id", "client_secret"], + "properties": { + "discover": {"type": "boolean"}, + "issuer": {"type": "string"}, + "client_id": {"type": "string"}, + "client_secret": {"type": "string"}, + "client_auth_method": { + "type": "string", + # the following list is the same as the keys of + # authlib.oauth2.auth.ClientAuth.DEFAULT_AUTH_METHODS. We inline it + # to avoid importing authlib here. + "enum": ["client_secret_basic", "client_secret_post", "none"], + }, + "scopes": {"type": "array", "items": {"type": "string"}}, + "authorization_endpoint": {"type": "string"}, + "token_endpoint": {"type": "string"}, + "userinfo_endpoint": {"type": "string"}, + "jwks_uri": {"type": "string"}, + "skip_verification": {"type": "boolean"}, + "user_profile_method": { + "type": "string", + "enum": ["auto", "userinfo_endpoint"], + }, + "allow_existing_users": {"type": "boolean"}, + "user_mapping_provider": {"type": ["object", "null"]}, + }, +} + +# the `oidc_config` setting can either be None (as it is in the default +# config), or an object. If an object, it is ignored unless it has an "enabled: True" +# property. +# +# It's *possible* to represent this with jsonschema, but the resultant errors aren't +# particularly clear, so we just check for either an object or a null here, and do +# additional checks in the code. +OIDC_CONFIG_SCHEMA = {"oneOf": [{"type": "null"}, {"type": "object"}]} + +MAIN_CONFIG_SCHEMA = { + "type": "object", + "properties": {"oidc_config": OIDC_CONFIG_SCHEMA}, +} + + def _parse_oidc_config_dict(oidc_config: JsonDict) -> "OidcProviderConfig": """Take the configuration dict and parse it into an OidcProviderConfig -- cgit 1.5.1 From 5310808d3bebd17275355ecd474bc013e8c7462d Mon Sep 17 00:00:00 2001 From: Richard van der Hoff Date: Tue, 12 Jan 2021 18:19:42 +0000 Subject: Give the user a better error when they present bad SSO creds If a user tries to do UI Auth via SSO, but uses the wrong account on the SSO IdP, try to give them a better error. Previously, the UIA would claim to be successful, but then the operation in question would simply fail with "auth fail". Instead, serve up an error page which explains the failure. --- changelog.d/9091.feature | 1 + docs/sample_config.yaml | 8 +++++++ synapse/config/sso.py | 10 +++++++++ synapse/handlers/sso.py | 33 +++++++++++++++++++++++----- synapse/res/templates/sso_auth_bad_user.html | 18 +++++++++++++++ 5 files changed, 65 insertions(+), 5 deletions(-) create mode 100644 changelog.d/9091.feature create mode 100644 synapse/res/templates/sso_auth_bad_user.html (limited to 'synapse/config') diff --git a/changelog.d/9091.feature b/changelog.d/9091.feature new file mode 100644 index 0000000000..79fcd701f8 --- /dev/null +++ b/changelog.d/9091.feature @@ -0,0 +1 @@ +During user-interactive authentication via single-sign-on, give a better error if the user uses the wrong account on the SSO IdP. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index c8ae46d1b3..9da351f9f3 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1969,6 +1969,14 @@ sso: # # This template has no additional variables. # + # * HTML page shown after a user-interactive authentication session which + # does not map correctly onto the expected user: 'sso_auth_bad_user.html'. + # + # When rendering, this template is given the following variables: + # * server_name: the homeserver's name. + # * user_id_to_verify: the MXID of the user that we are trying to + # validate. + # # * HTML page shown during single sign-on if a deactivated user (according to Synapse's database) # attempts to login: 'sso_account_deactivated.html'. # diff --git a/synapse/config/sso.py b/synapse/config/sso.py index 1aeb1c5c92..366f0d4698 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -37,6 +37,7 @@ class SSOConfig(Config): self.sso_error_template, sso_account_deactivated_template, sso_auth_success_template, + self.sso_auth_bad_user_template, ) = self.read_templates( [ "sso_login_idp_picker.html", @@ -45,6 +46,7 @@ class SSOConfig(Config): "sso_error.html", "sso_account_deactivated.html", "sso_auth_success.html", + "sso_auth_bad_user.html", ], template_dir, ) @@ -160,6 +162,14 @@ class SSOConfig(Config): # # This template has no additional variables. # + # * HTML page shown after a user-interactive authentication session which + # does not map correctly onto the expected user: 'sso_auth_bad_user.html'. + # + # When rendering, this template is given the following variables: + # * server_name: the homeserver's name. + # * user_id_to_verify: the MXID of the user that we are trying to + # validate. + # # * HTML page shown during single sign-on if a deactivated user (according to Synapse's database) # attempts to login: 'sso_account_deactivated.html'. # diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index d096e0b091..69ffc9d9c2 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -23,6 +23,7 @@ from typing_extensions import NoReturn, Protocol from twisted.web.http import Request from synapse.api.errors import Codes, RedirectException, SynapseError +from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.http import get_request_user_agent from synapse.http.server import respond_with_html from synapse.http.site import SynapseRequest @@ -147,6 +148,7 @@ class SsoHandler: self._server_name = hs.hostname self._registration_handler = hs.get_registration_handler() self._error_template = hs.config.sso_error_template + self._bad_user_template = hs.config.sso_auth_bad_user_template self._auth_handler = hs.get_auth_handler() # a lock on the mappings @@ -577,19 +579,40 @@ class SsoHandler: auth_provider_id, remote_user_id, ) + user_id_to_verify = await self._auth_handler.get_session_data( + ui_auth_session_id, UIAuthSessionDataConstants.REQUEST_USER_ID + ) # type: str + if not user_id: logger.warning( "Remote user %s/%s has not previously logged in here: UIA will fail", auth_provider_id, remote_user_id, ) - # Let the UIA flow handle this the same as if they presented creds for a - # different user. - user_id = "" + elif user_id != user_id_to_verify: + logger.warning( + "Remote user %s/%s mapped onto incorrect user %s: UIA will fail", + auth_provider_id, + remote_user_id, + user_id, + ) + else: + # success! + await self._auth_handler.complete_sso_ui_auth( + user_id, ui_auth_session_id, request + ) + return + + # the user_id didn't match: mark the stage of the authentication as unsuccessful + await self._store.mark_ui_auth_stage_complete( + ui_auth_session_id, LoginType.SSO, "" + ) - await self._auth_handler.complete_sso_ui_auth( - user_id, ui_auth_session_id, request + # render an error page. + html = self._bad_user_template.render( + server_name=self._server_name, user_id_to_verify=user_id_to_verify, ) + respond_with_html(request, 200, html) async def check_username_availability( self, localpart: str, session_id: str, diff --git a/synapse/res/templates/sso_auth_bad_user.html b/synapse/res/templates/sso_auth_bad_user.html new file mode 100644 index 0000000000..3611191bf9 --- /dev/null +++ b/synapse/res/templates/sso_auth_bad_user.html @@ -0,0 +1,18 @@ + + + Authentication Failed + + +
+

+ We were unable to validate your {{server_name | e}} account via + single-sign-on (SSO), because the SSO Identity Provider returned + different details than when you logged in. +

+

+ Try the operation again, and ensure that you use the same details on + the Identity Provider as when you log into your account. +

+
+ + -- cgit 1.5.1 From 4575ad0b1e86c814e6d1c3ca6ac31ba4eeeb5c66 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 15 Jan 2021 13:22:12 +0000 Subject: Store an IdP ID in the OIDC session (#9109) Again in preparation for handling more than one OIDC provider, add a new caveat to the macaroon used as an OIDC session cookie, which remembers which OIDC provider we are talking to. In future, when we get a callback, we'll need it to make sure we talk to the right IdP. As part of this, I'm adding an idp_id and idp_name field to the OIDC configuration object. They aren't yet documented, and we'll just use the old values by default. --- changelog.d/9109.feature | 1 + synapse/config/oidc_config.py | 26 +++++++++++++++++++++++--- synapse/handlers/oidc_handler.py | 22 ++++++++++++++++------ tests/handlers/test_oidc.py | 3 ++- 4 files changed, 42 insertions(+), 10 deletions(-) create mode 100644 changelog.d/9109.feature (limited to 'synapse/config') diff --git a/changelog.d/9109.feature b/changelog.d/9109.feature new file mode 100644 index 0000000000..01a24dcf49 --- /dev/null +++ b/changelog.d/9109.feature @@ -0,0 +1 @@ +Add support for multiple SSO Identity Providers. diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index c705de5694..fddca19223 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2020 Quentin Gliech -# Copyright 2020 The Matrix.org Foundation C.I.C. +# Copyright 2020-2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import string from typing import Optional, Type import attr @@ -38,7 +39,7 @@ class OIDCConfig(Config): oidc_config = config.get("oidc_config") if oidc_config and oidc_config.get("enabled", False): - validate_config(OIDC_PROVIDER_CONFIG_SCHEMA, oidc_config, "oidc_config") + validate_config(OIDC_PROVIDER_CONFIG_SCHEMA, oidc_config, ("oidc_config",)) self.oidc_provider = _parse_oidc_config_dict(oidc_config) if not self.oidc_provider: @@ -205,6 +206,8 @@ OIDC_PROVIDER_CONFIG_SCHEMA = { "type": "object", "required": ["issuer", "client_id", "client_secret"], "properties": { + "idp_id": {"type": "string", "minLength": 1, "maxLength": 128}, + "idp_name": {"type": "string"}, "discover": {"type": "boolean"}, "issuer": {"type": "string"}, "client_id": {"type": "string"}, @@ -277,7 +280,17 @@ def _parse_oidc_config_dict(oidc_config: JsonDict) -> "OidcProviderConfig": "methods: %s" % (", ".join(missing_methods),) ) + # MSC2858 will appy certain limits in what can be used as an IdP id, so let's + # enforce those limits now. + idp_id = oidc_config.get("idp_id", "oidc") + valid_idp_chars = set(string.ascii_letters + string.digits + "-._~") + + if any(c not in valid_idp_chars for c in idp_id): + raise ConfigError('idp_id may only contain A-Z, a-z, 0-9, "-", ".", "_", "~"') + return OidcProviderConfig( + idp_id=idp_id, + idp_name=oidc_config.get("idp_name", "OIDC"), discover=oidc_config.get("discover", True), issuer=oidc_config["issuer"], client_id=oidc_config["client_id"], @@ -296,8 +309,15 @@ def _parse_oidc_config_dict(oidc_config: JsonDict) -> "OidcProviderConfig": ) -@attr.s +@attr.s(slots=True, frozen=True) class OidcProviderConfig: + # a unique identifier for this identity provider. Used in the 'user_external_ids' + # table, as well as the query/path parameter used in the login protocol. + idp_id = attr.ib(type=str) + + # user-facing name for this identity provider. + idp_name = attr.ib(type=str) + # whether the OIDC discovery mechanism is used to discover endpoints discover = attr.ib(type=bool) diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index d6347bb1b8..f63a90ec5c 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -175,7 +175,7 @@ class OidcHandler: session_data = self._token_generator.verify_oidc_session_token( session, state ) - except MacaroonDeserializationException as e: + except (MacaroonDeserializationException, ValueError) as e: logger.exception("Invalid session") self._sso_handler.render_error(request, "invalid_session", str(e)) return @@ -253,10 +253,10 @@ class OidcProvider: self._server_name = hs.config.server_name # type: str # identifier for the external_ids table - self.idp_id = "oidc" + self.idp_id = provider.idp_id # user-facing name of this auth provider - self.idp_name = "OIDC" + self.idp_name = provider.idp_name self._sso_handler = hs.get_sso_handler() @@ -656,6 +656,7 @@ class OidcProvider: cookie = self._token_generator.generate_oidc_session_token( state=state, session_data=OidcSessionData( + idp_id=self.idp_id, nonce=nonce, client_redirect_url=client_redirect_url.decode(), ui_auth_session_id=ui_auth_session_id, @@ -924,6 +925,7 @@ class OidcSessionTokenGenerator: macaroon.add_first_party_caveat("gen = 1") macaroon.add_first_party_caveat("type = session") macaroon.add_first_party_caveat("state = %s" % (state,)) + macaroon.add_first_party_caveat("idp_id = %s" % (session_data.idp_id,)) macaroon.add_first_party_caveat("nonce = %s" % (session_data.nonce,)) macaroon.add_first_party_caveat( "client_redirect_url = %s" % (session_data.client_redirect_url,) @@ -952,6 +954,9 @@ class OidcSessionTokenGenerator: Returns: The data extracted from the session cookie + + Raises: + ValueError if an expected caveat is missing from the macaroon. """ macaroon = pymacaroons.Macaroon.deserialize(session) @@ -960,6 +965,7 @@ class OidcSessionTokenGenerator: v.satisfy_exact("type = session") v.satisfy_exact("state = %s" % (state,)) v.satisfy_general(lambda c: c.startswith("nonce = ")) + v.satisfy_general(lambda c: c.startswith("idp_id = ")) v.satisfy_general(lambda c: c.startswith("client_redirect_url = ")) # Sometimes there's a UI auth session ID, it seems to be OK to attempt # to always satisfy this. @@ -968,9 +974,9 @@ class OidcSessionTokenGenerator: v.verify(macaroon, self._macaroon_secret_key) - # Extract the `nonce`, `client_redirect_url`, and maybe the - # `ui_auth_session_id` from the token. + # Extract the session data from the token. nonce = self._get_value_from_macaroon(macaroon, "nonce") + idp_id = self._get_value_from_macaroon(macaroon, "idp_id") client_redirect_url = self._get_value_from_macaroon( macaroon, "client_redirect_url" ) @@ -983,6 +989,7 @@ class OidcSessionTokenGenerator: return OidcSessionData( nonce=nonce, + idp_id=idp_id, client_redirect_url=client_redirect_url, ui_auth_session_id=ui_auth_session_id, ) @@ -998,7 +1005,7 @@ class OidcSessionTokenGenerator: The extracted value Raises: - Exception: if the caveat was not in the macaroon + ValueError: if the caveat was not in the macaroon """ prefix = key + " = " for caveat in macaroon.caveats: @@ -1019,6 +1026,9 @@ class OidcSessionTokenGenerator: class OidcSessionData: """The attributes which are stored in a OIDC session cookie""" + # the Identity Provider being used + idp_id = attr.ib(type=str) + # The `nonce` parameter passed to the OIDC provider. nonce = attr.ib(type=str) diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 5d338bea87..38ae8ca19e 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -848,6 +848,7 @@ class OidcHandlerTestCase(HomeserverTestCase): return self.handler._token_generator.generate_oidc_session_token( state=state, session_data=OidcSessionData( + idp_id="oidc", nonce=nonce, client_redirect_url=client_redirect_url, ui_auth_session_id=ui_auth_session_id, @@ -990,7 +991,7 @@ async def _make_callback_with_userinfo( session = handler._token_generator.generate_oidc_session_token( state=state, session_data=OidcSessionData( - nonce="nonce", client_redirect_url=client_redirect_url, + idp_id="oidc", nonce="nonce", client_redirect_url=client_redirect_url, ), ) request = _build_callback_request("code", state, session) -- cgit 1.5.1 From 9ffac2bef1cbf74694280e4976605f3563f97074 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 15 Jan 2021 15:59:20 +0000 Subject: Remote dependency on distutils (#9125) `distutils` is pretty much deprecated these days, and replaced with `setuptools`. It's also annoying because it's you can't `pip install` it, and it's hard to figure out which debian package we should depend on to make sure it's there. Since we only use it for a tiny function anyway, let's just vendor said function into our codebase. --- changelog.d/9125.misc | 1 + debian/changelog | 6 ++++++ debian/control | 1 - synapse/config/registration.py | 11 +++++------ synapse/events/__init__.py | 3 ++- synapse/util/stringutils.py | 19 +++++++++++++++++++ 6 files changed, 33 insertions(+), 8 deletions(-) create mode 100644 changelog.d/9125.misc (limited to 'synapse/config') diff --git a/changelog.d/9125.misc b/changelog.d/9125.misc new file mode 100644 index 0000000000..08459caf5a --- /dev/null +++ b/changelog.d/9125.misc @@ -0,0 +1 @@ +Remove dependency on `distutils`. diff --git a/debian/changelog b/debian/changelog index 609436bf75..1c6308e3a2 100644 --- a/debian/changelog +++ b/debian/changelog @@ -1,3 +1,9 @@ +matrix-synapse-py3 (1.25.0ubuntu1) UNRELEASED; urgency=medium + + * Remove dependency on `python3-distutils`. + + -- Richard van der Hoff Fri, 15 Jan 2021 12:44:19 +0000 + matrix-synapse-py3 (1.25.0) stable; urgency=medium [ Dan Callahan ] diff --git a/debian/control b/debian/control index b10401be43..8167a901a4 100644 --- a/debian/control +++ b/debian/control @@ -31,7 +31,6 @@ Pre-Depends: dpkg (>= 1.16.1) Depends: adduser, debconf, - python3-distutils|libpython3-stdlib (<< 3.6), ${misc:Depends}, ${shlibs:Depends}, ${synapse:pydepends}, diff --git a/synapse/config/registration.py b/synapse/config/registration.py index cc5f75123c..740c3fc1b1 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -14,14 +14,13 @@ # limitations under the License. import os -from distutils.util import strtobool import pkg_resources from synapse.api.constants import RoomCreationPreset from synapse.config._base import Config, ConfigError from synapse.types import RoomAlias, UserID -from synapse.util.stringutils import random_string_with_symbols +from synapse.util.stringutils import random_string_with_symbols, strtobool class AccountValidityConfig(Config): @@ -86,12 +85,12 @@ class RegistrationConfig(Config): section = "registration" def read_config(self, config, **kwargs): - self.enable_registration = bool( - strtobool(str(config.get("enable_registration", False))) + self.enable_registration = strtobool( + str(config.get("enable_registration", False)) ) if "disable_registration" in config: - self.enable_registration = not bool( - strtobool(str(config["disable_registration"])) + self.enable_registration = not strtobool( + str(config["disable_registration"]) ) self.account_validity = AccountValidityConfig( diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 8028663fa8..3ec4120f85 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -17,7 +17,6 @@ import abc import os -from distutils.util import strtobool from typing import Dict, Optional, Tuple, Type from unpaddedbase64 import encode_base64 @@ -26,6 +25,7 @@ from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVers from synapse.types import JsonDict, RoomStreamToken from synapse.util.caches import intern_dict from synapse.util.frozenutils import freeze +from synapse.util.stringutils import strtobool # Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents # bugs where we accidentally share e.g. signature dicts. However, converting a @@ -34,6 +34,7 @@ from synapse.util.frozenutils import freeze # NOTE: This is overridden by the configuration by the Synapse worker apps, but # for the sake of tests, it is set here while it cannot be configured on the # homeserver object itself. + USE_FROZEN_DICTS = strtobool(os.environ.get("SYNAPSE_USE_FROZEN_DICTS", "0")) diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index 61d96a6c28..b103c8694c 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -75,3 +75,22 @@ def shortstr(iterable: Iterable, maxitems: int = 5) -> str: if len(items) <= maxitems: return str(items) return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]" + + +def strtobool(val: str) -> bool: + """Convert a string representation of truth to True or False + + True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values + are 'n', 'no', 'f', 'false', 'off', and '0'. Raises ValueError if + 'val' is anything else. + + This is lifted from distutils.util.strtobool, with the exception that it actually + returns a bool, rather than an int. + """ + val = val.lower() + if val in ("y", "yes", "t", "true", "on", "1"): + return True + elif val in ("n", "no", "f", "false", "off", "0"): + return False + else: + raise ValueError("invalid truth value %r" % (val,)) -- cgit 1.5.1 From 9de6b9411750c9adf72bdd9d180d2f51b89e3c03 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Fri, 15 Jan 2021 16:55:29 +0000 Subject: Land support for multiple OIDC providers (#9110) This is the final step for supporting multiple OIDC providers concurrently. First of all, we reorganise the config so that you can specify a list of OIDC providers, instead of a single one. Before: oidc_config: enabled: true issuer: "https://oidc_provider" # etc After: oidc_providers: - idp_id: prov1 issuer: "https://oidc_provider" - idp_id: prov2 issuer: "https://another_oidc_provider" The old format is still grandfathered in. With that done, it's then simply a matter of having OidcHandler instantiate a new OidcProvider for each configured provider. --- changelog.d/9110.feature | 1 + docs/openid.md | 201 ++++++++++++------------ docs/sample_config.yaml | 274 ++++++++++++++++---------------- synapse/config/cas.py | 2 +- synapse/config/oidc_config.py | 329 ++++++++++++++++++++++----------------- synapse/handlers/oidc_handler.py | 27 +++- tests/handlers/test_oidc.py | 4 +- 7 files changed, 456 insertions(+), 382 deletions(-) create mode 100644 changelog.d/9110.feature (limited to 'synapse/config') diff --git a/changelog.d/9110.feature b/changelog.d/9110.feature new file mode 100644 index 0000000000..01a24dcf49 --- /dev/null +++ b/changelog.d/9110.feature @@ -0,0 +1 @@ +Add support for multiple SSO Identity Providers. diff --git a/docs/openid.md b/docs/openid.md index ffa4238fff..b86ae89768 100644 --- a/docs/openid.md +++ b/docs/openid.md @@ -42,11 +42,10 @@ as follows: * For other installation mechanisms, see the documentation provided by the maintainer. -To enable the OpenID integration, you should then add an `oidc_config` section -to your configuration file (or uncomment the `enabled: true` line in the -existing section). See [sample_config.yaml](./sample_config.yaml) for some -sample settings, as well as the text below for example configurations for -specific providers. +To enable the OpenID integration, you should then add a section to the `oidc_providers` +setting in your configuration file (or uncomment one of the existing examples). +See [sample_config.yaml](./sample_config.yaml) for some sample settings, as well as +the text below for example configurations for specific providers. ## Sample configs @@ -62,20 +61,21 @@ Directory (tenant) ID as it will be used in the Azure links. Edit your Synapse config file and change the `oidc_config` section: ```yaml -oidc_config: - enabled: true - issuer: "https://login.microsoftonline.com//v2.0" - client_id: "" - client_secret: "" - scopes: ["openid", "profile"] - authorization_endpoint: "https://login.microsoftonline.com//oauth2/v2.0/authorize" - token_endpoint: "https://login.microsoftonline.com//oauth2/v2.0/token" - userinfo_endpoint: "https://graph.microsoft.com/oidc/userinfo" - - user_mapping_provider: - config: - localpart_template: "{{ user.preferred_username.split('@')[0] }}" - display_name_template: "{{ user.name }}" +oidc_providers: + - idp_id: microsoft + idp_name: Microsoft + issuer: "https://login.microsoftonline.com//v2.0" + client_id: "" + client_secret: "" + scopes: ["openid", "profile"] + authorization_endpoint: "https://login.microsoftonline.com//oauth2/v2.0/authorize" + token_endpoint: "https://login.microsoftonline.com//oauth2/v2.0/token" + userinfo_endpoint: "https://graph.microsoft.com/oidc/userinfo" + + user_mapping_provider: + config: + localpart_template: "{{ user.preferred_username.split('@')[0] }}" + display_name_template: "{{ user.name }}" ``` ### [Dex][dex-idp] @@ -103,17 +103,18 @@ Run with `dex serve examples/config-dev.yaml`. Synapse config: ```yaml -oidc_config: - enabled: true - skip_verification: true # This is needed as Dex is served on an insecure endpoint - issuer: "http://127.0.0.1:5556/dex" - client_id: "synapse" - client_secret: "secret" - scopes: ["openid", "profile"] - user_mapping_provider: - config: - localpart_template: "{{ user.name }}" - display_name_template: "{{ user.name|capitalize }}" +oidc_providers: + - idp_id: dex + idp_name: "My Dex server" + skip_verification: true # This is needed as Dex is served on an insecure endpoint + issuer: "http://127.0.0.1:5556/dex" + client_id: "synapse" + client_secret: "secret" + scopes: ["openid", "profile"] + user_mapping_provider: + config: + localpart_template: "{{ user.name }}" + display_name_template: "{{ user.name|capitalize }}" ``` ### [Keycloak][keycloak-idp] @@ -152,16 +153,17 @@ Follow the [Getting Started Guide](https://www.keycloak.org/getting-started) to 8. Copy Secret ```yaml -oidc_config: - enabled: true - issuer: "https://127.0.0.1:8443/auth/realms/{realm_name}" - client_id: "synapse" - client_secret: "copy secret generated from above" - scopes: ["openid", "profile"] - user_mapping_provider: - config: - localpart_template: "{{ user.preferred_username }}" - display_name_template: "{{ user.name }}" +oidc_providers: + - idp_id: keycloak + idp_name: "My KeyCloak server" + issuer: "https://127.0.0.1:8443/auth/realms/{realm_name}" + client_id: "synapse" + client_secret: "copy secret generated from above" + scopes: ["openid", "profile"] + user_mapping_provider: + config: + localpart_template: "{{ user.preferred_username }}" + display_name_template: "{{ user.name }}" ``` ### [Auth0][auth0] @@ -191,16 +193,17 @@ oidc_config: Synapse config: ```yaml -oidc_config: - enabled: true - issuer: "https://your-tier.eu.auth0.com/" # TO BE FILLED - client_id: "your-client-id" # TO BE FILLED - client_secret: "your-client-secret" # TO BE FILLED - scopes: ["openid", "profile"] - user_mapping_provider: - config: - localpart_template: "{{ user.preferred_username }}" - display_name_template: "{{ user.name }}" +oidc_providers: + - idp_id: auth0 + idp_name: Auth0 + issuer: "https://your-tier.eu.auth0.com/" # TO BE FILLED + client_id: "your-client-id" # TO BE FILLED + client_secret: "your-client-secret" # TO BE FILLED + scopes: ["openid", "profile"] + user_mapping_provider: + config: + localpart_template: "{{ user.preferred_username }}" + display_name_template: "{{ user.name }}" ``` ### GitHub @@ -219,21 +222,22 @@ does not return a `sub` property, an alternative `subject_claim` has to be set. Synapse config: ```yaml -oidc_config: - enabled: true - discover: false - issuer: "https://github.com/" - client_id: "your-client-id" # TO BE FILLED - client_secret: "your-client-secret" # TO BE FILLED - authorization_endpoint: "https://github.com/login/oauth/authorize" - token_endpoint: "https://github.com/login/oauth/access_token" - userinfo_endpoint: "https://api.github.com/user" - scopes: ["read:user"] - user_mapping_provider: - config: - subject_claim: "id" - localpart_template: "{{ user.login }}" - display_name_template: "{{ user.name }}" +oidc_providers: + - idp_id: github + idp_name: Github + discover: false + issuer: "https://github.com/" + client_id: "your-client-id" # TO BE FILLED + client_secret: "your-client-secret" # TO BE FILLED + authorization_endpoint: "https://github.com/login/oauth/authorize" + token_endpoint: "https://github.com/login/oauth/access_token" + userinfo_endpoint: "https://api.github.com/user" + scopes: ["read:user"] + user_mapping_provider: + config: + subject_claim: "id" + localpart_template: "{{ user.login }}" + display_name_template: "{{ user.name }}" ``` ### [Google][google-idp] @@ -243,16 +247,17 @@ oidc_config: 2. add an "OAuth Client ID" for a Web Application under "Credentials". 3. Copy the Client ID and Client Secret, and add the following to your synapse config: ```yaml - oidc_config: - enabled: true - issuer: "https://accounts.google.com/" - client_id: "your-client-id" # TO BE FILLED - client_secret: "your-client-secret" # TO BE FILLED - scopes: ["openid", "profile"] - user_mapping_provider: - config: - localpart_template: "{{ user.given_name|lower }}" - display_name_template: "{{ user.name }}" + oidc_providers: + - idp_id: google + idp_name: Google + issuer: "https://accounts.google.com/" + client_id: "your-client-id" # TO BE FILLED + client_secret: "your-client-secret" # TO BE FILLED + scopes: ["openid", "profile"] + user_mapping_provider: + config: + localpart_template: "{{ user.given_name|lower }}" + display_name_template: "{{ user.name }}" ``` 4. Back in the Google console, add this Authorized redirect URI: `[synapse public baseurl]/_synapse/oidc/callback`. @@ -266,16 +271,17 @@ oidc_config: Synapse config: ```yaml -oidc_config: - enabled: true - issuer: "https://id.twitch.tv/oauth2/" - client_id: "your-client-id" # TO BE FILLED - client_secret: "your-client-secret" # TO BE FILLED - client_auth_method: "client_secret_post" - user_mapping_provider: - config: - localpart_template: "{{ user.preferred_username }}" - display_name_template: "{{ user.name }}" +oidc_providers: + - idp_id: twitch + idp_name: Twitch + issuer: "https://id.twitch.tv/oauth2/" + client_id: "your-client-id" # TO BE FILLED + client_secret: "your-client-secret" # TO BE FILLED + client_auth_method: "client_secret_post" + user_mapping_provider: + config: + localpart_template: "{{ user.preferred_username }}" + display_name_template: "{{ user.name }}" ``` ### GitLab @@ -287,16 +293,17 @@ oidc_config: Synapse config: ```yaml -oidc_config: - enabled: true - issuer: "https://gitlab.com/" - client_id: "your-client-id" # TO BE FILLED - client_secret: "your-client-secret" # TO BE FILLED - client_auth_method: "client_secret_post" - scopes: ["openid", "read_user"] - user_profile_method: "userinfo_endpoint" - user_mapping_provider: - config: - localpart_template: '{{ user.nickname }}' - display_name_template: '{{ user.name }}' +oidc_providers: + - idp_id: gitlab + idp_name: Gitlab + issuer: "https://gitlab.com/" + client_id: "your-client-id" # TO BE FILLED + client_secret: "your-client-secret" # TO BE FILLED + client_auth_method: "client_secret_post" + scopes: ["openid", "read_user"] + user_profile_method: "userinfo_endpoint" + user_mapping_provider: + config: + localpart_template: '{{ user.nickname }}' + display_name_template: '{{ user.name }}' ``` diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 9da351f9f3..ae995efe9b 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1709,141 +1709,149 @@ saml2_config: #idp_entityid: 'https://our_idp/entityid' -# Enable OpenID Connect (OIDC) / OAuth 2.0 for registration and login. +# List of OpenID Connect (OIDC) / OAuth 2.0 identity providers, for registration +# and login. # -# See https://github.com/matrix-org/synapse/blob/master/docs/openid.md -# for some example configurations. +# Options for each entry include: # -oidc_config: - # Uncomment the following to enable authorization against an OpenID Connect - # server. Defaults to false. - # - #enabled: true - - # Uncomment the following to disable use of the OIDC discovery mechanism to - # discover endpoints. Defaults to true. - # - #discover: false - - # the OIDC issuer. Used to validate tokens and (if discovery is enabled) to - # discover the provider's endpoints. - # - # Required if 'enabled' is true. - # - #issuer: "https://accounts.example.com/" - - # oauth2 client id to use. - # - # Required if 'enabled' is true. - # - #client_id: "provided-by-your-issuer" - - # oauth2 client secret to use. - # - # Required if 'enabled' is true. - # - #client_secret: "provided-by-your-issuer" - - # auth method to use when exchanging the token. - # Valid values are 'client_secret_basic' (default), 'client_secret_post' and - # 'none'. - # - #client_auth_method: client_secret_post - - # list of scopes to request. This should normally include the "openid" scope. - # Defaults to ["openid"]. - # - #scopes: ["openid", "profile"] - - # the oauth2 authorization endpoint. Required if provider discovery is disabled. - # - #authorization_endpoint: "https://accounts.example.com/oauth2/auth" - - # the oauth2 token endpoint. Required if provider discovery is disabled. - # - #token_endpoint: "https://accounts.example.com/oauth2/token" - - # the OIDC userinfo endpoint. Required if discovery is disabled and the - # "openid" scope is not requested. - # - #userinfo_endpoint: "https://accounts.example.com/userinfo" - - # URI where to fetch the JWKS. Required if discovery is disabled and the - # "openid" scope is used. - # - #jwks_uri: "https://accounts.example.com/.well-known/jwks.json" - - # Uncomment to skip metadata verification. Defaults to false. - # - # Use this if you are connecting to a provider that is not OpenID Connect - # compliant. - # Avoid this in production. - # - #skip_verification: true - - # Whether to fetch the user profile from the userinfo endpoint. Valid - # values are: "auto" or "userinfo_endpoint". - # - # Defaults to "auto", which fetches the userinfo endpoint if "openid" is included - # in `scopes`. Uncomment the following to always fetch the userinfo endpoint. - # - #user_profile_method: "userinfo_endpoint" - - # Uncomment to allow a user logging in via OIDC to match a pre-existing account instead - # of failing. This could be used if switching from password logins to OIDC. Defaults to false. - # - #allow_existing_users: true - - # An external module can be provided here as a custom solution to mapping - # attributes returned from a OIDC provider onto a matrix user. - # - user_mapping_provider: - # The custom module's class. Uncomment to use a custom module. - # Default is 'synapse.handlers.oidc_handler.JinjaOidcMappingProvider'. - # - # See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers - # for information on implementing a custom mapping provider. - # - #module: mapping_provider.OidcMappingProvider - - # Custom configuration values for the module. This section will be passed as - # a Python dictionary to the user mapping provider module's `parse_config` - # method. - # - # The examples below are intended for the default provider: they should be - # changed if using a custom provider. - # - config: - # name of the claim containing a unique identifier for the user. - # Defaults to `sub`, which OpenID Connect compliant providers should provide. - # - #subject_claim: "sub" - - # Jinja2 template for the localpart of the MXID. - # - # When rendering, this template is given the following variables: - # * user: The claims returned by the UserInfo Endpoint and/or in the ID - # Token - # - # If this is not set, the user will be prompted to choose their - # own username. - # - #localpart_template: "{{ user.preferred_username }}" - - # Jinja2 template for the display name to set on first login. - # - # If unset, no displayname will be set. - # - #display_name_template: "{{ user.given_name }} {{ user.last_name }}" - - # Jinja2 templates for extra attributes to send back to the client during - # login. - # - # Note that these are non-standard and clients will ignore them without modifications. - # - #extra_attributes: - #birthdate: "{{ user.birthdate }}" - +# idp_id: a unique identifier for this identity provider. Used internally +# by Synapse; should be a single word such as 'github'. +# +# Note that, if this is changed, users authenticating via that provider +# will no longer be recognised as the same user! +# +# idp_name: A user-facing name for this identity provider, which is used to +# offer the user a choice of login mechanisms. +# +# discover: set to 'false' to disable the use of the OIDC discovery mechanism +# to discover endpoints. Defaults to true. +# +# issuer: Required. The OIDC issuer. Used to validate tokens and (if discovery +# is enabled) to discover the provider's endpoints. +# +# client_id: Required. oauth2 client id to use. +# +# client_secret: Required. oauth2 client secret to use. +# +# client_auth_method: auth method to use when exchanging the token. Valid +# values are 'client_secret_basic' (default), 'client_secret_post' and +# 'none'. +# +# scopes: list of scopes to request. This should normally include the "openid" +# scope. Defaults to ["openid"]. +# +# authorization_endpoint: the oauth2 authorization endpoint. Required if +# provider discovery is disabled. +# +# token_endpoint: the oauth2 token endpoint. Required if provider discovery is +# disabled. +# +# userinfo_endpoint: the OIDC userinfo endpoint. Required if discovery is +# disabled and the 'openid' scope is not requested. +# +# jwks_uri: URI where to fetch the JWKS. Required if discovery is disabled and +# the 'openid' scope is used. +# +# skip_verification: set to 'true' to skip metadata verification. Use this if +# you are connecting to a provider that is not OpenID Connect compliant. +# Defaults to false. Avoid this in production. +# +# user_profile_method: Whether to fetch the user profile from the userinfo +# endpoint. Valid values are: 'auto' or 'userinfo_endpoint'. +# +# Defaults to 'auto', which fetches the userinfo endpoint if 'openid' is +# included in 'scopes'. Set to 'userinfo_endpoint' to always fetch the +# userinfo endpoint. +# +# allow_existing_users: set to 'true' to allow a user logging in via OIDC to +# match a pre-existing account instead of failing. This could be used if +# switching from password logins to OIDC. Defaults to false. +# +# user_mapping_provider: Configuration for how attributes returned from a OIDC +# provider are mapped onto a matrix user. This setting has the following +# sub-properties: +# +# module: The class name of a custom mapping module. Default is +# 'synapse.handlers.oidc_handler.JinjaOidcMappingProvider'. +# See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers +# for information on implementing a custom mapping provider. +# +# config: Configuration for the mapping provider module. This section will +# be passed as a Python dictionary to the user mapping provider +# module's `parse_config` method. +# +# For the default provider, the following settings are available: +# +# sub: name of the claim containing a unique identifier for the +# user. Defaults to 'sub', which OpenID Connect compliant +# providers should provide. +# +# localpart_template: Jinja2 template for the localpart of the MXID. +# If this is not set, the user will be prompted to choose their +# own username. +# +# display_name_template: Jinja2 template for the display name to set +# on first login. If unset, no displayname will be set. +# +# extra_attributes: a map of Jinja2 templates for extra attributes +# to send back to the client during login. +# Note that these are non-standard and clients will ignore them +# without modifications. +# +# When rendering, the Jinja2 templates are given a 'user' variable, +# which is set to the claims returned by the UserInfo Endpoint and/or +# in the ID Token. +# +# See https://github.com/matrix-org/synapse/blob/master/docs/openid.md +# for information on how to configure these options. +# +# For backwards compatibility, it is also possible to configure a single OIDC +# provider via an 'oidc_config' setting. This is now deprecated and admins are +# advised to migrate to the 'oidc_providers' format. +# +oidc_providers: + # Generic example + # + #- idp_id: my_idp + # idp_name: "My OpenID provider" + # discover: false + # issuer: "https://accounts.example.com/" + # client_id: "provided-by-your-issuer" + # client_secret: "provided-by-your-issuer" + # client_auth_method: client_secret_post + # scopes: ["openid", "profile"] + # authorization_endpoint: "https://accounts.example.com/oauth2/auth" + # token_endpoint: "https://accounts.example.com/oauth2/token" + # userinfo_endpoint: "https://accounts.example.com/userinfo" + # jwks_uri: "https://accounts.example.com/.well-known/jwks.json" + # skip_verification: true + + # For use with Keycloak + # + #- idp_id: keycloak + # idp_name: Keycloak + # issuer: "https://127.0.0.1:8443/auth/realms/my_realm_name" + # client_id: "synapse" + # client_secret: "copy secret generated in Keycloak UI" + # scopes: ["openid", "profile"] + + # For use with Github + # + #- idp_id: google + # idp_name: Google + # discover: false + # issuer: "https://github.com/" + # client_id: "your-client-id" # TO BE FILLED + # client_secret: "your-client-secret" # TO BE FILLED + # authorization_endpoint: "https://github.com/login/oauth/authorize" + # token_endpoint: "https://github.com/login/oauth/access_token" + # userinfo_endpoint: "https://api.github.com/user" + # scopes: ["read:user"] + # user_mapping_provider: + # config: + # subject_claim: "id" + # localpart_template: "{ user.login }" + # display_name_template: "{ user.name }" # Enable Central Authentication Service (CAS) for registration and login. diff --git a/synapse/config/cas.py b/synapse/config/cas.py index 2f97e6d258..c7877b4095 100644 --- a/synapse/config/cas.py +++ b/synapse/config/cas.py @@ -40,7 +40,7 @@ class CasConfig(Config): self.cas_required_attributes = {} def generate_config_section(self, config_dir_path, server_name, **kwargs): - return """ + return """\ # Enable Central Authentication Service (CAS) for registration and login. # cas_config: diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index fddca19223..c7fa749377 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -15,7 +15,7 @@ # limitations under the License. import string -from typing import Optional, Type +from typing import Iterable, Optional, Type import attr @@ -33,16 +33,8 @@ class OIDCConfig(Config): section = "oidc" def read_config(self, config, **kwargs): - validate_config(MAIN_CONFIG_SCHEMA, config, ()) - - self.oidc_provider = None # type: Optional[OidcProviderConfig] - - oidc_config = config.get("oidc_config") - if oidc_config and oidc_config.get("enabled", False): - validate_config(OIDC_PROVIDER_CONFIG_SCHEMA, oidc_config, ("oidc_config",)) - self.oidc_provider = _parse_oidc_config_dict(oidc_config) - - if not self.oidc_provider: + self.oidc_providers = tuple(_parse_oidc_provider_configs(config)) + if not self.oidc_providers: return try: @@ -58,144 +50,153 @@ class OIDCConfig(Config): @property def oidc_enabled(self) -> bool: # OIDC is enabled if we have a provider - return bool(self.oidc_provider) + return bool(self.oidc_providers) def generate_config_section(self, config_dir_path, server_name, **kwargs): return """\ - # Enable OpenID Connect (OIDC) / OAuth 2.0 for registration and login. + # List of OpenID Connect (OIDC) / OAuth 2.0 identity providers, for registration + # and login. + # + # Options for each entry include: + # + # idp_id: a unique identifier for this identity provider. Used internally + # by Synapse; should be a single word such as 'github'. + # + # Note that, if this is changed, users authenticating via that provider + # will no longer be recognised as the same user! + # + # idp_name: A user-facing name for this identity provider, which is used to + # offer the user a choice of login mechanisms. + # + # discover: set to 'false' to disable the use of the OIDC discovery mechanism + # to discover endpoints. Defaults to true. + # + # issuer: Required. The OIDC issuer. Used to validate tokens and (if discovery + # is enabled) to discover the provider's endpoints. + # + # client_id: Required. oauth2 client id to use. + # + # client_secret: Required. oauth2 client secret to use. + # + # client_auth_method: auth method to use when exchanging the token. Valid + # values are 'client_secret_basic' (default), 'client_secret_post' and + # 'none'. + # + # scopes: list of scopes to request. This should normally include the "openid" + # scope. Defaults to ["openid"]. + # + # authorization_endpoint: the oauth2 authorization endpoint. Required if + # provider discovery is disabled. + # + # token_endpoint: the oauth2 token endpoint. Required if provider discovery is + # disabled. + # + # userinfo_endpoint: the OIDC userinfo endpoint. Required if discovery is + # disabled and the 'openid' scope is not requested. + # + # jwks_uri: URI where to fetch the JWKS. Required if discovery is disabled and + # the 'openid' scope is used. + # + # skip_verification: set to 'true' to skip metadata verification. Use this if + # you are connecting to a provider that is not OpenID Connect compliant. + # Defaults to false. Avoid this in production. + # + # user_profile_method: Whether to fetch the user profile from the userinfo + # endpoint. Valid values are: 'auto' or 'userinfo_endpoint'. + # + # Defaults to 'auto', which fetches the userinfo endpoint if 'openid' is + # included in 'scopes'. Set to 'userinfo_endpoint' to always fetch the + # userinfo endpoint. + # + # allow_existing_users: set to 'true' to allow a user logging in via OIDC to + # match a pre-existing account instead of failing. This could be used if + # switching from password logins to OIDC. Defaults to false. + # + # user_mapping_provider: Configuration for how attributes returned from a OIDC + # provider are mapped onto a matrix user. This setting has the following + # sub-properties: + # + # module: The class name of a custom mapping module. Default is + # {mapping_provider!r}. + # See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers + # for information on implementing a custom mapping provider. + # + # config: Configuration for the mapping provider module. This section will + # be passed as a Python dictionary to the user mapping provider + # module's `parse_config` method. + # + # For the default provider, the following settings are available: + # + # sub: name of the claim containing a unique identifier for the + # user. Defaults to 'sub', which OpenID Connect compliant + # providers should provide. + # + # localpart_template: Jinja2 template for the localpart of the MXID. + # If this is not set, the user will be prompted to choose their + # own username. + # + # display_name_template: Jinja2 template for the display name to set + # on first login. If unset, no displayname will be set. + # + # extra_attributes: a map of Jinja2 templates for extra attributes + # to send back to the client during login. + # Note that these are non-standard and clients will ignore them + # without modifications. + # + # When rendering, the Jinja2 templates are given a 'user' variable, + # which is set to the claims returned by the UserInfo Endpoint and/or + # in the ID Token. # # See https://github.com/matrix-org/synapse/blob/master/docs/openid.md - # for some example configurations. + # for information on how to configure these options. # - oidc_config: - # Uncomment the following to enable authorization against an OpenID Connect - # server. Defaults to false. - # - #enabled: true - - # Uncomment the following to disable use of the OIDC discovery mechanism to - # discover endpoints. Defaults to true. - # - #discover: false - - # the OIDC issuer. Used to validate tokens and (if discovery is enabled) to - # discover the provider's endpoints. - # - # Required if 'enabled' is true. - # - #issuer: "https://accounts.example.com/" - - # oauth2 client id to use. - # - # Required if 'enabled' is true. - # - #client_id: "provided-by-your-issuer" - - # oauth2 client secret to use. - # - # Required if 'enabled' is true. - # - #client_secret: "provided-by-your-issuer" - - # auth method to use when exchanging the token. - # Valid values are 'client_secret_basic' (default), 'client_secret_post' and - # 'none'. - # - #client_auth_method: client_secret_post - - # list of scopes to request. This should normally include the "openid" scope. - # Defaults to ["openid"]. - # - #scopes: ["openid", "profile"] - - # the oauth2 authorization endpoint. Required if provider discovery is disabled. - # - #authorization_endpoint: "https://accounts.example.com/oauth2/auth" - - # the oauth2 token endpoint. Required if provider discovery is disabled. - # - #token_endpoint: "https://accounts.example.com/oauth2/token" - - # the OIDC userinfo endpoint. Required if discovery is disabled and the - # "openid" scope is not requested. - # - #userinfo_endpoint: "https://accounts.example.com/userinfo" - - # URI where to fetch the JWKS. Required if discovery is disabled and the - # "openid" scope is used. - # - #jwks_uri: "https://accounts.example.com/.well-known/jwks.json" - - # Uncomment to skip metadata verification. Defaults to false. - # - # Use this if you are connecting to a provider that is not OpenID Connect - # compliant. - # Avoid this in production. - # - #skip_verification: true - - # Whether to fetch the user profile from the userinfo endpoint. Valid - # values are: "auto" or "userinfo_endpoint". + # For backwards compatibility, it is also possible to configure a single OIDC + # provider via an 'oidc_config' setting. This is now deprecated and admins are + # advised to migrate to the 'oidc_providers' format. + # + oidc_providers: + # Generic example # - # Defaults to "auto", which fetches the userinfo endpoint if "openid" is included - # in `scopes`. Uncomment the following to always fetch the userinfo endpoint. + #- idp_id: my_idp + # idp_name: "My OpenID provider" + # discover: false + # issuer: "https://accounts.example.com/" + # client_id: "provided-by-your-issuer" + # client_secret: "provided-by-your-issuer" + # client_auth_method: client_secret_post + # scopes: ["openid", "profile"] + # authorization_endpoint: "https://accounts.example.com/oauth2/auth" + # token_endpoint: "https://accounts.example.com/oauth2/token" + # userinfo_endpoint: "https://accounts.example.com/userinfo" + # jwks_uri: "https://accounts.example.com/.well-known/jwks.json" + # skip_verification: true + + # For use with Keycloak # - #user_profile_method: "userinfo_endpoint" - - # Uncomment to allow a user logging in via OIDC to match a pre-existing account instead - # of failing. This could be used if switching from password logins to OIDC. Defaults to false. - # - #allow_existing_users: true - - # An external module can be provided here as a custom solution to mapping - # attributes returned from a OIDC provider onto a matrix user. + #- idp_id: keycloak + # idp_name: Keycloak + # issuer: "https://127.0.0.1:8443/auth/realms/my_realm_name" + # client_id: "synapse" + # client_secret: "copy secret generated in Keycloak UI" + # scopes: ["openid", "profile"] + + # For use with Github # - user_mapping_provider: - # The custom module's class. Uncomment to use a custom module. - # Default is {mapping_provider!r}. - # - # See https://github.com/matrix-org/synapse/blob/master/docs/sso_mapping_providers.md#openid-mapping-providers - # for information on implementing a custom mapping provider. - # - #module: mapping_provider.OidcMappingProvider - - # Custom configuration values for the module. This section will be passed as - # a Python dictionary to the user mapping provider module's `parse_config` - # method. - # - # The examples below are intended for the default provider: they should be - # changed if using a custom provider. - # - config: - # name of the claim containing a unique identifier for the user. - # Defaults to `sub`, which OpenID Connect compliant providers should provide. - # - #subject_claim: "sub" - - # Jinja2 template for the localpart of the MXID. - # - # When rendering, this template is given the following variables: - # * user: The claims returned by the UserInfo Endpoint and/or in the ID - # Token - # - # If this is not set, the user will be prompted to choose their - # own username. - # - #localpart_template: "{{{{ user.preferred_username }}}}" - - # Jinja2 template for the display name to set on first login. - # - # If unset, no displayname will be set. - # - #display_name_template: "{{{{ user.given_name }}}} {{{{ user.last_name }}}}" - - # Jinja2 templates for extra attributes to send back to the client during - # login. - # - # Note that these are non-standard and clients will ignore them without modifications. - # - #extra_attributes: - #birthdate: "{{{{ user.birthdate }}}}" + #- idp_id: google + # idp_name: Google + # discover: false + # issuer: "https://github.com/" + # client_id: "your-client-id" # TO BE FILLED + # client_secret: "your-client-secret" # TO BE FILLED + # authorization_endpoint: "https://github.com/login/oauth/authorize" + # token_endpoint: "https://github.com/login/oauth/access_token" + # userinfo_endpoint: "https://api.github.com/user" + # scopes: ["read:user"] + # user_mapping_provider: + # config: + # subject_claim: "id" + # localpart_template: "{{ user.login }}" + # display_name_template: "{{ user.name }}" """.format( mapping_provider=DEFAULT_USER_MAPPING_PROVIDER ) @@ -234,7 +235,22 @@ OIDC_PROVIDER_CONFIG_SCHEMA = { }, } -# the `oidc_config` setting can either be None (as it is in the default +# the same as OIDC_PROVIDER_CONFIG_SCHEMA, but with compulsory idp_id and idp_name +OIDC_PROVIDER_CONFIG_WITH_ID_SCHEMA = { + "allOf": [OIDC_PROVIDER_CONFIG_SCHEMA, {"required": ["idp_id", "idp_name"]}] +} + + +# the `oidc_providers` list can either be None (as it is in the default config), or +# a list of provider configs, each of which requires an explicit ID and name. +OIDC_PROVIDER_LIST_SCHEMA = { + "oneOf": [ + {"type": "null"}, + {"type": "array", "items": OIDC_PROVIDER_CONFIG_WITH_ID_SCHEMA}, + ] +} + +# the `oidc_config` setting can either be None (which it used to be in the default # config), or an object. If an object, it is ignored unless it has an "enabled: True" # property. # @@ -243,12 +259,41 @@ OIDC_PROVIDER_CONFIG_SCHEMA = { # additional checks in the code. OIDC_CONFIG_SCHEMA = {"oneOf": [{"type": "null"}, {"type": "object"}]} +# the top-level schema can contain an "oidc_config" and/or an "oidc_providers". MAIN_CONFIG_SCHEMA = { "type": "object", - "properties": {"oidc_config": OIDC_CONFIG_SCHEMA}, + "properties": { + "oidc_config": OIDC_CONFIG_SCHEMA, + "oidc_providers": OIDC_PROVIDER_LIST_SCHEMA, + }, } +def _parse_oidc_provider_configs(config: JsonDict) -> Iterable["OidcProviderConfig"]: + """extract and parse the OIDC provider configs from the config dict + + The configuration may contain either a single `oidc_config` object with an + `enabled: True` property, or a list of provider configurations under + `oidc_providers`, *or both*. + + Returns a generator which yields the OidcProviderConfig objects + """ + validate_config(MAIN_CONFIG_SCHEMA, config, ()) + + for p in config.get("oidc_providers") or []: + yield _parse_oidc_config_dict(p) + + # for backwards-compatibility, it is also possible to provide a single "oidc_config" + # object with an "enabled: True" property. + oidc_config = config.get("oidc_config") + if oidc_config and oidc_config.get("enabled", False): + # MAIN_CONFIG_SCHEMA checks that `oidc_config` is an object, but not that + # it matches OIDC_PROVIDER_CONFIG_SCHEMA (see the comments on OIDC_CONFIG_SCHEMA + # above), so now we need to validate it. + validate_config(OIDC_PROVIDER_CONFIG_SCHEMA, oidc_config, ("oidc_config",)) + yield _parse_oidc_config_dict(oidc_config) + + def _parse_oidc_config_dict(oidc_config: JsonDict) -> "OidcProviderConfig": """Take the configuration dict and parse it into an OidcProviderConfig diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index f63a90ec5c..5e5fda7b2f 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -78,21 +78,28 @@ class OidcHandler: def __init__(self, hs: "HomeServer"): self._sso_handler = hs.get_sso_handler() - provider_conf = hs.config.oidc.oidc_provider + provider_confs = hs.config.oidc.oidc_providers # we should not have been instantiated if there is no configured provider. - assert provider_conf is not None + assert provider_confs self._token_generator = OidcSessionTokenGenerator(hs) - - self._provider = OidcProvider(hs, self._token_generator, provider_conf) + self._providers = { + p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs + } async def load_metadata(self) -> None: """Validate the config and load the metadata from the remote endpoint. Called at startup to ensure we have everything we need. """ - await self._provider.load_metadata() - await self._provider.load_jwks() + for idp_id, p in self._providers.items(): + try: + await p.load_metadata() + await p.load_jwks() + except Exception as e: + raise Exception( + "Error while initialising OIDC provider %r" % (idp_id,) + ) from e async def handle_oidc_callback(self, request: SynapseRequest) -> None: """Handle an incoming request to /_synapse/oidc/callback @@ -184,6 +191,12 @@ class OidcHandler: self._sso_handler.render_error(request, "mismatching_session", str(e)) return + oidc_provider = self._providers.get(session_data.idp_id) + if not oidc_provider: + logger.error("OIDC session uses unknown IdP %r", oidc_provider) + self._sso_handler.render_error(request, "unknown_idp", "Unknown IdP") + return + if b"code" not in request.args: logger.info("Code parameter is missing") self._sso_handler.render_error( @@ -193,7 +206,7 @@ class OidcHandler: code = request.args[b"code"][0].decode() - await self._provider.handle_oidc_callback(request, session_data, code) + await oidc_provider.handle_oidc_callback(request, session_data, code) class OidcError(Exception): diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 02e21ed6ca..b3dfa40d25 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -145,7 +145,7 @@ class OidcHandlerTestCase(HomeserverTestCase): hs = self.setup_test_homeserver(proxied_http_client=self.http_client) self.handler = hs.get_oidc_handler() - self.provider = self.handler._provider + self.provider = self.handler._providers["oidc"] sso_handler = hs.get_sso_handler() # Mock the render error method. self.render_error = Mock(return_value=None) @@ -866,7 +866,7 @@ async def _make_callback_with_userinfo( from synapse.handlers.oidc_handler import OidcSessionData handler = hs.get_oidc_handler() - provider = handler._provider + provider = handler._providers["oidc"] provider._exchange_code = simple_async_mock(return_value={}) provider._parse_id_token = simple_async_mock(return_value=userinfo) provider._fetch_userinfo = simple_async_mock(return_value=userinfo) -- cgit 1.5.1 From 6633a4015a7b4ba60f87c5e6f979a9c9d8f9d8fe Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 18 Jan 2021 15:47:59 +0000 Subject: Allow moving account data and receipts streams off master (#9104) --- changelog.d/9104.feature | 1 + synapse/app/generic_worker.py | 15 +- synapse/config/workers.py | 18 +- synapse/handlers/account_data.py | 144 ++++++++++++++++ synapse/handlers/read_marker.py | 5 +- synapse/handlers/receipts.py | 27 ++- synapse/handlers/room_member.py | 7 +- synapse/replication/http/__init__.py | 2 + synapse/replication/http/account_data.py | 187 +++++++++++++++++++++ synapse/replication/slave/storage/_base.py | 10 +- synapse/replication/slave/storage/account_data.py | 40 +---- synapse/replication/slave/storage/receipts.py | 35 +--- synapse/replication/tcp/handler.py | 19 +++ synapse/rest/client/v2_alpha/account_data.py | 22 +-- synapse/rest/client/v2_alpha/tags.py | 11 +- synapse/server.py | 5 + synapse/storage/databases/main/__init__.py | 10 +- synapse/storage/databases/main/account_data.py | 107 +++++++++--- synapse/storage/databases/main/deviceinbox.py | 4 +- .../storage/databases/main/event_push_actions.py | 92 +++++----- synapse/storage/databases/main/events_worker.py | 8 +- synapse/storage/databases/main/receipts.py | 108 ++++++++---- .../main/schema/delta/59/06shard_account_data.sql | 20 +++ .../delta/59/06shard_account_data.sql.postgres | 32 ++++ synapse/storage/databases/main/tags.py | 10 +- synapse/storage/util/id_generators.py | 84 +++++---- tests/storage/test_id_generators.py | 112 +++++++++++- 27 files changed, 855 insertions(+), 280 deletions(-) create mode 100644 changelog.d/9104.feature create mode 100644 synapse/replication/http/account_data.py create mode 100644 synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql create mode 100644 synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql.postgres (limited to 'synapse/config') diff --git a/changelog.d/9104.feature b/changelog.d/9104.feature new file mode 100644 index 0000000000..1c4f88bce9 --- /dev/null +++ b/changelog.d/9104.feature @@ -0,0 +1 @@ +Add experimental support for moving off receipts and account data persistence off master. diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py index cb202bda44..e60988fa4a 100644 --- a/synapse/app/generic_worker.py +++ b/synapse/app/generic_worker.py @@ -100,7 +100,16 @@ from synapse.rest.client.v1.profile import ( ) from synapse.rest.client.v1.push_rule import PushRuleRestServlet from synapse.rest.client.v1.voip import VoipRestServlet -from synapse.rest.client.v2_alpha import groups, room_keys, sync, user_directory +from synapse.rest.client.v2_alpha import ( + account_data, + groups, + read_marker, + receipts, + room_keys, + sync, + tags, + user_directory, +) from synapse.rest.client.v2_alpha._base import client_patterns from synapse.rest.client.v2_alpha.account import ThreepidRestServlet from synapse.rest.client.v2_alpha.account_data import ( @@ -531,6 +540,10 @@ class GenericWorkerServer(HomeServer): room.register_deprecated_servlets(self, resource) InitialSyncRestServlet(self).register(resource) room_keys.register_servlets(self, resource) + tags.register_servlets(self, resource) + account_data.register_servlets(self, resource) + receipts.register_servlets(self, resource) + read_marker.register_servlets(self, resource) SendToDeviceRestServlet(self).register(resource) diff --git a/synapse/config/workers.py b/synapse/config/workers.py index 364583f48b..f10e33f7b8 100644 --- a/synapse/config/workers.py +++ b/synapse/config/workers.py @@ -56,6 +56,12 @@ class WriterLocations: to_device = attr.ib( default=["master"], type=List[str], converter=_instance_to_list_converter, ) + account_data = attr.ib( + default=["master"], type=List[str], converter=_instance_to_list_converter, + ) + receipts = attr.ib( + default=["master"], type=List[str], converter=_instance_to_list_converter, + ) class WorkerConfig(Config): @@ -127,7 +133,7 @@ class WorkerConfig(Config): # Check that the configured writers for events and typing also appears in # `instance_map`. - for stream in ("events", "typing", "to_device"): + for stream in ("events", "typing", "to_device", "account_data", "receipts"): instances = _instance_to_list_converter(getattr(self.writers, stream)) for instance in instances: if instance != "master" and instance not in self.instance_map: @@ -141,6 +147,16 @@ class WorkerConfig(Config): "Must only specify one instance to handle `to_device` messages." ) + if len(self.writers.account_data) != 1: + raise ConfigError( + "Must only specify one instance to handle `account_data` messages." + ) + + if len(self.writers.receipts) != 1: + raise ConfigError( + "Must only specify one instance to handle `receipts` messages." + ) + self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events) # Whether this worker should run background tasks or not. diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py index 341135822e..b1a5df9638 100644 --- a/synapse/handlers/account_data.py +++ b/synapse/handlers/account_data.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- # Copyright 2015, 2016 OpenMarket Ltd +# Copyright 2021 The Matrix.org Foundation C.I.C. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,14 +13,157 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import random from typing import TYPE_CHECKING, List, Tuple +from synapse.replication.http.account_data import ( + ReplicationAddTagRestServlet, + ReplicationRemoveTagRestServlet, + ReplicationRoomAccountDataRestServlet, + ReplicationUserAccountDataRestServlet, +) from synapse.types import JsonDict, UserID if TYPE_CHECKING: from synapse.app.homeserver import HomeServer +class AccountDataHandler: + def __init__(self, hs: "HomeServer"): + self._store = hs.get_datastore() + self._instance_name = hs.get_instance_name() + self._notifier = hs.get_notifier() + + self._user_data_client = ReplicationUserAccountDataRestServlet.make_client(hs) + self._room_data_client = ReplicationRoomAccountDataRestServlet.make_client(hs) + self._add_tag_client = ReplicationAddTagRestServlet.make_client(hs) + self._remove_tag_client = ReplicationRemoveTagRestServlet.make_client(hs) + self._account_data_writers = hs.config.worker.writers.account_data + + async def add_account_data_to_room( + self, user_id: str, room_id: str, account_data_type: str, content: JsonDict + ) -> int: + """Add some account_data to a room for a user. + + Args: + user_id: The user to add a tag for. + room_id: The room to add a tag for. + account_data_type: The type of account_data to add. + content: A json object to associate with the tag. + + Returns: + The maximum stream ID. + """ + if self._instance_name in self._account_data_writers: + max_stream_id = await self._store.add_account_data_to_room( + user_id, room_id, account_data_type, content + ) + + self._notifier.on_new_event( + "account_data_key", max_stream_id, users=[user_id] + ) + + return max_stream_id + else: + response = await self._room_data_client( + instance_name=random.choice(self._account_data_writers), + user_id=user_id, + room_id=room_id, + account_data_type=account_data_type, + content=content, + ) + return response["max_stream_id"] + + async def add_account_data_for_user( + self, user_id: str, account_data_type: str, content: JsonDict + ) -> int: + """Add some account_data to a room for a user. + + Args: + user_id: The user to add a tag for. + account_data_type: The type of account_data to add. + content: A json object to associate with the tag. + + Returns: + The maximum stream ID. + """ + + if self._instance_name in self._account_data_writers: + max_stream_id = await self._store.add_account_data_for_user( + user_id, account_data_type, content + ) + + self._notifier.on_new_event( + "account_data_key", max_stream_id, users=[user_id] + ) + return max_stream_id + else: + response = await self._user_data_client( + instance_name=random.choice(self._account_data_writers), + user_id=user_id, + account_data_type=account_data_type, + content=content, + ) + return response["max_stream_id"] + + async def add_tag_to_room( + self, user_id: str, room_id: str, tag: str, content: JsonDict + ) -> int: + """Add a tag to a room for a user. + + Args: + user_id: The user to add a tag for. + room_id: The room to add a tag for. + tag: The tag name to add. + content: A json object to associate with the tag. + + Returns: + The next account data ID. + """ + if self._instance_name in self._account_data_writers: + max_stream_id = await self._store.add_tag_to_room( + user_id, room_id, tag, content + ) + + self._notifier.on_new_event( + "account_data_key", max_stream_id, users=[user_id] + ) + return max_stream_id + else: + response = await self._add_tag_client( + instance_name=random.choice(self._account_data_writers), + user_id=user_id, + room_id=room_id, + tag=tag, + content=content, + ) + return response["max_stream_id"] + + async def remove_tag_from_room(self, user_id: str, room_id: str, tag: str) -> int: + """Remove a tag from a room for a user. + + Returns: + The next account data ID. + """ + if self._instance_name in self._account_data_writers: + max_stream_id = await self._store.remove_tag_from_room( + user_id, room_id, tag + ) + + self._notifier.on_new_event( + "account_data_key", max_stream_id, users=[user_id] + ) + return max_stream_id + else: + response = await self._remove_tag_client( + instance_name=random.choice(self._account_data_writers), + user_id=user_id, + room_id=room_id, + tag=tag, + ) + return response["max_stream_id"] + + class AccountDataEventSource: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py index a7550806e6..6bb2fd936b 100644 --- a/synapse/handlers/read_marker.py +++ b/synapse/handlers/read_marker.py @@ -31,8 +31,8 @@ class ReadMarkerHandler(BaseHandler): super().__init__(hs) self.server_name = hs.config.server_name self.store = hs.get_datastore() + self.account_data_handler = hs.get_account_data_handler() self.read_marker_linearizer = Linearizer(name="read_marker") - self.notifier = hs.get_notifier() async def received_client_read_marker( self, room_id: str, user_id: str, event_id: str @@ -59,7 +59,6 @@ class ReadMarkerHandler(BaseHandler): if should_update: content = {"event_id": event_id} - max_id = await self.store.add_account_data_to_room( + await self.account_data_handler.add_account_data_to_room( user_id, room_id, "m.fully_read", content ) - self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index a9abdf42e0..cc21fc2284 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -32,10 +32,26 @@ class ReceiptsHandler(BaseHandler): self.server_name = hs.config.server_name self.store = hs.get_datastore() self.hs = hs - self.federation = hs.get_federation_sender() - hs.get_federation_registry().register_edu_handler( - "m.receipt", self._received_remote_receipt - ) + + # We only need to poke the federation sender explicitly if its on the + # same instance. Other federation sender instances will get notified by + # `synapse.app.generic_worker.FederationSenderHandler` when it sees it + # in the receipts stream. + self.federation_sender = None + if hs.should_send_federation(): + self.federation_sender = hs.get_federation_sender() + + # If we can handle the receipt EDUs we do so, otherwise we route them + # to the appropriate worker. + if hs.get_instance_name() in hs.config.worker.writers.receipts: + hs.get_federation_registry().register_edu_handler( + "m.receipt", self._received_remote_receipt + ) + else: + hs.get_federation_registry().register_instances_for_edu( + "m.receipt", hs.config.worker.writers.receipts, + ) + self.clock = self.hs.get_clock() self.state = hs.get_state_handler() @@ -125,7 +141,8 @@ class ReceiptsHandler(BaseHandler): if not is_new: return - await self.federation.send_read_receipt(receipt) + if self.federation_sender: + await self.federation_sender.send_read_receipt(receipt) class ReceiptEventSource: diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index cb5a29bc7e..e001e418f9 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -63,6 +63,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): self.registration_handler = hs.get_registration_handler() self.profile_handler = hs.get_profile_handler() self.event_creation_handler = hs.get_event_creation_handler() + self.account_data_handler = hs.get_account_data_handler() self.member_linearizer = Linearizer(name="member") @@ -253,7 +254,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): direct_rooms[key].append(new_room_id) # Save back to user's m.direct account data - await self.store.add_account_data_for_user( + await self.account_data_handler.add_account_data_for_user( user_id, AccountDataTypes.DIRECT, direct_rooms ) break @@ -263,7 +264,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # Copy each room tag to the new room for tag, tag_content in room_tags.items(): - await self.store.add_tag_to_room(user_id, new_room_id, tag, tag_content) + await self.account_data_handler.add_tag_to_room( + user_id, new_room_id, tag, tag_content + ) async def update_membership( self, diff --git a/synapse/replication/http/__init__.py b/synapse/replication/http/__init__.py index a84a064c8d..dd527e807f 100644 --- a/synapse/replication/http/__init__.py +++ b/synapse/replication/http/__init__.py @@ -15,6 +15,7 @@ from synapse.http.server import JsonResource from synapse.replication.http import ( + account_data, devices, federation, login, @@ -40,6 +41,7 @@ class ReplicationRestResource(JsonResource): presence.register_servlets(hs, self) membership.register_servlets(hs, self) streams.register_servlets(hs, self) + account_data.register_servlets(hs, self) # The following can't currently be instantiated on workers. if hs.config.worker.worker_app is None: diff --git a/synapse/replication/http/account_data.py b/synapse/replication/http/account_data.py new file mode 100644 index 0000000000..52d32528ee --- /dev/null +++ b/synapse/replication/http/account_data.py @@ -0,0 +1,187 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from synapse.http.servlet import parse_json_object_from_request +from synapse.replication.http._base import ReplicationEndpoint + +logger = logging.getLogger(__name__) + + +class ReplicationUserAccountDataRestServlet(ReplicationEndpoint): + """Add user account data on the appropriate account data worker. + + Request format: + + POST /_synapse/replication/add_user_account_data/:user_id/:type + + { + "content": { ... }, + } + + """ + + NAME = "add_user_account_data" + PATH_ARGS = ("user_id", "account_data_type") + CACHE = False + + def __init__(self, hs): + super().__init__(hs) + + self.handler = hs.get_account_data_handler() + self.clock = hs.get_clock() + + @staticmethod + async def _serialize_payload(user_id, account_data_type, content): + payload = { + "content": content, + } + + return payload + + async def _handle_request(self, request, user_id, account_data_type): + content = parse_json_object_from_request(request) + + max_stream_id = await self.handler.add_account_data_for_user( + user_id, account_data_type, content["content"] + ) + + return 200, {"max_stream_id": max_stream_id} + + +class ReplicationRoomAccountDataRestServlet(ReplicationEndpoint): + """Add room account data on the appropriate account data worker. + + Request format: + + POST /_synapse/replication/add_room_account_data/:user_id/:room_id/:account_data_type + + { + "content": { ... }, + } + + """ + + NAME = "add_room_account_data" + PATH_ARGS = ("user_id", "room_id", "account_data_type") + CACHE = False + + def __init__(self, hs): + super().__init__(hs) + + self.handler = hs.get_account_data_handler() + self.clock = hs.get_clock() + + @staticmethod + async def _serialize_payload(user_id, room_id, account_data_type, content): + payload = { + "content": content, + } + + return payload + + async def _handle_request(self, request, user_id, room_id, account_data_type): + content = parse_json_object_from_request(request) + + max_stream_id = await self.handler.add_account_data_to_room( + user_id, room_id, account_data_type, content["content"] + ) + + return 200, {"max_stream_id": max_stream_id} + + +class ReplicationAddTagRestServlet(ReplicationEndpoint): + """Add tag on the appropriate account data worker. + + Request format: + + POST /_synapse/replication/add_tag/:user_id/:room_id/:tag + + { + "content": { ... }, + } + + """ + + NAME = "add_tag" + PATH_ARGS = ("user_id", "room_id", "tag") + CACHE = False + + def __init__(self, hs): + super().__init__(hs) + + self.handler = hs.get_account_data_handler() + self.clock = hs.get_clock() + + @staticmethod + async def _serialize_payload(user_id, room_id, tag, content): + payload = { + "content": content, + } + + return payload + + async def _handle_request(self, request, user_id, room_id, tag): + content = parse_json_object_from_request(request) + + max_stream_id = await self.handler.add_tag_to_room( + user_id, room_id, tag, content["content"] + ) + + return 200, {"max_stream_id": max_stream_id} + + +class ReplicationRemoveTagRestServlet(ReplicationEndpoint): + """Remove tag on the appropriate account data worker. + + Request format: + + POST /_synapse/replication/remove_tag/:user_id/:room_id/:tag + + {} + + """ + + NAME = "remove_tag" + PATH_ARGS = ( + "user_id", + "room_id", + "tag", + ) + CACHE = False + + def __init__(self, hs): + super().__init__(hs) + + self.handler = hs.get_account_data_handler() + self.clock = hs.get_clock() + + @staticmethod + async def _serialize_payload(user_id, room_id, tag): + + return {} + + async def _handle_request(self, request, user_id, room_id, tag): + max_stream_id = await self.handler.remove_tag_from_room(user_id, room_id, tag,) + + return 200, {"max_stream_id": max_stream_id} + + +def register_servlets(hs, http_server): + ReplicationUserAccountDataRestServlet(hs).register(http_server) + ReplicationRoomAccountDataRestServlet(hs).register(http_server) + ReplicationAddTagRestServlet(hs).register(http_server) + ReplicationRemoveTagRestServlet(hs).register(http_server) diff --git a/synapse/replication/slave/storage/_base.py b/synapse/replication/slave/storage/_base.py index d0089fe06c..693c9ab901 100644 --- a/synapse/replication/slave/storage/_base.py +++ b/synapse/replication/slave/storage/_base.py @@ -33,9 +33,13 @@ class BaseSlavedStore(CacheInvalidationWorkerStore): database, stream_name="caches", instance_name=hs.get_instance_name(), - table="cache_invalidation_stream_by_instance", - instance_column="instance_name", - id_column="stream_id", + tables=[ + ( + "cache_invalidation_stream_by_instance", + "instance_name", + "stream_id", + ) + ], sequence_name="cache_invalidation_stream_seq", writers=[], ) # type: Optional[MultiWriterIdGenerator] diff --git a/synapse/replication/slave/storage/account_data.py b/synapse/replication/slave/storage/account_data.py index 4268565fc8..21afe5f155 100644 --- a/synapse/replication/slave/storage/account_data.py +++ b/synapse/replication/slave/storage/account_data.py @@ -15,47 +15,9 @@ # limitations under the License. from synapse.replication.slave.storage._base import BaseSlavedStore -from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker -from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream -from synapse.storage.database import DatabasePool from synapse.storage.databases.main.account_data import AccountDataWorkerStore from synapse.storage.databases.main.tags import TagsWorkerStore class SlavedAccountDataStore(TagsWorkerStore, AccountDataWorkerStore, BaseSlavedStore): - def __init__(self, database: DatabasePool, db_conn, hs): - self._account_data_id_gen = SlavedIdTracker( - db_conn, - "account_data", - "stream_id", - extra_tables=[ - ("room_account_data", "stream_id"), - ("room_tags_revisions", "stream_id"), - ], - ) - - super().__init__(database, db_conn, hs) - - def get_max_account_data_stream_id(self): - return self._account_data_id_gen.get_current_token() - - def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == TagAccountDataStream.NAME: - self._account_data_id_gen.advance(instance_name, token) - for row in rows: - self.get_tags_for_user.invalidate((row.user_id,)) - self._account_data_stream_cache.entity_has_changed(row.user_id, token) - elif stream_name == AccountDataStream.NAME: - self._account_data_id_gen.advance(instance_name, token) - for row in rows: - if not row.room_id: - self.get_global_account_data_by_type_for_user.invalidate( - (row.data_type, row.user_id) - ) - self.get_account_data_for_user.invalidate((row.user_id,)) - self.get_account_data_for_room.invalidate((row.user_id, row.room_id)) - self.get_account_data_for_room_and_type.invalidate( - (row.user_id, row.room_id, row.data_type) - ) - self._account_data_stream_cache.entity_has_changed(row.user_id, token) - return super().process_replication_rows(stream_name, instance_name, token, rows) + pass diff --git a/synapse/replication/slave/storage/receipts.py b/synapse/replication/slave/storage/receipts.py index 6195917376..3dfdd9961d 100644 --- a/synapse/replication/slave/storage/receipts.py +++ b/synapse/replication/slave/storage/receipts.py @@ -14,43 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from synapse.replication.tcp.streams import ReceiptsStream -from synapse.storage.database import DatabasePool from synapse.storage.databases.main.receipts import ReceiptsWorkerStore from ._base import BaseSlavedStore -from ._slaved_id_tracker import SlavedIdTracker class SlavedReceiptsStore(ReceiptsWorkerStore, BaseSlavedStore): - def __init__(self, database: DatabasePool, db_conn, hs): - # We instantiate this first as the ReceiptsWorkerStore constructor - # needs to be able to call get_max_receipt_stream_id - self._receipts_id_gen = SlavedIdTracker( - db_conn, "receipts_linearized", "stream_id" - ) - - super().__init__(database, db_conn, hs) - - def get_max_receipt_stream_id(self): - return self._receipts_id_gen.get_current_token() - - def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id): - self.get_receipts_for_user.invalidate((user_id, receipt_type)) - self._get_linearized_receipts_for_room.invalidate_many((room_id,)) - self.get_last_receipt_event_id_for_user.invalidate( - (user_id, room_id, receipt_type) - ) - self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id) - self.get_receipts_for_room.invalidate((room_id, receipt_type)) - - def process_replication_rows(self, stream_name, instance_name, token, rows): - if stream_name == ReceiptsStream.NAME: - self._receipts_id_gen.advance(instance_name, token) - for row in rows: - self.invalidate_caches_for_receipt( - row.room_id, row.receipt_type, row.user_id - ) - self._receipts_stream_cache.entity_has_changed(row.room_id, token) - - return super().process_replication_rows(stream_name, instance_name, token, rows) + pass diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 1f89249475..317796d5e0 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -51,11 +51,14 @@ from synapse.replication.tcp.commands import ( from synapse.replication.tcp.protocol import AbstractConnection from synapse.replication.tcp.streams import ( STREAMS_MAP, + AccountDataStream, BackfillStream, CachesStream, EventsStream, FederationStream, + ReceiptsStream, Stream, + TagAccountDataStream, ToDeviceStream, TypingStream, ) @@ -132,6 +135,22 @@ class ReplicationCommandHandler: continue + if isinstance(stream, (AccountDataStream, TagAccountDataStream)): + # Only add AccountDataStream and TagAccountDataStream as a source on the + # instance in charge of account_data persistence. + if hs.get_instance_name() in hs.config.worker.writers.account_data: + self._streams_to_replicate.append(stream) + + continue + + if isinstance(stream, ReceiptsStream): + # Only add ReceiptsStream as a source on the instance in charge of + # receipts. + if hs.get_instance_name() in hs.config.worker.writers.receipts: + self._streams_to_replicate.append(stream) + + continue + # Only add any other streams if we're on master. if hs.config.worker_app is not None: continue diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py index 87a5b1b86b..3f28c0bc3e 100644 --- a/synapse/rest/client/v2_alpha/account_data.py +++ b/synapse/rest/client/v2_alpha/account_data.py @@ -37,24 +37,16 @@ class AccountDataServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() - self.notifier = hs.get_notifier() - self._is_worker = hs.config.worker_app is not None + self.handler = hs.get_account_data_handler() async def on_PUT(self, request, user_id, account_data_type): - if self._is_worker: - raise Exception("Cannot handle PUT /account_data on worker") - requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add account data for other users.") body = parse_json_object_from_request(request) - max_id = await self.store.add_account_data_for_user( - user_id, account_data_type, body - ) - - self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) + await self.handler.add_account_data_for_user(user_id, account_data_type, body) return 200, {} @@ -89,13 +81,9 @@ class RoomAccountDataServlet(RestServlet): super().__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() - self.notifier = hs.get_notifier() - self._is_worker = hs.config.worker_app is not None + self.handler = hs.get_account_data_handler() async def on_PUT(self, request, user_id, room_id, account_data_type): - if self._is_worker: - raise Exception("Cannot handle PUT /account_data on worker") - requester = await self.auth.get_user_by_req(request) if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add account data for other users.") @@ -109,12 +97,10 @@ class RoomAccountDataServlet(RestServlet): " Use /rooms/!roomId:server.name/read_markers", ) - max_id = await self.store.add_account_data_to_room( + await self.handler.add_account_data_to_room( user_id, room_id, account_data_type, body ) - self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) - return 200, {} async def on_GET(self, request, user_id, room_id, account_data_type): diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py index bf3a79db44..a97cd66c52 100644 --- a/synapse/rest/client/v2_alpha/tags.py +++ b/synapse/rest/client/v2_alpha/tags.py @@ -58,8 +58,7 @@ class TagServlet(RestServlet): def __init__(self, hs): super().__init__() self.auth = hs.get_auth() - self.store = hs.get_datastore() - self.notifier = hs.get_notifier() + self.handler = hs.get_account_data_handler() async def on_PUT(self, request, user_id, room_id, tag): requester = await self.auth.get_user_by_req(request) @@ -68,9 +67,7 @@ class TagServlet(RestServlet): body = parse_json_object_from_request(request) - max_id = await self.store.add_tag_to_room(user_id, room_id, tag, body) - - self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) + await self.handler.add_tag_to_room(user_id, room_id, tag, body) return 200, {} @@ -79,9 +76,7 @@ class TagServlet(RestServlet): if user_id != requester.user.to_string(): raise AuthError(403, "Cannot add tags for other users.") - max_id = await self.store.remove_tag_from_room(user_id, room_id, tag) - - self.notifier.on_new_event("account_data_key", max_id, users=[user_id]) + await self.handler.remove_tag_from_room(user_id, room_id, tag) return 200, {} diff --git a/synapse/server.py b/synapse/server.py index d4c235cda5..9cdda83aa1 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -55,6 +55,7 @@ from synapse.federation.sender import FederationSender from synapse.federation.transport.client import TransportLayerClient from synapse.groups.attestations import GroupAttestationSigning, GroupAttestionRenewer from synapse.groups.groups_server import GroupsServerHandler, GroupsServerWorkerHandler +from synapse.handlers.account_data import AccountDataHandler from synapse.handlers.account_validity import AccountValidityHandler from synapse.handlers.acme import AcmeHandler from synapse.handlers.admin import AdminHandler @@ -711,6 +712,10 @@ class HomeServer(metaclass=abc.ABCMeta): def get_module_api(self) -> ModuleApi: return ModuleApi(self, self.get_auth_handler()) + @cache_in_self + def get_account_data_handler(self) -> AccountDataHandler: + return AccountDataHandler(self) + async def remove_pusher(self, app_id: str, push_key: str, user_id: str): return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id) diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index c4de07a0a8..ae561a2da3 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -160,9 +160,13 @@ class DataStore( database, stream_name="caches", instance_name=hs.get_instance_name(), - table="cache_invalidation_stream_by_instance", - instance_column="instance_name", - id_column="stream_id", + tables=[ + ( + "cache_invalidation_stream_by_instance", + "instance_name", + "stream_id", + ) + ], sequence_name="cache_invalidation_stream_seq", writers=[], ) diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py index bad8260892..68896f34af 100644 --- a/synapse/storage/databases/main/account_data.py +++ b/synapse/storage/databases/main/account_data.py @@ -14,14 +14,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import abc import logging from typing import Dict, List, Optional, Set, Tuple from synapse.api.constants import AccountDataTypes +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker +from synapse.replication.tcp.streams import AccountDataStream, TagAccountDataStream from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool -from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.storage.engines import PostgresEngine +from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached @@ -30,14 +32,57 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache logger = logging.getLogger(__name__) -# The ABCMeta metaclass ensures that it cannot be instantiated without -# the abstract methods being implemented. -class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): +class AccountDataWorkerStore(SQLBaseStore): """This is an abstract base class where subclasses must implement `get_max_account_data_stream_id` which can be called in the initializer. """ def __init__(self, database: DatabasePool, db_conn, hs): + self._instance_name = hs.get_instance_name() + + if isinstance(database.engine, PostgresEngine): + self._can_write_to_account_data = ( + self._instance_name in hs.config.worker.writers.account_data + ) + + self._account_data_id_gen = MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + stream_name="account_data", + instance_name=self._instance_name, + tables=[ + ("room_account_data", "instance_name", "stream_id"), + ("room_tags_revisions", "instance_name", "stream_id"), + ("account_data", "instance_name", "stream_id"), + ], + sequence_name="account_data_sequence", + writers=hs.config.worker.writers.account_data, + ) + else: + self._can_write_to_account_data = True + + # We shouldn't be running in worker mode with SQLite, but its useful + # to support it for unit tests. + # + # If this process is the writer than we need to use + # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets + # updated over replication. (Multiple writers are not supported for + # SQLite). + if hs.get_instance_name() in hs.config.worker.writers.events: + self._account_data_id_gen = StreamIdGenerator( + db_conn, + "room_account_data", + "stream_id", + extra_tables=[("room_tags_revisions", "stream_id")], + ) + else: + self._account_data_id_gen = SlavedIdTracker( + db_conn, + "room_account_data", + "stream_id", + extra_tables=[("room_tags_revisions", "stream_id")], + ) + account_max = self.get_max_account_data_stream_id() self._account_data_stream_cache = StreamChangeCache( "AccountDataAndTagsChangeCache", account_max @@ -45,14 +90,13 @@ class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): super().__init__(database, db_conn, hs) - @abc.abstractmethod - def get_max_account_data_stream_id(self): + def get_max_account_data_stream_id(self) -> int: """Get the current max stream ID for account data stream Returns: int """ - raise NotImplementedError() + return self._account_data_id_gen.get_current_token() @cached() async def get_account_data_for_user( @@ -307,25 +351,26 @@ class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): ) ) - -class AccountDataStore(AccountDataWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs): - self._account_data_id_gen = StreamIdGenerator( - db_conn, - "room_account_data", - "stream_id", - extra_tables=[("room_tags_revisions", "stream_id")], - ) - - super().__init__(database, db_conn, hs) - - def get_max_account_data_stream_id(self) -> int: - """Get the current max stream id for the private user data stream - - Returns: - The maximum stream ID. - """ - return self._account_data_id_gen.get_current_token() + def process_replication_rows(self, stream_name, instance_name, token, rows): + if stream_name == TagAccountDataStream.NAME: + self._account_data_id_gen.advance(instance_name, token) + for row in rows: + self.get_tags_for_user.invalidate((row.user_id,)) + self._account_data_stream_cache.entity_has_changed(row.user_id, token) + elif stream_name == AccountDataStream.NAME: + self._account_data_id_gen.advance(instance_name, token) + for row in rows: + if not row.room_id: + self.get_global_account_data_by_type_for_user.invalidate( + (row.data_type, row.user_id) + ) + self.get_account_data_for_user.invalidate((row.user_id,)) + self.get_account_data_for_room.invalidate((row.user_id, row.room_id)) + self.get_account_data_for_room_and_type.invalidate( + (row.user_id, row.room_id, row.data_type) + ) + self._account_data_stream_cache.entity_has_changed(row.user_id, token) + return super().process_replication_rows(stream_name, instance_name, token, rows) async def add_account_data_to_room( self, user_id: str, room_id: str, account_data_type: str, content: JsonDict @@ -341,6 +386,8 @@ class AccountDataStore(AccountDataWorkerStore): Returns: The maximum stream ID. """ + assert self._can_write_to_account_data + content_json = json_encoder.encode(content) async with self._account_data_id_gen.get_next() as next_id: @@ -381,6 +428,8 @@ class AccountDataStore(AccountDataWorkerStore): Returns: The maximum stream ID. """ + assert self._can_write_to_account_data + async with self._account_data_id_gen.get_next() as next_id: await self.db_pool.runInteraction( "add_user_account_data", @@ -463,3 +512,7 @@ class AccountDataStore(AccountDataWorkerStore): # Invalidate the cache for any ignored users which were added or removed. for ignored_user_id in previously_ignored_users ^ currently_ignored_users: self._invalidate_cache_and_stream(txn, self.ignored_by, (ignored_user_id,)) + + +class AccountDataStore(AccountDataWorkerStore): + pass diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index 58d3f71e45..31f70ac5ef 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -54,9 +54,7 @@ class DeviceInboxWorkerStore(SQLBaseStore): db=database, stream_name="to_device", instance_name=self._instance_name, - table="device_inbox", - instance_column="instance_name", - id_column="stream_id", + tables=[("device_inbox", "instance_name", "stream_id")], sequence_name="device_inbox_sequence", writers=hs.config.worker.writers.to_device, ) diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index e5c03cc609..1b657191a9 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -835,6 +835,52 @@ class EventPushActionsWorkerStore(SQLBaseStore): (rotate_to_stream_ordering,), ) + def _remove_old_push_actions_before_txn( + self, txn, room_id, user_id, stream_ordering + ): + """ + Purges old push actions for a user and room before a given + stream_ordering. + + We however keep a months worth of highlighted notifications, so that + users can still get a list of recent highlights. + + Args: + txn: The transcation + room_id: Room ID to delete from + user_id: user ID to delete for + stream_ordering: The lowest stream ordering which will + not be deleted. + """ + txn.call_after( + self.get_unread_event_push_actions_by_room_for_user.invalidate_many, + (room_id, user_id), + ) + + # We need to join on the events table to get the received_ts for + # event_push_actions and sqlite won't let us use a join in a delete so + # we can't just delete where received_ts < x. Furthermore we can + # only identify event_push_actions by a tuple of room_id, event_id + # we we can't use a subquery. + # Instead, we look up the stream ordering for the last event in that + # room received before the threshold time and delete event_push_actions + # in the room with a stream_odering before that. + txn.execute( + "DELETE FROM event_push_actions " + " WHERE user_id = ? AND room_id = ? AND " + " stream_ordering <= ?" + " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)", + (user_id, room_id, stream_ordering, self.stream_ordering_month_ago), + ) + + txn.execute( + """ + DELETE FROM event_push_summary + WHERE room_id = ? AND user_id = ? AND stream_ordering <= ? + """, + (room_id, user_id, stream_ordering), + ) + class EventPushActionsStore(EventPushActionsWorkerStore): EPA_HIGHLIGHT_INDEX = "epa_highlight_index" @@ -894,52 +940,6 @@ class EventPushActionsStore(EventPushActionsWorkerStore): pa["actions"] = _deserialize_action(pa["actions"], pa["highlight"]) return push_actions - def _remove_old_push_actions_before_txn( - self, txn, room_id, user_id, stream_ordering - ): - """ - Purges old push actions for a user and room before a given - stream_ordering. - - We however keep a months worth of highlighted notifications, so that - users can still get a list of recent highlights. - - Args: - txn: The transcation - room_id: Room ID to delete from - user_id: user ID to delete for - stream_ordering: The lowest stream ordering which will - not be deleted. - """ - txn.call_after( - self.get_unread_event_push_actions_by_room_for_user.invalidate_many, - (room_id, user_id), - ) - - # We need to join on the events table to get the received_ts for - # event_push_actions and sqlite won't let us use a join in a delete so - # we can't just delete where received_ts < x. Furthermore we can - # only identify event_push_actions by a tuple of room_id, event_id - # we we can't use a subquery. - # Instead, we look up the stream ordering for the last event in that - # room received before the threshold time and delete event_push_actions - # in the room with a stream_odering before that. - txn.execute( - "DELETE FROM event_push_actions " - " WHERE user_id = ? AND room_id = ? AND " - " stream_ordering <= ?" - " AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)", - (user_id, room_id, stream_ordering, self.stream_ordering_month_ago), - ) - - txn.execute( - """ - DELETE FROM event_push_summary - WHERE room_id = ? AND user_id = ? AND stream_ordering <= ? - """, - (room_id, user_id, stream_ordering), - ) - def _action_has_highlight(actions): for action in actions: diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 4732685f6e..71d823be72 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -96,9 +96,7 @@ class EventsWorkerStore(SQLBaseStore): db=database, stream_name="events", instance_name=hs.get_instance_name(), - table="events", - instance_column="instance_name", - id_column="stream_ordering", + tables=[("events", "instance_name", "stream_ordering")], sequence_name="events_stream_seq", writers=hs.config.worker.writers.events, ) @@ -107,9 +105,7 @@ class EventsWorkerStore(SQLBaseStore): db=database, stream_name="backfill", instance_name=hs.get_instance_name(), - table="events", - instance_column="instance_name", - id_column="stream_ordering", + tables=[("events", "instance_name", "stream_ordering")], sequence_name="events_backfill_stream_seq", positive=False, writers=hs.config.worker.writers.events, diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 1e7949a323..e0e57f0578 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -14,15 +14,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import abc import logging from typing import Any, Dict, List, Optional, Tuple from twisted.internet import defer +from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker +from synapse.replication.tcp.streams import ReceiptsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool -from synapse.storage.util.id_generators import StreamIdGenerator +from synapse.storage.engines import PostgresEngine +from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList @@ -31,28 +33,56 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache logger = logging.getLogger(__name__) -# The ABCMeta metaclass ensures that it cannot be instantiated without -# the abstract methods being implemented. -class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): - """This is an abstract base class where subclasses must implement - `get_max_receipt_stream_id` which can be called in the initializer. - """ - +class ReceiptsWorkerStore(SQLBaseStore): def __init__(self, database: DatabasePool, db_conn, hs): + self._instance_name = hs.get_instance_name() + + if isinstance(database.engine, PostgresEngine): + self._can_write_to_receipts = ( + self._instance_name in hs.config.worker.writers.receipts + ) + + self._receipts_id_gen = MultiWriterIdGenerator( + db_conn=db_conn, + db=database, + stream_name="account_data", + instance_name=self._instance_name, + tables=[("receipts_linearized", "instance_name", "stream_id")], + sequence_name="receipts_sequence", + writers=hs.config.worker.writers.receipts, + ) + else: + self._can_write_to_receipts = True + + # We shouldn't be running in worker mode with SQLite, but its useful + # to support it for unit tests. + # + # If this process is the writer than we need to use + # `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets + # updated over replication. (Multiple writers are not supported for + # SQLite). + if hs.get_instance_name() in hs.config.worker.writers.events: + self._receipts_id_gen = StreamIdGenerator( + db_conn, "receipts_linearized", "stream_id" + ) + else: + self._receipts_id_gen = SlavedIdTracker( + db_conn, "receipts_linearized", "stream_id" + ) + super().__init__(database, db_conn, hs) self._receipts_stream_cache = StreamChangeCache( "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id() ) - @abc.abstractmethod def get_max_receipt_stream_id(self): """Get the current max stream ID for receipts stream Returns: int """ - raise NotImplementedError() + return self._receipts_id_gen.get_current_token() @cached() async def get_users_with_read_receipts_in_room(self, room_id): @@ -428,19 +458,25 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta): self.get_users_with_read_receipts_in_room.invalidate((room_id,)) - -class ReceiptsStore(ReceiptsWorkerStore): - def __init__(self, database: DatabasePool, db_conn, hs): - # We instantiate this first as the ReceiptsWorkerStore constructor - # needs to be able to call get_max_receipt_stream_id - self._receipts_id_gen = StreamIdGenerator( - db_conn, "receipts_linearized", "stream_id" + def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id): + self.get_receipts_for_user.invalidate((user_id, receipt_type)) + self._get_linearized_receipts_for_room.invalidate_many((room_id,)) + self.get_last_receipt_event_id_for_user.invalidate( + (user_id, room_id, receipt_type) ) + self._invalidate_get_users_with_receipts_in_room(room_id, receipt_type, user_id) + self.get_receipts_for_room.invalidate((room_id, receipt_type)) + + def process_replication_rows(self, stream_name, instance_name, token, rows): + if stream_name == ReceiptsStream.NAME: + self._receipts_id_gen.advance(instance_name, token) + for row in rows: + self.invalidate_caches_for_receipt( + row.room_id, row.receipt_type, row.user_id + ) + self._receipts_stream_cache.entity_has_changed(row.room_id, token) - super().__init__(database, db_conn, hs) - - def get_max_receipt_stream_id(self): - return self._receipts_id_gen.get_current_token() + return super().process_replication_rows(stream_name, instance_name, token, rows) def insert_linearized_receipt_txn( self, txn, room_id, receipt_type, user_id, event_id, data, stream_id @@ -452,6 +488,8 @@ class ReceiptsStore(ReceiptsWorkerStore): otherwise, the rx timestamp of the event that the RR corresponds to (or 0 if the event is unknown) """ + assert self._can_write_to_receipts + res = self.db_pool.simple_select_one_txn( txn, table="events", @@ -483,28 +521,14 @@ class ReceiptsStore(ReceiptsWorkerStore): ) return None - txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type)) - txn.call_after( - self._invalidate_get_users_with_receipts_in_room, - room_id, - receipt_type, - user_id, - ) - txn.call_after(self.get_receipts_for_user.invalidate, (user_id, receipt_type)) - # FIXME: This shouldn't invalidate the whole cache txn.call_after( - self._get_linearized_receipts_for_room.invalidate_many, (room_id,) + self.invalidate_caches_for_receipt, room_id, receipt_type, user_id ) txn.call_after( self._receipts_stream_cache.entity_has_changed, room_id, stream_id ) - txn.call_after( - self.get_last_receipt_event_id_for_user.invalidate, - (user_id, room_id, receipt_type), - ) - self.db_pool.simple_upsert_txn( txn, table="receipts_linearized", @@ -543,6 +567,8 @@ class ReceiptsStore(ReceiptsWorkerStore): Automatically does conversion between linearized and graph representations. """ + assert self._can_write_to_receipts + if not event_ids: return None @@ -607,6 +633,8 @@ class ReceiptsStore(ReceiptsWorkerStore): async def insert_graph_receipt( self, room_id, receipt_type, user_id, event_ids, data ): + assert self._can_write_to_receipts + return await self.db_pool.runInteraction( "insert_graph_receipt", self.insert_graph_receipt_txn, @@ -620,6 +648,8 @@ class ReceiptsStore(ReceiptsWorkerStore): def insert_graph_receipt_txn( self, txn, room_id, receipt_type, user_id, event_ids, data ): + assert self._can_write_to_receipts + txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type)) txn.call_after( self._invalidate_get_users_with_receipts_in_room, @@ -653,3 +683,7 @@ class ReceiptsStore(ReceiptsWorkerStore): "data": json_encoder.encode(data), }, ) + + +class ReceiptsStore(ReceiptsWorkerStore): + pass diff --git a/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql b/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql new file mode 100644 index 0000000000..46abf8d562 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql @@ -0,0 +1,20 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE room_account_data ADD COLUMN instance_name TEXT; +ALTER TABLE room_tags_revisions ADD COLUMN instance_name TEXT; +ALTER TABLE account_data ADD COLUMN instance_name TEXT; + +ALTER TABLE receipts_linearized ADD COLUMN instance_name TEXT; diff --git a/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql.postgres b/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql.postgres new file mode 100644 index 0000000000..4a6e6c74f5 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/59/06shard_account_data.sql.postgres @@ -0,0 +1,32 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE SEQUENCE IF NOT EXISTS account_data_sequence; + +-- We need to take the max across all the account_data tables as they share the +-- ID generator +SELECT setval('account_data_sequence', ( + SELECT GREATEST( + (SELECT COALESCE(MAX(stream_id), 1) FROM room_account_data), + (SELECT COALESCE(MAX(stream_id), 1) FROM room_tags_revisions), + (SELECT COALESCE(MAX(stream_id), 1) FROM account_data) + ) +)); + +CREATE SEQUENCE IF NOT EXISTS receipts_sequence; + +SELECT setval('receipts_sequence', ( + SELECT COALESCE(MAX(stream_id), 1) FROM receipts_linearized +)); diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py index 74da9c49f2..50067eabfc 100644 --- a/synapse/storage/databases/main/tags.py +++ b/synapse/storage/databases/main/tags.py @@ -183,8 +183,6 @@ class TagsWorkerStore(AccountDataWorkerStore): ) return {row["tag"]: db_to_json(row["content"]) for row in rows} - -class TagsStore(TagsWorkerStore): async def add_tag_to_room( self, user_id: str, room_id: str, tag: str, content: JsonDict ) -> int: @@ -199,6 +197,8 @@ class TagsStore(TagsWorkerStore): Returns: The next account data ID. """ + assert self._can_write_to_account_data + content_json = json_encoder.encode(content) def add_tag_txn(txn, next_id): @@ -223,6 +223,7 @@ class TagsStore(TagsWorkerStore): Returns: The next account data ID. """ + assert self._can_write_to_account_data def remove_tag_txn(txn, next_id): sql = ( @@ -250,6 +251,7 @@ class TagsStore(TagsWorkerStore): room_id: The ID of the room. next_id: The the revision to advance to. """ + assert self._can_write_to_account_data txn.call_after( self._account_data_stream_cache.entity_has_changed, user_id, next_id @@ -278,3 +280,7 @@ class TagsStore(TagsWorkerStore): # which stream_id ends up in the table, as long as it is higher # than the id that the client has. pass + + +class TagsStore(TagsWorkerStore): + pass diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 133c0e7a28..39a3ab1162 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -17,7 +17,7 @@ import logging import threading from collections import deque from contextlib import contextmanager -from typing import Dict, List, Optional, Set, Union +from typing import Dict, List, Optional, Set, Tuple, Union import attr from typing_extensions import Deque @@ -186,11 +186,12 @@ class MultiWriterIdGenerator: Args: db_conn db - stream_name: A name for the stream. + stream_name: A name for the stream, for use in the `stream_positions` + table. (Does not need to be the same as the replication stream name) instance_name: The name of this instance. - table: Database table associated with stream. - instance_column: Column that stores the row's writer's instance name - id_column: Column that stores the stream ID. + tables: List of tables associated with the stream. Tuple of table + name, column name that stores the writer's instance name, and + column name that stores the stream ID. sequence_name: The name of the postgres sequence used to generate new IDs. writers: A list of known writers to use to populate current positions @@ -206,9 +207,7 @@ class MultiWriterIdGenerator: db: DatabasePool, stream_name: str, instance_name: str, - table: str, - instance_column: str, - id_column: str, + tables: List[Tuple[str, str, str]], sequence_name: str, writers: List[str], positive: bool = True, @@ -260,15 +259,16 @@ class MultiWriterIdGenerator: self._sequence_gen = PostgresSequenceGenerator(sequence_name) # We check that the table and sequence haven't diverged. - self._sequence_gen.check_consistency( - db_conn, table=table, id_column=id_column, positive=positive - ) + for table, _, id_column in tables: + self._sequence_gen.check_consistency( + db_conn, table=table, id_column=id_column, positive=positive + ) # This goes and fills out the above state from the database. - self._load_current_ids(db_conn, table, instance_column, id_column) + self._load_current_ids(db_conn, tables) def _load_current_ids( - self, db_conn, table: str, instance_column: str, id_column: str + self, db_conn, tables: List[Tuple[str, str, str]], ): cur = db_conn.cursor(txn_name="_load_current_ids") @@ -306,17 +306,22 @@ class MultiWriterIdGenerator: # We add a GREATEST here to ensure that the result is always # positive. (This can be a problem for e.g. backfill streams where # the server has never backfilled). - sql = """ - SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1) - FROM %(table)s - """ % { - "id": id_column, - "table": table, - "agg": "MAX" if self._positive else "-MIN", - } - cur.execute(sql) - (stream_id,) = cur.fetchone() - self._persisted_upto_position = stream_id + max_stream_id = 1 + for table, _, id_column in tables: + sql = """ + SELECT GREATEST(COALESCE(%(agg)s(%(id)s), 1), 1) + FROM %(table)s + """ % { + "id": id_column, + "table": table, + "agg": "MAX" if self._positive else "-MIN", + } + cur.execute(sql) + (stream_id,) = cur.fetchone() + + max_stream_id = max(max_stream_id, stream_id) + + self._persisted_upto_position = max_stream_id else: # If we have a min_stream_id then we pull out everything greater # than it from the DB so that we can prefill @@ -329,21 +334,28 @@ class MultiWriterIdGenerator: # stream positions table before restart (or the stream position # table otherwise got out of date). - sql = """ - SELECT %(instance)s, %(id)s FROM %(table)s - WHERE ? %(cmp)s %(id)s - """ % { - "id": id_column, - "table": table, - "instance": instance_column, - "cmp": "<=" if self._positive else ">=", - } - cur.execute(sql, (min_stream_id * self._return_factor,)) - self._persisted_upto_position = min_stream_id + rows = [] + for table, instance_column, id_column in tables: + sql = """ + SELECT %(instance)s, %(id)s FROM %(table)s + WHERE ? %(cmp)s %(id)s + """ % { + "id": id_column, + "table": table, + "instance": instance_column, + "cmp": "<=" if self._positive else ">=", + } + cur.execute(sql, (min_stream_id * self._return_factor,)) + + rows.extend(cur) + + # Sort so that we handle rows in order for each instance. + rows.sort() + with self._lock: - for (instance, stream_id,) in cur: + for (instance, stream_id,) in rows: stream_id = self._return_factor * stream_id self._add_persisted_position(stream_id) diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index cc0612cf65..3e2fd4da01 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -51,9 +51,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): self.db_pool, stream_name="test_stream", instance_name=instance_name, - table="foobar", - instance_column="instance_name", - id_column="stream_id", + tables=[("foobar", "instance_name", "stream_id")], sequence_name="foobar_seq", writers=writers, ) @@ -487,9 +485,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): self.db_pool, stream_name="test_stream", instance_name=instance_name, - table="foobar", - instance_column="instance_name", - id_column="stream_id", + tables=[("foobar", "instance_name", "stream_id")], sequence_name="foobar_seq", writers=writers, positive=False, @@ -579,3 +575,107 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2}) self.assertEqual(id_gen_1.get_persisted_upto_position(), -2) self.assertEqual(id_gen_2.get_persisted_upto_position(), -2) + + +class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase): + if not USE_POSTGRES_FOR_TESTS: + skip = "Requires Postgres" + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.db_pool = self.store.db_pool # type: DatabasePool + + self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) + + def _setup_db(self, txn): + txn.execute("CREATE SEQUENCE foobar_seq") + txn.execute( + """ + CREATE TABLE foobar1 ( + stream_id BIGINT NOT NULL, + instance_name TEXT NOT NULL, + data TEXT + ); + """ + ) + + txn.execute( + """ + CREATE TABLE foobar2 ( + stream_id BIGINT NOT NULL, + instance_name TEXT NOT NULL, + data TEXT + ); + """ + ) + + def _create_id_generator( + self, instance_name="master", writers=["master"] + ) -> MultiWriterIdGenerator: + def _create(conn): + return MultiWriterIdGenerator( + conn, + self.db_pool, + stream_name="test_stream", + instance_name=instance_name, + tables=[ + ("foobar1", "instance_name", "stream_id"), + ("foobar2", "instance_name", "stream_id"), + ], + sequence_name="foobar_seq", + writers=writers, + ) + + return self.get_success_or_raise(self.db_pool.runWithConnection(_create)) + + def _insert_rows( + self, + table: str, + instance_name: str, + number: int, + update_stream_table: bool = True, + ): + """Insert N rows as the given instance, inserting with stream IDs pulled + from the postgres sequence. + """ + + def _insert(txn): + for _ in range(number): + txn.execute( + "INSERT INTO %s VALUES (nextval('foobar_seq'), ?)" % (table,), + (instance_name,), + ) + if update_stream_table: + txn.execute( + """ + INSERT INTO stream_positions VALUES ('test_stream', ?, lastval()) + ON CONFLICT (stream_name, instance_name) DO UPDATE SET stream_id = lastval() + """, + (instance_name,), + ) + + self.get_success(self.db_pool.runInteraction("_insert_rows", _insert)) + + def test_load_existing_stream(self): + """Test creating ID gens with multiple tables that have rows from after + the position in `stream_positions` table. + """ + self._insert_rows("foobar1", "first", 3) + self._insert_rows("foobar2", "second", 3) + self._insert_rows("foobar2", "second", 1, update_stream_table=False) + + first_id_gen = self._create_id_generator("first", writers=["first", "second"]) + second_id_gen = self._create_id_generator("second", writers=["first", "second"]) + + # The first ID gen will notice that it can advance its token to 7 as it + # has no in progress writes... + self.assertEqual(first_id_gen.get_positions(), {"first": 7, "second": 6}) + self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 7) + self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 6) + self.assertEqual(first_id_gen.get_persisted_upto_position(), 7) + + # ... but the second ID gen doesn't know that. + self.assertEqual(second_id_gen.get_positions(), {"first": 3, "second": 7}) + self.assertEqual(second_id_gen.get_current_token_for_writer("first"), 3) + self.assertEqual(second_id_gen.get_current_token_for_writer("second"), 7) + self.assertEqual(first_id_gen.get_persisted_upto_position(), 7) -- cgit 1.5.1 From 73b03722f446bf182f5f7a0ed318dffd55513bd3 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 19 Jan 2021 14:56:54 +0000 Subject: Fix error messages from OIDC config parsing (#9153) Make sure we report the correct config path for errors in the OIDC configs. --- changelog.d/9153.feature | 1 + synapse/config/oidc_config.py | 25 +++++++++++++++---------- 2 files changed, 16 insertions(+), 10 deletions(-) create mode 100644 changelog.d/9153.feature (limited to 'synapse/config') diff --git a/changelog.d/9153.feature b/changelog.d/9153.feature new file mode 100644 index 0000000000..01a24dcf49 --- /dev/null +++ b/changelog.d/9153.feature @@ -0,0 +1 @@ +Add support for multiple SSO Identity Providers. diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index c7fa749377..80a24cfbc9 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -15,7 +15,7 @@ # limitations under the License. import string -from typing import Iterable, Optional, Type +from typing import Iterable, Optional, Tuple, Type import attr @@ -280,8 +280,8 @@ def _parse_oidc_provider_configs(config: JsonDict) -> Iterable["OidcProviderConf """ validate_config(MAIN_CONFIG_SCHEMA, config, ()) - for p in config.get("oidc_providers") or []: - yield _parse_oidc_config_dict(p) + for i, p in enumerate(config.get("oidc_providers") or []): + yield _parse_oidc_config_dict(p, ("oidc_providers", "" % (i,))) # for backwards-compatibility, it is also possible to provide a single "oidc_config" # object with an "enabled: True" property. @@ -291,10 +291,12 @@ def _parse_oidc_provider_configs(config: JsonDict) -> Iterable["OidcProviderConf # it matches OIDC_PROVIDER_CONFIG_SCHEMA (see the comments on OIDC_CONFIG_SCHEMA # above), so now we need to validate it. validate_config(OIDC_PROVIDER_CONFIG_SCHEMA, oidc_config, ("oidc_config",)) - yield _parse_oidc_config_dict(oidc_config) + yield _parse_oidc_config_dict(oidc_config, ("oidc_config",)) -def _parse_oidc_config_dict(oidc_config: JsonDict) -> "OidcProviderConfig": +def _parse_oidc_config_dict( + oidc_config: JsonDict, config_path: Tuple[str, ...] +) -> "OidcProviderConfig": """Take the configuration dict and parse it into an OidcProviderConfig Raises: @@ -305,7 +307,7 @@ def _parse_oidc_config_dict(oidc_config: JsonDict) -> "OidcProviderConfig": ump_config.setdefault("config", {}) (user_mapping_provider_class, user_mapping_provider_config,) = load_module( - ump_config, ("oidc_config", "user_mapping_provider") + ump_config, config_path + ("user_mapping_provider",) ) # Ensure loaded user mapping module has defined all necessary methods @@ -320,9 +322,9 @@ def _parse_oidc_config_dict(oidc_config: JsonDict) -> "OidcProviderConfig": ] if missing_methods: raise ConfigError( - "Class specified by oidc_config." - "user_mapping_provider.module is missing required " - "methods: %s" % (", ".join(missing_methods),) + "Class %s is missing required " + "methods: %s" % (user_mapping_provider_class, ", ".join(missing_methods),), + config_path + ("user_mapping_provider", "module"), ) # MSC2858 will appy certain limits in what can be used as an IdP id, so let's @@ -331,7 +333,10 @@ def _parse_oidc_config_dict(oidc_config: JsonDict) -> "OidcProviderConfig": valid_idp_chars = set(string.ascii_letters + string.digits + "-._~") if any(c not in valid_idp_chars for c in idp_id): - raise ConfigError('idp_id may only contain A-Z, a-z, 0-9, "-", ".", "_", "~"') + raise ConfigError( + 'idp_id may only contain A-Z, a-z, 0-9, "-", ".", "_", "~"', + config_path + ("idp_id",), + ) return OidcProviderConfig( idp_id=idp_id, -- cgit 1.5.1 From fa50e4bf4ddcb8e98d44700513a28c490f80f02b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 20 Jan 2021 12:30:41 +0000 Subject: Give `public_baseurl` a default value (#9159) --- changelog.d/9159.feature | 1 + docs/sample_config.yaml | 31 +++++++++++++++++-------------- synapse/api/urls.py | 2 -- synapse/config/_base.py | 11 ++++++----- synapse/config/emailconfig.py | 8 -------- synapse/config/oidc_config.py | 2 -- synapse/config/registration.py | 21 ++++----------------- synapse/config/saml2_config.py | 2 -- synapse/config/server.py | 24 +++++++++++++++--------- synapse/config/sso.py | 13 +++++-------- synapse/handlers/identity.py | 2 -- synapse/rest/well_known.py | 4 ---- tests/rest/test_well_known.py | 9 --------- tests/utils.py | 1 - 14 files changed, 48 insertions(+), 83 deletions(-) create mode 100644 changelog.d/9159.feature (limited to 'synapse/config') diff --git a/changelog.d/9159.feature b/changelog.d/9159.feature new file mode 100644 index 0000000000..b7748757de --- /dev/null +++ b/changelog.d/9159.feature @@ -0,0 +1 @@ +Give the `public_baseurl` a default value, if it is not explicitly set in the configuration file. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index ae995efe9b..7fdd798d70 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -67,11 +67,16 @@ pid_file: DATADIR/homeserver.pid # #web_client_location: https://riot.example.com/ -# The public-facing base URL that clients use to access this HS -# (not including _matrix/...). This is the same URL a user would -# enter into the 'custom HS URL' field on their client. If you -# use synapse with a reverse proxy, this should be the URL to reach -# synapse via the proxy. +# The public-facing base URL that clients use to access this Homeserver (not +# including _matrix/...). This is the same URL a user might enter into the +# 'Custom Homeserver URL' field on their client. If you use Synapse with a +# reverse proxy, this should be the URL to reach Synapse via the proxy. +# Otherwise, it should be the URL to reach Synapse's client HTTP listener (see +# 'listeners' below). +# +# If this is left unset, it defaults to 'https:///'. (Note that +# that will not work unless you configure Synapse or a reverse-proxy to listen +# on port 443.) # #public_baseurl: https://example.com/ @@ -1150,8 +1155,9 @@ account_validity: # send an email to the account's email address with a renewal link. By # default, no such emails are sent. # - # If you enable this setting, you will also need to fill out the 'email' and - # 'public_baseurl' configuration sections. + # If you enable this setting, you will also need to fill out the 'email' + # configuration section. You should also check that 'public_baseurl' is set + # correctly. # #renew_at: 1w @@ -1242,8 +1248,7 @@ account_validity: # The identity server which we suggest that clients should use when users log # in on this server. # -# (By default, no suggestion is made, so it is left up to the client. -# This setting is ignored unless public_baseurl is also set.) +# (By default, no suggestion is made, so it is left up to the client.) # #default_identity_server: https://matrix.org @@ -1268,8 +1273,6 @@ account_validity: # by the Matrix Identity Service API specification: # https://matrix.org/docs/spec/identity_service/latest # -# If a delegate is specified, the config option public_baseurl must also be filled out. -# account_threepid_delegates: #email: https://example.com # Delegate email sending to example.com #msisdn: http://localhost:8090 # Delegate SMS sending to this local process @@ -1901,9 +1904,9 @@ sso: # phishing attacks from evil.site. To avoid this, include a slash after the # hostname: "https://my.client/". # - # If public_baseurl is set, then the login fallback page (used by clients - # that don't natively support the required login flows) is whitelisted in - # addition to any URLs in this list. + # The login fallback page (used by clients that don't natively support the + # required login flows) is automatically whitelisted in addition to any URLs + # in this list. # # By default, this list is empty. # diff --git a/synapse/api/urls.py b/synapse/api/urls.py index 6379c86dde..e36aeef31f 100644 --- a/synapse/api/urls.py +++ b/synapse/api/urls.py @@ -42,8 +42,6 @@ class ConsentURIBuilder: """ if hs_config.form_secret is None: raise ConfigError("form_secret not set in config") - if hs_config.public_baseurl is None: - raise ConfigError("public_baseurl not set in config") self._hmac_secret = hs_config.form_secret.encode("utf-8") self._public_baseurl = hs_config.public_baseurl diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 2931a88207..94144efc87 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -252,11 +252,12 @@ class Config: env = jinja2.Environment(loader=loader, autoescape=autoescape) # Update the environment with our custom filters - env.filters.update({"format_ts": _format_ts_filter}) - if self.public_baseurl: - env.filters.update( - {"mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl)} - ) + env.filters.update( + { + "format_ts": _format_ts_filter, + "mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl), + } + ) for filename in filenames: # Load the template diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index d4328c46b9..6a487afd34 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -166,11 +166,6 @@ class EmailConfig(Config): if not self.email_notif_from: missing.append("email.notif_from") - # public_baseurl is required to build password reset and validation links that - # will be emailed to users - if config.get("public_baseurl") is None: - missing.append("public_baseurl") - if missing: raise ConfigError( MISSING_PASSWORD_RESET_CONFIG_ERROR % (", ".join(missing),) @@ -269,9 +264,6 @@ class EmailConfig(Config): if not self.email_notif_from: missing.append("email.notif_from") - if config.get("public_baseurl") is None: - missing.append("public_baseurl") - if missing: raise ConfigError( "email.enable_notifs is True but required keys are missing: %s" diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index 80a24cfbc9..df55367434 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -43,8 +43,6 @@ class OIDCConfig(Config): raise ConfigError(e.message) from e public_baseurl = self.public_baseurl - if public_baseurl is None: - raise ConfigError("oidc_config requires a public_baseurl to be set") self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback" @property diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 740c3fc1b1..4bfc69cb7a 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -49,10 +49,6 @@ class AccountValidityConfig(Config): self.startup_job_max_delta = self.period * 10.0 / 100.0 - if self.renew_by_email_enabled: - if "public_baseurl" not in synapse_config: - raise ConfigError("Can't send renewal emails without 'public_baseurl'") - template_dir = config.get("template_dir") if not template_dir: @@ -109,13 +105,6 @@ class RegistrationConfig(Config): account_threepid_delegates = config.get("account_threepid_delegates") or {} self.account_threepid_delegate_email = account_threepid_delegates.get("email") self.account_threepid_delegate_msisdn = account_threepid_delegates.get("msisdn") - if self.account_threepid_delegate_msisdn and not self.public_baseurl: - raise ConfigError( - "The configuration option `public_baseurl` is required if " - "`account_threepid_delegate.msisdn` is set, such that " - "clients know where to submit validation tokens to. Please " - "configure `public_baseurl`." - ) self.default_identity_server = config.get("default_identity_server") self.allow_guest_access = config.get("allow_guest_access", False) @@ -240,8 +229,9 @@ class RegistrationConfig(Config): # send an email to the account's email address with a renewal link. By # default, no such emails are sent. # - # If you enable this setting, you will also need to fill out the 'email' and - # 'public_baseurl' configuration sections. + # If you enable this setting, you will also need to fill out the 'email' + # configuration section. You should also check that 'public_baseurl' is set + # correctly. # #renew_at: 1w @@ -332,8 +322,7 @@ class RegistrationConfig(Config): # The identity server which we suggest that clients should use when users log # in on this server. # - # (By default, no suggestion is made, so it is left up to the client. - # This setting is ignored unless public_baseurl is also set.) + # (By default, no suggestion is made, so it is left up to the client.) # #default_identity_server: https://matrix.org @@ -358,8 +347,6 @@ class RegistrationConfig(Config): # by the Matrix Identity Service API specification: # https://matrix.org/docs/spec/identity_service/latest # - # If a delegate is specified, the config option public_baseurl must also be filled out. - # account_threepid_delegates: #email: https://example.com # Delegate email sending to example.com #msisdn: http://localhost:8090 # Delegate SMS sending to this local process diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index 7b97d4f114..f33dfa0d6a 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -189,8 +189,6 @@ class SAML2Config(Config): import saml2 public_baseurl = self.public_baseurl - if public_baseurl is None: - raise ConfigError("saml2_config requires a public_baseurl to be set") if self.saml2_grandfathered_mxid_source_attribute: optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute) diff --git a/synapse/config/server.py b/synapse/config/server.py index 7242a4aa8e..75ba161f35 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -161,7 +161,11 @@ class ServerConfig(Config): self.print_pidfile = config.get("print_pidfile") self.user_agent_suffix = config.get("user_agent_suffix") self.use_frozen_dicts = config.get("use_frozen_dicts", False) - self.public_baseurl = config.get("public_baseurl") + self.public_baseurl = config.get("public_baseurl") or "https://%s/" % ( + self.server_name, + ) + if self.public_baseurl[-1] != "/": + self.public_baseurl += "/" # Whether to enable user presence. self.use_presence = config.get("use_presence", True) @@ -317,9 +321,6 @@ class ServerConfig(Config): # Always blacklist 0.0.0.0, :: self.federation_ip_range_blacklist.update(["0.0.0.0", "::"]) - if self.public_baseurl is not None: - if self.public_baseurl[-1] != "/": - self.public_baseurl += "/" self.start_pushers = config.get("start_pushers", True) # (undocumented) option for torturing the worker-mode replication a bit, @@ -740,11 +741,16 @@ class ServerConfig(Config): # #web_client_location: https://riot.example.com/ - # The public-facing base URL that clients use to access this HS - # (not including _matrix/...). This is the same URL a user would - # enter into the 'custom HS URL' field on their client. If you - # use synapse with a reverse proxy, this should be the URL to reach - # synapse via the proxy. + # The public-facing base URL that clients use to access this Homeserver (not + # including _matrix/...). This is the same URL a user might enter into the + # 'Custom Homeserver URL' field on their client. If you use Synapse with a + # reverse proxy, this should be the URL to reach Synapse via the proxy. + # Otherwise, it should be the URL to reach Synapse's client HTTP listener (see + # 'listeners' below). + # + # If this is left unset, it defaults to 'https:///'. (Note that + # that will not work unless you configure Synapse or a reverse-proxy to listen + # on port 443.) # #public_baseurl: https://example.com/ diff --git a/synapse/config/sso.py b/synapse/config/sso.py index 366f0d4698..59be825532 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -64,11 +64,8 @@ class SSOConfig(Config): # gracefully to the client). This would make it pointless to ask the user for # confirmation, since the URL the confirmation page would be showing wouldn't be # the client's. - # public_baseurl is an optional setting, so we only add the fallback's URL to the - # list if it's provided (because we can't figure out what that URL is otherwise). - if self.public_baseurl: - login_fallback_url = self.public_baseurl + "_matrix/static/client/login" - self.sso_client_whitelist.append(login_fallback_url) + login_fallback_url = self.public_baseurl + "_matrix/static/client/login" + self.sso_client_whitelist.append(login_fallback_url) def generate_config_section(self, **kwargs): return """\ @@ -86,9 +83,9 @@ class SSOConfig(Config): # phishing attacks from evil.site. To avoid this, include a slash after the # hostname: "https://my.client/". # - # If public_baseurl is set, then the login fallback page (used by clients - # that don't natively support the required login flows) is whitelisted in - # addition to any URLs in this list. + # The login fallback page (used by clients that don't natively support the + # required login flows) is automatically whitelisted in addition to any URLs + # in this list. # # By default, this list is empty. # diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index c05036ad1f..f61844d688 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -476,8 +476,6 @@ class IdentityHandler(BaseHandler): except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") - assert self.hs.config.public_baseurl - # we need to tell the client to send the token back to us, since it doesn't # otherwise know where to send it, so add submit_url response parameter # (see also MSC2078) diff --git a/synapse/rest/well_known.py b/synapse/rest/well_known.py index f591cc6c5c..241fe746d9 100644 --- a/synapse/rest/well_known.py +++ b/synapse/rest/well_known.py @@ -34,10 +34,6 @@ class WellKnownBuilder: self._config = hs.config def get_well_known(self): - # if we don't have a public_baseurl, we can't help much here. - if self._config.public_baseurl is None: - return None - result = {"m.homeserver": {"base_url": self._config.public_baseurl}} if self._config.default_identity_server: diff --git a/tests/rest/test_well_known.py b/tests/rest/test_well_known.py index 14de0921be..c5e44af9f7 100644 --- a/tests/rest/test_well_known.py +++ b/tests/rest/test_well_known.py @@ -40,12 +40,3 @@ class WellKnownTests(unittest.HomeserverTestCase): "m.identity_server": {"base_url": "https://testis"}, }, ) - - def test_well_known_no_public_baseurl(self): - self.hs.config.public_baseurl = None - - channel = self.make_request( - "GET", "/.well-known/matrix/client", shorthand=False - ) - - self.assertEqual(channel.code, 404) diff --git a/tests/utils.py b/tests/utils.py index 977eeaf6ee..09614093bc 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -159,7 +159,6 @@ def default_config(name, parse=False): "remote": {"per_second": 10000, "burst_count": 10000}, }, "saml2_enabled": False, - "public_baseurl": None, "default_identity_server": None, "key_refresh_interval": 24 * 60 * 60 * 1000, "old_signing_keys": {}, -- cgit 1.5.1 From 0cd2938bc854d947ae8102ded688a626c9fac5b5 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 20 Jan 2021 13:15:14 +0000 Subject: Support icons for Identity Providers (#9154) --- changelog.d/9154.feature | 1 + docs/sample_config.yaml | 4 ++ mypy.ini | 1 + synapse/config/oidc_config.py | 20 ++++++ synapse/config/server.py | 2 +- synapse/federation/federation_server.py | 2 +- synapse/federation/transport/server.py | 2 +- synapse/handlers/cas_handler.py | 4 ++ synapse/handlers/oidc_handler.py | 3 + synapse/handlers/room.py | 2 +- synapse/handlers/saml_handler.py | 4 ++ synapse/handlers/sso.py | 5 ++ synapse/http/endpoint.py | 79 --------------------- synapse/res/templates/sso_login_idp_picker.html | 3 + synapse/rest/client/v1/room.py | 3 +- synapse/storage/databases/main/room.py | 6 +- synapse/types.py | 2 +- synapse/util/stringutils.py | 92 +++++++++++++++++++++++++ tests/http/test_endpoint.py | 2 +- 19 files changed, 146 insertions(+), 91 deletions(-) create mode 100644 changelog.d/9154.feature delete mode 100644 synapse/http/endpoint.py (limited to 'synapse/config') diff --git a/changelog.d/9154.feature b/changelog.d/9154.feature new file mode 100644 index 0000000000..01a24dcf49 --- /dev/null +++ b/changelog.d/9154.feature @@ -0,0 +1 @@ +Add support for multiple SSO Identity Providers. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 7fdd798d70..b49a5da8cc 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1726,6 +1726,10 @@ saml2_config: # idp_name: A user-facing name for this identity provider, which is used to # offer the user a choice of login mechanisms. # +# idp_icon: An optional icon for this identity provider, which is presented +# by identity picker pages. If given, must be an MXC URI of the format +# mxc:/// +# # discover: set to 'false' to disable the use of the OIDC discovery mechanism # to discover endpoints. Defaults to true. # diff --git a/mypy.ini b/mypy.ini index b996867121..bd99069c81 100644 --- a/mypy.ini +++ b/mypy.ini @@ -100,6 +100,7 @@ files = synapse/util/async_helpers.py, synapse/util/caches, synapse/util/metrics.py, + synapse/util/stringutils.py, tests/replication, tests/test_utils, tests/handlers/test_password_providers.py, diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index df55367434..f257fcd412 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -23,6 +23,7 @@ from synapse.config._util import validate_config from synapse.python_dependencies import DependencyException, check_requirements from synapse.types import Collection, JsonDict from synapse.util.module_loader import load_module +from synapse.util.stringutils import parse_and_validate_mxc_uri from ._base import Config, ConfigError @@ -66,6 +67,10 @@ class OIDCConfig(Config): # idp_name: A user-facing name for this identity provider, which is used to # offer the user a choice of login mechanisms. # + # idp_icon: An optional icon for this identity provider, which is presented + # by identity picker pages. If given, must be an MXC URI of the format + # mxc:/// + # # discover: set to 'false' to disable the use of the OIDC discovery mechanism # to discover endpoints. Defaults to true. # @@ -207,6 +212,7 @@ OIDC_PROVIDER_CONFIG_SCHEMA = { "properties": { "idp_id": {"type": "string", "minLength": 1, "maxLength": 128}, "idp_name": {"type": "string"}, + "idp_icon": {"type": "string"}, "discover": {"type": "boolean"}, "issuer": {"type": "string"}, "client_id": {"type": "string"}, @@ -336,9 +342,20 @@ def _parse_oidc_config_dict( config_path + ("idp_id",), ) + # MSC2858 also specifies that the idp_icon must be a valid MXC uri + idp_icon = oidc_config.get("idp_icon") + if idp_icon is not None: + try: + parse_and_validate_mxc_uri(idp_icon) + except ValueError as e: + raise ConfigError( + "idp_icon must be a valid MXC URI", config_path + ("idp_icon",) + ) from e + return OidcProviderConfig( idp_id=idp_id, idp_name=oidc_config.get("idp_name", "OIDC"), + idp_icon=idp_icon, discover=oidc_config.get("discover", True), issuer=oidc_config["issuer"], client_id=oidc_config["client_id"], @@ -366,6 +383,9 @@ class OidcProviderConfig: # user-facing name for this identity provider. idp_name = attr.ib(type=str) + # Optional MXC URI for icon for this IdP. + idp_icon = attr.ib(type=Optional[str]) + # whether the OIDC discovery mechanism is used to discover endpoints discover = attr.ib(type=bool) diff --git a/synapse/config/server.py b/synapse/config/server.py index 75ba161f35..47a0370173 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -26,7 +26,7 @@ import yaml from netaddr import IPSet from synapse.api.room_versions import KNOWN_ROOM_VERSIONS -from synapse.http.endpoint import parse_and_validate_server_name +from synapse.util.stringutils import parse_and_validate_server_name from ._base import Config, ConfigError diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index e5339aca23..171d25c945 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -49,7 +49,6 @@ from synapse.events import EventBase from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.federation.persistence import TransactionActions from synapse.federation.units import Edu, Transaction -from synapse.http.endpoint import parse_server_name from synapse.http.servlet import assert_params_in_dict from synapse.logging.context import ( make_deferred_yieldable, @@ -66,6 +65,7 @@ from synapse.types import JsonDict, get_domain_from_id from synapse.util import glob_to_regex, json_decoder, unwrapFirstError from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.caches.response_cache import ResponseCache +from synapse.util.stringutils import parse_server_name if TYPE_CHECKING: from synapse.server import HomeServer diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index cfd094e58f..95c64510a9 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -28,7 +28,6 @@ from synapse.api.urls import ( FEDERATION_V1_PREFIX, FEDERATION_V2_PREFIX, ) -from synapse.http.endpoint import parse_and_validate_server_name from synapse.http.server import JsonResource from synapse.http.servlet import ( parse_boolean_from_args, @@ -45,6 +44,7 @@ from synapse.logging.opentracing import ( ) from synapse.server import HomeServer from synapse.types import ThirdPartyInstanceID, get_domain_from_id +from synapse.util.stringutils import parse_and_validate_server_name from synapse.util.versionstring import get_version_string logger = logging.getLogger(__name__) diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index f3430c6713..0f342c607b 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -80,6 +80,10 @@ class CasHandler: # user-facing name of this auth provider self.idp_name = "CAS" + # we do not currently support icons for CAS auth, but this is required by + # the SsoIdentityProvider protocol type. + self.idp_icon = None + self._sso_handler = hs.get_sso_handler() self._sso_handler.register_identity_provider(self) diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index ba686d74b2..1607e12935 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -271,6 +271,9 @@ class OidcProvider: # user-facing name of this auth provider self.idp_name = provider.idp_name + # MXC URI for icon for this auth provider + self.idp_icon = provider.idp_icon + self._sso_handler = hs.get_sso_handler() self._sso_handler.register_identity_provider(self) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 3bece6d668..ee27d99135 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -38,7 +38,6 @@ from synapse.api.filtering import Filter from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.events import EventBase from synapse.events.utils import copy_power_levels_contents -from synapse.http.endpoint import parse_and_validate_server_name from synapse.storage.state import StateFilter from synapse.types import ( JsonDict, @@ -55,6 +54,7 @@ from synapse.types import ( from synapse.util import stringutils from synapse.util.async_helpers import Linearizer from synapse.util.caches.response_cache import ResponseCache +from synapse.util.stringutils import parse_and_validate_server_name from synapse.visibility import filter_events_for_client from ._base import BaseHandler diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index a8376543c9..38461cf79d 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -78,6 +78,10 @@ class SamlHandler(BaseHandler): # user-facing name of this auth provider self.idp_name = "SAML" + # we do not currently support icons for SAML auth, but this is required by + # the SsoIdentityProvider protocol type. + self.idp_icon = None + # a map from saml session id to Saml2SessionData object self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData] diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index dcc85e9871..d493327a10 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -75,6 +75,11 @@ class SsoIdentityProvider(Protocol): def idp_name(self) -> str: """User-facing name for this provider""" + @property + def idp_icon(self) -> Optional[str]: + """Optional MXC URI for user-facing icon""" + return None + @abc.abstractmethod async def handle_redirect_request( self, diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py deleted file mode 100644 index 92a5b606c8..0000000000 --- a/synapse/http/endpoint.py +++ /dev/null @@ -1,79 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -import re - -logger = logging.getLogger(__name__) - - -def parse_server_name(server_name): - """Split a server name into host/port parts. - - Args: - server_name (str): server name to parse - - Returns: - Tuple[str, int|None]: host/port parts. - - Raises: - ValueError if the server name could not be parsed. - """ - try: - if server_name[-1] == "]": - # ipv6 literal, hopefully - return server_name, None - - domain_port = server_name.rsplit(":", 1) - domain = domain_port[0] - port = int(domain_port[1]) if domain_port[1:] else None - return domain, port - except Exception: - raise ValueError("Invalid server name '%s'" % server_name) - - -VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z") - - -def parse_and_validate_server_name(server_name): - """Split a server name into host/port parts and do some basic validation. - - Args: - server_name (str): server name to parse - - Returns: - Tuple[str, int|None]: host/port parts. - - Raises: - ValueError if the server name could not be parsed. - """ - host, port = parse_server_name(server_name) - - # these tests don't need to be bulletproof as we'll find out soon enough - # if somebody is giving us invalid data. What we *do* need is to be sure - # that nobody is sneaking IP literals in that look like hostnames, etc. - - # look for ipv6 literals - if host[0] == "[": - if host[-1] != "]": - raise ValueError("Mismatched [...] in server name '%s'" % (server_name,)) - return host, port - - # otherwise it should only be alphanumerics. - if not VALID_HOST_REGEX.match(host): - raise ValueError( - "Server name '%s' contains invalid characters" % (server_name,) - ) - - return host, port diff --git a/synapse/res/templates/sso_login_idp_picker.html b/synapse/res/templates/sso_login_idp_picker.html index f53c9cd679..5b38481012 100644 --- a/synapse/res/templates/sso_login_idp_picker.html +++ b/synapse/res/templates/sso_login_idp_picker.html @@ -17,6 +17,9 @@
  • +{% if p.idp_icon %} + +{% endif %}
  • {% endfor %} diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index e6725b03b0..f95627ee61 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -32,7 +32,6 @@ from synapse.api.errors import ( ) from synapse.api.filtering import Filter from synapse.events.utils import format_event_for_client_v2 -from synapse.http.endpoint import parse_and_validate_server_name from synapse.http.servlet import ( RestServlet, assert_params_in_dict, @@ -47,7 +46,7 @@ from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID from synapse.util import json_decoder -from synapse.util.stringutils import random_string +from synapse.util.stringutils import parse_and_validate_server_name, random_string if TYPE_CHECKING: import synapse.server diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 284f2ce77c..a9fcb5f59c 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -16,7 +16,6 @@ import collections import logging -import re from abc import abstractmethod from enum import Enum from typing import Any, Dict, List, Optional, Tuple @@ -30,6 +29,7 @@ from synapse.storage.databases.main.search import SearchStore from synapse.types import JsonDict, ThirdPartyInstanceID from synapse.util import json_encoder from synapse.util.caches.descriptors import cached +from synapse.util.stringutils import MXC_REGEX logger = logging.getLogger(__name__) @@ -660,8 +660,6 @@ class RoomWorkerStore(SQLBaseStore): The local and remote media as a lists of tuples where the key is the hostname and the value is the media ID. """ - mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)") - sql = """ SELECT stream_ordering, json FROM events JOIN event_json USING (room_id, event_id) @@ -688,7 +686,7 @@ class RoomWorkerStore(SQLBaseStore): for url in (content_url, thumbnail_url): if not url: continue - matches = mxc_re.match(url) + matches = MXC_REGEX.match(url) if matches: hostname = matches.group(1) media_id = matches.group(2) diff --git a/synapse/types.py b/synapse/types.py index 20a43d05bf..eafe729dfe 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -37,7 +37,7 @@ from signedjson.key import decode_verify_key_bytes from unpaddedbase64 import decode_base64 from synapse.api.errors import Codes, SynapseError -from synapse.http.endpoint import parse_and_validate_server_name +from synapse.util.stringutils import parse_and_validate_server_name if TYPE_CHECKING: from synapse.appservice.api import ApplicationService diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index b103c8694c..f8038bf861 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -18,6 +18,7 @@ import random import re import string from collections.abc import Iterable +from typing import Optional, Tuple from synapse.api.errors import Codes, SynapseError @@ -26,6 +27,15 @@ _string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@" # https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$") +# https://matrix.org/docs/spec/client_server/r0.6.1#matrix-content-mxc-uris, +# together with https://github.com/matrix-org/matrix-doc/issues/2177 which basically +# says "there is no grammar for media ids" +# +# The server_name part of this is purposely lax: use parse_and_validate_mxc for +# additional validation. +# +MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$") + # random_string and random_string_with_symbols are used for a range of things, # some cryptographically important, some less so. We use SystemRandom to make sure # we get cryptographically-secure randoms. @@ -59,6 +69,88 @@ def assert_valid_client_secret(client_secret): ) +def parse_server_name(server_name: str) -> Tuple[str, Optional[int]]: + """Split a server name into host/port parts. + + Args: + server_name: server name to parse + + Returns: + host/port parts. + + Raises: + ValueError if the server name could not be parsed. + """ + try: + if server_name[-1] == "]": + # ipv6 literal, hopefully + return server_name, None + + domain_port = server_name.rsplit(":", 1) + domain = domain_port[0] + port = int(domain_port[1]) if domain_port[1:] else None + return domain, port + except Exception: + raise ValueError("Invalid server name '%s'" % server_name) + + +VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z") + + +def parse_and_validate_server_name(server_name: str) -> Tuple[str, Optional[int]]: + """Split a server name into host/port parts and do some basic validation. + + Args: + server_name: server name to parse + + Returns: + host/port parts. + + Raises: + ValueError if the server name could not be parsed. + """ + host, port = parse_server_name(server_name) + + # these tests don't need to be bulletproof as we'll find out soon enough + # if somebody is giving us invalid data. What we *do* need is to be sure + # that nobody is sneaking IP literals in that look like hostnames, etc. + + # look for ipv6 literals + if host[0] == "[": + if host[-1] != "]": + raise ValueError("Mismatched [...] in server name '%s'" % (server_name,)) + return host, port + + # otherwise it should only be alphanumerics. + if not VALID_HOST_REGEX.match(host): + raise ValueError( + "Server name '%s' contains invalid characters" % (server_name,) + ) + + return host, port + + +def parse_and_validate_mxc_uri(mxc: str) -> Tuple[str, Optional[int], str]: + """Parse the given string as an MXC URI + + Checks that the "server name" part is a valid server name + + Args: + mxc: the (alleged) MXC URI to be checked + Returns: + hostname, port, media id + Raises: + ValueError if the URI cannot be parsed + """ + m = MXC_REGEX.match(mxc) + if not m: + raise ValueError("mxc URI %r did not match expected format" % (mxc,)) + server_name = m.group(1) + media_id = m.group(2) + host, port = parse_and_validate_server_name(server_name) + return host, port, media_id + + def shortstr(iterable: Iterable, maxitems: int = 5) -> str: """If iterable has maxitems or fewer, return the stringification of a list containing those items. diff --git a/tests/http/test_endpoint.py b/tests/http/test_endpoint.py index b2e9533b07..d06ea518ce 100644 --- a/tests/http/test_endpoint.py +++ b/tests/http/test_endpoint.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from synapse.http.endpoint import parse_and_validate_server_name, parse_server_name +from synapse.util.stringutils import parse_and_validate_server_name, parse_server_name from tests import unittest -- cgit 1.5.1 From e51b2f3f912534c8f6af70c746c993352a05c1be Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 20 Jan 2021 13:55:14 +0000 Subject: Tighten the restrictions on `idp_id` (#9177) --- changelog.d/9177.feature | 1 + synapse/config/oidc_config.py | 12 +++++++++--- 2 files changed, 10 insertions(+), 3 deletions(-) create mode 100644 changelog.d/9177.feature (limited to 'synapse/config') diff --git a/changelog.d/9177.feature b/changelog.d/9177.feature new file mode 100644 index 0000000000..01a24dcf49 --- /dev/null +++ b/changelog.d/9177.feature @@ -0,0 +1 @@ +Add support for multiple SSO Identity Providers. diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index f257fcd412..8cb0c42f36 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -331,17 +331,23 @@ def _parse_oidc_config_dict( config_path + ("user_mapping_provider", "module"), ) - # MSC2858 will appy certain limits in what can be used as an IdP id, so let's + # MSC2858 will apply certain limits in what can be used as an IdP id, so let's # enforce those limits now. + # TODO: factor out this stuff to a generic function idp_id = oidc_config.get("idp_id", "oidc") - valid_idp_chars = set(string.ascii_letters + string.digits + "-._~") + valid_idp_chars = set(string.ascii_lowercase + string.digits + "-._") if any(c not in valid_idp_chars for c in idp_id): raise ConfigError( - 'idp_id may only contain A-Z, a-z, 0-9, "-", ".", "_", "~"', + 'idp_id may only contain a-z, 0-9, "-", ".", "_"', config_path + ("idp_id",), ) + if idp_id[0] not in string.ascii_lowercase: + raise ConfigError( + "idp_id must start with a-z", config_path + ("idp_id",), + ) + # MSC2858 also specifies that the idp_icon must be a valid MXC uri idp_icon = oidc_config.get("idp_icon") if idp_icon is not None: -- cgit 1.5.1 From 7447f197026db570c1c1af240642566b31f81e42 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 21 Jan 2021 12:25:02 +0000 Subject: Prefix idp_id with "oidc-" (#9189) ... to avoid clashes with other SSO mechanisms --- changelog.d/9189.misc | 1 + docs/sample_config.yaml | 13 +++++++++---- synapse/config/oidc_config.py | 28 ++++++++++++++++++++++++---- tests/rest/client/v1/test_login.py | 2 +- 4 files changed, 35 insertions(+), 9 deletions(-) create mode 100644 changelog.d/9189.misc (limited to 'synapse/config') diff --git a/changelog.d/9189.misc b/changelog.d/9189.misc new file mode 100644 index 0000000000..9a5740aac2 --- /dev/null +++ b/changelog.d/9189.misc @@ -0,0 +1 @@ +Add an `oidc-` prefix to any `idp_id`s which are given in the `oidc_providers` configuration. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index b49a5da8cc..87bfe22237 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1728,7 +1728,9 @@ saml2_config: # # idp_icon: An optional icon for this identity provider, which is presented # by identity picker pages. If given, must be an MXC URI of the format -# mxc:/// +# mxc:///. (An easy way to obtain such an MXC URI +# is to upload an image to an (unencrypted) room and then copy the "url" +# from the source of the event.) # # discover: set to 'false' to disable the use of the OIDC discovery mechanism # to discover endpoints. Defaults to true. @@ -1814,13 +1816,16 @@ saml2_config: # # For backwards compatibility, it is also possible to configure a single OIDC # provider via an 'oidc_config' setting. This is now deprecated and admins are -# advised to migrate to the 'oidc_providers' format. +# advised to migrate to the 'oidc_providers' format. (When doing that migration, +# use 'oidc' for the idp_id to ensure that existing users continue to be +# recognised.) # oidc_providers: # Generic example # #- idp_id: my_idp # idp_name: "My OpenID provider" + # idp_icon: "mxc://example.com/mediaid" # discover: false # issuer: "https://accounts.example.com/" # client_id: "provided-by-your-issuer" @@ -1844,8 +1849,8 @@ oidc_providers: # For use with Github # - #- idp_id: google - # idp_name: Google + #- idp_id: github + # idp_name: Github # discover: false # issuer: "https://github.com/" # client_id: "your-client-id" # TO BE FILLED diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index 8cb0c42f36..d58a83be7f 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -69,7 +69,9 @@ class OIDCConfig(Config): # # idp_icon: An optional icon for this identity provider, which is presented # by identity picker pages. If given, must be an MXC URI of the format - # mxc:/// + # mxc:///. (An easy way to obtain such an MXC URI + # is to upload an image to an (unencrypted) room and then copy the "url" + # from the source of the event.) # # discover: set to 'false' to disable the use of the OIDC discovery mechanism # to discover endpoints. Defaults to true. @@ -155,13 +157,16 @@ class OIDCConfig(Config): # # For backwards compatibility, it is also possible to configure a single OIDC # provider via an 'oidc_config' setting. This is now deprecated and admins are - # advised to migrate to the 'oidc_providers' format. + # advised to migrate to the 'oidc_providers' format. (When doing that migration, + # use 'oidc' for the idp_id to ensure that existing users continue to be + # recognised.) # oidc_providers: # Generic example # #- idp_id: my_idp # idp_name: "My OpenID provider" + # idp_icon: "mxc://example.com/mediaid" # discover: false # issuer: "https://accounts.example.com/" # client_id: "provided-by-your-issuer" @@ -185,8 +190,8 @@ class OIDCConfig(Config): # For use with Github # - #- idp_id: google - # idp_name: Google + #- idp_id: github + # idp_name: Github # discover: false # issuer: "https://github.com/" # client_id: "your-client-id" # TO BE FILLED @@ -210,6 +215,8 @@ OIDC_PROVIDER_CONFIG_SCHEMA = { "type": "object", "required": ["issuer", "client_id", "client_secret"], "properties": { + # TODO: fix the maxLength here depending on what MSC2528 decides + # remember that we prefix the ID given here with `oidc-` "idp_id": {"type": "string", "minLength": 1, "maxLength": 128}, "idp_name": {"type": "string"}, "idp_icon": {"type": "string"}, @@ -335,6 +342,8 @@ def _parse_oidc_config_dict( # enforce those limits now. # TODO: factor out this stuff to a generic function idp_id = oidc_config.get("idp_id", "oidc") + + # TODO: update this validity check based on what MSC2858 decides. valid_idp_chars = set(string.ascii_lowercase + string.digits + "-._") if any(c not in valid_idp_chars for c in idp_id): @@ -348,6 +357,17 @@ def _parse_oidc_config_dict( "idp_id must start with a-z", config_path + ("idp_id",), ) + # prefix the given IDP with a prefix specific to the SSO mechanism, to avoid + # clashes with other mechs (such as SAML, CAS). + # + # We allow "oidc" as an exception so that people migrating from old-style + # "oidc_config" format (which has long used "oidc" as its idp_id) can migrate to + # a new-style "oidc_providers" entry without changing the idp_id for their provider + # (and thereby invalidating their user_external_ids data). + + if idp_id != "oidc": + idp_id = "oidc-" + idp_id + # MSC2858 also specifies that the idp_icon must be a valid MXC uri idp_icon = oidc_config.get("idp_icon") if idp_icon is not None: diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 2d25490374..2672ce24c6 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -446,7 +446,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): p.feed(channel.result["body"].decode("utf-8")) p.close() - self.assertCountEqual(p.radios["idp"], ["cas", "oidc", "idp1", "saml"]) + self.assertCountEqual(p.radios["idp"], ["cas", "oidc", "oidc-idp1", "saml"]) self.assertEqual(p.hiddens["redirectUrl"], TEST_CLIENT_REDIRECT_URL) -- cgit 1.5.1 From 42a8e81370855a2c612f2acfd1c0648329a12aff Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 21 Jan 2021 13:20:58 +0000 Subject: Add a check for duplicate IdP ids (#9184) --- changelog.d/9184.misc | 1 + synapse/config/oidc_config.py | 11 +++++++++++ 2 files changed, 12 insertions(+) create mode 100644 changelog.d/9184.misc (limited to 'synapse/config') diff --git a/changelog.d/9184.misc b/changelog.d/9184.misc new file mode 100644 index 0000000000..70da3d6cf5 --- /dev/null +++ b/changelog.d/9184.misc @@ -0,0 +1 @@ +Emit an error at startup if different Identity Providers are configured with the same `idp_id`. diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index d58a83be7f..bfeceeed18 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -15,6 +15,7 @@ # limitations under the License. import string +from collections import Counter from typing import Iterable, Optional, Tuple, Type import attr @@ -43,6 +44,16 @@ class OIDCConfig(Config): except DependencyException as e: raise ConfigError(e.message) from e + # check we don't have any duplicate idp_ids now. (The SSO handler will also + # check for duplicates when the REST listeners get registered, but that happens + # after synapse has forked so doesn't give nice errors.) + c = Counter([i.idp_id for i in self.oidc_providers]) + for idp_id, count in c.items(): + if count > 1: + raise ConfigError( + "Multiple OIDC providers have the idp_id %r." % idp_id + ) + public_baseurl = self.public_baseurl self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback" -- cgit 1.5.1 From dd8da8c5f6ac525a7456437913a03f68d4504605 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 26 Jan 2021 13:57:31 +0000 Subject: Precompute joined hosts and store in Redis (#9198) --- changelog.d/9198.misc | 1 + stubs/txredisapi.pyi | 12 +++- synapse/config/_base.pyi | 2 + synapse/federation/sender/__init__.py | 50 +++++++++----- synapse/handlers/federation.py | 5 ++ synapse/handlers/message.py | 42 ++++++++++++ synapse/replication/tcp/external_cache.py | 105 ++++++++++++++++++++++++++++++ synapse/replication/tcp/handler.py | 15 +---- synapse/server.py | 30 +++++++++ synapse/state/__init__.py | 11 +++- tests/replication/_base.py | 41 +++++++----- 11 files changed, 265 insertions(+), 49 deletions(-) create mode 100644 changelog.d/9198.misc create mode 100644 synapse/replication/tcp/external_cache.py (limited to 'synapse/config') diff --git a/changelog.d/9198.misc b/changelog.d/9198.misc new file mode 100644 index 0000000000..a6cb77fbb2 --- /dev/null +++ b/changelog.d/9198.misc @@ -0,0 +1 @@ +Precompute joined hosts and store in Redis. diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi index bdc892ec82..618548a305 100644 --- a/stubs/txredisapi.pyi +++ b/stubs/txredisapi.pyi @@ -15,11 +15,21 @@ """Contains *incomplete* type hints for txredisapi. """ -from typing import List, Optional, Type, Union +from typing import Any, List, Optional, Type, Union class RedisProtocol: def publish(self, channel: str, message: bytes): ... async def ping(self) -> None: ... + async def set( + self, + key: str, + value: Any, + expire: Optional[int] = None, + pexpire: Optional[int] = None, + only_if_not_exists: bool = False, + only_if_exists: bool = False, + ) -> None: ... + async def get(self, key: str) -> Any: ... class SubscriberProtocol(RedisProtocol): def __init__(self, *args, **kwargs): ... diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index 29aa064e57..8ba669059a 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -18,6 +18,7 @@ from synapse.config import ( password_auth_providers, push, ratelimiting, + redis, registration, repository, room_directory, @@ -79,6 +80,7 @@ class RootConfig: roomdirectory: room_directory.RoomDirectoryConfig thirdpartyrules: third_party_event_rules.ThirdPartyRulesConfig tracer: tracer.TracerConfig + redis: redis.RedisConfig config_classes: List = ... def __init__(self) -> None: ... diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 604cfd1935..643b26ae6d 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -142,6 +142,8 @@ class FederationSender: self._wake_destinations_needing_catchup, ) + self._external_cache = hs.get_external_cache() + def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue: """Get or create a PerDestinationQueue for the given destination @@ -197,22 +199,40 @@ class FederationSender: if not event.internal_metadata.should_proactively_send(): return - try: - # Get the state from before the event. - # We need to make sure that this is the state from before - # the event and not from after it. - # Otherwise if the last member on a server in a room is - # banned then it won't receive the event because it won't - # be in the room after the ban. - destinations = await self.state.get_hosts_in_room_at_events( - event.room_id, event_ids=event.prev_event_ids() - ) - except Exception: - logger.exception( - "Failed to calculate hosts in room for event: %s", - event.event_id, + destinations = None # type: Optional[Set[str]] + if not event.prev_event_ids(): + # If there are no prev event IDs then the state is empty + # and so no remote servers in the room + destinations = set() + else: + # We check the external cache for the destinations, which is + # stored per state group. + + sg = await self._external_cache.get( + "event_to_prev_state_group", event.event_id ) - return + if sg: + destinations = await self._external_cache.get( + "get_joined_hosts", str(sg) + ) + + if destinations is None: + try: + # Get the state from before the event. + # We need to make sure that this is the state from before + # the event and not from after it. + # Otherwise if the last member on a server in a room is + # banned then it won't receive the event because it won't + # be in the room after the ban. + destinations = await self.state.get_hosts_in_room_at_events( + event.room_id, event_ids=event.prev_event_ids() + ) + except Exception: + logger.exception( + "Failed to calculate hosts in room for event: %s", + event.event_id, + ) + return destinations = { d diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index fd8de8696d..b6dc7f99b6 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -2093,6 +2093,11 @@ class FederationHandler(BaseHandler): if event.type == EventTypes.GuestAccess and not context.rejected: await self.maybe_kick_guest_users(event) + # If we are going to send this event over federation we precaclculate + # the joined hosts. + if event.internal_metadata.get_send_on_behalf_of(): + await self.event_creation_handler.cache_joined_hosts_for_event(event) + return context async def _check_for_soft_fail( diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 9dfeab09cd..e2a7d567fa 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -432,6 +432,8 @@ class EventCreationHandler: self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages + self._external_cache = hs.get_external_cache() + async def create_event( self, requester: Requester, @@ -939,6 +941,8 @@ class EventCreationHandler: await self.action_generator.handle_push_actions_for_event(event, context) + await self.cache_joined_hosts_for_event(event) + try: # If we're a worker we need to hit out to the master. writer_instance = self._events_shard_config.get_instance(event.room_id) @@ -978,6 +982,44 @@ class EventCreationHandler: await self.store.remove_push_actions_from_staging(event.event_id) raise + async def cache_joined_hosts_for_event(self, event: EventBase) -> None: + """Precalculate the joined hosts at the event, when using Redis, so that + external federation senders don't have to recalculate it themselves. + """ + + if not self._external_cache.is_enabled(): + return + + # We actually store two mappings, event ID -> prev state group, + # state group -> joined hosts, which is much more space efficient + # than event ID -> joined hosts. + # + # Note: We have to cache event ID -> prev state group, as we don't + # store that in the DB. + # + # Note: We always set the state group -> joined hosts cache, even if + # we already set it, so that the expiry time is reset. + + state_entry = await self.state.resolve_state_groups_for_events( + event.room_id, event_ids=event.prev_event_ids() + ) + + if state_entry.state_group: + joined_hosts = await self.store.get_joined_hosts(event.room_id, state_entry) + + await self._external_cache.set( + "event_to_prev_state_group", + event.event_id, + state_entry.state_group, + expiry_ms=60 * 60 * 1000, + ) + await self._external_cache.set( + "get_joined_hosts", + str(state_entry.state_group), + list(joined_hosts), + expiry_ms=60 * 60 * 1000, + ) + async def _validate_canonical_alias( self, directory_handler, room_alias_str: str, expected_room_id: str ) -> None: diff --git a/synapse/replication/tcp/external_cache.py b/synapse/replication/tcp/external_cache.py new file mode 100644 index 0000000000..34fa3ff5b3 --- /dev/null +++ b/synapse/replication/tcp/external_cache.py @@ -0,0 +1,105 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import TYPE_CHECKING, Any, Optional + +from prometheus_client import Counter + +from synapse.logging.context import make_deferred_yieldable +from synapse.util import json_decoder, json_encoder + +if TYPE_CHECKING: + from synapse.server import HomeServer + +set_counter = Counter( + "synapse_external_cache_set", + "Number of times we set a cache", + labelnames=["cache_name"], +) + +get_counter = Counter( + "synapse_external_cache_get", + "Number of times we get a cache", + labelnames=["cache_name", "hit"], +) + + +logger = logging.getLogger(__name__) + + +class ExternalCache: + """A cache backed by an external Redis. Does nothing if no Redis is + configured. + """ + + def __init__(self, hs: "HomeServer"): + self._redis_connection = hs.get_outbound_redis_connection() + + def _get_redis_key(self, cache_name: str, key: str) -> str: + return "cache_v1:%s:%s" % (cache_name, key) + + def is_enabled(self) -> bool: + """Whether the external cache is used or not. + + It's safe to use the cache when this returns false, the methods will + just no-op, but the function is useful to avoid doing unnecessary work. + """ + return self._redis_connection is not None + + async def set(self, cache_name: str, key: str, value: Any, expiry_ms: int) -> None: + """Add the key/value to the named cache, with the expiry time given. + """ + + if self._redis_connection is None: + return + + set_counter.labels(cache_name).inc() + + # txredisapi requires the value to be string, bytes or numbers, so we + # encode stuff in JSON. + encoded_value = json_encoder.encode(value) + + logger.debug("Caching %s %s: %r", cache_name, key, encoded_value) + + return await make_deferred_yieldable( + self._redis_connection.set( + self._get_redis_key(cache_name, key), encoded_value, pexpire=expiry_ms, + ) + ) + + async def get(self, cache_name: str, key: str) -> Optional[Any]: + """Look up a key/value in the named cache. + """ + + if self._redis_connection is None: + return None + + result = await make_deferred_yieldable( + self._redis_connection.get(self._get_redis_key(cache_name, key)) + ) + + logger.debug("Got cache result %s %s: %r", cache_name, key, result) + + get_counter.labels(cache_name, result is not None).inc() + + if not result: + return None + + # For some reason the integers get magically converted back to integers + if isinstance(result, int): + return result + + return json_decoder.decode(result) diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 58d46a5951..8ea8dcd587 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -286,13 +286,6 @@ class ReplicationCommandHandler: if hs.config.redis.redis_enabled: from synapse.replication.tcp.redis import ( RedisDirectTcpReplicationClientFactory, - lazyConnection, - ) - - logger.info( - "Connecting to redis (host=%r port=%r)", - hs.config.redis_host, - hs.config.redis_port, ) # First let's ensure that we have a ReplicationStreamer started. @@ -303,13 +296,7 @@ class ReplicationCommandHandler: # connection after SUBSCRIBE is called). # First create the connection for sending commands. - outbound_redis_connection = lazyConnection( - hs=hs, - host=hs.config.redis_host, - port=hs.config.redis_port, - password=hs.config.redis.redis_password, - reconnect=True, - ) + outbound_redis_connection = hs.get_outbound_redis_connection() # Now create the factory/connection for the subscription stream. self._factory = RedisDirectTcpReplicationClientFactory( diff --git a/synapse/server.py b/synapse/server.py index 9cdda83aa1..9bdd3177d7 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -103,6 +103,7 @@ from synapse.notifier import Notifier from synapse.push.action_generator import ActionGenerator from synapse.push.pusherpool import PusherPool from synapse.replication.tcp.client import ReplicationDataHandler +from synapse.replication.tcp.external_cache import ExternalCache from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.resource import ReplicationStreamer from synapse.replication.tcp.streams import STREAMS_MAP, Stream @@ -128,6 +129,8 @@ from synapse.util.stringutils import random_string logger = logging.getLogger(__name__) if TYPE_CHECKING: + from txredisapi import RedisProtocol + from synapse.handlers.oidc_handler import OidcHandler from synapse.handlers.saml_handler import SamlHandler @@ -716,6 +719,33 @@ class HomeServer(metaclass=abc.ABCMeta): def get_account_data_handler(self) -> AccountDataHandler: return AccountDataHandler(self) + @cache_in_self + def get_external_cache(self) -> ExternalCache: + return ExternalCache(self) + + @cache_in_self + def get_outbound_redis_connection(self) -> Optional["RedisProtocol"]: + if not self.config.redis.redis_enabled: + return None + + # We only want to import redis module if we're using it, as we have + # `txredisapi` as an optional dependency. + from synapse.replication.tcp.redis import lazyConnection + + logger.info( + "Connecting to redis (host=%r port=%r) for external cache", + self.config.redis_host, + self.config.redis_port, + ) + + return lazyConnection( + hs=self, + host=self.config.redis_host, + port=self.config.redis_port, + password=self.config.redis.redis_password, + reconnect=True, + ) + async def remove_pusher(self, app_id: str, push_key: str, user_id: str): return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 84f59c7d85..3bd9ff8ca0 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -310,6 +310,7 @@ class StateHandler: state_group_before_event = None state_group_before_event_prev_group = None deltas_to_state_group_before_event = None + entry = None else: # otherwise, we'll need to resolve the state across the prev_events. @@ -340,9 +341,13 @@ class StateHandler: current_state_ids=state_ids_before_event, ) - # XXX: can we update the state cache entry for the new state group? or - # could we set a flag on resolve_state_groups_for_events to tell it to - # always make a state group? + # Assign the new state group to the cached state entry. + # + # Note that this can race in that we could generate multiple state + # groups for the same state entry, but that is just inefficient + # rather than dangerous. + if entry and entry.state_group is None: + entry.state_group = state_group_before_event # # now if it's not a state event, we're done diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 3379189785..d5dce1f83f 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -212,6 +212,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): # Fake in memory Redis server that servers can connect to. self._redis_server = FakeRedisPubSubServer() + # We may have an attempt to connect to redis for the external cache already. + self.connect_any_redis_attempts() + store = self.hs.get_datastore() self.database_pool = store.db_pool @@ -401,25 +404,23 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): fake one. """ clients = self.reactor.tcpClients - self.assertEqual(len(clients), 1) - (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) - self.assertEqual(host, "localhost") - self.assertEqual(port, 6379) + while clients: + (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) + self.assertEqual(host, "localhost") + self.assertEqual(port, 6379) - client_protocol = client_factory.buildProtocol(None) - server_protocol = self._redis_server.buildProtocol(None) + client_protocol = client_factory.buildProtocol(None) + server_protocol = self._redis_server.buildProtocol(None) - client_to_server_transport = FakeTransport( - server_protocol, self.reactor, client_protocol - ) - client_protocol.makeConnection(client_to_server_transport) - - server_to_client_transport = FakeTransport( - client_protocol, self.reactor, server_protocol - ) - server_protocol.makeConnection(server_to_client_transport) + client_to_server_transport = FakeTransport( + server_protocol, self.reactor, client_protocol + ) + client_protocol.makeConnection(client_to_server_transport) - return client_to_server_transport, server_to_client_transport + server_to_client_transport = FakeTransport( + client_protocol, self.reactor, server_protocol + ) + server_protocol.makeConnection(server_to_client_transport) class TestReplicationDataHandler(GenericWorkerReplicationHandler): @@ -624,6 +625,12 @@ class FakeRedisPubSubProtocol(Protocol): (channel,) = args self._server.add_subscriber(self) self.send(["subscribe", channel, 1]) + + # Since we use SET/GET to cache things we can safely no-op them. + elif command == b"SET": + self.send("OK") + elif command == b"GET": + self.send(None) else: raise Exception("Unknown command") @@ -645,6 +652,8 @@ class FakeRedisPubSubProtocol(Protocol): # We assume bytes are just unicode strings. obj = obj.decode("utf-8") + if obj is None: + return "$-1\r\n" if isinstance(obj, str): return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj) if isinstance(obj, int): -- cgit 1.5.1 From 26837d5dbeae211968b3d52cdc10f005ba612a9f Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 26 Jan 2021 10:49:25 -0500 Subject: Do not require the CAS service URL setting (use public_baseurl instead). (#9199) The current configuration is handled for backwards compatibility, but is considered deprecated. --- changelog.d/9199.removal | 1 + docs/sample_config.yaml | 4 ---- synapse/config/cas.py | 12 +++++++----- synapse/config/oidc_config.py | 3 +-- synapse/handlers/cas_handler.py | 6 +----- 5 files changed, 10 insertions(+), 16 deletions(-) create mode 100644 changelog.d/9199.removal (limited to 'synapse/config') diff --git a/changelog.d/9199.removal b/changelog.d/9199.removal new file mode 100644 index 0000000000..fbd2916cbf --- /dev/null +++ b/changelog.d/9199.removal @@ -0,0 +1 @@ +The `service_url` parameter in `cas_config` is deprecated in favor of `public_baseurl`. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 87bfe22237..c2ccd68f3a 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1878,10 +1878,6 @@ cas_config: # #server_url: "https://cas-server.com" - # The public URL of the homeserver. - # - #service_url: "https://homeserver.domain.com:8448" - # The attribute of the CAS response to use as the display name. # # If unset, no displayname will be set. diff --git a/synapse/config/cas.py b/synapse/config/cas.py index c7877b4095..b226890c2a 100644 --- a/synapse/config/cas.py +++ b/synapse/config/cas.py @@ -30,7 +30,13 @@ class CasConfig(Config): if self.cas_enabled: self.cas_server_url = cas_config["server_url"] - self.cas_service_url = cas_config["service_url"] + public_base_url = cas_config.get("service_url") or self.public_baseurl + if public_base_url[-1] != "/": + public_base_url += "/" + # TODO Update this to a _synapse URL. + self.cas_service_url = ( + public_base_url + "_matrix/client/r0/login/cas/ticket" + ) self.cas_displayname_attribute = cas_config.get("displayname_attribute") self.cas_required_attributes = cas_config.get("required_attributes") or {} else: @@ -53,10 +59,6 @@ class CasConfig(Config): # #server_url: "https://cas-server.com" - # The public URL of the homeserver. - # - #service_url: "https://homeserver.domain.com:8448" - # The attribute of the CAS response to use as the display name. # # If unset, no displayname will be set. diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index bfeceeed18..0162d7f7b0 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -54,8 +54,7 @@ class OIDCConfig(Config): "Multiple OIDC providers have the idp_id %r." % idp_id ) - public_baseurl = self.public_baseurl - self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback" + self.oidc_callback_url = self.public_baseurl + "_synapse/oidc/callback" @property def oidc_enabled(self) -> bool: diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index 0f342c607b..21b6bc4992 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -99,11 +99,7 @@ class CasHandler: Returns: The URL to use as a "service" parameter. """ - return "%s%s?%s" % ( - self._cas_service_url, - "/_matrix/client/r0/login/cas/ticket", - urllib.parse.urlencode(args), - ) + return "%s?%s" % (self._cas_service_url, urllib.parse.urlencode(args),) async def _validate_ticket( self, ticket: str, service_args: Dict[str, str] -- cgit 1.5.1 From a737cc27134c50059440ca33510b0baea53b4225 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 27 Jan 2021 12:41:24 +0000 Subject: Implement MSC2858 support (#9183) Fixes #8928. --- changelog.d/9183.feature | 1 + synapse/config/_base.pyi | 2 + synapse/config/experimental.py | 29 ++++++++++++ synapse/config/homeserver.py | 2 + synapse/handlers/sso.py | 23 +++++++--- synapse/http/server.py | 44 ++++++++++++++---- synapse/rest/client/v1/login.py | 55 ++++++++++++++++++++--- tests/rest/client/v1/test_login.py | 92 ++++++++++++++++++++++++++++++++++++++ tests/utils.py | 3 +- 9 files changed, 230 insertions(+), 21 deletions(-) create mode 100644 changelog.d/9183.feature create mode 100644 synapse/config/experimental.py (limited to 'synapse/config') diff --git a/changelog.d/9183.feature b/changelog.d/9183.feature new file mode 100644 index 0000000000..2d5c735042 --- /dev/null +++ b/changelog.d/9183.feature @@ -0,0 +1 @@ +Add experimental support for allowing clients to pick an SSO Identity Provider ([MSC2858](https://github.com/matrix-org/matrix-doc/pull/2858). diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index 29aa064e57..3ccea4b02d 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -9,6 +9,7 @@ from synapse.config import ( consent_config, database, emailconfig, + experimental, groups, jwt_config, key, @@ -48,6 +49,7 @@ def path_exists(file_path: str): ... class RootConfig: server: server.ServerConfig + experimental: experimental.ExperimentalConfig tls: tls.TlsConfig database: database.DatabaseConfig logging: logger.LoggingConfig diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py new file mode 100644 index 0000000000..b1c1c51e4d --- /dev/null +++ b/synapse/config/experimental.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.config._base import Config +from synapse.types import JsonDict + + +class ExperimentalConfig(Config): + """Config section for enabling experimental features""" + + section = "experimental" + + def read_config(self, config: JsonDict, **kwargs): + experimental = config.get("experimental_features") or {} + + # MSC2858 (multiple SSO identity providers) + self.msc2858_enabled = experimental.get("msc2858_enabled", False) # type: bool diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 4bd2b3587b..64a2429f77 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -24,6 +24,7 @@ from .cas import CasConfig from .consent_config import ConsentConfig from .database import DatabaseConfig from .emailconfig import EmailConfig +from .experimental import ExperimentalConfig from .federation import FederationConfig from .groups import GroupsConfig from .jwt_config import JWTConfig @@ -57,6 +58,7 @@ class HomeServerConfig(RootConfig): config_classes = [ ServerConfig, + ExperimentalConfig, TlsConfig, FederationConfig, CacheConfig, diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index d493327a10..afc1341d09 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -23,7 +23,7 @@ from typing_extensions import NoReturn, Protocol from twisted.web.http import Request from synapse.api.constants import LoginType -from synapse.api.errors import Codes, RedirectException, SynapseError +from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.http import get_request_user_agent from synapse.http.server import respond_with_html @@ -235,7 +235,10 @@ class SsoHandler: respond_with_html(request, code, html) async def handle_redirect_request( - self, request: SynapseRequest, client_redirect_url: bytes, + self, + request: SynapseRequest, + client_redirect_url: bytes, + idp_id: Optional[str], ) -> str: """Handle a request to /login/sso/redirect @@ -243,6 +246,7 @@ class SsoHandler: request: incoming HTTP request client_redirect_url: the URL that we should redirect the client to after login. + idp_id: optional identity provider chosen by the client Returns: the URI to redirect to @@ -252,10 +256,19 @@ class SsoHandler: 400, "Homeserver not configured for SSO.", errcode=Codes.UNRECOGNIZED ) + # if the client chose an IdP, use that + idp = None # type: Optional[SsoIdentityProvider] + if idp_id: + idp = self._identity_providers.get(idp_id) + if not idp: + raise NotFoundError("Unknown identity provider") + # if we only have one auth provider, redirect to it directly - if len(self._identity_providers) == 1: - ap = next(iter(self._identity_providers.values())) - return await ap.handle_redirect_request(request, client_redirect_url) + elif len(self._identity_providers) == 1: + idp = next(iter(self._identity_providers.values())) + + if idp: + return await idp.handle_redirect_request(request, client_redirect_url) # otherwise, redirect to the IDP picker return "/_synapse/client/pick_idp?" + urlencode( diff --git a/synapse/http/server.py b/synapse/http/server.py index e464bfe6c7..d69d579b3a 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -22,10 +22,22 @@ import types import urllib from http import HTTPStatus from io import BytesIO -from typing import Any, Callable, Dict, Iterator, List, Tuple, Union +from typing import ( + Any, + Awaitable, + Callable, + Dict, + Iterable, + Iterator, + List, + Pattern, + Tuple, + Union, +) import jinja2 from canonicaljson import iterencode_canonical_json +from typing_extensions import Protocol from zope.interface import implementer from twisted.internet import defer, interfaces @@ -168,11 +180,25 @@ def wrap_async_request_handler(h): return preserve_fn(wrapped_async_request_handler) -class HttpServer: +# Type of a callback method for processing requests +# it is actually called with a SynapseRequest and a kwargs dict for the params, +# but I can't figure out how to represent that. +ServletCallback = Callable[ + ..., Union[None, Awaitable[None], Tuple[int, Any], Awaitable[Tuple[int, Any]]] +] + + +class HttpServer(Protocol): """ Interface for registering callbacks on a HTTP server """ - def register_paths(self, method, path_patterns, callback): + def register_paths( + self, + method: str, + path_patterns: Iterable[Pattern], + callback: ServletCallback, + servlet_classname: str, + ) -> None: """ Register a callback that gets fired if we receive a http request with the given method for a path that matches the given regex. @@ -180,12 +206,14 @@ class HttpServer: an unpacked tuple. Args: - method (str): The method to listen to. - path_patterns (list): The regex used to match requests. - callback (function): The function to fire if we receive a matched + method: The HTTP method to listen to. + path_patterns: The regex used to match requests. + callback: The function to fire if we receive a matched request. The first argument will be the request object and subsequent arguments will be any matched groups from the regex. - This should return a tuple of (code, response). + This should return either tuple of (code, response), or None. + servlet_classname (str): The name of the handler to be used in prometheus + and opentracing logs. """ pass @@ -354,7 +382,7 @@ class JsonResource(DirectServeJsonResource): def _get_handler_for_request( self, request: SynapseRequest - ) -> Tuple[Callable, str, Dict[str, str]]: + ) -> Tuple[ServletCallback, str, Dict[str, str]]: """Finds a callback method to handle the given request. Returns: diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index be938df962..0a561eea60 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -19,7 +19,8 @@ from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional from synapse.api.errors import Codes, LoginError, SynapseError from synapse.api.ratelimiting import Ratelimiter from synapse.appservice import ApplicationService -from synapse.http.server import finish_request +from synapse.handlers.sso import SsoIdentityProvider +from synapse.http.server import HttpServer, finish_request from synapse.http.servlet import ( RestServlet, parse_json_object_from_request, @@ -60,11 +61,14 @@ class LoginRestServlet(RestServlet): self.saml2_enabled = hs.config.saml2_enabled self.cas_enabled = hs.config.cas_enabled self.oidc_enabled = hs.config.oidc_enabled + self._msc2858_enabled = hs.config.experimental.msc2858_enabled self.auth = hs.get_auth() self.auth_handler = self.hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() + self._sso_handler = hs.get_sso_handler() + self._well_known_builder = WellKnownBuilder(hs) self._address_ratelimiter = Ratelimiter( clock=hs.get_clock(), @@ -89,8 +93,17 @@ class LoginRestServlet(RestServlet): flows.append({"type": LoginRestServlet.CAS_TYPE}) if self.cas_enabled or self.saml2_enabled or self.oidc_enabled: - flows.append({"type": LoginRestServlet.SSO_TYPE}) - # While its valid for us to advertise this login type generally, + sso_flow = {"type": LoginRestServlet.SSO_TYPE} # type: JsonDict + + if self._msc2858_enabled: + sso_flow["org.matrix.msc2858.identity_providers"] = [ + _get_auth_flow_dict_for_idp(idp) + for idp in self._sso_handler.get_identity_providers().values() + ] + + flows.append(sso_flow) + + # While it's valid for us to advertise this login type generally, # synapse currently only gives out these tokens as part of the # SSO login flow. # Generally we don't want to advertise login flows that clients @@ -311,8 +324,20 @@ class LoginRestServlet(RestServlet): return result +def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict: + """Return an entry for the login flow dict + + Returns an entry suitable for inclusion in "identity_providers" in the + response to GET /_matrix/client/r0/login + """ + e = {"id": idp.idp_id, "name": idp.idp_name} # type: JsonDict + if idp.idp_icon: + e["icon"] = idp.idp_icon + return e + + class SsoRedirectServlet(RestServlet): - PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True) + PATTERNS = client_patterns("/login/(cas|sso)/redirect$", v1=True) def __init__(self, hs: "HomeServer"): # make sure that the relevant handlers are instantiated, so that they @@ -324,13 +349,31 @@ class SsoRedirectServlet(RestServlet): if hs.config.oidc_enabled: hs.get_oidc_handler() self._sso_handler = hs.get_sso_handler() + self._msc2858_enabled = hs.config.experimental.msc2858_enabled + + def register(self, http_server: HttpServer) -> None: + super().register(http_server) + if self._msc2858_enabled: + # expose additional endpoint for MSC2858 support + http_server.register_paths( + "GET", + client_patterns( + "/org.matrix.msc2858/login/sso/redirect/(?P[A-Za-z0-9_.~-]+)$", + releases=(), + unstable=True, + ), + self.on_GET, + self.__class__.__name__, + ) - async def on_GET(self, request: SynapseRequest): + async def on_GET( + self, request: SynapseRequest, idp_id: Optional[str] = None + ) -> None: client_redirect_url = parse_string( request, "redirectUrl", required=True, encoding=None ) sso_url = await self._sso_handler.handle_redirect_request( - request, client_redirect_url + request, client_redirect_url, idp_id, ) logger.info("Redirecting to %s", sso_url) request.redirect(sso_url) diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 2672ce24c6..e2bb945453 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -75,6 +75,10 @@ TEST_CLIENT_REDIRECT_URL = 'https://x?&q"+%3D%2B"="fö%26=o"' # the query params in TEST_CLIENT_REDIRECT_URL EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("", ""), ('q" =+"', '"fö&=o"')] +# (possibly experimental) login flows we expect to appear in the list after the normal +# ones +ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}] + class LoginRestServletTestCase(unittest.HomeserverTestCase): @@ -426,6 +430,57 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): d["/_synapse/oidc"] = OIDCResource(self.hs) return d + def test_get_login_flows(self): + """GET /login should return password and SSO flows""" + channel = self.make_request("GET", "/_matrix/client/r0/login") + self.assertEqual(channel.code, 200, channel.result) + + expected_flows = [ + {"type": "m.login.cas"}, + {"type": "m.login.sso"}, + {"type": "m.login.token"}, + {"type": "m.login.password"}, + ] + ADDITIONAL_LOGIN_FLOWS + + self.assertCountEqual(channel.json_body["flows"], expected_flows) + + @override_config({"experimental_features": {"msc2858_enabled": True}}) + def test_get_msc2858_login_flows(self): + """The SSO flow should include IdP info if MSC2858 is enabled""" + channel = self.make_request("GET", "/_matrix/client/r0/login") + self.assertEqual(channel.code, 200, channel.result) + + # stick the flows results in a dict by type + flow_results = {} # type: Dict[str, Any] + for f in channel.json_body["flows"]: + flow_type = f["type"] + self.assertNotIn( + flow_type, flow_results, "duplicate flow type %s" % (flow_type,) + ) + flow_results[flow_type] = f + + self.assertIn("m.login.sso", flow_results, "m.login.sso was not returned") + sso_flow = flow_results.pop("m.login.sso") + # we should have a set of IdPs + self.assertCountEqual( + sso_flow["org.matrix.msc2858.identity_providers"], + [ + {"id": "cas", "name": "CAS"}, + {"id": "saml", "name": "SAML"}, + {"id": "oidc-idp1", "name": "IDP1"}, + {"id": "oidc", "name": "OIDC"}, + ], + ) + + # the rest of the flows are simple + expected_flows = [ + {"type": "m.login.cas"}, + {"type": "m.login.token"}, + {"type": "m.login.password"}, + ] + ADDITIONAL_LOGIN_FLOWS + + self.assertCountEqual(flow_results.values(), expected_flows) + def test_multi_sso_redirect(self): """/login/sso/redirect should redirect to an identity picker""" # first hit the redirect url, which should redirect to our idp picker @@ -564,6 +619,43 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 400, channel.result) + def test_client_idp_redirect_msc2858_disabled(self): + """If the client tries to pick an IdP but MSC2858 is disabled, return a 400""" + channel = self.make_request( + "GET", + "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL), + ) + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") + + @override_config({"experimental_features": {"msc2858_enabled": True}}) + def test_client_idp_redirect_to_unknown(self): + """If the client tries to pick an unknown IdP, return a 404""" + channel = self.make_request( + "GET", + "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/xxx?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL), + ) + self.assertEqual(channel.code, 404, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") + + @override_config({"experimental_features": {"msc2858_enabled": True}}) + def test_client_idp_redirect_to_oidc(self): + """If the client pick a known IdP, redirect to it""" + channel = self.make_request( + "GET", + "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL), + ) + + self.assertEqual(channel.code, 302, channel.result) + oidc_uri = channel.headers.getRawHeaders("Location")[0] + oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) + + # it should redirect us to the auth page of the OIDC server + self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) + @staticmethod def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str: prefix = key + " = " diff --git a/tests/utils.py b/tests/utils.py index 09614093bc..022223cf24 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -33,7 +33,6 @@ from synapse.api.room_versions import RoomVersions from synapse.config.database import DatabaseConnectionConfig from synapse.config.homeserver import HomeServerConfig from synapse.config.server import DEFAULT_ROOM_VERSION -from synapse.http.server import HttpServer from synapse.logging.context import current_context, set_current_context from synapse.server import HomeServer from synapse.storage import DataStore @@ -351,7 +350,7 @@ def mock_getRawHeaders(headers=None): # This is a mock /resource/ not an entire server -class MockHttpResource(HttpServer): +class MockHttpResource: def __init__(self, prefix=""): self.callbacks = [] # 3-tuple of method/pattern/function self.prefix = prefix -- cgit 1.5.1 From e54746bdf7d5c831eabe4dcea76a7626f1de73df Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 27 Jan 2021 10:59:50 -0500 Subject: Clean-up the template loading code. (#9200) * Enables autoescape by default for HTML files. * Adds a new read_template method for reading a single template. * Some logic clean-up. --- UPGRADE.rst | 37 ++++++++++++++++++++++ changelog.d/9200.misc | 1 + synapse/config/_base.py | 42 +++++++++++++++---------- synapse/config/captcha.py | 4 +-- synapse/config/consent_config.py | 2 +- synapse/config/registration.py | 4 +-- synapse/push/mailer.py | 18 +++++++++-- synapse/res/templates/sso_auth_bad_user.html | 2 +- synapse/res/templates/sso_auth_confirm.html | 4 +-- synapse/res/templates/sso_error.html | 2 +- synapse/res/templates/sso_login_idp_picker.html | 12 +++---- synapse/res/templates/sso_redirect_confirm.html | 6 ++-- 12 files changed, 96 insertions(+), 38 deletions(-) create mode 100644 changelog.d/9200.misc (limited to 'synapse/config') diff --git a/UPGRADE.rst b/UPGRADE.rst index d09dbd4e21..e62e647a1d 100644 --- a/UPGRADE.rst +++ b/UPGRADE.rst @@ -85,6 +85,43 @@ for example: wget https://packages.matrix.org/debian/pool/main/m/matrix-synapse-py3/matrix-synapse-py3_1.3.0+stretch1_amd64.deb dpkg -i matrix-synapse-py3_1.3.0+stretch1_amd64.deb +Upgrading to v1.27.0 +==================== + +Changes to HTML templates +------------------------- + +The HTML templates for SSO and email notifications now have `Jinja2's autoescape `_ +enabled for files ending in ``.html``, ``.htm``, and ``.xml``. If you hae customised +these templates and see issues when viewing them you might need to update them. +It is expected that most configurations will need no changes. + +If you have customised the templates *names* for these templates it is recommended +to verify they end in ``.html`` to ensure autoescape is enabled. + +The above applies to the following templates: + +* ``add_threepid.html`` +* ``add_threepid_failure.html`` +* ``add_threepid_success.html`` +* ``notice_expiry.html`` +* ``notice_expiry.html`` +* ``notif_mail.html`` (which, by default, includes ``room.html`` and ``notif.html``) +* ``password_reset.html`` +* ``password_reset_confirmation.html`` +* ``password_reset_failure.html`` +* ``password_reset_success.html`` +* ``registration.html`` +* ``registration_failure.html`` +* ``registration_success.html`` +* ``sso_account_deactivated.html`` +* ``sso_auth_bad_user.html`` +* ``sso_auth_confirm.html`` +* ``sso_auth_success.html`` +* ``sso_error.html`` +* ``sso_login_idp_picker.html`` +* ``sso_redirect_confirm.html`` + Upgrading to v1.26.0 ==================== diff --git a/changelog.d/9200.misc b/changelog.d/9200.misc new file mode 100644 index 0000000000..5f239ff9da --- /dev/null +++ b/changelog.d/9200.misc @@ -0,0 +1 @@ +Clean-up template loading code. diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 94144efc87..6a0768ce00 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -203,11 +203,28 @@ class Config: with open(file_path) as file_stream: return file_stream.read() + def read_template(self, filename: str) -> jinja2.Template: + """Load a template file from disk. + + This function will attempt to load the given template from the default Synapse + template directory. + + Files read are treated as Jinja templates. The templates is not rendered yet + and has autoescape enabled. + + Args: + filename: A template filename to read. + + Raises: + ConfigError: if the file's path is incorrect or otherwise cannot be read. + + Returns: + A jinja2 template. + """ + return self.read_templates([filename])[0] + def read_templates( - self, - filenames: List[str], - custom_template_directory: Optional[str] = None, - autoescape: bool = False, + self, filenames: List[str], custom_template_directory: Optional[str] = None, ) -> List[jinja2.Template]: """Load a list of template files from disk using the given variables. @@ -215,7 +232,8 @@ class Config: template directory. If `custom_template_directory` is supplied, that directory is tried first. - Files read are treated as Jinja templates. These templates are not rendered yet. + Files read are treated as Jinja templates. The templates are not rendered yet + and have autoescape enabled. Args: filenames: A list of template filenames to read. @@ -223,16 +241,12 @@ class Config: custom_template_directory: A directory to try to look for the templates before using the default Synapse template directory instead. - autoescape: Whether to autoescape variables before inserting them into the - template. - Raises: ConfigError: if the file's path is incorrect or otherwise cannot be read. Returns: A list of jinja2 templates. """ - templates = [] search_directories = [self.default_template_dir] # The loader will first look in the custom template directory (if specified) for the @@ -249,7 +263,7 @@ class Config: search_directories.insert(0, custom_template_directory) loader = jinja2.FileSystemLoader(search_directories) - env = jinja2.Environment(loader=loader, autoescape=autoescape) + env = jinja2.Environment(loader=loader, autoescape=jinja2.select_autoescape(),) # Update the environment with our custom filters env.filters.update( @@ -259,12 +273,8 @@ class Config: } ) - for filename in filenames: - # Load the template - template = env.get_template(filename) - templates.append(template) - - return templates + # Load the templates + return [env.get_template(filename) for filename in filenames] def _format_ts_filter(value: int, format: str): diff --git a/synapse/config/captcha.py b/synapse/config/captcha.py index cb00958165..9e48f865cc 100644 --- a/synapse/config/captcha.py +++ b/synapse/config/captcha.py @@ -28,9 +28,7 @@ class CaptchaConfig(Config): "recaptcha_siteverify_api", "https://www.recaptcha.net/recaptcha/api/siteverify", ) - self.recaptcha_template = self.read_templates( - ["recaptcha.html"], autoescape=True - )[0] + self.recaptcha_template = self.read_template("recaptcha.html") def generate_config_section(self, **kwargs): return """\ diff --git a/synapse/config/consent_config.py b/synapse/config/consent_config.py index 6efa59b110..c47f364b14 100644 --- a/synapse/config/consent_config.py +++ b/synapse/config/consent_config.py @@ -89,7 +89,7 @@ class ConsentConfig(Config): def read_config(self, config, **kwargs): consent_config = config.get("user_consent") - self.terms_template = self.read_templates(["terms.html"], autoescape=True)[0] + self.terms_template = self.read_template("terms.html") if consent_config is None: return diff --git a/synapse/config/registration.py b/synapse/config/registration.py index 4bfc69cb7a..ac48913a0b 100644 --- a/synapse/config/registration.py +++ b/synapse/config/registration.py @@ -176,9 +176,7 @@ class RegistrationConfig(Config): self.session_lifetime = session_lifetime # The success template used during fallback auth. - self.fallback_success_template = self.read_templates( - ["auth_success.html"], autoescape=True - )[0] + self.fallback_success_template = self.read_template("auth_success.html") def generate_config_section(self, generate_secrets=False, **kwargs): if generate_secrets: diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 4d875dcb91..745b1dde94 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -668,6 +668,15 @@ class Mailer: def safe_markup(raw_html: str) -> jinja2.Markup: + """ + Sanitise a raw HTML string to a set of allowed tags and attributes, and linkify any bare URLs. + + Args + raw_html: Unsafe HTML. + + Returns: + A Markup object ready to safely use in a Jinja template. + """ return jinja2.Markup( bleach.linkify( bleach.clean( @@ -684,8 +693,13 @@ def safe_markup(raw_html: str) -> jinja2.Markup: def safe_text(raw_text: str) -> jinja2.Markup: """ - Process text: treat it as HTML but escape any tags (ie. just escape the - HTML) then linkify it. + Sanitise text (escape any HTML tags), and then linkify any bare URLs. + + Args + raw_text: Unsafe text which might include HTML markup. + + Returns: + A Markup object ready to safely use in a Jinja template. """ return jinja2.Markup( bleach.linkify(bleach.clean(raw_text, tags=[], attributes={}, strip=False)) diff --git a/synapse/res/templates/sso_auth_bad_user.html b/synapse/res/templates/sso_auth_bad_user.html index 3611191bf9..f7099098c7 100644 --- a/synapse/res/templates/sso_auth_bad_user.html +++ b/synapse/res/templates/sso_auth_bad_user.html @@ -5,7 +5,7 @@

    - We were unable to validate your {{server_name | e}} account via + We were unable to validate your {{ server_name }} account via single-sign-on (SSO), because the SSO Identity Provider returned different details than when you logged in.

    diff --git a/synapse/res/templates/sso_auth_confirm.html b/synapse/res/templates/sso_auth_confirm.html index 0d9de9d465..4e7ca3a2ed 100644 --- a/synapse/res/templates/sso_auth_confirm.html +++ b/synapse/res/templates/sso_auth_confirm.html @@ -5,8 +5,8 @@

    - A client is trying to {{ description | e }}. To confirm this action, - re-authenticate with single sign-on. + A client is trying to {{ description }}. To confirm this action, + re-authenticate with single sign-on. If you did not expect this, your account may be compromised!

    diff --git a/synapse/res/templates/sso_error.html b/synapse/res/templates/sso_error.html index 944bc9c9ca..af8459719a 100644 --- a/synapse/res/templates/sso_error.html +++ b/synapse/res/templates/sso_error.html @@ -12,7 +12,7 @@

    There was an error during authentication:

    -
    {{ error_description | e }}
    +
    {{ error_description }}

    If you are seeing this page after clicking a link sent to you via email, make sure you only click the confirmation link once, and that you open the diff --git a/synapse/res/templates/sso_login_idp_picker.html b/synapse/res/templates/sso_login_idp_picker.html index 5b38481012..62a640dad2 100644 --- a/synapse/res/templates/sso_login_idp_picker.html +++ b/synapse/res/templates/sso_login_idp_picker.html @@ -3,22 +3,22 @@ - {{server_name | e}} Login + {{ server_name }} Login

    -

    {{server_name | e}} Login

    +

    {{ server_name }} Login