From 0f1afbe8dc853c8726de797ede10cad2fe336b16 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 11 Aug 2020 18:00:17 +0100 Subject: Change HomeServer definition to work with typing. Duplicating function signatures between server.py and server.pyi is silly. This commit changes that by changing all `build_*` methods to `get_*` methods and changing the `_make_dependency_method` to work work as a descriptor that caches the produced value. There are some changes in other files that were made to fix the typing in server.py. --- synapse/handlers/oidc_handler.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) (limited to 'synapse/handlers/oidc_handler.py') diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 87f0c5e197..fa5ee5de8f 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -14,7 +14,7 @@ # limitations under the License. import json import logging -from typing import Dict, Generic, List, Optional, Tuple, TypeVar +from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar from urllib.parse import urlencode import attr @@ -39,9 +39,11 @@ from synapse.http.server import respond_with_html from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable from synapse.push.mailer import load_jinja2_templates -from synapse.server import HomeServer from synapse.types import UserID, map_username_to_mxid_localpart +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) SESSION_COOKIE_NAME = b"oidc_session" @@ -91,7 +93,7 @@ class OidcHandler: """Handles requests related to the OpenID Connect login flow. """ - def __init__(self, hs: HomeServer): + def __init__(self, hs: "HomeServer"): self._callback_url = hs.config.oidc_callback_url # type: str self._scopes = hs.config.oidc_scopes # type: List[str] self._client_auth = ClientAuth( -- cgit 1.5.1 From e04e465b4d2c66acb8885c31736c7b7bb4e7be52 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Mon, 17 Aug 2020 17:05:00 +0100 Subject: Use the default templates when a custom template file cannot be found (#8037) Fixes https://github.com/matrix-org/synapse/issues/6583 --- changelog.d/8037.feature | 1 + docs/sample_config.yaml | 4 +- synapse/config/_base.py | 100 ++++++++++++++++++++- synapse/config/emailconfig.py | 145 ++++++++++++++----------------- synapse/config/saml2_config.py | 14 +-- synapse/config/sso.py | 37 ++++---- synapse/handlers/account_validity.py | 20 +---- synapse/handlers/auth.py | 12 ++- synapse/handlers/oidc_handler.py | 5 +- synapse/push/mailer.py | 72 +-------------- synapse/push/pusher.py | 31 ++----- synapse/python_dependencies.py | 2 - synapse/rest/client/v2_alpha/account.py | 44 +++------- synapse/rest/client/v2_alpha/register.py | 31 ++----- tests/config/test_base.py | 82 +++++++++++++++++ 15 files changed, 310 insertions(+), 290 deletions(-) create mode 100644 changelog.d/8037.feature create mode 100644 tests/config/test_base.py (limited to 'synapse/handlers/oidc_handler.py') diff --git a/changelog.d/8037.feature b/changelog.d/8037.feature new file mode 100644 index 0000000000..2e5127477d --- /dev/null +++ b/changelog.d/8037.feature @@ -0,0 +1 @@ +Use the default template file when its equivalent is not found in a custom template directory. \ No newline at end of file diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 9235b89fb1..f168853f67 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -2002,9 +2002,7 @@ email: # Directory in which Synapse will try to find the template files below. # If not set, default templates from within the Synapse package will be used. # - # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates. - # If you *do* uncomment it, you will need to make sure that all the templates - # below are in the directory. + # Do not uncomment this setting unless you want to customise the templates. # # Synapse will look for the following templates in this directory: # diff --git a/synapse/config/_base.py b/synapse/config/_base.py index fd137853b1..1417487427 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -18,12 +18,16 @@ import argparse import errno import os +import time +import urllib.parse from collections import OrderedDict from hashlib import sha256 from textwrap import dedent -from typing import Any, List, MutableMapping, Optional +from typing import Any, Callable, List, MutableMapping, Optional import attr +import jinja2 +import pkg_resources import yaml @@ -100,6 +104,11 @@ class Config(object): def __init__(self, root_config=None): self.root = root_config + # Get the path to the default Synapse template directory + self.default_template_dir = pkg_resources.resource_filename( + "synapse", "res/templates" + ) + def __getattr__(self, item: str) -> Any: """ Try and fetch a configuration option that does not exist on this class. @@ -184,6 +193,95 @@ class Config(object): with open(file_path) as file_stream: return file_stream.read() + def read_templates( + self, filenames: List[str], custom_template_directory: Optional[str] = None, + ) -> List[jinja2.Template]: + """Load a list of template files from disk using the given variables. + + This function will attempt to load the given templates from the default Synapse + template directory. If `custom_template_directory` is supplied, that directory + is tried first. + + Files read are treated as Jinja templates. These templates are not rendered yet. + + Args: + filenames: A list of template filenames to read. + + custom_template_directory: A directory to try to look for the templates + before using the default Synapse template directory instead. + + Raises: + ConfigError: if the file's path is incorrect or otherwise cannot be read. + + Returns: + A list of jinja2 templates. + """ + templates = [] + search_directories = [self.default_template_dir] + + # The loader will first look in the custom template directory (if specified) for the + # given filename. If it doesn't find it, it will use the default template dir instead + if custom_template_directory: + # Check that the given template directory exists + if not self.path_exists(custom_template_directory): + raise ConfigError( + "Configured template directory does not exist: %s" + % (custom_template_directory,) + ) + + # Search the custom template directory as well + search_directories.insert(0, custom_template_directory) + + loader = jinja2.FileSystemLoader(search_directories) + env = jinja2.Environment(loader=loader, autoescape=True) + + # Update the environment with our custom filters + env.filters.update( + { + "format_ts": _format_ts_filter, + "mxc_to_http": _create_mxc_to_http_filter(self.public_baseurl), + } + ) + + for filename in filenames: + # Load the template + template = env.get_template(filename) + templates.append(template) + + return templates + + +def _format_ts_filter(value: int, format: str): + return time.strftime(format, time.localtime(value / 1000)) + + +def _create_mxc_to_http_filter(public_baseurl: str) -> Callable: + """Create and return a jinja2 filter that converts MXC urls to HTTP + + Args: + public_baseurl: The public, accessible base URL of the homeserver + """ + + def mxc_to_http_filter(value, width, height, resize_method="crop"): + if value[0:6] != "mxc://": + return "" + + server_and_media_id = value[6:] + fragment = None + if "#" in server_and_media_id: + server_and_media_id, fragment = server_and_media_id.split("#", 1) + fragment = "#" + fragment + + params = {"width": width, "height": height, "method": resize_method} + return "%s_matrix/media/v1/thumbnail/%s?%s%s" % ( + public_baseurl, + server_and_media_id, + urllib.parse.urlencode(params), + fragment or "", + ) + + return mxc_to_http_filter + class RootConfig(object): """ diff --git a/synapse/config/emailconfig.py b/synapse/config/emailconfig.py index a63acbdc63..7a796996c0 100644 --- a/synapse/config/emailconfig.py +++ b/synapse/config/emailconfig.py @@ -23,7 +23,6 @@ from enum import Enum from typing import Optional import attr -import pkg_resources from ._base import Config, ConfigError @@ -98,21 +97,18 @@ class EmailConfig(Config): if parsed[1] == "": raise RuntimeError("Invalid notif_from address") + # A user-configurable template directory template_dir = email_config.get("template_dir") - # we need an absolute path, because we change directory after starting (and - # we don't yet know what auxiliary templates like mail.css we will need). - # (Note that loading as package_resources with jinja.PackageLoader doesn't - # work for the same reason.) - if not template_dir: - template_dir = pkg_resources.resource_filename("synapse", "res/templates") - - self.email_template_dir = os.path.abspath(template_dir) + if isinstance(template_dir, str): + # We need an absolute path, because we change directory after starting (and + # we don't yet know what auxiliary templates like mail.css we will need). + template_dir = os.path.abspath(template_dir) + elif template_dir is not None: + # If template_dir is something other than a str or None, warn the user + raise ConfigError("Config option email.template_dir must be type str") self.email_enable_notifs = email_config.get("enable_notifs", False) - account_validity_config = config.get("account_validity") or {} - account_validity_renewal_enabled = account_validity_config.get("renew_at") - self.threepid_behaviour_email = ( # Have Synapse handle the email sending if account_threepid_delegates.email # is not defined @@ -166,19 +162,6 @@ class EmailConfig(Config): email_config.get("validation_token_lifetime", "1h") ) - if ( - self.email_enable_notifs - or account_validity_renewal_enabled - or self.threepid_behaviour_email == ThreepidBehaviour.LOCAL - ): - # make sure we can import the required deps - import bleach - import jinja2 - - # prevent unused warnings - jinja2 - bleach - if self.threepid_behaviour_email == ThreepidBehaviour.LOCAL: missing = [] if not self.email_notif_from: @@ -196,49 +179,49 @@ class EmailConfig(Config): # These email templates have placeholders in them, and thus must be # parsed using a templating engine during a request - self.email_password_reset_template_html = email_config.get( + password_reset_template_html = email_config.get( "password_reset_template_html", "password_reset.html" ) - self.email_password_reset_template_text = email_config.get( + password_reset_template_text = email_config.get( "password_reset_template_text", "password_reset.txt" ) - self.email_registration_template_html = email_config.get( + registration_template_html = email_config.get( "registration_template_html", "registration.html" ) - self.email_registration_template_text = email_config.get( + registration_template_text = email_config.get( "registration_template_text", "registration.txt" ) - self.email_add_threepid_template_html = email_config.get( + add_threepid_template_html = email_config.get( "add_threepid_template_html", "add_threepid.html" ) - self.email_add_threepid_template_text = email_config.get( + add_threepid_template_text = email_config.get( "add_threepid_template_text", "add_threepid.txt" ) - self.email_password_reset_template_failure_html = email_config.get( + password_reset_template_failure_html = email_config.get( "password_reset_template_failure_html", "password_reset_failure.html" ) - self.email_registration_template_failure_html = email_config.get( + registration_template_failure_html = email_config.get( "registration_template_failure_html", "registration_failure.html" ) - self.email_add_threepid_template_failure_html = email_config.get( + add_threepid_template_failure_html = email_config.get( "add_threepid_template_failure_html", "add_threepid_failure.html" ) # These templates do not support any placeholder variables, so we # will read them from disk once during setup - email_password_reset_template_success_html = email_config.get( + password_reset_template_success_html = email_config.get( "password_reset_template_success_html", "password_reset_success.html" ) - email_registration_template_success_html = email_config.get( + registration_template_success_html = email_config.get( "registration_template_success_html", "registration_success.html" ) - email_add_threepid_template_success_html = email_config.get( + add_threepid_template_success_html = email_config.get( "add_threepid_template_success_html", "add_threepid_success.html" ) - # Check templates exist - for f in [ + # Read all templates from disk + ( self.email_password_reset_template_html, self.email_password_reset_template_text, self.email_registration_template_html, @@ -248,32 +231,36 @@ class EmailConfig(Config): self.email_password_reset_template_failure_html, self.email_registration_template_failure_html, self.email_add_threepid_template_failure_html, - email_password_reset_template_success_html, - email_registration_template_success_html, - email_add_threepid_template_success_html, - ]: - p = os.path.join(self.email_template_dir, f) - if not os.path.isfile(p): - raise ConfigError("Unable to find template file %s" % (p,)) - - # Retrieve content of web templates - filepath = os.path.join( - self.email_template_dir, email_password_reset_template_success_html + password_reset_template_success_html_template, + registration_template_success_html_template, + add_threepid_template_success_html_template, + ) = self.read_templates( + [ + password_reset_template_html, + password_reset_template_text, + registration_template_html, + registration_template_text, + add_threepid_template_html, + add_threepid_template_text, + password_reset_template_failure_html, + registration_template_failure_html, + add_threepid_template_failure_html, + password_reset_template_success_html, + registration_template_success_html, + add_threepid_template_success_html, + ], + template_dir, ) - self.email_password_reset_template_success_html = self.read_file( - filepath, "email.password_reset_template_success_html" - ) - filepath = os.path.join( - self.email_template_dir, email_registration_template_success_html - ) - self.email_registration_template_success_html_content = self.read_file( - filepath, "email.registration_template_success_html" + + # Render templates that do not contain any placeholders + self.email_password_reset_template_success_html_content = ( + password_reset_template_success_html_template.render() ) - filepath = os.path.join( - self.email_template_dir, email_add_threepid_template_success_html + self.email_registration_template_success_html_content = ( + registration_template_success_html_template.render() ) - self.email_add_threepid_template_success_html_content = self.read_file( - filepath, "email.add_threepid_template_success_html" + self.email_add_threepid_template_success_html_content = ( + add_threepid_template_success_html_template.render() ) if self.email_enable_notifs: @@ -290,17 +277,19 @@ class EmailConfig(Config): % (", ".join(missing),) ) - self.email_notif_template_html = email_config.get( + notif_template_html = email_config.get( "notif_template_html", "notif_mail.html" ) - self.email_notif_template_text = email_config.get( + notif_template_text = email_config.get( "notif_template_text", "notif_mail.txt" ) - for f in self.email_notif_template_text, self.email_notif_template_html: - p = os.path.join(self.email_template_dir, f) - if not os.path.isfile(p): - raise ConfigError("Unable to find email template file %s" % (p,)) + ( + self.email_notif_template_html, + self.email_notif_template_text, + ) = self.read_templates( + [notif_template_html, notif_template_text], template_dir, + ) self.email_notif_for_new_users = email_config.get( "notif_for_new_users", True @@ -309,18 +298,20 @@ class EmailConfig(Config): "client_base_url", email_config.get("riot_base_url", None) ) - if account_validity_renewal_enabled: - self.email_expiry_template_html = email_config.get( + if self.account_validity.renew_by_email_enabled: + expiry_template_html = email_config.get( "expiry_template_html", "notice_expiry.html" ) - self.email_expiry_template_text = email_config.get( + expiry_template_text = email_config.get( "expiry_template_text", "notice_expiry.txt" ) - for f in self.email_expiry_template_text, self.email_expiry_template_html: - p = os.path.join(self.email_template_dir, f) - if not os.path.isfile(p): - raise ConfigError("Unable to find email template file %s" % (p,)) + ( + self.account_validity_template_html, + self.account_validity_template_text, + ) = self.read_templates( + [expiry_template_html, expiry_template_text], template_dir, + ) subjects_config = email_config.get("subjects", {}) subjects = {} @@ -400,9 +391,7 @@ class EmailConfig(Config): # Directory in which Synapse will try to find the template files below. # If not set, default templates from within the Synapse package will be used. # - # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates. - # If you *do* uncomment it, you will need to make sure that all the templates - # below are in the directory. + # Do not uncomment this setting unless you want to customise the templates. # # Synapse will look for the following templates in this directory: # diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index 9277b5f342..036f8c0e90 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -18,8 +18,6 @@ import logging from typing import Any, List import attr -import jinja2 -import pkg_resources from synapse.python_dependencies import DependencyException, check_requirements from synapse.util.module_loader import load_module, load_python_module @@ -171,15 +169,9 @@ class SAML2Config(Config): saml2_config.get("saml_session_lifetime", "15m") ) - template_dir = saml2_config.get("template_dir") - if not template_dir: - template_dir = pkg_resources.resource_filename("synapse", "res/templates",) - - loader = jinja2.FileSystemLoader(template_dir) - # enable auto-escape here, to having to remember to escape manually in the - # template - env = jinja2.Environment(loader=loader, autoescape=True) - self.saml2_error_html_template = env.get_template("saml_error.html") + self.saml2_error_html_template = self.read_templates( + ["saml_error.html"], saml2_config.get("template_dir") + ) def _default_saml_config_dict( self, required_attributes: set, optional_attributes: set diff --git a/synapse/config/sso.py b/synapse/config/sso.py index 73b7296399..4427676167 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -12,11 +12,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os from typing import Any, Dict -import pkg_resources - from ._base import Config @@ -29,22 +26,32 @@ class SSOConfig(Config): def read_config(self, config, **kwargs): sso_config = config.get("sso") or {} # type: Dict[str, Any] - # Pick a template directory in order of: - # * The sso-specific template_dir - # * /path/to/synapse/install/res/templates + # The sso-specific template_dir template_dir = sso_config.get("template_dir") - if not template_dir: - template_dir = pkg_resources.resource_filename("synapse", "res/templates",) - self.sso_template_dir = template_dir - self.sso_account_deactivated_template = self.read_file( - os.path.join(self.sso_template_dir, "sso_account_deactivated.html"), - "sso_account_deactivated_template", + # Read templates from disk + ( + self.sso_redirect_confirm_template, + self.sso_auth_confirm_template, + self.sso_error_template, + sso_account_deactivated_template, + sso_auth_success_template, + ) = self.read_templates( + [ + "sso_redirect_confirm.html", + "sso_auth_confirm.html", + "sso_error.html", + "sso_account_deactivated.html", + "sso_auth_success.html", + ], + template_dir, ) - self.sso_auth_success_template = self.read_file( - os.path.join(self.sso_template_dir, "sso_auth_success.html"), - "sso_auth_success_template", + + # These templates have no placeholders, so render them here + self.sso_account_deactivated_template = ( + sso_account_deactivated_template.render() ) + self.sso_auth_success_template = sso_auth_success_template.render() self.sso_client_whitelist = sso_config.get("client_whitelist") or [] diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py index 590135d19c..b865bf5b48 100644 --- a/synapse/handlers/account_validity.py +++ b/synapse/handlers/account_validity.py @@ -26,11 +26,6 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import UserID from synapse.util import stringutils -try: - from synapse.push.mailer import load_jinja2_templates -except ImportError: - load_jinja2_templates = None - logger = logging.getLogger(__name__) @@ -47,9 +42,11 @@ class AccountValidityHandler(object): if ( self._account_validity.enabled and self._account_validity.renew_by_email_enabled - and load_jinja2_templates ): # Don't do email-specific configuration if renewal by email is disabled. + self._template_html = self.config.account_validity_template_html + self._template_text = self.config.account_validity_template_text + try: app_name = self.hs.config.email_app_name @@ -65,17 +62,6 @@ class AccountValidityHandler(object): self._raw_from = email.utils.parseaddr(self._from_string)[1] - self._template_html, self._template_text = load_jinja2_templates( - self.config.email_template_dir, - [ - self.config.email_expiry_template_html, - self.config.email_expiry_template_text, - ], - apply_format_ts_filter=True, - apply_mxc_to_http_filter=True, - public_baseurl=self.config.public_baseurl, - ) - # Check the renewal emails to send and send them every 30min. def send_emails(): # run as a background process to make sure that the database transactions diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index c24e7bafe0..68d6870e40 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -42,7 +42,6 @@ from synapse.http.site import SynapseRequest from synapse.logging.context import defer_to_thread from synapse.metrics.background_process_metrics import run_as_background_process from synapse.module_api import ModuleApi -from synapse.push.mailer import load_jinja2_templates from synapse.types import Requester, UserID from synapse.util import stringutils as stringutils from synapse.util.threepids import canonicalise_email @@ -132,18 +131,17 @@ class AuthHandler(BaseHandler): # after the SSO completes and before redirecting them back to their client. # It notifies the user they are about to give access to their matrix account # to the client. - self._sso_redirect_confirm_template = load_jinja2_templates( - hs.config.sso_template_dir, ["sso_redirect_confirm.html"], - )[0] + self._sso_redirect_confirm_template = hs.config.sso_redirect_confirm_template + # The following template is shown during user interactive authentication # in the fallback auth scenario. It notifies the user that they are # authenticating for an operation to occur on their account. - self._sso_auth_confirm_template = load_jinja2_templates( - hs.config.sso_template_dir, ["sso_auth_confirm.html"], - )[0] + self._sso_auth_confirm_template = hs.config.sso_auth_confirm_template + # The following template is shown after a successful user interactive # authentication session. It tells the user they can close the window. self._sso_auth_success_template = hs.config.sso_auth_success_template + # The following template is shown during the SSO authentication process if # the account is deactivated. self._sso_account_deactivated_template = ( diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index fa5ee5de8f..87d28a7ae9 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -38,7 +38,6 @@ from synapse.config import ConfigError from synapse.http.server import respond_with_html from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable -from synapse.push.mailer import load_jinja2_templates from synapse.types import UserID, map_username_to_mxid_localpart if TYPE_CHECKING: @@ -123,9 +122,7 @@ class OidcHandler: self._hostname = hs.hostname # type: str self._server_name = hs.config.server_name # type: str self._macaroon_secret_key = hs.config.macaroon_secret_key - self._error_template = load_jinja2_templates( - hs.config.sso_template_dir, ["sso_error.html"] - )[0] + self._error_template = hs.config.sso_error_template # identifier for the external_ids table self._auth_provider_id = "oidc" diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index af117fddf9..c38e037281 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -16,8 +16,7 @@ import email.mime.multipart import email.utils import logging -import time -import urllib +import urllib.parse from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText from typing import Iterable, List, TypeVar @@ -640,72 +639,3 @@ def string_ordinal_total(s): for c in s: tot += ord(c) return tot - - -def format_ts_filter(value, format): - return time.strftime(format, time.localtime(value / 1000)) - - -def load_jinja2_templates( - template_dir, - template_filenames, - apply_format_ts_filter=False, - apply_mxc_to_http_filter=False, - public_baseurl=None, -): - """Loads and returns one or more jinja2 templates and applies optional filters - - Args: - template_dir (str): The directory where templates are stored - template_filenames (list[str]): A list of template filenames - apply_format_ts_filter (bool): Whether to apply a template filter that formats - timestamps - apply_mxc_to_http_filter (bool): Whether to apply a template filter that converts - mxc urls to http urls - public_baseurl (str|None): The public baseurl of the server. Required for - apply_mxc_to_http_filter to be enabled - - Returns: - A list of jinja2 templates corresponding to the given list of filenames, - with order preserved - """ - logger.info( - "loading email templates %s from '%s'", template_filenames, template_dir - ) - loader = jinja2.FileSystemLoader(template_dir) - env = jinja2.Environment(loader=loader) - - if apply_format_ts_filter: - env.filters["format_ts"] = format_ts_filter - - if apply_mxc_to_http_filter and public_baseurl: - env.filters["mxc_to_http"] = _create_mxc_to_http_filter(public_baseurl) - - templates = [] - for template_filename in template_filenames: - template = env.get_template(template_filename) - templates.append(template) - - return templates - - -def _create_mxc_to_http_filter(public_baseurl): - def mxc_to_http_filter(value, width, height, resize_method="crop"): - if value[0:6] != "mxc://": - return "" - - serverAndMediaId = value[6:] - fragment = None - if "#" in serverAndMediaId: - (serverAndMediaId, fragment) = serverAndMediaId.split("#", 1) - fragment = "#" + fragment - - params = {"width": width, "height": height, "method": resize_method} - return "%s_matrix/media/v1/thumbnail/%s?%s%s" % ( - public_baseurl, - serverAndMediaId, - urllib.parse.urlencode(params), - fragment or "", - ) - - return mxc_to_http_filter diff --git a/synapse/push/pusher.py b/synapse/push/pusher.py index 8ad0bf5936..f626797133 100644 --- a/synapse/push/pusher.py +++ b/synapse/push/pusher.py @@ -15,22 +15,13 @@ import logging +from synapse.push.emailpusher import EmailPusher +from synapse.push.mailer import Mailer + from .httppusher import HttpPusher logger = logging.getLogger(__name__) -# We try importing this if we can (it will fail if we don't -# have the optional email dependencies installed). We don't -# yet have the config to know if we need the email pusher, -# but importing this after daemonizing seems to fail -# (even though a simple test of importing from a daemonized -# process works fine) -try: - from synapse.push.emailpusher import EmailPusher - from synapse.push.mailer import Mailer, load_jinja2_templates -except Exception: - pass - class PusherFactory(object): def __init__(self, hs): @@ -43,16 +34,8 @@ class PusherFactory(object): if hs.config.email_enable_notifs: self.mailers = {} # app_name -> Mailer - self.notif_template_html, self.notif_template_text = load_jinja2_templates( - self.config.email_template_dir, - [ - self.config.email_notif_template_html, - self.config.email_notif_template_text, - ], - apply_format_ts_filter=True, - apply_mxc_to_http_filter=True, - public_baseurl=self.config.public_baseurl, - ) + self._notif_template_html = hs.config.email_notif_template_html + self._notif_template_text = hs.config.email_notif_template_text self.pusher_types["email"] = self._create_email_pusher @@ -73,8 +56,8 @@ class PusherFactory(object): mailer = Mailer( hs=self.hs, app_name=app_name, - template_html=self.notif_template_html, - template_text=self.notif_template_text, + template_html=self._notif_template_html, + template_text=self._notif_template_text, ) self.mailers[app_name] = mailer return EmailPusher(self.hs, pusherdict, mailer) diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index e5f22fb858..3250d41dde 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -78,8 +78,6 @@ CONDITIONAL_REQUIREMENTS = { "matrix-synapse-ldap3": ["matrix-synapse-ldap3>=0.1"], # we use execute_batch, which arrived in psycopg 2.7. "postgres": ["psycopg2>=2.7"], - # ConsentResource uses select_autoescape, which arrived in jinja 2.9 - "resources.consent": ["Jinja2>=2.9"], # ACME support is required to provision TLS certificates from authorities # that use the protocol, such as Let's Encrypt. "acme": [ diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index fead85074b..203e76b9f2 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -32,7 +32,7 @@ from synapse.http.servlet import ( parse_json_object_from_request, parse_string, ) -from synapse.push.mailer import Mailer, load_jinja2_templates +from synapse.push.mailer import Mailer from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.stringutils import assert_valid_client_secret, random_string from synapse.util.threepids import canonicalise_email, check_3pid_allowed @@ -53,21 +53,11 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): self.identity_handler = hs.get_handlers().identity_handler if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - template_html, template_text = load_jinja2_templates( - self.config.email_template_dir, - [ - self.config.email_password_reset_template_html, - self.config.email_password_reset_template_text, - ], - apply_format_ts_filter=True, - apply_mxc_to_http_filter=True, - public_baseurl=self.config.public_baseurl, - ) self.mailer = Mailer( hs=self.hs, app_name=self.config.email_app_name, - template_html=template_html, - template_text=template_text, + template_html=self.config.email_password_reset_template_html, + template_text=self.config.email_password_reset_template_text, ) async def on_POST(self, request): @@ -169,9 +159,8 @@ class PasswordResetSubmitTokenServlet(RestServlet): self.clock = hs.get_clock() self.store = hs.get_datastore() if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - (self.failure_email_template,) = load_jinja2_templates( - self.config.email_template_dir, - [self.config.email_password_reset_template_failure_html], + self._failure_email_template = ( + self.config.email_password_reset_template_failure_html ) async def on_GET(self, request, medium): @@ -214,14 +203,14 @@ class PasswordResetSubmitTokenServlet(RestServlet): return None # Otherwise show the success template - html = self.config.email_password_reset_template_success_html + html = self.config.email_password_reset_template_success_html_content status_code = 200 except ThreepidValidationError as e: status_code = e.code # Show a failure page with a reason template_vars = {"failure_reason": e.msg} - html = self.failure_email_template.render(**template_vars) + html = self._failure_email_template.render(**template_vars) respond_with_html(request, status_code, html) @@ -411,19 +400,11 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): self.store = self.hs.get_datastore() if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - template_html, template_text = load_jinja2_templates( - self.config.email_template_dir, - [ - self.config.email_add_threepid_template_html, - self.config.email_add_threepid_template_text, - ], - public_baseurl=self.config.public_baseurl, - ) self.mailer = Mailer( hs=self.hs, app_name=self.config.email_app_name, - template_html=template_html, - template_text=template_text, + template_html=self.config.email_add_threepid_template_html, + template_text=self.config.email_add_threepid_template_text, ) async def on_POST(self, request): @@ -578,9 +559,8 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): self.clock = hs.get_clock() self.store = hs.get_datastore() if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - (self.failure_email_template,) = load_jinja2_templates( - self.config.email_template_dir, - [self.config.email_add_threepid_template_failure_html], + self._failure_email_template = ( + self.config.email_add_threepid_template_failure_html ) async def on_GET(self, request): @@ -631,7 +611,7 @@ class AddThreepidEmailSubmitTokenServlet(RestServlet): # Show a failure page with a reason template_vars = {"failure_reason": e.msg} - html = self.failure_email_template.render(**template_vars) + html = self._failure_email_template.render(**template_vars) respond_with_html(request, status_code, html) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index f808175698..7290fd0756 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -44,7 +44,7 @@ from synapse.http.servlet import ( parse_json_object_from_request, parse_string, ) -from synapse.push.mailer import load_jinja2_templates +from synapse.push.mailer import Mailer from synapse.util.msisdn import phone_number_to_msisdn from synapse.util.ratelimitutils import FederationRateLimiter from synapse.util.stringutils import assert_valid_client_secret, random_string @@ -81,23 +81,11 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): self.config = hs.config if self.hs.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - from synapse.push.mailer import Mailer, load_jinja2_templates - - template_html, template_text = load_jinja2_templates( - self.config.email_template_dir, - [ - self.config.email_registration_template_html, - self.config.email_registration_template_text, - ], - apply_format_ts_filter=True, - apply_mxc_to_http_filter=True, - public_baseurl=self.config.public_baseurl, - ) self.mailer = Mailer( hs=self.hs, app_name=self.config.email_app_name, - template_html=template_html, - template_text=template_text, + template_html=self.config.email_registration_template_html, + template_text=self.config.email_registration_template_text, ) async def on_POST(self, request): @@ -262,15 +250,8 @@ class RegistrationSubmitTokenServlet(RestServlet): self.store = hs.get_datastore() if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - (self.failure_email_template,) = load_jinja2_templates( - self.config.email_template_dir, - [self.config.email_registration_template_failure_html], - ) - - if self.config.threepid_behaviour_email == ThreepidBehaviour.LOCAL: - (self.failure_email_template,) = load_jinja2_templates( - self.config.email_template_dir, - [self.config.email_registration_template_failure_html], + self._failure_email_template = ( + self.config.email_registration_template_failure_html ) async def on_GET(self, request, medium): @@ -318,7 +299,7 @@ class RegistrationSubmitTokenServlet(RestServlet): # Show a failure page with a reason template_vars = {"failure_reason": e.msg} - html = self.failure_email_template.render(**template_vars) + html = self._failure_email_template.render(**template_vars) respond_with_html(request, status_code, html) diff --git a/tests/config/test_base.py b/tests/config/test_base.py new file mode 100644 index 0000000000..42ee5f56d9 --- /dev/null +++ b/tests/config/test_base.py @@ -0,0 +1,82 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os.path +import tempfile + +from synapse.config import ConfigError +from synapse.util.stringutils import random_string + +from tests import unittest + + +class BaseConfigTestCase(unittest.HomeserverTestCase): + def prepare(self, reactor, clock, hs): + self.hs = hs + + def test_loading_missing_templates(self): + # Use a temporary directory that exists on the system, but that isn't likely to + # contain template files + with tempfile.TemporaryDirectory() as tmp_dir: + # Attempt to load an HTML template from our custom template directory + template = self.hs.config.read_templates(["sso_error.html"], tmp_dir)[0] + + # If no errors, we should've gotten the default template instead + + # Render the template + a_random_string = random_string(5) + html_content = template.render({"error_description": a_random_string}) + + # Check that our string exists in the template + self.assertIn( + a_random_string, + html_content, + "Template file did not contain our test string", + ) + + def test_loading_custom_templates(self): + # Use a temporary directory that exists on the system + with tempfile.TemporaryDirectory() as tmp_dir: + # Create a temporary bogus template file + with tempfile.NamedTemporaryFile(dir=tmp_dir) as tmp_template: + # Get temporary file's filename + template_filename = os.path.basename(tmp_template.name) + + # Write a custom HTML template + contents = b"{{ test_variable }}" + tmp_template.write(contents) + tmp_template.flush() + + # Attempt to load the template from our custom template directory + template = ( + self.hs.config.read_templates([template_filename], tmp_dir) + )[0] + + # Render the template + a_random_string = random_string(5) + html_content = template.render({"test_variable": a_random_string}) + + # Check that our string exists in the template + self.assertIn( + a_random_string, + html_content, + "Template file did not contain our test string", + ) + + def test_loading_template_from_nonexistent_custom_directory(self): + with self.assertRaises(ConfigError): + self.hs.config.read_templates( + ["some_filename.html"], "a_nonexistent_directory" + ) -- cgit 1.5.1 From eebf52be060876ff14bbcbbc86b64ff9965b3622 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 19 Aug 2020 07:26:03 -0400 Subject: Be stricter about JSON that is accepted by Synapse (#8106) --- changelog.d/8106.bugfix | 1 + synapse/api/errors.py | 6 +++--- synapse/federation/federation_server.py | 5 ++--- synapse/federation/sender/transaction_manager.py | 5 ++--- synapse/handlers/e2e_keys.py | 8 ++++---- synapse/handlers/identity.py | 5 ++--- synapse/handlers/message.py | 5 +++-- synapse/handlers/oidc_handler.py | 6 +++--- synapse/handlers/ui_auth/checkers.py | 5 ++--- synapse/http/client.py | 11 ++++++----- synapse/http/federation/well_known_resolver.py | 5 ++--- synapse/http/servlet.py | 5 ++--- synapse/logging/opentracing.py | 7 +++++-- synapse/replication/tcp/commands.py | 12 +++++------- synapse/rest/client/v1/room.py | 11 +++++++---- synapse/rest/client/v2_alpha/sync.py | 5 ++--- synapse/rest/key/v2/remote_key_resource.py | 8 +++++--- synapse/storage/_base.py | 7 +++---- synapse/storage/databases/main/events_worker.py | 16 ++++++++++++++-- synapse/util/__init__.py | 14 ++++++++++++-- 20 files changed, 85 insertions(+), 62 deletions(-) create mode 100644 changelog.d/8106.bugfix (limited to 'synapse/handlers/oidc_handler.py') diff --git a/changelog.d/8106.bugfix b/changelog.d/8106.bugfix new file mode 100644 index 0000000000..c46c60448f --- /dev/null +++ b/changelog.d/8106.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where invalid JSON would be accepted by Synapse. diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 6e40630ab6..a3f314118a 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -21,10 +21,10 @@ import typing from http import HTTPStatus from typing import Dict, List, Optional, Union -from canonicaljson import json - from twisted.web import http +from synapse.util import json_decoder + if typing.TYPE_CHECKING: from synapse.types import JsonDict @@ -593,7 +593,7 @@ class HttpResponseException(CodeMessageException): # try to parse the body as json, to get better errcode/msg, but # default to M_UNKNOWN with the HTTP status as the error text try: - j = json.loads(self.response.decode("utf-8")) + j = json_decoder.decode(self.response.decode("utf-8")) except ValueError: j = {} diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 11c5d63298..630f571cd4 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -28,7 +28,6 @@ from typing import ( Union, ) -from canonicaljson import json from prometheus_client import Counter, Histogram from twisted.internet import defer @@ -63,7 +62,7 @@ from synapse.replication.http.federation import ( ReplicationGetQueryRestServlet, ) from synapse.types import JsonDict, get_domain_from_id -from synapse.util import glob_to_regex, unwrapFirstError +from synapse.util import glob_to_regex, json_decoder, unwrapFirstError from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.caches.response_cache import ResponseCache @@ -551,7 +550,7 @@ class FederationServer(FederationBase): for device_id, keys in device_keys.items(): for key_id, json_str in keys.items(): json_result.setdefault(user_id, {})[device_id] = { - key_id: json.loads(json_str) + key_id: json_decoder.decode(json_str) } logger.info( diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index c7f6cb3d73..9bd534a313 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -15,8 +15,6 @@ import logging from typing import TYPE_CHECKING, List, Tuple -from canonicaljson import json - from synapse.api.errors import HttpResponseException from synapse.events import EventBase from synapse.federation.persistence import TransactionActions @@ -28,6 +26,7 @@ from synapse.logging.opentracing import ( tags, whitelisted_homeserver, ) +from synapse.util import json_decoder from synapse.util.metrics import measure_func if TYPE_CHECKING: @@ -71,7 +70,7 @@ class TransactionManager(object): for edu in pending_edus: context = edu.get_context() if context: - span_contexts.append(extract_text_map(json.loads(context))) + span_contexts.append(extract_text_map(json_decoder.decode(context))) if keep_destination: edu.strip_context() diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 84169c1022..d8def45e38 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -19,7 +19,7 @@ import logging from typing import Dict, List, Optional, Tuple import attr -from canonicaljson import encode_canonical_json, json +from canonicaljson import encode_canonical_json from signedjson.key import VerifyKey, decode_verify_key_bytes from signedjson.sign import SignatureVerifyException, verify_signed_json from unpaddedbase64 import decode_base64 @@ -35,7 +35,7 @@ from synapse.types import ( get_domain_from_id, get_verify_key_from_cross_signing_key, ) -from synapse.util import unwrapFirstError +from synapse.util import json_decoder, unwrapFirstError from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.retryutils import NotRetryingDestination @@ -404,7 +404,7 @@ class E2eKeysHandler(object): for device_id, keys in device_keys.items(): for key_id, json_bytes in keys.items(): json_result.setdefault(user_id, {})[device_id] = { - key_id: json.loads(json_bytes) + key_id: json_decoder.decode(json_bytes) } @trace @@ -1186,7 +1186,7 @@ def _exception_to_failure(e): def _one_time_keys_match(old_key_json, new_key): - old_key = json.loads(old_key_json) + old_key = json_decoder.decode(old_key_json) # if either is a string rather than an object, they must match exactly if not isinstance(old_key, dict) or not isinstance(new_key, dict): diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index 92b7404706..0ce6ddfbe4 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -21,8 +21,6 @@ import logging import urllib.parse from typing import Awaitable, Callable, Dict, List, Optional, Tuple -from canonicaljson import json - from twisted.internet.error import TimeoutError from synapse.api.errors import ( @@ -34,6 +32,7 @@ from synapse.api.errors import ( from synapse.config.emailconfig import ThreepidBehaviour from synapse.http.client import SimpleHttpClient from synapse.types import JsonDict, Requester +from synapse.util import json_decoder from synapse.util.hash import sha256_and_url_safe_base64 from synapse.util.stringutils import assert_valid_client_secret, random_string @@ -177,7 +176,7 @@ class IdentityHandler(BaseHandler): except TimeoutError: raise SynapseError(500, "Timed out contacting identity server") except CodeMessageException as e: - data = json.loads(e.msg) # XXX WAT? + data = json_decoder.decode(e.msg) # XXX WAT? return data logger.info("Got 404 when POSTing JSON %s, falling back to v1 URL", bind_url) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index b999d91d1a..c955a86be0 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -17,7 +17,7 @@ import logging from typing import TYPE_CHECKING, Dict, List, Optional, Tuple -from canonicaljson import encode_canonical_json, json +from canonicaljson import encode_canonical_json from twisted.internet.interfaces import IDelayedCall @@ -55,6 +55,7 @@ from synapse.types import ( UserID, create_requester, ) +from synapse.util import json_decoder from synapse.util.async_helpers import Linearizer from synapse.util.frozenutils import frozendict_json_encoder from synapse.util.metrics import measure_func @@ -864,7 +865,7 @@ class EventCreationHandler(object): # Ensure that we can round trip before trying to persist in db try: dump = frozendict_json_encoder.encode(event.content) - json.loads(dump) + json_decoder.decode(dump) except Exception: logger.exception("Failed to encode content: %r", event.content) raise diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 87d28a7ae9..dd3703cbd2 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import json import logging from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar from urllib.parse import urlencode @@ -39,6 +38,7 @@ from synapse.http.server import respond_with_html from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable from synapse.types import UserID, map_username_to_mxid_localpart +from synapse.util import json_decoder if TYPE_CHECKING: from synapse.server import HomeServer @@ -367,7 +367,7 @@ class OidcHandler: # and check for an error field. If not, we respond with a generic # error message. try: - resp = json.loads(resp_body.decode("utf-8")) + resp = json_decoder.decode(resp_body.decode("utf-8")) error = resp["error"] description = resp.get("error_description", error) except (ValueError, KeyError): @@ -384,7 +384,7 @@ class OidcHandler: # Since it is a not a 5xx code, body should be a valid JSON. It will # raise if not. - resp = json.loads(resp_body.decode("utf-8")) + resp = json_decoder.decode(resp_body.decode("utf-8")) if "error" in resp: error = resp["error"] diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py index a011e9fe29..9146dc1a3b 100644 --- a/synapse/handlers/ui_auth/checkers.py +++ b/synapse/handlers/ui_auth/checkers.py @@ -16,13 +16,12 @@ import logging from typing import Any -from canonicaljson import json - from twisted.web.client import PartialDownloadError from synapse.api.constants import LoginType from synapse.api.errors import Codes, LoginError, SynapseError from synapse.config.emailconfig import ThreepidBehaviour +from synapse.util import json_decoder logger = logging.getLogger(__name__) @@ -117,7 +116,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker): except PartialDownloadError as pde: # Twisted is silly data = pde.response - resp_body = json.loads(data.decode("utf-8")) + resp_body = json_decoder.decode(data.decode("utf-8")) if "success" in resp_body: # Note that we do NOT check the hostname here: we explicitly diff --git a/synapse/http/client.py b/synapse/http/client.py index 8aeb70cdec..dad01a8e56 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -19,7 +19,7 @@ import urllib from io import BytesIO import treq -from canonicaljson import encode_canonical_json, json +from canonicaljson import encode_canonical_json from netaddr import IPAddress from prometheus_client import Counter from zope.interface import implementer, provider @@ -47,6 +47,7 @@ from synapse.http import ( from synapse.http.proxyagent import ProxyAgent from synapse.logging.context import make_deferred_yieldable from synapse.logging.opentracing import set_tag, start_active_span, tags +from synapse.util import json_decoder from synapse.util.async_helpers import timeout_deferred logger = logging.getLogger(__name__) @@ -391,7 +392,7 @@ class SimpleHttpClient(object): body = await make_deferred_yieldable(readBody(response)) if 200 <= response.code < 300: - return json.loads(body.decode("utf-8")) + return json_decoder.decode(body.decode("utf-8")) else: raise HttpResponseException( response.code, response.phrase.decode("ascii", errors="replace"), body @@ -433,7 +434,7 @@ class SimpleHttpClient(object): body = await make_deferred_yieldable(readBody(response)) if 200 <= response.code < 300: - return json.loads(body.decode("utf-8")) + return json_decoder.decode(body.decode("utf-8")) else: raise HttpResponseException( response.code, response.phrase.decode("ascii", errors="replace"), body @@ -463,7 +464,7 @@ class SimpleHttpClient(object): actual_headers.update(headers) body = await self.get_raw(uri, args, headers=headers) - return json.loads(body.decode("utf-8")) + return json_decoder.decode(body.decode("utf-8")) async def put_json(self, uri, json_body, args={}, headers=None): """ Puts some json to the given URI. @@ -506,7 +507,7 @@ class SimpleHttpClient(object): body = await make_deferred_yieldable(readBody(response)) if 200 <= response.code < 300: - return json.loads(body.decode("utf-8")) + return json_decoder.decode(body.decode("utf-8")) else: raise HttpResponseException( response.code, response.phrase.decode("ascii", errors="replace"), body diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py index 89a3b041ce..f794315deb 100644 --- a/synapse/http/federation/well_known_resolver.py +++ b/synapse/http/federation/well_known_resolver.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json import logging import random import time @@ -26,7 +25,7 @@ from twisted.web.http import stringToDatetime from twisted.web.http_headers import Headers from synapse.logging.context import make_deferred_yieldable -from synapse.util import Clock +from synapse.util import Clock, json_decoder from synapse.util.caches.ttlcache import TTLCache from synapse.util.metrics import Measure @@ -181,7 +180,7 @@ class WellKnownResolver(object): if response.code != 200: raise Exception("Non-200 response %s" % (response.code,)) - parsed_body = json.loads(body.decode("utf-8")) + parsed_body = json_decoder.decode(body.decode("utf-8")) logger.info("Response from .well-known: %s", parsed_body) result = parsed_body["m.server"].encode("ascii") diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index a34e5ead88..53acba56cb 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -17,9 +17,8 @@ import logging -from canonicaljson import json - from synapse.api.errors import Codes, SynapseError +from synapse.util import json_decoder logger = logging.getLogger(__name__) @@ -215,7 +214,7 @@ def parse_json_value_from_request(request, allow_empty_body=False): return None try: - content = json.loads(content_bytes.decode("utf-8")) + content = json_decoder.decode(content_bytes.decode("utf-8")) except Exception as e: logger.warning("Unable to parse JSON: %s", e) raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index 21dbd9f415..abe532d350 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -177,6 +177,7 @@ from canonicaljson import json from twisted.internet import defer from synapse.config import ConfigError +from synapse.util import json_decoder if TYPE_CHECKING: from synapse.http.site import SynapseRequest @@ -499,7 +500,9 @@ def start_active_span_from_edu( if opentracing is None: return _noop_context_manager() - carrier = json.loads(edu_content.get("context", "{}")).get("opentracing", {}) + carrier = json_decoder.decode(edu_content.get("context", "{}")).get( + "opentracing", {} + ) context = opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier) _references = [ opentracing.child_of(span_context_from_string(x)) @@ -699,7 +702,7 @@ def span_context_from_string(carrier): Returns: The active span context decoded from a string. """ - carrier = json.loads(carrier) + carrier = json_decoder.decode(carrier) return opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier) diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py index d853e4447e..8cd47770c1 100644 --- a/synapse/replication/tcp/commands.py +++ b/synapse/replication/tcp/commands.py @@ -21,9 +21,7 @@ import abc import logging from typing import Tuple, Type -from canonicaljson import json - -from synapse.util import json_encoder as _json_encoder +from synapse.util import json_decoder, json_encoder logger = logging.getLogger(__name__) @@ -125,7 +123,7 @@ class RdataCommand(Command): stream_name, instance_name, None if token == "batch" else int(token), - json.loads(row_json), + json_decoder.decode(row_json), ) def to_line(self): @@ -134,7 +132,7 @@ class RdataCommand(Command): self.stream_name, self.instance_name, str(self.token) if self.token is not None else "batch", - _json_encoder.encode(self.row), + json_encoder.encode(self.row), ) ) @@ -359,7 +357,7 @@ class UserIpCommand(Command): def from_line(cls, line): user_id, jsn = line.split(" ", 1) - access_token, ip, user_agent, device_id, last_seen = json.loads(jsn) + access_token, ip, user_agent, device_id, last_seen = json_decoder.decode(jsn) return cls(user_id, access_token, ip, user_agent, device_id, last_seen) @@ -367,7 +365,7 @@ class UserIpCommand(Command): return ( self.user_id + " " - + _json_encoder.encode( + + json_encoder.encode( ( self.access_token, self.ip, diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 2ab30ce897..f216382636 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -21,8 +21,6 @@ import re from typing import List, Optional from urllib import parse as urlparse -from canonicaljson import json - from synapse.api.constants import EventTypes, Membership from synapse.api.errors import ( AuthError, @@ -46,6 +44,7 @@ from synapse.rest.client.v2_alpha._base import client_patterns from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID +from synapse.util import json_decoder MYPY = False if MYPY: @@ -519,7 +518,9 @@ class RoomMessageListRestServlet(RestServlet): filter_str = parse_string(request, b"filter", encoding="utf-8") if filter_str: filter_json = urlparse.unquote(filter_str) - event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter] + event_filter = Filter( + json_decoder.decode(filter_json) + ) # type: Optional[Filter] if ( event_filter and event_filter.filter_json.get("event_format", "client") @@ -631,7 +632,9 @@ class RoomEventContextServlet(RestServlet): filter_str = parse_string(request, b"filter", encoding="utf-8") if filter_str: filter_json = urlparse.unquote(filter_str) - event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter] + event_filter = Filter( + json_decoder.decode(filter_json) + ) # type: Optional[Filter] else: event_filter = None diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py index a5c24fbd63..96488b131a 100644 --- a/synapse/rest/client/v2_alpha/sync.py +++ b/synapse/rest/client/v2_alpha/sync.py @@ -16,8 +16,6 @@ import itertools import logging -from canonicaljson import json - from synapse.api.constants import PresenceState from synapse.api.errors import Codes, StoreError, SynapseError from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection @@ -29,6 +27,7 @@ from synapse.handlers.presence import format_user_presence_state from synapse.handlers.sync import SyncConfig from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.types import StreamToken +from synapse.util import json_decoder from ._base import client_patterns, set_timeline_upper_limit @@ -125,7 +124,7 @@ class SyncRestServlet(RestServlet): filter_collection = DEFAULT_FILTER_COLLECTION elif filter_id.startswith("{"): try: - filter_object = json.loads(filter_id) + filter_object = json_decoder.decode(filter_id) set_timeline_upper_limit( filter_object, self.hs.config.filter_timeline_limit ) diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index e266204f95..5db7f81c2d 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -15,19 +15,19 @@ import logging from typing import Dict, Set -from canonicaljson import json from signedjson.sign import sign_json from synapse.api.errors import Codes, SynapseError from synapse.crypto.keyring import ServerKeyFetcher from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.servlet import parse_integer, parse_json_object_from_request +from synapse.util import json_decoder logger = logging.getLogger(__name__) class RemoteKey(DirectServeJsonResource): - """HTTP resource for retreiving the TLS certificate and NACL signature + """HTTP resource for retrieving the TLS certificate and NACL signature verification keys for a collection of servers. Checks that the reported X.509 TLS certificate matches the one used in the HTTPS connection. Checks that the NACL signature for the remote server is valid. Returns a dict of @@ -209,13 +209,15 @@ class RemoteKey(DirectServeJsonResource): # Cast to bytes since postgresql returns a memoryview. json_results.add(bytes(result["key_json"])) + # If there is a cache miss, request the missing keys, then recurse (and + # ensure the result is sent). if cache_misses and query_remote_on_cache_miss: await self.fetcher.get_keys(cache_misses) await self.query_keys(request, query, query_remote_on_cache_miss=False) else: signed_keys = [] for key_json in json_results: - key_json = json.loads(key_json.decode("utf-8")) + key_json = json_decoder.decode(key_json.decode("utf-8")) for signing_key in self.config.key_server_signing_keys: key_json = sign_json(key_json, self.config.server_name, signing_key) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 6814bf5fcf..ab49d227de 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -19,12 +19,11 @@ import random from abc import ABCMeta from typing import Any, Optional -from canonicaljson import json - from synapse.storage.database import LoggingTransaction # noqa: F401 from synapse.storage.database import make_in_list_sql_clause # noqa: F401 from synapse.storage.database import DatabasePool from synapse.types import Collection, get_domain_from_id +from synapse.util import json_decoder logger = logging.getLogger(__name__) @@ -99,13 +98,13 @@ def db_to_json(db_content): if isinstance(db_content, memoryview): db_content = db_content.tobytes() - # Decode it to a Unicode string before feeding it to json.loads, since + # Decode it to a Unicode string before feeding it to the JSON decoder, since # Python 3.5 does not support deserializing bytes. if isinstance(db_content, (bytes, bytearray)): db_content = db_content.decode("utf8") try: - return json.loads(db_content) + return json_decoder.decode(db_content) except Exception: logging.warning("Tried to decode '%r' as JSON and failed", db_content) raise diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index e3a154a527..4a3333c0db 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -596,8 +596,20 @@ class EventsWorkerStore(SQLBaseStore): if not allow_rejected and rejected_reason: continue - d = db_to_json(row["json"]) - internal_metadata = db_to_json(row["internal_metadata"]) + # If the event or metadata cannot be parsed, log the error and act + # as if the event is unknown. + try: + d = db_to_json(row["json"]) + except ValueError: + logger.error("Unable to parse json from event: %s", event_id) + continue + try: + internal_metadata = db_to_json(row["internal_metadata"]) + except ValueError: + logger.error( + "Unable to parse internal_metadata from event: %s", event_id + ) + continue format_version = row["format_version"] if format_version is None: diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index b3f76428b6..b2a22dbd5c 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -25,8 +25,18 @@ from synapse.logging import context logger = logging.getLogger(__name__) -# Create a custom encoder to reduce the whitespace produced by JSON encoding. -json_encoder = json.JSONEncoder(separators=(",", ":")) + +def _reject_invalid_json(val): + """Do not allow Infinity, -Infinity, or NaN values in JSON.""" + raise json.JSONDecodeError("Invalid JSON value: '%s'" % val) + + +# Create a custom encoder to reduce the whitespace produced by JSON encoding and +# ensure that valid JSON is produced. +json_encoder = json.JSONEncoder(allow_nan=False, separators=(",", ":")) + +# Create a custom decoder to reject Python extensions to JSON. +json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json) def unwrapFirstError(failure): -- cgit 1.5.1 From 3f91638da6ea0aeaf789ddc8ca1e624a11b7ebb2 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 20 Aug 2020 15:42:58 -0400 Subject: Allow denying or shadow banning registrations via the spam checker (#8034) --- changelog.d/8034.feature | 1 + synapse/events/spamcheck.py | 35 ++++++++++++++- synapse/handlers/auth.py | 8 ++++ synapse/handlers/cas_handler.py | 11 ++++- synapse/handlers/oidc_handler.py | 21 +++++++-- synapse/handlers/register.py | 26 ++++++++++- synapse/handlers/saml_handler.py | 18 +++++++- synapse/rest/client/v2_alpha/register.py | 5 +++ synapse/spam_checker_api/__init__.py | 11 +++++ .../main/schema/delta/58/07persist_ui_auth_ips.sql | 25 +++++++++++ synapse/storage/databases/main/ui_auth.py | 39 +++++++++++++++- tests/handlers/test_oidc.py | 18 ++++++-- tests/handlers/test_register.py | 52 +++++++++++++++++++++- tests/handlers/test_user_directory.py | 6 +-- 14 files changed, 258 insertions(+), 18 deletions(-) create mode 100644 changelog.d/8034.feature create mode 100644 synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql (limited to 'synapse/handlers/oidc_handler.py') diff --git a/changelog.d/8034.feature b/changelog.d/8034.feature new file mode 100644 index 0000000000..813e6d0903 --- /dev/null +++ b/changelog.d/8034.feature @@ -0,0 +1 @@ +Add support for shadow-banning users (ignoring any message send requests). diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py index 1ffc9525d1..a7cddac974 100644 --- a/synapse/events/spamcheck.py +++ b/synapse/events/spamcheck.py @@ -15,9 +15,10 @@ # limitations under the License. import inspect -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional, Tuple -from synapse.spam_checker_api import SpamCheckerApi +from synapse.spam_checker_api import RegistrationBehaviour, SpamCheckerApi +from synapse.types import Collection MYPY = False if MYPY: @@ -160,3 +161,33 @@ class SpamChecker(object): return True return False + + def check_registration_for_spam( + self, + email_threepid: Optional[dict], + username: Optional[str], + request_info: Collection[Tuple[str, str]], + ) -> RegistrationBehaviour: + """Checks if we should allow the given registration request. + + Args: + email_threepid: The email threepid used for registering, if any + username: The request user name, if any + request_info: List of tuples of user agent and IP that + were used during the registration process. + + Returns: + Enum for how the request should be handled + """ + + for spam_checker in self.spam_checkers: + # For backwards compatibility, only run if the method exists on the + # spam checker + checker = getattr(spam_checker, "check_registration_for_spam", None) + if checker: + behaviour = checker(email_threepid, username, request_info) + assert isinstance(behaviour, RegistrationBehaviour) + if behaviour != RegistrationBehaviour.ALLOW: + return behaviour + + return RegistrationBehaviour.ALLOW diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 68d6870e40..654f58ddae 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -364,6 +364,14 @@ class AuthHandler(BaseHandler): # authentication flow. await self.store.set_ui_auth_clientdict(sid, clientdict) + user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[ + 0 + ].decode("ascii", "surrogateescape") + + await self.store.add_user_agent_ip_to_ui_auth_session( + session.session_id, user_agent, clientip + ) + if not authdict: raise InteractiveAuthIncompleteError( session.session_id, self._auth_dict_for_flows(flows, session.session_id) diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index 786e608fa2..a4cc4b9a5a 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -35,6 +35,7 @@ class CasHandler: """ def __init__(self, hs): + self.hs = hs self._hostname = hs.hostname self._auth_handler = hs.get_auth_handler() self._registration_handler = hs.get_registration_handler() @@ -210,8 +211,16 @@ class CasHandler: else: if not registered_user_id: + # Pull out the user-agent and IP from the request. + user_agent = request.requestHeaders.getRawHeaders( + b"User-Agent", default=[b""] + )[0].decode("ascii", "surrogateescape") + ip_address = self.hs.get_ip_from_request(request) + registered_user_id = await self._registration_handler.register_user( - localpart=localpart, default_display_name=user_display_name + localpart=localpart, + default_display_name=user_display_name, + user_agent_ips=(user_agent, ip_address), ) await self._auth_handler.complete_sso_login( diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index dd3703cbd2..c5bd2fea68 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -93,6 +93,7 @@ class OidcHandler: """ def __init__(self, hs: "HomeServer"): + self.hs = hs self._callback_url = hs.config.oidc_callback_url # type: str self._scopes = hs.config.oidc_scopes # type: List[str] self._client_auth = ClientAuth( @@ -689,9 +690,17 @@ class OidcHandler: self._render_error(request, "invalid_token", str(e)) return + # Pull out the user-agent and IP from the request. + user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[ + 0 + ].decode("ascii", "surrogateescape") + ip_address = self.hs.get_ip_from_request(request) + # Call the mapper to register/login the user try: - user_id = await self._map_userinfo_to_user(userinfo, token) + user_id = await self._map_userinfo_to_user( + userinfo, token, user_agent, ip_address + ) except MappingException as e: logger.exception("Could not map user") self._render_error(request, "mapping_error", str(e)) @@ -828,7 +837,9 @@ class OidcHandler: now = self._clock.time_msec() return now < expiry - async def _map_userinfo_to_user(self, userinfo: UserInfo, token: Token) -> str: + async def _map_userinfo_to_user( + self, userinfo: UserInfo, token: Token, user_agent: str, ip_address: str + ) -> str: """Maps a UserInfo object to a mxid. UserInfo should have a claim that uniquely identifies users. This claim @@ -843,6 +854,8 @@ class OidcHandler: Args: userinfo: an object representing the user token: a dict with the tokens obtained from the provider + user_agent: The user agent of the client making the request. + ip_address: The IP address of the client making the request. Raises: MappingException: if there was an error while mapping some properties @@ -899,7 +912,9 @@ class OidcHandler: # It's the first time this user is logging in and the mapped mxid was # not taken, register the user registered_user_id = await self._registration_handler.register_user( - localpart=localpart, default_display_name=attributes["display_name"], + localpart=localpart, + default_display_name=attributes["display_name"], + user_agent_ips=(user_agent, ip_address), ) await self._datastore.record_user_external_id( diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index ccd96e4626..cde2dbca92 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -26,6 +26,7 @@ from synapse.replication.http.register import ( ReplicationPostRegisterActionsServlet, ReplicationRegisterServlet, ) +from synapse.spam_checker_api import RegistrationBehaviour from synapse.storage.state import StateFilter from synapse.types import RoomAlias, UserID, create_requester @@ -52,6 +53,8 @@ class RegistrationHandler(BaseHandler): self.macaroon_gen = hs.get_macaroon_generator() self._server_notices_mxid = hs.config.server_notices_mxid + self.spam_checker = hs.get_spam_checker() + if hs.config.worker_app: self._register_client = ReplicationRegisterServlet.make_client(hs) self._register_device_client = RegisterDeviceReplicationServlet.make_client( @@ -144,7 +147,7 @@ class RegistrationHandler(BaseHandler): address=None, bind_emails=[], by_admin=False, - shadow_banned=False, + user_agent_ips=None, ): """Registers a new client on the server. @@ -162,7 +165,8 @@ class RegistrationHandler(BaseHandler): bind_emails (List[str]): list of emails to bind to this account. by_admin (bool): True if this registration is being made via the admin api, otherwise False. - shadow_banned (bool): Shadow-ban the created user. + user_agent_ips (List[(str, str)]): Tuples of IP addresses and user-agents used + during the registration process. Returns: str: user_id Raises: @@ -170,6 +174,24 @@ class RegistrationHandler(BaseHandler): """ self.check_registration_ratelimit(address) + result = self.spam_checker.check_registration_for_spam( + threepid, localpart, user_agent_ips or [], + ) + + if result == RegistrationBehaviour.DENY: + logger.info( + "Blocked registration of %r", localpart, + ) + # We return a 429 to make it not obvious that they've been + # denied. + raise SynapseError(429, "Rate limited") + + shadow_banned = result == RegistrationBehaviour.SHADOW_BAN + if shadow_banned: + logger.info( + "Shadow banning registration of %r", localpart, + ) + # do not check_auth_blocking if the call is coming through the Admin API if not by_admin: await self.auth.check_auth_blocking(threepid=threepid) diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index c1fcb98454..b426199aa6 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -54,6 +54,7 @@ class Saml2SessionData: class SamlHandler: def __init__(self, hs: "synapse.server.HomeServer"): + self.hs = hs self._saml_client = Saml2Client(hs.config.saml2_sp_config) self._auth = hs.get_auth() self._auth_handler = hs.get_auth_handler() @@ -133,8 +134,14 @@ class SamlHandler: # the dict. self.expire_sessions() + # Pull out the user-agent and IP from the request. + user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[ + 0 + ].decode("ascii", "surrogateescape") + ip_address = self.hs.get_ip_from_request(request) + user_id, current_session = await self._map_saml_response_to_user( - resp_bytes, relay_state + resp_bytes, relay_state, user_agent, ip_address ) # Complete the interactive auth session or the login. @@ -147,7 +154,11 @@ class SamlHandler: await self._auth_handler.complete_sso_login(user_id, request, relay_state) async def _map_saml_response_to_user( - self, resp_bytes: str, client_redirect_url: str + self, + resp_bytes: str, + client_redirect_url: str, + user_agent: str, + ip_address: str, ) -> Tuple[str, Optional[Saml2SessionData]]: """ Given a sample response, retrieve the cached session and user for it. @@ -155,6 +166,8 @@ class SamlHandler: Args: resp_bytes: The SAML response. client_redirect_url: The redirect URL passed in by the client. + user_agent: The user agent of the client making the request. + ip_address: The IP address of the client making the request. Returns: Tuple of the user ID and SAML session associated with this response. @@ -291,6 +304,7 @@ class SamlHandler: localpart=localpart, default_display_name=displayname, bind_emails=emails, + user_agent_ips=(user_agent, ip_address), ) await self._datastore.record_user_external_id( diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index 7290fd0756..be0e680ac5 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -591,12 +591,17 @@ class RegisterRestServlet(RestServlet): Codes.THREEPID_IN_USE, ) + entries = await self.store.get_user_agents_ips_to_ui_auth_session( + session_id + ) + registered_user_id = await self.registration_handler.register_user( localpart=desired_username, password_hash=password_hash, guest_access_token=guest_access_token, threepid=threepid, address=client_addr, + user_agent_ips=entries, ) # Necessary due to auth checks prior to the threepid being # written to the db diff --git a/synapse/spam_checker_api/__init__.py b/synapse/spam_checker_api/__init__.py index 7f63f1bfa0..9be92e2565 100644 --- a/synapse/spam_checker_api/__init__.py +++ b/synapse/spam_checker_api/__init__.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from enum import Enum from twisted.internet import defer @@ -25,6 +26,16 @@ if MYPY: logger = logging.getLogger(__name__) +class RegistrationBehaviour(Enum): + """ + Enum to define whether a registration request should allowed, denied, or shadow-banned. + """ + + ALLOW = "allow" + SHADOW_BAN = "shadow_ban" + DENY = "deny" + + class SpamCheckerApi(object): """A proxy object that gets passed to spam checkers so they can get access to rooms and other relevant information. diff --git a/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql b/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql new file mode 100644 index 0000000000..4cc96a5341 --- /dev/null +++ b/synapse/storage/databases/main/schema/delta/58/07persist_ui_auth_ips.sql @@ -0,0 +1,25 @@ +/* Copyright 2020 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- A table of the IP address and user-agent used to complete each step of a +-- user-interactive authentication session. +CREATE TABLE IF NOT EXISTS ui_auth_sessions_ips( + session_id TEXT NOT NULL, + ip TEXT NOT NULL, + user_agent TEXT NOT NULL, + UNIQUE (session_id, ip, user_agent), + FOREIGN KEY (session_id) + REFERENCES ui_auth_sessions (session_id) +); diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py index 6281a41a3d..9eef8e57c5 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import attr @@ -260,6 +260,34 @@ class UIAuthWorkerStore(SQLBaseStore): return serverdict.get(key, default) + async def add_user_agent_ip_to_ui_auth_session( + self, session_id: str, user_agent: str, ip: str, + ): + """Add the given user agent / IP to the tracking table + """ + await self.db_pool.simple_upsert( + table="ui_auth_sessions_ips", + keyvalues={"session_id": session_id, "user_agent": user_agent, "ip": ip}, + values={}, + desc="add_user_agent_ip_to_ui_auth_session", + ) + + async def get_user_agents_ips_to_ui_auth_session( + self, session_id: str, + ) -> List[Tuple[str, str]]: + """Get the given user agents / IPs used during the ui auth process + + Returns: + List of user_agent/ip pairs + """ + rows = await self.db_pool.simple_select_list( + table="ui_auth_sessions_ips", + keyvalues={"session_id": session_id}, + retcols=("user_agent", "ip"), + desc="get_user_agents_ips_to_ui_auth_session", + ) + return [(row["user_agent"], row["ip"]) for row in rows] + class UIAuthStore(UIAuthWorkerStore): def delete_old_ui_auth_sessions(self, expiration_time: int): @@ -285,6 +313,15 @@ class UIAuthStore(UIAuthWorkerStore): txn.execute(sql, [expiration_time]) session_ids = [r[0] for r in txn.fetchall()] + # Delete the corresponding IP/user agents. + self.db_pool.simple_delete_many_txn( + txn, + table="ui_auth_sessions_ips", + column="session_id", + iterable=session_ids, + keyvalues={}, + ) + # Delete the corresponding completed credentials. self.db_pool.simple_delete_many_txn( txn, diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 1bb25ab684..f92f3b8c15 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -374,12 +374,16 @@ class OidcHandlerTestCase(HomeserverTestCase): self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo) self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) self.handler._auth_handler.complete_sso_login = simple_async_mock() - request = Mock(spec=["args", "getCookie", "addCookie"]) + request = Mock( + spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"] + ) code = "code" state = "state" nonce = "nonce" client_redirect_url = "http://client/redirect" + user_agent = "Browser" + ip_address = "10.0.0.1" session = self.handler._generate_oidc_session_token( state=state, nonce=nonce, @@ -392,6 +396,10 @@ class OidcHandlerTestCase(HomeserverTestCase): request.args[b"code"] = [code.encode("utf-8")] request.args[b"state"] = [state.encode("utf-8")] + request.requestHeaders = Mock(spec=["getRawHeaders"]) + request.requestHeaders.getRawHeaders.return_value = [user_agent.encode("ascii")] + request.getClientIP.return_value = ip_address + yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) self.handler._auth_handler.complete_sso_login.assert_called_once_with( @@ -399,7 +407,9 @@ class OidcHandlerTestCase(HomeserverTestCase): ) self.handler._exchange_code.assert_called_once_with(code) self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce) - self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token) + self.handler._map_userinfo_to_user.assert_called_once_with( + userinfo, token, user_agent, ip_address + ) self.handler._fetch_userinfo.assert_not_called() self.handler._render_error.assert_not_called() @@ -431,7 +441,9 @@ class OidcHandlerTestCase(HomeserverTestCase): ) self.handler._exchange_code.assert_called_once_with(code) self.handler._parse_id_token.assert_not_called() - self.handler._map_userinfo_to_user.assert_called_once_with(userinfo, token) + self.handler._map_userinfo_to_user.assert_called_once_with( + userinfo, token, user_agent, ip_address + ) self.handler._fetch_userinfo.assert_called_once_with(token) self.handler._render_error.assert_not_called() diff --git a/tests/handlers/test_register.py b/tests/handlers/test_register.py index e364b1bd62..5c92d0e8c9 100644 --- a/tests/handlers/test_register.py +++ b/tests/handlers/test_register.py @@ -17,18 +17,21 @@ from mock import Mock from twisted.internet import defer +from synapse.api.auth import Auth from synapse.api.constants import UserTypes from synapse.api.errors import Codes, ResourceLimitError, SynapseError from synapse.handlers.register import RegistrationHandler +from synapse.spam_checker_api import RegistrationBehaviour from synapse.types import RoomAlias, UserID, create_requester from tests.test_utils import make_awaitable from tests.unittest import override_config +from tests.utils import mock_getRawHeaders from .. import unittest -class RegistrationHandlers(object): +class RegistrationHandlers: def __init__(self, hs): self.registration_handler = RegistrationHandler(hs) @@ -475,6 +478,53 @@ class RegistrationTestCase(unittest.HomeserverTestCase): self.handler.register_user(localpart=invalid_user_id), SynapseError ) + def test_spam_checker_deny(self): + """A spam checker can deny registration, which results in an error.""" + + class DenyAll: + def check_registration_for_spam( + self, email_threepid, username, request_info + ): + return RegistrationBehaviour.DENY + + # Configure a spam checker that denies all users. + spam_checker = self.hs.get_spam_checker() + spam_checker.spam_checkers = [DenyAll()] + + self.get_failure(self.handler.register_user(localpart="user"), SynapseError) + + def test_spam_checker_shadow_ban(self): + """A spam checker can choose to shadow-ban a user, which allows registration to succeed.""" + + class BanAll: + def check_registration_for_spam( + self, email_threepid, username, request_info + ): + return RegistrationBehaviour.SHADOW_BAN + + # Configure a spam checker that denies all users. + spam_checker = self.hs.get_spam_checker() + spam_checker.spam_checkers = [BanAll()] + + user_id = self.get_success(self.handler.register_user(localpart="user")) + + # Get an access token. + token = self.macaroon_generator.generate_access_token(user_id) + self.get_success( + self.store.add_access_token_to_user( + user_id=user_id, token=token, device_id=None, valid_until_ms=None + ) + ) + + # Ensure the user was marked as shadow-banned. + request = Mock(args={}) + request.args[b"access_token"] = [token.encode("ascii")] + request.requestHeaders.getRawHeaders = mock_getRawHeaders() + auth = Auth(self.hs) + requester = self.get_success(auth.get_user_by_req(request)) + + self.assertTrue(requester.shadow_banned) + async def get_or_create_user( self, requester, localpart, displayname, password_hash=None ): diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index 31ed89a5cd..87be94111f 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -238,7 +238,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): def test_spam_checker(self): """ - A user which fails to the spam checks will not appear in search results. + A user which fails the spam checks will not appear in search results. """ u1 = self.register_user("user1", "pass") u1_token = self.login(u1, "pass") @@ -269,7 +269,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): # Configure a spam checker that does not filter any users. spam_checker = self.hs.get_spam_checker() - class AllowAll(object): + class AllowAll: def check_username_for_spam(self, user_profile): # Allow all users. return False @@ -282,7 +282,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.assertEqual(len(s["results"]), 1) # Configure a spam checker that filters all users. - class BlockAll(object): + class BlockAll: def check_username_for_spam(self, user_profile): # All users are spammy. return True -- cgit 1.5.1 From b055dc93220217fe55f8f4d28945f86353c2f3a8 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 28 Aug 2020 08:56:36 -0400 Subject: Ensure that the OpenID Connect remote ID is a string. (#8190) --- changelog.d/8190.bugfix | 1 + synapse/handlers/oidc_handler.py | 3 +++ tests/handlers/test_oidc.py | 41 ++++++++++++++++++++++++++++++++++++++-- 3 files changed, 43 insertions(+), 2 deletions(-) create mode 100644 changelog.d/8190.bugfix (limited to 'synapse/handlers/oidc_handler.py') diff --git a/changelog.d/8190.bugfix b/changelog.d/8190.bugfix new file mode 100644 index 0000000000..bf6717ab28 --- /dev/null +++ b/changelog.d/8190.bugfix @@ -0,0 +1 @@ +Fix logging in via OpenID Connect with a provider that uses integer user IDs. diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index c5bd2fea68..1b06f3173f 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -869,6 +869,9 @@ class OidcHandler: raise MappingException( "Failed to extract subject from OIDC response: %s" % (e,) ) + # Some OIDC providers use integer IDs, but Synapse expects external IDs + # to be strings. + remote_user_id = str(remote_user_id) logger.info( "Looking for existing mapping for user %s:%s", diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index f92f3b8c15..89ec5fcb31 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -75,7 +75,17 @@ COMMON_CONFIG = { COOKIE_NAME = b"oidc_session" COOKIE_PATH = "/_synapse/oidc" -MockedMappingProvider = Mock(OidcMappingProvider) + +class TestMappingProvider(OidcMappingProvider): + @staticmethod + def parse_config(config): + return + + def get_remote_user_id(self, userinfo): + return userinfo["sub"] + + async def map_user_attributes(self, userinfo, token): + return {"localpart": userinfo["username"], "display_name": None} def simple_async_mock(return_value=None, raises=None): @@ -123,7 +133,7 @@ class OidcHandlerTestCase(HomeserverTestCase): oidc_config["issuer"] = ISSUER oidc_config["scopes"] = SCOPES oidc_config["user_mapping_provider"] = { - "module": __name__ + ".MockedMappingProvider" + "module": __name__ + ".TestMappingProvider", } config["oidc_config"] = oidc_config @@ -580,3 +590,30 @@ class OidcHandlerTestCase(HomeserverTestCase): with self.assertRaises(OidcError) as exc: yield defer.ensureDeferred(self.handler._exchange_code(code)) self.assertEqual(exc.exception.error, "some_error") + + def test_map_userinfo_to_user(self): + """Ensure that mapping the userinfo returned from a provider to an MXID works properly.""" + userinfo = { + "sub": "test_user", + "username": "test_user", + } + # The token doesn't matter with the default user mapping provider. + token = {} + mxid = self.get_success( + self.handler._map_userinfo_to_user( + userinfo, token, "user-agent", "10.10.10.10" + ) + ) + self.assertEqual(mxid, "@test_user:test") + + # Some providers return an integer ID. + userinfo = { + "sub": 1234, + "username": "test_user_2", + } + mxid = self.get_success( + self.handler._map_userinfo_to_user( + userinfo, token, "user-agent", "10.10.10.10" + ) + ) + self.assertEqual(mxid, "@test_user_2:test") -- cgit 1.5.1 From 6605470bfb8944d369b8fc73195a380b95b6de9d Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 14 Sep 2020 09:05:36 -0400 Subject: Improve SAML error messages (#8248) --- UPGRADE.rst | 14 +++ changelog.d/8248.feature | 1 + docs/sample_config.yaml | 30 +----- synapse/config/saml2_config.py | 34 +------ synapse/handlers/oidc_handler.py | 4 +- synapse/handlers/saml_handler.py | 169 +++++++++++++++++++++----------- synapse/res/templates/saml_error.html | 52 ---------- synapse/res/templates/sso_error.html | 43 +++++++- synapse/rest/saml2/response_resource.py | 16 +-- 9 files changed, 178 insertions(+), 185 deletions(-) create mode 100644 changelog.d/8248.feature delete mode 100644 synapse/res/templates/saml_error.html (limited to 'synapse/handlers/oidc_handler.py') diff --git a/UPGRADE.rst b/UPGRADE.rst index fc8982ddfe..49e86e628f 100644 --- a/UPGRADE.rst +++ b/UPGRADE.rst @@ -114,6 +114,20 @@ request to with the query parameters from the original link, presented as a URL-encoded form. See the file itself for more details. +Updated Single Sign-on HTML Templates +------------------------------------- + +The ``saml_error.html`` template was removed from Synapse and replaced with the +``sso_error.html`` template. If your Synapse is configured to use SAML and a +custom ``sso_redirect_confirm_template_dir`` configuration then any customisations +of the ``saml_error.html`` template will need to be merged into the ``sso_error.html`` +template. These templates are similar, but the parameters are slightly different: + +* The ``msg`` parameter should be renamed to ``error_description``. +* There is no longer a ``code`` parameter for the response code. +* A string ``error`` parameter is available that includes a short hint of why a + user is seeing the error page. + Upgrading to v1.18.0 ==================== diff --git a/changelog.d/8248.feature b/changelog.d/8248.feature new file mode 100644 index 0000000000..f3c4a74bc7 --- /dev/null +++ b/changelog.d/8248.feature @@ -0,0 +1 @@ +Consolidate the SSO error template across all configuration. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 2a5b2e0935..fb04ff283d 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1485,11 +1485,14 @@ trusted_key_servers: # At least one of `sp_config` or `config_path` must be set in this section to # enable SAML login. # -# (You will probably also want to set the following options to `false` to +# You will probably also want to set the following options to `false` to # disable the regular login/registration flows: # * enable_registration # * password_config.enabled # +# You will also want to investigate the settings under the "sso" configuration +# section below. +# # Once SAML support is enabled, a metadata file will be exposed at # https://:/_matrix/saml2/metadata.xml, which you may be able to # use to configure your SAML IdP with. Alternatively, you can manually configure @@ -1612,31 +1615,6 @@ saml2_config: # - attribute: department # value: "sales" - # Directory in which Synapse will try to find the template files below. - # If not set, default templates from within the Synapse package will be used. - # - # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates. - # If you *do* uncomment it, you will need to make sure that all the templates - # below are in the directory. - # - # Synapse will look for the following templates in this directory: - # - # * HTML page to display to users if something goes wrong during the - # authentication process: 'saml_error.html'. - # - # When rendering, this template is given the following variables: - # * code: an HTML error code corresponding to the error that is being - # returned (typically 400 or 500) - # - # * msg: a textual message describing the error. - # - # The variables will automatically be HTML-escaped. - # - # You can see the default templates at: - # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates - # - #template_dir: "res/templates" - # OpenID Connect integration. The following settings can be used to make Synapse # use an OpenID Connect Provider for authentication, instead of its internal diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index cc7401888b..99aa8b3bf1 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -169,10 +169,6 @@ class SAML2Config(Config): saml2_config.get("saml_session_lifetime", "15m") ) - self.saml2_error_html_template = self.read_templates( - ["saml_error.html"], saml2_config.get("template_dir") - )[0] - def _default_saml_config_dict( self, required_attributes: set, optional_attributes: set ): @@ -225,11 +221,14 @@ class SAML2Config(Config): # At least one of `sp_config` or `config_path` must be set in this section to # enable SAML login. # - # (You will probably also want to set the following options to `false` to + # You will probably also want to set the following options to `false` to # disable the regular login/registration flows: # * enable_registration # * password_config.enabled # + # You will also want to investigate the settings under the "sso" configuration + # section below. + # # Once SAML support is enabled, a metadata file will be exposed at # https://:/_matrix/saml2/metadata.xml, which you may be able to # use to configure your SAML IdP with. Alternatively, you can manually configure @@ -351,31 +350,6 @@ class SAML2Config(Config): # value: "staff" # - attribute: department # value: "sales" - - # Directory in which Synapse will try to find the template files below. - # If not set, default templates from within the Synapse package will be used. - # - # DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates. - # If you *do* uncomment it, you will need to make sure that all the templates - # below are in the directory. - # - # Synapse will look for the following templates in this directory: - # - # * HTML page to display to users if something goes wrong during the - # authentication process: 'saml_error.html'. - # - # When rendering, this template is given the following variables: - # * code: an HTML error code corresponding to the error that is being - # returned (typically 400 or 500) - # - # * msg: a textual message describing the error. - # - # The variables will automatically be HTML-escaped. - # - # You can see the default templates at: - # https://github.com/matrix-org/synapse/tree/master/synapse/res/templates - # - #template_dir: "res/templates" """ % { "config_dir_path": config_dir_path } diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 1b06f3173f..4230dbaf99 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -131,10 +131,10 @@ class OidcHandler: def _render_error( self, request, error: str, error_description: Optional[str] = None ) -> None: - """Renders the error template and respond with it. + """Render the error template and respond to the request with it. This is used to show errors to the user. The template of this page can - be found under ``synapse/res/templates/sso_error.html``. + be found under `synapse/res/templates/sso_error.html`. Args: request: The incoming request from the browser. diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index 66b063f991..8715abd4d1 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -21,9 +21,10 @@ import saml2 import saml2.response from saml2.client import Saml2Client -from synapse.api.errors import AuthError, SynapseError +from synapse.api.errors import SynapseError from synapse.config import ConfigError from synapse.config.saml2_config import SamlAttributeRequirement +from synapse.http.server import respond_with_html from synapse.http.servlet import parse_string from synapse.http.site import SynapseRequest from synapse.module_api import ModuleApi @@ -41,6 +42,10 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +class MappingException(Exception): + """Used to catch errors when mapping the SAML2 response to a user.""" + + @attr.s class Saml2SessionData: """Data we track about SAML2 sessions""" @@ -68,6 +73,7 @@ class SamlHandler: hs.config.saml2_grandfathered_mxid_source_attribute ) self._saml2_attribute_requirements = hs.config.saml2.attribute_requirements + self._error_template = hs.config.sso_error_template # plugin to do custom mapping from saml response to mxid self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class( @@ -84,6 +90,25 @@ class SamlHandler: # a lock on the mappings self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock) + def _render_error( + self, request, error: str, error_description: Optional[str] = None + ) -> None: + """Render the error template and respond to the request with it. + + This is used to show errors to the user. The template of this page can + be found under `synapse/res/templates/sso_error.html`. + + Args: + request: The incoming request from the browser. + We'll respond with an HTML page describing the error. + error: A technical identifier for this error. + error_description: A human-readable description of the error. + """ + html = self._error_template.render( + error=error, error_description=error_description + ) + respond_with_html(request, 400, html) + def handle_redirect_request( self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None ) -> bytes: @@ -134,49 +159,6 @@ class SamlHandler: # the dict. self.expire_sessions() - # Pull out the user-agent and IP from the request. - user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[ - 0 - ].decode("ascii", "surrogateescape") - ip_address = self.hs.get_ip_from_request(request) - - user_id, current_session = await self._map_saml_response_to_user( - resp_bytes, relay_state, user_agent, ip_address - ) - - # Complete the interactive auth session or the login. - if current_session and current_session.ui_auth_session_id: - await self._auth_handler.complete_sso_ui_auth( - user_id, current_session.ui_auth_session_id, request - ) - - else: - await self._auth_handler.complete_sso_login(user_id, request, relay_state) - - async def _map_saml_response_to_user( - self, - resp_bytes: str, - client_redirect_url: str, - user_agent: str, - ip_address: str, - ) -> Tuple[str, Optional[Saml2SessionData]]: - """ - Given a sample response, retrieve the cached session and user for it. - - Args: - resp_bytes: The SAML response. - client_redirect_url: The redirect URL passed in by the client. - user_agent: The user agent of the client making the request. - ip_address: The IP address of the client making the request. - - Returns: - Tuple of the user ID and SAML session associated with this response. - - Raises: - SynapseError if there was a problem with the response. - RedirectException: some mapping providers may raise this if they need - to redirect to an interstitial page. - """ try: saml2_auth = self._saml_client.parse_authn_request_response( resp_bytes, @@ -189,12 +171,23 @@ class SamlHandler: # in the (user-visible) exception message, so let's log the exception here # so we can track down the session IDs later. logger.warning(str(e)) - raise SynapseError(400, "Unexpected SAML2 login.") + self._render_error( + request, "unsolicited_response", "Unexpected SAML2 login." + ) + return except Exception as e: - raise SynapseError(400, "Unable to parse SAML2 response: %s." % (e,)) + self._render_error( + request, + "invalid_response", + "Unable to parse SAML2 response: %s." % (e,), + ) + return if saml2_auth.not_signed: - raise SynapseError(400, "SAML2 response was not signed.") + self._render_error( + request, "unsigned_respond", "SAML2 response was not signed." + ) + return logger.debug("SAML2 response: %s", saml2_auth.origxml) for assertion in saml2_auth.assertions: @@ -213,15 +206,73 @@ class SamlHandler: saml2_auth.in_response_to, None ) + # Ensure that the attributes of the logged in user meet the required + # attributes. for requirement in self._saml2_attribute_requirements: - _check_attribute_requirement(saml2_auth.ava, requirement) + if not _check_attribute_requirement(saml2_auth.ava, requirement): + self._render_error( + request, "unauthorised", "You are not authorised to log in here." + ) + return + + # Pull out the user-agent and IP from the request. + user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[ + 0 + ].decode("ascii", "surrogateescape") + ip_address = self.hs.get_ip_from_request(request) + + # Call the mapper to register/login the user + try: + user_id = await self._map_saml_response_to_user( + saml2_auth, relay_state, user_agent, ip_address + ) + except MappingException as e: + logger.exception("Could not map user") + self._render_error(request, "mapping_error", str(e)) + return + + # Complete the interactive auth session or the login. + if current_session and current_session.ui_auth_session_id: + await self._auth_handler.complete_sso_ui_auth( + user_id, current_session.ui_auth_session_id, request + ) + + else: + await self._auth_handler.complete_sso_login(user_id, request, relay_state) + + async def _map_saml_response_to_user( + self, + saml2_auth: saml2.response.AuthnResponse, + client_redirect_url: str, + user_agent: str, + ip_address: str, + ) -> str: + """ + Given a SAML response, retrieve the user ID for it and possibly register the user. + + Args: + saml2_auth: The parsed SAML2 response. + client_redirect_url: The redirect URL passed in by the client. + user_agent: The user agent of the client making the request. + ip_address: The IP address of the client making the request. + + Returns: + The user ID associated with this response. + + Raises: + MappingException if there was a problem mapping the response to a user. + RedirectException: some mapping providers may raise this if they need + to redirect to an interstitial page. + """ remote_user_id = self._user_mapping_provider.get_remote_user_id( saml2_auth, client_redirect_url ) if not remote_user_id: - raise Exception("Failed to extract remote user id from SAML response") + raise MappingException( + "Failed to extract remote user id from SAML response" + ) with (await self._mapping_lock.queue(self._auth_provider_id)): # first of all, check if we already have a mapping for this user @@ -235,7 +286,7 @@ class SamlHandler: ) if registered_user_id is not None: logger.info("Found existing mapping %s", registered_user_id) - return registered_user_id, current_session + return registered_user_id # backwards-compatibility hack: see if there is an existing user with a # suitable mapping from the uid @@ -260,7 +311,7 @@ class SamlHandler: await self._datastore.record_user_external_id( self._auth_provider_id, remote_user_id, registered_user_id ) - return registered_user_id, current_session + return registered_user_id # Map saml response to user attributes using the configured mapping provider for i in range(1000): @@ -277,7 +328,7 @@ class SamlHandler: localpart = attribute_dict.get("mxid_localpart") if not localpart: - raise Exception( + raise MappingException( "Error parsing SAML2 response: SAML mapping provider plugin " "did not return a mxid_localpart value" ) @@ -294,8 +345,8 @@ class SamlHandler: else: # Unable to generate a username in 1000 iterations # Break and return error to the user - raise SynapseError( - 500, "Unable to generate a Matrix ID from the SAML response" + raise MappingException( + "Unable to generate a Matrix ID from the SAML response" ) logger.info("Mapped SAML user to local part %s", localpart) @@ -310,7 +361,7 @@ class SamlHandler: await self._datastore.record_user_external_id( self._auth_provider_id, remote_user_id, registered_user_id ) - return registered_user_id, current_session + return registered_user_id def expire_sessions(self): expire_before = self._clock.time_msec() - self._saml2_session_lifetime @@ -323,11 +374,11 @@ class SamlHandler: del self._outstanding_requests_dict[reqid] -def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement): +def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement) -> bool: values = ava.get(req.attribute, []) for v in values: if v == req.value: - return + return True logger.info( "SAML2 attribute %s did not match required value '%s' (was '%s')", @@ -335,7 +386,7 @@ def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement): req.value, values, ) - raise AuthError(403, "You are not authorized to log in here.") + return False DOT_REPLACE_PATTERN = re.compile( @@ -390,7 +441,7 @@ class DefaultSamlMappingProvider: return saml_response.ava["uid"][0] except KeyError: logger.warning("SAML2 response lacks a 'uid' attestation") - raise SynapseError(400, "'uid' not in SAML2 response") + raise MappingException("'uid' not in SAML2 response") def saml_response_to_user_attributes( self, diff --git a/synapse/res/templates/saml_error.html b/synapse/res/templates/saml_error.html deleted file mode 100644 index 01cd9bdaf3..0000000000 --- a/synapse/res/templates/saml_error.html +++ /dev/null @@ -1,52 +0,0 @@ - - - - - SSO login error - - -{# a 403 means we have actively rejected their login #} -{% if code == 403 %} -

You are not allowed to log in here.

-{% else %} -

- There was an error during authentication: -

-
{{ msg }}
-

- If you are seeing this page after clicking a link sent to you via email, make - sure you only click the confirmation link once, and that you open the - validation link in the same client you're logging in from. -

-

- Try logging in again from your Matrix client and if the problem persists - please contact the server's administrator. -

- - -{% endif %} - - diff --git a/synapse/res/templates/sso_error.html b/synapse/res/templates/sso_error.html index 43a211386b..af8459719a 100644 --- a/synapse/res/templates/sso_error.html +++ b/synapse/res/templates/sso_error.html @@ -5,14 +5,49 @@ SSO error -

Oops! Something went wrong during authentication.

+{# If an error of unauthorised is returned it means we have actively rejected their login #} +{% if error == "unauthorised" %} +

You are not allowed to log in here.

+{% else %} +

+ There was an error during authentication: +

+
{{ error_description }}
+

+ If you are seeing this page after clicking a link sent to you via email, make + sure you only click the confirmation link once, and that you open the + validation link in the same client you're logging in from. +

Try logging in again from your Matrix client and if the problem persists please contact the server's administrator.

Error: {{ error }}

- {% if error_description %} -
{{ error_description }}
- {% endif %} + + +{% endif %} diff --git a/synapse/rest/saml2/response_resource.py b/synapse/rest/saml2/response_resource.py index c10188a5d7..f6668fb5e3 100644 --- a/synapse/rest/saml2/response_resource.py +++ b/synapse/rest/saml2/response_resource.py @@ -13,10 +13,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from twisted.python import failure -from synapse.api.errors import SynapseError -from synapse.http.server import DirectServeHtmlResource, return_html_error +from synapse.http.server import DirectServeHtmlResource class SAML2ResponseResource(DirectServeHtmlResource): @@ -27,21 +25,15 @@ class SAML2ResponseResource(DirectServeHtmlResource): def __init__(self, hs): super().__init__() self._saml_handler = hs.get_saml_handler() - self._error_html_template = hs.config.saml2.saml2_error_html_template async def _async_render_GET(self, request): # We're not expecting any GET request on that resource if everything goes right, # but some IdPs sometimes end up responding with a 302 redirect on this endpoint. # In this case, just tell the user that something went wrong and they should # try to authenticate again. - f = failure.Failure( - SynapseError(400, "Unexpected GET request on /saml2/authn_response") + self._saml_handler._render_error( + request, "unexpected_get", "Unexpected GET request on /saml2/authn_response" ) - return_html_error(f, request, self._error_html_template) async def _async_render_POST(self, request): - try: - await self._saml_handler.handle_saml_response(request) - except Exception: - f = failure.Failure() - return_html_error(f, request, self._error_html_template) + await self._saml_handler.handle_saml_response(request) -- cgit 1.5.1 From abd04b6af0671517a01781c8bd10fef2a6c32cc4 Mon Sep 17 00:00:00 2001 From: Tdxdxoz Date: Fri, 25 Sep 2020 19:01:45 +0800 Subject: Allow existing users to login via OpenID Connect. (#8345) Co-authored-by: Benjamin Koch This adds configuration flags that will match a user to pre-existing users when logging in via OpenID Connect. This is useful when switching to an existing SSO system. --- changelog.d/8345.feature | 1 + docs/sample_config.yaml | 5 +++ synapse/config/oidc_config.py | 6 ++++ synapse/handlers/oidc_handler.py | 42 +++++++++++++++++--------- synapse/storage/databases/main/registration.py | 4 +-- tests/handlers/test_oidc.py | 35 +++++++++++++++++++++ 6 files changed, 76 insertions(+), 17 deletions(-) create mode 100644 changelog.d/8345.feature (limited to 'synapse/handlers/oidc_handler.py') diff --git a/changelog.d/8345.feature b/changelog.d/8345.feature new file mode 100644 index 0000000000..4ee5b6a56e --- /dev/null +++ b/changelog.d/8345.feature @@ -0,0 +1 @@ +Add a configuration option that allows existing users to log in with OpenID Connect. Contributed by @BBBSnowball and @OmmyZhang. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index fb04ff283d..845f537795 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1689,6 +1689,11 @@ oidc_config: # #skip_verification: true + # Uncomment to allow a user logging in via OIDC to match a pre-existing account instead + # of failing. This could be used if switching from password logins to OIDC. Defaults to false. + # + #allow_existing_users: true + # An external module can be provided here as a custom solution to mapping # attributes returned from a OIDC provider onto a matrix user. # diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index e0939bce84..70fc8a2f62 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -56,6 +56,7 @@ class OIDCConfig(Config): self.oidc_userinfo_endpoint = oidc_config.get("userinfo_endpoint") self.oidc_jwks_uri = oidc_config.get("jwks_uri") self.oidc_skip_verification = oidc_config.get("skip_verification", False) + self.oidc_allow_existing_users = oidc_config.get("allow_existing_users", False) ump_config = oidc_config.get("user_mapping_provider", {}) ump_config.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER) @@ -158,6 +159,11 @@ class OIDCConfig(Config): # #skip_verification: true + # Uncomment to allow a user logging in via OIDC to match a pre-existing account instead + # of failing. This could be used if switching from password logins to OIDC. Defaults to false. + # + #allow_existing_users: true + # An external module can be provided here as a custom solution to mapping # attributes returned from a OIDC provider onto a matrix user. # diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 4230dbaf99..0e06e4408d 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -114,6 +114,7 @@ class OidcHandler: hs.config.oidc_user_mapping_provider_config ) # type: OidcMappingProvider self._skip_verification = hs.config.oidc_skip_verification # type: bool + self._allow_existing_users = hs.config.oidc_allow_existing_users # type: bool self._http_client = hs.get_proxied_http_client() self._auth_handler = hs.get_auth_handler() @@ -849,7 +850,8 @@ class OidcHandler: If we don't find the user that way, we should register the user, mapping the localpart and the display name from the UserInfo. - If a user already exists with the mxid we've mapped, raise an exception. + If a user already exists with the mxid we've mapped and allow_existing_users + is disabled, raise an exception. Args: userinfo: an object representing the user @@ -905,21 +907,31 @@ class OidcHandler: localpart = map_username_to_mxid_localpart(attributes["localpart"]) - user_id = UserID(localpart, self._hostname) - if await self._datastore.get_users_by_id_case_insensitive(user_id.to_string()): - # This mxid is taken - raise MappingException( - "mxid '{}' is already taken".format(user_id.to_string()) + user_id = UserID(localpart, self._hostname).to_string() + users = await self._datastore.get_users_by_id_case_insensitive(user_id) + if users: + if self._allow_existing_users: + if len(users) == 1: + registered_user_id = next(iter(users)) + elif user_id in users: + registered_user_id = user_id + else: + raise MappingException( + "Attempted to login as '{}' but it matches more than one user inexactly: {}".format( + user_id, list(users.keys()) + ) + ) + else: + # This mxid is taken + raise MappingException("mxid '{}' is already taken".format(user_id)) + else: + # It's the first time this user is logging in and the mapped mxid was + # not taken, register the user + registered_user_id = await self._registration_handler.register_user( + localpart=localpart, + default_display_name=attributes["display_name"], + user_agent_ips=(user_agent, ip_address), ) - - # It's the first time this user is logging in and the mapped mxid was - # not taken, register the user - registered_user_id = await self._registration_handler.register_user( - localpart=localpart, - default_display_name=attributes["display_name"], - user_agent_ips=(user_agent, ip_address), - ) - await self._datastore.record_user_external_id( self._auth_provider_id, remote_user_id, registered_user_id, ) diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 33825e8949..48ce7ecd16 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -393,7 +393,7 @@ class RegistrationWorkerStore(SQLBaseStore): async def get_user_by_external_id( self, auth_provider: str, external_id: str - ) -> str: + ) -> Optional[str]: """Look up a user by their external auth id Args: @@ -401,7 +401,7 @@ class RegistrationWorkerStore(SQLBaseStore): external_id: id on that system Returns: - str|None: the mxid of the user, or None if they are not known + the mxid of the user, or None if they are not known """ return await self.db_pool.simple_select_one_onecol( table="user_external_ids", diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 89ec5fcb31..5910772aa8 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -617,3 +617,38 @@ class OidcHandlerTestCase(HomeserverTestCase): ) ) self.assertEqual(mxid, "@test_user_2:test") + + # Test if the mxid is already taken + store = self.hs.get_datastore() + user3 = UserID.from_string("@test_user_3:test") + self.get_success( + store.register_user(user_id=user3.to_string(), password_hash=None) + ) + userinfo = {"sub": "test3", "username": "test_user_3"} + e = self.get_failure( + self.handler._map_userinfo_to_user( + userinfo, token, "user-agent", "10.10.10.10" + ), + MappingException, + ) + self.assertEqual(str(e.value), "mxid '@test_user_3:test' is already taken") + + @override_config({"oidc_config": {"allow_existing_users": True}}) + def test_map_userinfo_to_existing_user(self): + """Existing users can log in with OpenID Connect when allow_existing_users is True.""" + store = self.hs.get_datastore() + user4 = UserID.from_string("@test_user_4:test") + self.get_success( + store.register_user(user_id=user4.to_string(), password_hash=None) + ) + userinfo = { + "sub": "test4", + "username": "test_user_4", + } + token = {} + mxid = self.get_success( + self.handler._map_userinfo_to_user( + userinfo, token, "user-agent", "10.10.10.10" + ) + ) + self.assertEqual(mxid, "@test_user_4:test") -- cgit 1.5.1 From 8b40843392e2df80d4f1108295ae6acd972100b0 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 30 Sep 2020 13:02:43 -0400 Subject: Allow additional SSO properties to be passed to the client (#8413) --- changelog.d/8413.feature | 1 + docs/sample_config.yaml | 8 ++ docs/sso_mapping_providers.md | 14 +++- docs/workers.md | 16 ++++ synapse/config/oidc_config.py | 8 ++ synapse/handlers/auth.py | 60 ++++++++++++++- synapse/handlers/oidc_handler.py | 56 +++++++++++++- synapse/rest/client/v1/login.py | 22 ++++-- tests/handlers/test_oidc.py | 160 +++++++++++++++++++++++++-------------- 9 files changed, 278 insertions(+), 67 deletions(-) create mode 100644 changelog.d/8413.feature (limited to 'synapse/handlers/oidc_handler.py') diff --git a/changelog.d/8413.feature b/changelog.d/8413.feature new file mode 100644 index 0000000000..abe40a901c --- /dev/null +++ b/changelog.d/8413.feature @@ -0,0 +1 @@ +Support passing additional single sign-on parameters to the client. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 70cc06a6d8..066844b5a9 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1748,6 +1748,14 @@ oidc_config: # #display_name_template: "{{ user.given_name }} {{ user.last_name }}" + # Jinja2 templates for extra attributes to send back to the client during + # login. + # + # Note that these are non-standard and clients will ignore them without modifications. + # + #extra_attributes: + #birthdate: "{{ user.birthdate }}" + # Enable CAS for registration and login. diff --git a/docs/sso_mapping_providers.md b/docs/sso_mapping_providers.md index abea432343..32b06aa2c5 100644 --- a/docs/sso_mapping_providers.md +++ b/docs/sso_mapping_providers.md @@ -57,7 +57,7 @@ A custom mapping provider must specify the following methods: - This method must return a string, which is the unique identifier for the user. Commonly the ``sub`` claim of the response. * `map_user_attributes(self, userinfo, token)` - - This method should be async. + - This method must be async. - Arguments: - `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user information from. @@ -66,6 +66,18 @@ A custom mapping provider must specify the following methods: - Returns a dictionary with two keys: - localpart: A required string, used to generate the Matrix ID. - displayname: An optional string, the display name for the user. +* `get_extra_attributes(self, userinfo, token)` + - This method must be async. + - Arguments: + - `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user + information from. + - `token` - A dictionary which includes information necessary to make + further requests to the OpenID provider. + - Returns a dictionary that is suitable to be serialized to JSON. This + will be returned as part of the response during a successful login. + + Note that care should be taken to not overwrite any of the parameters + usually returned as part of the [login response](https://matrix.org/docs/spec/client_server/latest#post-matrix-client-r0-login). ### Default OpenID Mapping Provider diff --git a/docs/workers.md b/docs/workers.md index df0ac84d94..ad4d8ca9f2 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -243,6 +243,22 @@ for the room are in flight: ^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/messages$ +Additionally, the following endpoints should be included if Synapse is configured +to use SSO (you only need to include the ones for whichever SSO provider you're +using): + + # OpenID Connect requests. + ^/_matrix/client/(api/v1|r0|unstable)/login/sso/redirect$ + ^/_synapse/oidc/callback$ + + # SAML requests. + ^/_matrix/client/(api/v1|r0|unstable)/login/sso/redirect$ + ^/_matrix/saml2/authn_response$ + + # CAS requests. + ^/_matrix/client/(api/v1|r0|unstable)/login/(cas|sso)/redirect$ + ^/_matrix/client/(api/v1|r0|unstable)/login/cas/ticket$ + Note that a HTTP listener with `client` and `federation` resources must be configured in the `worker_listeners` option in the worker config. diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index 70fc8a2f62..f924116819 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -204,6 +204,14 @@ class OIDCConfig(Config): # If unset, no displayname will be set. # #display_name_template: "{{{{ user.given_name }}}} {{{{ user.last_name }}}}" + + # Jinja2 templates for extra attributes to send back to the client during + # login. + # + # Note that these are non-standard and clients will ignore them without modifications. + # + #extra_attributes: + #birthdate: "{{{{ user.birthdate }}}}" """.format( mapping_provider=DEFAULT_USER_MAPPING_PROVIDER ) diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 0322b60cfc..00eae92052 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -137,6 +137,15 @@ def login_id_phone_to_thirdparty(identifier: JsonDict) -> Dict[str, str]: } +@attr.s(slots=True) +class SsoLoginExtraAttributes: + """Data we track about SAML2 sessions""" + + # time the session was created, in milliseconds + creation_time = attr.ib(type=int) + extra_attributes = attr.ib(type=JsonDict) + + class AuthHandler(BaseHandler): SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000 @@ -239,6 +248,10 @@ class AuthHandler(BaseHandler): # cast to tuple for use with str.startswith self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist) + # A mapping of user ID to extra attributes to include in the login + # response. + self._extra_attributes = {} # type: Dict[str, SsoLoginExtraAttributes] + async def validate_user_via_ui_auth( self, requester: Requester, @@ -1165,6 +1178,7 @@ class AuthHandler(BaseHandler): registered_user_id: str, request: SynapseRequest, client_redirect_url: str, + extra_attributes: Optional[JsonDict] = None, ): """Having figured out a mxid for this user, complete the HTTP request @@ -1173,6 +1187,8 @@ class AuthHandler(BaseHandler): request: The request to complete. client_redirect_url: The URL to which to redirect the user at the end of the process. + extra_attributes: Extra attributes which will be passed to the client + during successful login. Must be JSON serializable. """ # If the account has been deactivated, do not proceed with the login # flow. @@ -1181,19 +1197,30 @@ class AuthHandler(BaseHandler): respond_with_html(request, 403, self._sso_account_deactivated_template) return - self._complete_sso_login(registered_user_id, request, client_redirect_url) + self._complete_sso_login( + registered_user_id, request, client_redirect_url, extra_attributes + ) def _complete_sso_login( self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str, + extra_attributes: Optional[JsonDict] = None, ): """ The synchronous portion of complete_sso_login. This exists purely for backwards compatibility of synapse.module_api.ModuleApi. """ + # Store any extra attributes which will be passed in the login response. + # Note that this is per-user so it may overwrite a previous value, this + # is considered OK since the newest SSO attributes should be most valid. + if extra_attributes: + self._extra_attributes[registered_user_id] = SsoLoginExtraAttributes( + self._clock.time_msec(), extra_attributes, + ) + # Create a login token login_token = self.macaroon_gen.generate_short_term_login_token( registered_user_id @@ -1226,6 +1253,37 @@ class AuthHandler(BaseHandler): ) respond_with_html(request, 200, html) + async def _sso_login_callback(self, login_result: JsonDict) -> None: + """ + A login callback which might add additional attributes to the login response. + + Args: + login_result: The data to be sent to the client. Includes the user + ID and access token. + """ + # Expire attributes before processing. Note that there shouldn't be any + # valid logins that still have extra attributes. + self._expire_sso_extra_attributes() + + extra_attributes = self._extra_attributes.get(login_result["user_id"]) + if extra_attributes: + login_result.update(extra_attributes.extra_attributes) + + def _expire_sso_extra_attributes(self) -> None: + """ + Iterate through the mapping of user IDs to extra attributes and remove any that are no longer valid. + """ + # TODO This should match the amount of time the macaroon is valid for. + LOGIN_TOKEN_EXPIRATION_TIME = 2 * 60 * 1000 + expire_before = self._clock.time_msec() - LOGIN_TOKEN_EXPIRATION_TIME + to_expire = set() + for user_id, data in self._extra_attributes.items(): + if data.creation_time < expire_before: + to_expire.add(user_id) + for user_id in to_expire: + logger.debug("Expiring extra attributes for user %s", user_id) + del self._extra_attributes[user_id] + @staticmethod def add_query_param_to_url(url: str, param_name: str, param: Any): url_parts = list(urllib.parse.urlparse(url)) diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 0e06e4408d..19cd652675 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -37,7 +37,7 @@ from synapse.config import ConfigError from synapse.http.server import respond_with_html from synapse.http.site import SynapseRequest from synapse.logging.context import make_deferred_yieldable -from synapse.types import UserID, map_username_to_mxid_localpart +from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart from synapse.util import json_decoder if TYPE_CHECKING: @@ -707,6 +707,15 @@ class OidcHandler: self._render_error(request, "mapping_error", str(e)) return + # Mapping providers might not have get_extra_attributes: only call this + # method if it exists. + extra_attributes = None + get_extra_attributes = getattr( + self._user_mapping_provider, "get_extra_attributes", None + ) + if get_extra_attributes: + extra_attributes = await get_extra_attributes(userinfo, token) + # and finally complete the login if ui_auth_session_id: await self._auth_handler.complete_sso_ui_auth( @@ -714,7 +723,7 @@ class OidcHandler: ) else: await self._auth_handler.complete_sso_login( - user_id, request, client_redirect_url + user_id, request, client_redirect_url, extra_attributes ) def _generate_oidc_session_token( @@ -984,7 +993,7 @@ class OidcMappingProvider(Generic[C]): async def map_user_attributes( self, userinfo: UserInfo, token: Token ) -> UserAttribute: - """Map a ``UserInfo`` objects into user attributes. + """Map a `UserInfo` object into user attributes. Args: userinfo: An object representing the user given by the OIDC provider @@ -995,6 +1004,18 @@ class OidcMappingProvider(Generic[C]): """ raise NotImplementedError() + async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict: + """Map a `UserInfo` object into additional attributes passed to the client during login. + + Args: + userinfo: An object representing the user given by the OIDC provider + token: A dict with the tokens returned by the provider + + Returns: + A dict containing additional attributes. Must be JSON serializable. + """ + return {} + # Used to clear out "None" values in templates def jinja_finalize(thing): @@ -1009,6 +1030,7 @@ class JinjaOidcMappingConfig: subject_claim = attr.ib() # type: str localpart_template = attr.ib() # type: Template display_name_template = attr.ib() # type: Optional[Template] + extra_attributes = attr.ib() # type: Dict[str, Template] class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): @@ -1047,10 +1069,28 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): % (e,) ) + extra_attributes = {} # type Dict[str, Template] + if "extra_attributes" in config: + extra_attributes_config = config.get("extra_attributes") or {} + if not isinstance(extra_attributes_config, dict): + raise ConfigError( + "oidc_config.user_mapping_provider.config.extra_attributes must be a dict" + ) + + for key, value in extra_attributes_config.items(): + try: + extra_attributes[key] = env.from_string(value) + except Exception as e: + raise ConfigError( + "invalid jinja template for oidc_config.user_mapping_provider.config.extra_attributes.%s: %r" + % (key, e) + ) + return JinjaOidcMappingConfig( subject_claim=subject_claim, localpart_template=localpart_template, display_name_template=display_name_template, + extra_attributes=extra_attributes, ) def get_remote_user_id(self, userinfo: UserInfo) -> str: @@ -1071,3 +1111,13 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): display_name = None return UserAttribute(localpart=localpart, display_name=display_name) + + async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict: + extras = {} # type: Dict[str, str] + for key, template in self._config.extra_attributes.items(): + try: + extras[key] = template.render(user=userinfo).strip() + except Exception as e: + # Log an error and skip this value (don't break login for this). + logger.error("Failed to render OIDC extra attribute %s: %s" % (key, e)) + return extras diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 250b03a025..b9347b87c7 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -284,9 +284,7 @@ class LoginRestServlet(RestServlet): self, user_id: str, login_submission: JsonDict, - callback: Optional[ - Callable[[Dict[str, str]], Awaitable[Dict[str, str]]] - ] = None, + callback: Optional[Callable[[Dict[str, str]], Awaitable[None]]] = None, create_non_existent_users: bool = False, ) -> Dict[str, str]: """Called when we've successfully authed the user and now need to @@ -299,12 +297,12 @@ class LoginRestServlet(RestServlet): Args: user_id: ID of the user to register. login_submission: Dictionary of login information. - callback: Callback function to run after registration. + callback: Callback function to run after login. create_non_existent_users: Whether to create the user if they don't exist. Defaults to False. Returns: - result: Dictionary of account information after successful registration. + result: Dictionary of account information after successful login. """ # Before we actually log them in we check if they've already logged in @@ -339,14 +337,24 @@ class LoginRestServlet(RestServlet): return result async def _do_token_login(self, login_submission: JsonDict) -> Dict[str, str]: + """ + Handle the final stage of SSO login. + + Args: + login_submission: The JSON request body. + + Returns: + The body of the JSON response. + """ token = login_submission["token"] auth_handler = self.auth_handler user_id = await auth_handler.validate_short_term_login_token_and_get_user_id( token ) - result = await self._complete_login(user_id, login_submission) - return result + return await self._complete_login( + user_id, login_submission, self.auth_handler._sso_login_callback + ) async def _do_jwt_login(self, login_submission: JsonDict) -> Dict[str, str]: token = login_submission.get("token", None) diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index 5910772aa8..d5087e58be 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -21,7 +21,6 @@ from mock import Mock, patch import attr import pymacaroons -from twisted.internet import defer from twisted.python.failure import Failure from twisted.web._newclient import ResponseDone @@ -87,6 +86,13 @@ class TestMappingProvider(OidcMappingProvider): async def map_user_attributes(self, userinfo, token): return {"localpart": userinfo["username"], "display_name": None} + # Do not include get_extra_attributes to test backwards compatibility paths. + + +class TestMappingProviderExtra(TestMappingProvider): + async def get_extra_attributes(self, userinfo, token): + return {"phone": userinfo["phone"]} + def simple_async_mock(return_value=None, raises=None): # AsyncMock is not available in python3.5, this mimics part of its behaviour @@ -126,7 +132,7 @@ class OidcHandlerTestCase(HomeserverTestCase): config = self.default_config() config["public_baseurl"] = BASE_URL - oidc_config = config.get("oidc_config", {}) + oidc_config = {} oidc_config["enabled"] = True oidc_config["client_id"] = CLIENT_ID oidc_config["client_secret"] = CLIENT_SECRET @@ -135,6 +141,10 @@ class OidcHandlerTestCase(HomeserverTestCase): oidc_config["user_mapping_provider"] = { "module": __name__ + ".TestMappingProvider", } + + # Update this config with what's in the default config so that + # override_config works as expected. + oidc_config.update(config.get("oidc_config", {})) config["oidc_config"] = oidc_config hs = self.setup_test_homeserver( @@ -165,11 +175,10 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertEqual(self.handler._client_auth.client_secret, CLIENT_SECRET) @override_config({"oidc_config": {"discover": True}}) - @defer.inlineCallbacks def test_discovery(self): """The handler should discover the endpoints from OIDC discovery document.""" # This would throw if some metadata were invalid - metadata = yield defer.ensureDeferred(self.handler.load_metadata()) + metadata = self.get_success(self.handler.load_metadata()) self.http_client.get_json.assert_called_once_with(WELL_KNOWN) self.assertEqual(metadata.issuer, ISSUER) @@ -181,43 +190,40 @@ class OidcHandlerTestCase(HomeserverTestCase): # subsequent calls should be cached self.http_client.reset_mock() - yield defer.ensureDeferred(self.handler.load_metadata()) + self.get_success(self.handler.load_metadata()) self.http_client.get_json.assert_not_called() @override_config({"oidc_config": COMMON_CONFIG}) - @defer.inlineCallbacks def test_no_discovery(self): """When discovery is disabled, it should not try to load from discovery document.""" - yield defer.ensureDeferred(self.handler.load_metadata()) + self.get_success(self.handler.load_metadata()) self.http_client.get_json.assert_not_called() @override_config({"oidc_config": COMMON_CONFIG}) - @defer.inlineCallbacks def test_load_jwks(self): """JWKS loading is done once (then cached) if used.""" - jwks = yield defer.ensureDeferred(self.handler.load_jwks()) + jwks = self.get_success(self.handler.load_jwks()) self.http_client.get_json.assert_called_once_with(JWKS_URI) self.assertEqual(jwks, {"keys": []}) # subsequent calls should be cached… self.http_client.reset_mock() - yield defer.ensureDeferred(self.handler.load_jwks()) + self.get_success(self.handler.load_jwks()) self.http_client.get_json.assert_not_called() # …unless forced self.http_client.reset_mock() - yield defer.ensureDeferred(self.handler.load_jwks(force=True)) + self.get_success(self.handler.load_jwks(force=True)) self.http_client.get_json.assert_called_once_with(JWKS_URI) # Throw if the JWKS uri is missing with self.metadata_edit({"jwks_uri": None}): - with self.assertRaises(RuntimeError): - yield defer.ensureDeferred(self.handler.load_jwks(force=True)) + self.get_failure(self.handler.load_jwks(force=True), RuntimeError) # Return empty key set if JWKS are not used self.handler._scopes = [] # not asking the openid scope self.http_client.get_json.reset_mock() - jwks = yield defer.ensureDeferred(self.handler.load_jwks(force=True)) + jwks = self.get_success(self.handler.load_jwks(force=True)) self.http_client.get_json.assert_not_called() self.assertEqual(jwks, {"keys": []}) @@ -299,11 +305,10 @@ class OidcHandlerTestCase(HomeserverTestCase): # This should not throw self.handler._validate_metadata() - @defer.inlineCallbacks def test_redirect_request(self): """The redirect request has the right arguments & generates a valid session cookie.""" req = Mock(spec=["addCookie"]) - url = yield defer.ensureDeferred( + url = self.get_success( self.handler.handle_redirect_request(req, b"http://client/redirect") ) url = urlparse(url) @@ -343,20 +348,18 @@ class OidcHandlerTestCase(HomeserverTestCase): self.assertEqual(params["nonce"], [nonce]) self.assertEqual(redirect, "http://client/redirect") - @defer.inlineCallbacks def test_callback_error(self): """Errors from the provider returned in the callback are displayed.""" self.handler._render_error = Mock() request = Mock(args={}) request.args[b"error"] = [b"invalid_client"] - yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_client", "") request.args[b"error_description"] = [b"some description"] - yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_client", "some description") - @defer.inlineCallbacks def test_callback(self): """Code callback works and display errors if something went wrong. @@ -377,7 +380,7 @@ class OidcHandlerTestCase(HomeserverTestCase): "sub": "foo", "preferred_username": "bar", } - user_id = UserID("foo", "domain.org") + user_id = "@foo:domain.org" self.handler._render_error = Mock(return_value=None) self.handler._exchange_code = simple_async_mock(return_value=token) self.handler._parse_id_token = simple_async_mock(return_value=userinfo) @@ -394,13 +397,12 @@ class OidcHandlerTestCase(HomeserverTestCase): client_redirect_url = "http://client/redirect" user_agent = "Browser" ip_address = "10.0.0.1" - session = self.handler._generate_oidc_session_token( + request.getCookie.return_value = self.handler._generate_oidc_session_token( state=state, nonce=nonce, client_redirect_url=client_redirect_url, ui_auth_session_id=None, ) - request.getCookie.return_value = session request.args = {} request.args[b"code"] = [code.encode("utf-8")] @@ -410,10 +412,10 @@ class OidcHandlerTestCase(HomeserverTestCase): request.requestHeaders.getRawHeaders.return_value = [user_agent.encode("ascii")] request.getClientIP.return_value = ip_address - yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.get_success(self.handler.handle_oidc_callback(request)) self.handler._auth_handler.complete_sso_login.assert_called_once_with( - user_id, request, client_redirect_url, + user_id, request, client_redirect_url, {}, ) self.handler._exchange_code.assert_called_once_with(code) self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce) @@ -427,13 +429,13 @@ class OidcHandlerTestCase(HomeserverTestCase): self.handler._map_userinfo_to_user = simple_async_mock( raises=MappingException() ) - yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("mapping_error") self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) # Handle ID token errors self.handler._parse_id_token = simple_async_mock(raises=Exception()) - yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_token") self.handler._auth_handler.complete_sso_login.reset_mock() @@ -444,10 +446,10 @@ class OidcHandlerTestCase(HomeserverTestCase): # With userinfo fetching self.handler._scopes = [] # do not ask the "openid" scope - yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.get_success(self.handler.handle_oidc_callback(request)) self.handler._auth_handler.complete_sso_login.assert_called_once_with( - user_id, request, client_redirect_url, + user_id, request, client_redirect_url, {}, ) self.handler._exchange_code.assert_called_once_with(code) self.handler._parse_id_token.assert_not_called() @@ -459,17 +461,16 @@ class OidcHandlerTestCase(HomeserverTestCase): # Handle userinfo fetching error self.handler._fetch_userinfo = simple_async_mock(raises=Exception()) - yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("fetch_error") # Handle code exchange failure self.handler._exchange_code = simple_async_mock( raises=OidcError("invalid_request") ) - yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_request") - @defer.inlineCallbacks def test_callback_session(self): """The callback verifies the session presence and validity""" self.handler._render_error = Mock(return_value=None) @@ -478,20 +479,20 @@ class OidcHandlerTestCase(HomeserverTestCase): # Missing cookie request.args = {} request.getCookie.return_value = None - yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("missing_session", "No session cookie found") # Missing session parameter request.args = {} request.getCookie.return_value = "session" - yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_request", "State parameter is missing") # Invalid cookie request.args = {} request.args[b"state"] = [b"state"] request.getCookie.return_value = "session" - yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_session") # Mismatching session @@ -504,18 +505,17 @@ class OidcHandlerTestCase(HomeserverTestCase): request.args = {} request.args[b"state"] = [b"mismatching state"] request.getCookie.return_value = session - yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("mismatching_session") # Valid session request.args = {} request.args[b"state"] = [b"state"] request.getCookie.return_value = session - yield defer.ensureDeferred(self.handler.handle_oidc_callback(request)) + self.get_success(self.handler.handle_oidc_callback(request)) self.assertRenderedError("invalid_request") @override_config({"oidc_config": {"client_auth_method": "client_secret_post"}}) - @defer.inlineCallbacks def test_exchange_code(self): """Code exchange behaves correctly and handles various error scenarios.""" token = {"type": "bearer"} @@ -524,7 +524,7 @@ class OidcHandlerTestCase(HomeserverTestCase): return_value=FakeResponse(code=200, phrase=b"OK", body=token_json) ) code = "code" - ret = yield defer.ensureDeferred(self.handler._exchange_code(code)) + ret = self.get_success(self.handler._exchange_code(code)) kwargs = self.http_client.request.call_args[1] self.assertEqual(ret, token) @@ -546,10 +546,9 @@ class OidcHandlerTestCase(HomeserverTestCase): body=b'{"error": "foo", "error_description": "bar"}', ) ) - with self.assertRaises(OidcError) as exc: - yield defer.ensureDeferred(self.handler._exchange_code(code)) - self.assertEqual(exc.exception.error, "foo") - self.assertEqual(exc.exception.error_description, "bar") + exc = self.get_failure(self.handler._exchange_code(code), OidcError) + self.assertEqual(exc.value.error, "foo") + self.assertEqual(exc.value.error_description, "bar") # Internal server error with no JSON body self.http_client.request = simple_async_mock( @@ -557,9 +556,8 @@ class OidcHandlerTestCase(HomeserverTestCase): code=500, phrase=b"Internal Server Error", body=b"Not JSON", ) ) - with self.assertRaises(OidcError) as exc: - yield defer.ensureDeferred(self.handler._exchange_code(code)) - self.assertEqual(exc.exception.error, "server_error") + exc = self.get_failure(self.handler._exchange_code(code), OidcError) + self.assertEqual(exc.value.error, "server_error") # Internal server error with JSON body self.http_client.request = simple_async_mock( @@ -569,17 +567,16 @@ class OidcHandlerTestCase(HomeserverTestCase): body=b'{"error": "internal_server_error"}', ) ) - with self.assertRaises(OidcError) as exc: - yield defer.ensureDeferred(self.handler._exchange_code(code)) - self.assertEqual(exc.exception.error, "internal_server_error") + + exc = self.get_failure(self.handler._exchange_code(code), OidcError) + self.assertEqual(exc.value.error, "internal_server_error") # 4xx error without "error" field self.http_client.request = simple_async_mock( return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",) ) - with self.assertRaises(OidcError) as exc: - yield defer.ensureDeferred(self.handler._exchange_code(code)) - self.assertEqual(exc.exception.error, "server_error") + exc = self.get_failure(self.handler._exchange_code(code), OidcError) + self.assertEqual(exc.value.error, "server_error") # 2xx error with "error" field self.http_client.request = simple_async_mock( @@ -587,9 +584,62 @@ class OidcHandlerTestCase(HomeserverTestCase): code=200, phrase=b"OK", body=b'{"error": "some_error"}', ) ) - with self.assertRaises(OidcError) as exc: - yield defer.ensureDeferred(self.handler._exchange_code(code)) - self.assertEqual(exc.exception.error, "some_error") + exc = self.get_failure(self.handler._exchange_code(code), OidcError) + self.assertEqual(exc.value.error, "some_error") + + @override_config( + { + "oidc_config": { + "user_mapping_provider": { + "module": __name__ + ".TestMappingProviderExtra" + } + } + } + ) + def test_extra_attributes(self): + """ + Login while using a mapping provider that implements get_extra_attributes. + """ + token = { + "type": "bearer", + "id_token": "id_token", + "access_token": "access_token", + } + userinfo = { + "sub": "foo", + "phone": "1234567", + } + user_id = "@foo:domain.org" + self.handler._exchange_code = simple_async_mock(return_value=token) + self.handler._parse_id_token = simple_async_mock(return_value=userinfo) + self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id) + self.handler._auth_handler.complete_sso_login = simple_async_mock() + request = Mock( + spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"] + ) + + state = "state" + client_redirect_url = "http://client/redirect" + request.getCookie.return_value = self.handler._generate_oidc_session_token( + state=state, + nonce="nonce", + client_redirect_url=client_redirect_url, + ui_auth_session_id=None, + ) + + request.args = {} + request.args[b"code"] = [b"code"] + request.args[b"state"] = [state.encode("utf-8")] + + request.requestHeaders = Mock(spec=["getRawHeaders"]) + request.requestHeaders.getRawHeaders.return_value = [b"Browser"] + request.getClientIP.return_value = "10.0.0.1" + + self.get_success(self.handler.handle_oidc_callback(request)) + + self.handler._auth_handler.complete_sso_login.assert_called_once_with( + user_id, request, client_redirect_url, {"phone": "1234567"}, + ) def test_map_userinfo_to_user(self): """Ensure that mapping the userinfo returned from a provider to an MXID works properly.""" -- cgit 1.5.1