From fa50e4bf4ddcb8e98d44700513a28c490f80f02b Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 20 Jan 2021 12:30:41 +0000 Subject: Give `public_baseurl` a default value (#9159) --- synapse/handlers/identity.py | 2 -- 1 file changed, 2 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index c05036ad1f..f61844d688 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -476,8 +476,6 @@ class IdentityHandler(BaseHandler): except RequestTimedOutError: raise SynapseError(500, "Timed out contacting identity server") - assert self.hs.config.public_baseurl - # we need to tell the client to send the token back to us, since it doesn't # otherwise know where to send it, so add submit_url response parameter # (see also MSC2078) -- cgit 1.5.1 From 0cd2938bc854d947ae8102ded688a626c9fac5b5 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 20 Jan 2021 13:15:14 +0000 Subject: Support icons for Identity Providers (#9154) --- changelog.d/9154.feature | 1 + docs/sample_config.yaml | 4 ++ mypy.ini | 1 + synapse/config/oidc_config.py | 20 ++++++ synapse/config/server.py | 2 +- synapse/federation/federation_server.py | 2 +- synapse/federation/transport/server.py | 2 +- synapse/handlers/cas_handler.py | 4 ++ synapse/handlers/oidc_handler.py | 3 + synapse/handlers/room.py | 2 +- synapse/handlers/saml_handler.py | 4 ++ synapse/handlers/sso.py | 5 ++ synapse/http/endpoint.py | 79 --------------------- synapse/res/templates/sso_login_idp_picker.html | 3 + synapse/rest/client/v1/room.py | 3 +- synapse/storage/databases/main/room.py | 6 +- synapse/types.py | 2 +- synapse/util/stringutils.py | 92 +++++++++++++++++++++++++ tests/http/test_endpoint.py | 2 +- 19 files changed, 146 insertions(+), 91 deletions(-) create mode 100644 changelog.d/9154.feature delete mode 100644 synapse/http/endpoint.py (limited to 'synapse/handlers') diff --git a/changelog.d/9154.feature b/changelog.d/9154.feature new file mode 100644 index 0000000000..01a24dcf49 --- /dev/null +++ b/changelog.d/9154.feature @@ -0,0 +1 @@ +Add support for multiple SSO Identity Providers. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 7fdd798d70..b49a5da8cc 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1726,6 +1726,10 @@ saml2_config: # idp_name: A user-facing name for this identity provider, which is used to # offer the user a choice of login mechanisms. # +# idp_icon: An optional icon for this identity provider, which is presented +# by identity picker pages. If given, must be an MXC URI of the format +# mxc:/// +# # discover: set to 'false' to disable the use of the OIDC discovery mechanism # to discover endpoints. Defaults to true. 
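The idp_icon value described above must be an MXC URI of the form mxc://<server-name>/<media-id>. As a rough sketch of how such a value can be checked with the parse_and_validate_mxc_uri helper that this patch adds to synapse/util/stringutils.py (the URI below is a made-up example):

from synapse.util.stringutils import parse_and_validate_mxc_uri

# A made-up but well-formed MXC URI for an identity provider icon.
host, port, media_id = parse_and_validate_mxc_uri("mxc://example.com/SEsfnsuifSDFSSEF")
# host == "example.com", port is None, media_id == "SEsfnsuifSDFSSEF"

# Anything that is not a valid MXC URI raises ValueError, which the config
# code in this patch converts into a ConfigError for the idp_icon option.
try:
    parse_and_validate_mxc_uri("https://example.com/icon.png")
except ValueError:
    pass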
# diff --git a/mypy.ini b/mypy.ini index b996867121..bd99069c81 100644 --- a/mypy.ini +++ b/mypy.ini @@ -100,6 +100,7 @@ files = synapse/util/async_helpers.py, synapse/util/caches, synapse/util/metrics.py, + synapse/util/stringutils.py, tests/replication, tests/test_utils, tests/handlers/test_password_providers.py, diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index df55367434..f257fcd412 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -23,6 +23,7 @@ from synapse.config._util import validate_config from synapse.python_dependencies import DependencyException, check_requirements from synapse.types import Collection, JsonDict from synapse.util.module_loader import load_module +from synapse.util.stringutils import parse_and_validate_mxc_uri from ._base import Config, ConfigError @@ -66,6 +67,10 @@ class OIDCConfig(Config): # idp_name: A user-facing name for this identity provider, which is used to # offer the user a choice of login mechanisms. # + # idp_icon: An optional icon for this identity provider, which is presented + # by identity picker pages. If given, must be an MXC URI of the format + # mxc:/// + # # discover: set to 'false' to disable the use of the OIDC discovery mechanism # to discover endpoints. Defaults to true. # @@ -207,6 +212,7 @@ OIDC_PROVIDER_CONFIG_SCHEMA = { "properties": { "idp_id": {"type": "string", "minLength": 1, "maxLength": 128}, "idp_name": {"type": "string"}, + "idp_icon": {"type": "string"}, "discover": {"type": "boolean"}, "issuer": {"type": "string"}, "client_id": {"type": "string"}, @@ -336,9 +342,20 @@ def _parse_oidc_config_dict( config_path + ("idp_id",), ) + # MSC2858 also specifies that the idp_icon must be a valid MXC uri + idp_icon = oidc_config.get("idp_icon") + if idp_icon is not None: + try: + parse_and_validate_mxc_uri(idp_icon) + except ValueError as e: + raise ConfigError( + "idp_icon must be a valid MXC URI", config_path + ("idp_icon",) + ) from e + return OidcProviderConfig( idp_id=idp_id, idp_name=oidc_config.get("idp_name", "OIDC"), + idp_icon=idp_icon, discover=oidc_config.get("discover", True), issuer=oidc_config["issuer"], client_id=oidc_config["client_id"], @@ -366,6 +383,9 @@ class OidcProviderConfig: # user-facing name for this identity provider. idp_name = attr.ib(type=str) + # Optional MXC URI for icon for this IdP. 
+ idp_icon = attr.ib(type=Optional[str]) + # whether the OIDC discovery mechanism is used to discover endpoints discover = attr.ib(type=bool) diff --git a/synapse/config/server.py b/synapse/config/server.py index 75ba161f35..47a0370173 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -26,7 +26,7 @@ import yaml from netaddr import IPSet from synapse.api.room_versions import KNOWN_ROOM_VERSIONS -from synapse.http.endpoint import parse_and_validate_server_name +from synapse.util.stringutils import parse_and_validate_server_name from ._base import Config, ConfigError diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index e5339aca23..171d25c945 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -49,7 +49,6 @@ from synapse.events import EventBase from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.federation.persistence import TransactionActions from synapse.federation.units import Edu, Transaction -from synapse.http.endpoint import parse_server_name from synapse.http.servlet import assert_params_in_dict from synapse.logging.context import ( make_deferred_yieldable, @@ -66,6 +65,7 @@ from synapse.types import JsonDict, get_domain_from_id from synapse.util import glob_to_regex, json_decoder, unwrapFirstError from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.caches.response_cache import ResponseCache +from synapse.util.stringutils import parse_server_name if TYPE_CHECKING: from synapse.server import HomeServer diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index cfd094e58f..95c64510a9 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -28,7 +28,6 @@ from synapse.api.urls import ( FEDERATION_V1_PREFIX, FEDERATION_V2_PREFIX, ) -from synapse.http.endpoint import parse_and_validate_server_name from synapse.http.server import JsonResource from synapse.http.servlet import ( parse_boolean_from_args, @@ -45,6 +44,7 @@ from synapse.logging.opentracing import ( ) from synapse.server import HomeServer from synapse.types import ThirdPartyInstanceID, get_domain_from_id +from synapse.util.stringutils import parse_and_validate_server_name from synapse.util.versionstring import get_version_string logger = logging.getLogger(__name__) diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index f3430c6713..0f342c607b 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -80,6 +80,10 @@ class CasHandler: # user-facing name of this auth provider self.idp_name = "CAS" + # we do not currently support icons for CAS auth, but this is required by + # the SsoIdentityProvider protocol type. 
+ self.idp_icon = None + self._sso_handler = hs.get_sso_handler() self._sso_handler.register_identity_provider(self) diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index ba686d74b2..1607e12935 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -271,6 +271,9 @@ class OidcProvider: # user-facing name of this auth provider self.idp_name = provider.idp_name + # MXC URI for icon for this auth provider + self.idp_icon = provider.idp_icon + self._sso_handler = hs.get_sso_handler() self._sso_handler.register_identity_provider(self) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 3bece6d668..ee27d99135 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -38,7 +38,6 @@ from synapse.api.filtering import Filter from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.events import EventBase from synapse.events.utils import copy_power_levels_contents -from synapse.http.endpoint import parse_and_validate_server_name from synapse.storage.state import StateFilter from synapse.types import ( JsonDict, @@ -55,6 +54,7 @@ from synapse.types import ( from synapse.util import stringutils from synapse.util.async_helpers import Linearizer from synapse.util.caches.response_cache import ResponseCache +from synapse.util.stringutils import parse_and_validate_server_name from synapse.visibility import filter_events_for_client from ._base import BaseHandler diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index a8376543c9..38461cf79d 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -78,6 +78,10 @@ class SamlHandler(BaseHandler): # user-facing name of this auth provider self.idp_name = "SAML" + # we do not currently support icons for SAML auth, but this is required by + # the SsoIdentityProvider protocol type. + self.idp_icon = None + # a map from saml session id to Saml2SessionData object self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData] diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index dcc85e9871..d493327a10 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -75,6 +75,11 @@ class SsoIdentityProvider(Protocol): def idp_name(self) -> str: """User-facing name for this provider""" + @property + def idp_icon(self) -> Optional[str]: + """Optional MXC URI for user-facing icon""" + return None + @abc.abstractmethod async def handle_redirect_request( self, diff --git a/synapse/http/endpoint.py b/synapse/http/endpoint.py deleted file mode 100644 index 92a5b606c8..0000000000 --- a/synapse/http/endpoint.py +++ /dev/null @@ -1,79 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2014-2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging -import re - -logger = logging.getLogger(__name__) - - -def parse_server_name(server_name): - """Split a server name into host/port parts. 
- - Args: - server_name (str): server name to parse - - Returns: - Tuple[str, int|None]: host/port parts. - - Raises: - ValueError if the server name could not be parsed. - """ - try: - if server_name[-1] == "]": - # ipv6 literal, hopefully - return server_name, None - - domain_port = server_name.rsplit(":", 1) - domain = domain_port[0] - port = int(domain_port[1]) if domain_port[1:] else None - return domain, port - except Exception: - raise ValueError("Invalid server name '%s'" % server_name) - - -VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z") - - -def parse_and_validate_server_name(server_name): - """Split a server name into host/port parts and do some basic validation. - - Args: - server_name (str): server name to parse - - Returns: - Tuple[str, int|None]: host/port parts. - - Raises: - ValueError if the server name could not be parsed. - """ - host, port = parse_server_name(server_name) - - # these tests don't need to be bulletproof as we'll find out soon enough - # if somebody is giving us invalid data. What we *do* need is to be sure - # that nobody is sneaking IP literals in that look like hostnames, etc. - - # look for ipv6 literals - if host[0] == "[": - if host[-1] != "]": - raise ValueError("Mismatched [...] in server name '%s'" % (server_name,)) - return host, port - - # otherwise it should only be alphanumerics. - if not VALID_HOST_REGEX.match(host): - raise ValueError( - "Server name '%s' contains invalid characters" % (server_name,) - ) - - return host, port diff --git a/synapse/res/templates/sso_login_idp_picker.html b/synapse/res/templates/sso_login_idp_picker.html index f53c9cd679..5b38481012 100644 --- a/synapse/res/templates/sso_login_idp_picker.html +++ b/synapse/res/templates/sso_login_idp_picker.html @@ -17,6 +17,9 @@
+{% if p.idp_icon %}
+
+{% endif %}
  • {% endfor %} diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index e6725b03b0..f95627ee61 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -32,7 +32,6 @@ from synapse.api.errors import ( ) from synapse.api.filtering import Filter from synapse.events.utils import format_event_for_client_v2 -from synapse.http.endpoint import parse_and_validate_server_name from synapse.http.servlet import ( RestServlet, assert_params_in_dict, @@ -47,7 +46,7 @@ from synapse.storage.state import StateFilter from synapse.streams.config import PaginationConfig from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID from synapse.util import json_decoder -from synapse.util.stringutils import random_string +from synapse.util.stringutils import parse_and_validate_server_name, random_string if TYPE_CHECKING: import synapse.server diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 284f2ce77c..a9fcb5f59c 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -16,7 +16,6 @@ import collections import logging -import re from abc import abstractmethod from enum import Enum from typing import Any, Dict, List, Optional, Tuple @@ -30,6 +29,7 @@ from synapse.storage.databases.main.search import SearchStore from synapse.types import JsonDict, ThirdPartyInstanceID from synapse.util import json_encoder from synapse.util.caches.descriptors import cached +from synapse.util.stringutils import MXC_REGEX logger = logging.getLogger(__name__) @@ -660,8 +660,6 @@ class RoomWorkerStore(SQLBaseStore): The local and remote media as a lists of tuples where the key is the hostname and the value is the media ID. """ - mxc_re = re.compile("^mxc://([^/]+)/([^/#?]+)") - sql = """ SELECT stream_ordering, json FROM events JOIN event_json USING (room_id, event_id) @@ -688,7 +686,7 @@ class RoomWorkerStore(SQLBaseStore): for url in (content_url, thumbnail_url): if not url: continue - matches = mxc_re.match(url) + matches = MXC_REGEX.match(url) if matches: hostname = matches.group(1) media_id = matches.group(2) diff --git a/synapse/types.py b/synapse/types.py index 20a43d05bf..eafe729dfe 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -37,7 +37,7 @@ from signedjson.key import decode_verify_key_bytes from unpaddedbase64 import decode_base64 from synapse.api.errors import Codes, SynapseError -from synapse.http.endpoint import parse_and_validate_server_name +from synapse.util.stringutils import parse_and_validate_server_name if TYPE_CHECKING: from synapse.appservice.api import ApplicationService diff --git a/synapse/util/stringutils.py b/synapse/util/stringutils.py index b103c8694c..f8038bf861 100644 --- a/synapse/util/stringutils.py +++ b/synapse/util/stringutils.py @@ -18,6 +18,7 @@ import random import re import string from collections.abc import Iterable +from typing import Optional, Tuple from synapse.api.errors import Codes, SynapseError @@ -26,6 +27,15 @@ _string_with_symbols = string.digits + string.ascii_letters + ".,;:^&*-_+=#~@" # https://matrix.org/docs/spec/client_server/r0.6.0#post-matrix-client-r0-register-email-requesttoken client_secret_regex = re.compile(r"^[0-9a-zA-Z\.\=\_\-]+$") +# https://matrix.org/docs/spec/client_server/r0.6.1#matrix-content-mxc-uris, +# together with https://github.com/matrix-org/matrix-doc/issues/2177 which basically +# says "there is no grammar for media ids" +# +# The server_name part of this is purposely lax: use 
parse_and_validate_mxc for +# additional validation. +# +MXC_REGEX = re.compile("^mxc://([^/]+)/([^/#?]+)$") + # random_string and random_string_with_symbols are used for a range of things, # some cryptographically important, some less so. We use SystemRandom to make sure # we get cryptographically-secure randoms. @@ -59,6 +69,88 @@ def assert_valid_client_secret(client_secret): ) +def parse_server_name(server_name: str) -> Tuple[str, Optional[int]]: + """Split a server name into host/port parts. + + Args: + server_name: server name to parse + + Returns: + host/port parts. + + Raises: + ValueError if the server name could not be parsed. + """ + try: + if server_name[-1] == "]": + # ipv6 literal, hopefully + return server_name, None + + domain_port = server_name.rsplit(":", 1) + domain = domain_port[0] + port = int(domain_port[1]) if domain_port[1:] else None + return domain, port + except Exception: + raise ValueError("Invalid server name '%s'" % server_name) + + +VALID_HOST_REGEX = re.compile("\\A[0-9a-zA-Z.-]+\\Z") + + +def parse_and_validate_server_name(server_name: str) -> Tuple[str, Optional[int]]: + """Split a server name into host/port parts and do some basic validation. + + Args: + server_name: server name to parse + + Returns: + host/port parts. + + Raises: + ValueError if the server name could not be parsed. + """ + host, port = parse_server_name(server_name) + + # these tests don't need to be bulletproof as we'll find out soon enough + # if somebody is giving us invalid data. What we *do* need is to be sure + # that nobody is sneaking IP literals in that look like hostnames, etc. + + # look for ipv6 literals + if host[0] == "[": + if host[-1] != "]": + raise ValueError("Mismatched [...] in server name '%s'" % (server_name,)) + return host, port + + # otherwise it should only be alphanumerics. + if not VALID_HOST_REGEX.match(host): + raise ValueError( + "Server name '%s' contains invalid characters" % (server_name,) + ) + + return host, port + + +def parse_and_validate_mxc_uri(mxc: str) -> Tuple[str, Optional[int], str]: + """Parse the given string as an MXC URI + + Checks that the "server name" part is a valid server name + + Args: + mxc: the (alleged) MXC URI to be checked + Returns: + hostname, port, media id + Raises: + ValueError if the URI cannot be parsed + """ + m = MXC_REGEX.match(mxc) + if not m: + raise ValueError("mxc URI %r did not match expected format" % (mxc,)) + server_name = m.group(1) + media_id = m.group(2) + host, port = parse_and_validate_server_name(server_name) + return host, port, media_id + + def shortstr(iterable: Iterable, maxitems: int = 5) -> str: """If iterable has maxitems or fewer, return the stringification of a list containing those items. diff --git a/tests/http/test_endpoint.py b/tests/http/test_endpoint.py index b2e9533b07..d06ea518ce 100644 --- a/tests/http/test_endpoint.py +++ b/tests/http/test_endpoint.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from synapse.http.endpoint import parse_and_validate_server_name, parse_server_name +from synapse.util.stringutils import parse_and_validate_server_name, parse_server_name from tests import unittest -- cgit 1.5.1 From dd8da8c5f6ac525a7456437913a03f68d4504605 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 26 Jan 2021 13:57:31 +0000 Subject: Precompute joined hosts and store in Redis (#9198) --- changelog.d/9198.misc | 1 + stubs/txredisapi.pyi | 12 +++- synapse/config/_base.pyi | 2 + synapse/federation/sender/__init__.py | 50 +++++++++----- synapse/handlers/federation.py | 5 ++ synapse/handlers/message.py | 42 ++++++++++++ synapse/replication/tcp/external_cache.py | 105 ++++++++++++++++++++++++++++++ synapse/replication/tcp/handler.py | 15 +---- synapse/server.py | 30 +++++++++ synapse/state/__init__.py | 11 +++- tests/replication/_base.py | 41 +++++++----- 11 files changed, 265 insertions(+), 49 deletions(-) create mode 100644 changelog.d/9198.misc create mode 100644 synapse/replication/tcp/external_cache.py (limited to 'synapse/handlers') diff --git a/changelog.d/9198.misc b/changelog.d/9198.misc new file mode 100644 index 0000000000..a6cb77fbb2 --- /dev/null +++ b/changelog.d/9198.misc @@ -0,0 +1 @@ +Precompute joined hosts and store in Redis. diff --git a/stubs/txredisapi.pyi b/stubs/txredisapi.pyi index bdc892ec82..618548a305 100644 --- a/stubs/txredisapi.pyi +++ b/stubs/txredisapi.pyi @@ -15,11 +15,21 @@ """Contains *incomplete* type hints for txredisapi. """ -from typing import List, Optional, Type, Union +from typing import Any, List, Optional, Type, Union class RedisProtocol: def publish(self, channel: str, message: bytes): ... async def ping(self) -> None: ... + async def set( + self, + key: str, + value: Any, + expire: Optional[int] = None, + pexpire: Optional[int] = None, + only_if_not_exists: bool = False, + only_if_exists: bool = False, + ) -> None: ... + async def get(self, key: str) -> Any: ... class SubscriberProtocol(RedisProtocol): def __init__(self, *args, **kwargs): ... diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index 29aa064e57..8ba669059a 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -18,6 +18,7 @@ from synapse.config import ( password_auth_providers, push, ratelimiting, + redis, registration, repository, room_directory, @@ -79,6 +80,7 @@ class RootConfig: roomdirectory: room_directory.RoomDirectoryConfig thirdpartyrules: third_party_event_rules.ThirdPartyRulesConfig tracer: tracer.TracerConfig + redis: redis.RedisConfig config_classes: List = ... def __init__(self) -> None: ... diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 604cfd1935..643b26ae6d 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -142,6 +142,8 @@ class FederationSender: self._wake_destinations_needing_catchup, ) + self._external_cache = hs.get_external_cache() + def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue: """Get or create a PerDestinationQueue for the given destination @@ -197,22 +199,40 @@ class FederationSender: if not event.internal_metadata.should_proactively_send(): return - try: - # Get the state from before the event. - # We need to make sure that this is the state from before - # the event and not from after it. - # Otherwise if the last member on a server in a room is - # banned then it won't receive the event because it won't - # be in the room after the ban. 
- destinations = await self.state.get_hosts_in_room_at_events( - event.room_id, event_ids=event.prev_event_ids() - ) - except Exception: - logger.exception( - "Failed to calculate hosts in room for event: %s", - event.event_id, + destinations = None # type: Optional[Set[str]] + if not event.prev_event_ids(): + # If there are no prev event IDs then the state is empty + # and so no remote servers in the room + destinations = set() + else: + # We check the external cache for the destinations, which is + # stored per state group. + + sg = await self._external_cache.get( + "event_to_prev_state_group", event.event_id ) - return + if sg: + destinations = await self._external_cache.get( + "get_joined_hosts", str(sg) + ) + + if destinations is None: + try: + # Get the state from before the event. + # We need to make sure that this is the state from before + # the event and not from after it. + # Otherwise if the last member on a server in a room is + # banned then it won't receive the event because it won't + # be in the room after the ban. + destinations = await self.state.get_hosts_in_room_at_events( + event.room_id, event_ids=event.prev_event_ids() + ) + except Exception: + logger.exception( + "Failed to calculate hosts in room for event: %s", + event.event_id, + ) + return destinations = { d diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index fd8de8696d..b6dc7f99b6 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -2093,6 +2093,11 @@ class FederationHandler(BaseHandler): if event.type == EventTypes.GuestAccess and not context.rejected: await self.maybe_kick_guest_users(event) + # If we are going to send this event over federation we precaclculate + # the joined hosts. + if event.internal_metadata.get_send_on_behalf_of(): + await self.event_creation_handler.cache_joined_hosts_for_event(event) + return context async def _check_for_soft_fail( diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 9dfeab09cd..e2a7d567fa 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -432,6 +432,8 @@ class EventCreationHandler: self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages + self._external_cache = hs.get_external_cache() + async def create_event( self, requester: Requester, @@ -939,6 +941,8 @@ class EventCreationHandler: await self.action_generator.handle_push_actions_for_event(event, context) + await self.cache_joined_hosts_for_event(event) + try: # If we're a worker we need to hit out to the master. writer_instance = self._events_shard_config.get_instance(event.room_id) @@ -978,6 +982,44 @@ class EventCreationHandler: await self.store.remove_push_actions_from_staging(event.event_id) raise + async def cache_joined_hosts_for_event(self, event: EventBase) -> None: + """Precalculate the joined hosts at the event, when using Redis, so that + external federation senders don't have to recalculate it themselves. + """ + + if not self._external_cache.is_enabled(): + return + + # We actually store two mappings, event ID -> prev state group, + # state group -> joined hosts, which is much more space efficient + # than event ID -> joined hosts. + # + # Note: We have to cache event ID -> prev state group, as we don't + # store that in the DB. + # + # Note: We always set the state group -> joined hosts cache, even if + # we already set it, so that the expiry time is reset. 
+ + state_entry = await self.state.resolve_state_groups_for_events( + event.room_id, event_ids=event.prev_event_ids() + ) + + if state_entry.state_group: + joined_hosts = await self.store.get_joined_hosts(event.room_id, state_entry) + + await self._external_cache.set( + "event_to_prev_state_group", + event.event_id, + state_entry.state_group, + expiry_ms=60 * 60 * 1000, + ) + await self._external_cache.set( + "get_joined_hosts", + str(state_entry.state_group), + list(joined_hosts), + expiry_ms=60 * 60 * 1000, + ) + async def _validate_canonical_alias( self, directory_handler, room_alias_str: str, expected_room_id: str ) -> None: diff --git a/synapse/replication/tcp/external_cache.py b/synapse/replication/tcp/external_cache.py new file mode 100644 index 0000000000..34fa3ff5b3 --- /dev/null +++ b/synapse/replication/tcp/external_cache.py @@ -0,0 +1,105 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import TYPE_CHECKING, Any, Optional + +from prometheus_client import Counter + +from synapse.logging.context import make_deferred_yieldable +from synapse.util import json_decoder, json_encoder + +if TYPE_CHECKING: + from synapse.server import HomeServer + +set_counter = Counter( + "synapse_external_cache_set", + "Number of times we set a cache", + labelnames=["cache_name"], +) + +get_counter = Counter( + "synapse_external_cache_get", + "Number of times we get a cache", + labelnames=["cache_name", "hit"], +) + + +logger = logging.getLogger(__name__) + + +class ExternalCache: + """A cache backed by an external Redis. Does nothing if no Redis is + configured. + """ + + def __init__(self, hs: "HomeServer"): + self._redis_connection = hs.get_outbound_redis_connection() + + def _get_redis_key(self, cache_name: str, key: str) -> str: + return "cache_v1:%s:%s" % (cache_name, key) + + def is_enabled(self) -> bool: + """Whether the external cache is used or not. + + It's safe to use the cache when this returns false, the methods will + just no-op, but the function is useful to avoid doing unnecessary work. + """ + return self._redis_connection is not None + + async def set(self, cache_name: str, key: str, value: Any, expiry_ms: int) -> None: + """Add the key/value to the named cache, with the expiry time given. + """ + + if self._redis_connection is None: + return + + set_counter.labels(cache_name).inc() + + # txredisapi requires the value to be string, bytes or numbers, so we + # encode stuff in JSON. + encoded_value = json_encoder.encode(value) + + logger.debug("Caching %s %s: %r", cache_name, key, encoded_value) + + return await make_deferred_yieldable( + self._redis_connection.set( + self._get_redis_key(cache_name, key), encoded_value, pexpire=expiry_ms, + ) + ) + + async def get(self, cache_name: str, key: str) -> Optional[Any]: + """Look up a key/value in the named cache. 
+ """ + + if self._redis_connection is None: + return None + + result = await make_deferred_yieldable( + self._redis_connection.get(self._get_redis_key(cache_name, key)) + ) + + logger.debug("Got cache result %s %s: %r", cache_name, key, result) + + get_counter.labels(cache_name, result is not None).inc() + + if not result: + return None + + # For some reason the integers get magically converted back to integers + if isinstance(result, int): + return result + + return json_decoder.decode(result) diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index 58d46a5951..8ea8dcd587 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -286,13 +286,6 @@ class ReplicationCommandHandler: if hs.config.redis.redis_enabled: from synapse.replication.tcp.redis import ( RedisDirectTcpReplicationClientFactory, - lazyConnection, - ) - - logger.info( - "Connecting to redis (host=%r port=%r)", - hs.config.redis_host, - hs.config.redis_port, ) # First let's ensure that we have a ReplicationStreamer started. @@ -303,13 +296,7 @@ class ReplicationCommandHandler: # connection after SUBSCRIBE is called). # First create the connection for sending commands. - outbound_redis_connection = lazyConnection( - hs=hs, - host=hs.config.redis_host, - port=hs.config.redis_port, - password=hs.config.redis.redis_password, - reconnect=True, - ) + outbound_redis_connection = hs.get_outbound_redis_connection() # Now create the factory/connection for the subscription stream. self._factory = RedisDirectTcpReplicationClientFactory( diff --git a/synapse/server.py b/synapse/server.py index 9cdda83aa1..9bdd3177d7 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -103,6 +103,7 @@ from synapse.notifier import Notifier from synapse.push.action_generator import ActionGenerator from synapse.push.pusherpool import PusherPool from synapse.replication.tcp.client import ReplicationDataHandler +from synapse.replication.tcp.external_cache import ExternalCache from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.resource import ReplicationStreamer from synapse.replication.tcp.streams import STREAMS_MAP, Stream @@ -128,6 +129,8 @@ from synapse.util.stringutils import random_string logger = logging.getLogger(__name__) if TYPE_CHECKING: + from txredisapi import RedisProtocol + from synapse.handlers.oidc_handler import OidcHandler from synapse.handlers.saml_handler import SamlHandler @@ -716,6 +719,33 @@ class HomeServer(metaclass=abc.ABCMeta): def get_account_data_handler(self) -> AccountDataHandler: return AccountDataHandler(self) + @cache_in_self + def get_external_cache(self) -> ExternalCache: + return ExternalCache(self) + + @cache_in_self + def get_outbound_redis_connection(self) -> Optional["RedisProtocol"]: + if not self.config.redis.redis_enabled: + return None + + # We only want to import redis module if we're using it, as we have + # `txredisapi` as an optional dependency. 
+ from synapse.replication.tcp.redis import lazyConnection + + logger.info( + "Connecting to redis (host=%r port=%r) for external cache", + self.config.redis_host, + self.config.redis_port, + ) + + return lazyConnection( + hs=self, + host=self.config.redis_host, + port=self.config.redis_port, + password=self.config.redis.redis_password, + reconnect=True, + ) + async def remove_pusher(self, app_id: str, push_key: str, user_id: str): return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id) diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 84f59c7d85..3bd9ff8ca0 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -310,6 +310,7 @@ class StateHandler: state_group_before_event = None state_group_before_event_prev_group = None deltas_to_state_group_before_event = None + entry = None else: # otherwise, we'll need to resolve the state across the prev_events. @@ -340,9 +341,13 @@ class StateHandler: current_state_ids=state_ids_before_event, ) - # XXX: can we update the state cache entry for the new state group? or - # could we set a flag on resolve_state_groups_for_events to tell it to - # always make a state group? + # Assign the new state group to the cached state entry. + # + # Note that this can race in that we could generate multiple state + # groups for the same state entry, but that is just inefficient + # rather than dangerous. + if entry and entry.state_group is None: + entry.state_group = state_group_before_event # # now if it's not a state event, we're done diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 3379189785..d5dce1f83f 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -212,6 +212,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): # Fake in memory Redis server that servers can connect to. self._redis_server = FakeRedisPubSubServer() + # We may have an attempt to connect to redis for the external cache already. + self.connect_any_redis_attempts() + store = self.hs.get_datastore() self.database_pool = store.db_pool @@ -401,25 +404,23 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): fake one. 
""" clients = self.reactor.tcpClients - self.assertEqual(len(clients), 1) - (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) - self.assertEqual(host, "localhost") - self.assertEqual(port, 6379) + while clients: + (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) + self.assertEqual(host, "localhost") + self.assertEqual(port, 6379) - client_protocol = client_factory.buildProtocol(None) - server_protocol = self._redis_server.buildProtocol(None) + client_protocol = client_factory.buildProtocol(None) + server_protocol = self._redis_server.buildProtocol(None) - client_to_server_transport = FakeTransport( - server_protocol, self.reactor, client_protocol - ) - client_protocol.makeConnection(client_to_server_transport) - - server_to_client_transport = FakeTransport( - client_protocol, self.reactor, server_protocol - ) - server_protocol.makeConnection(server_to_client_transport) + client_to_server_transport = FakeTransport( + server_protocol, self.reactor, client_protocol + ) + client_protocol.makeConnection(client_to_server_transport) - return client_to_server_transport, server_to_client_transport + server_to_client_transport = FakeTransport( + client_protocol, self.reactor, server_protocol + ) + server_protocol.makeConnection(server_to_client_transport) class TestReplicationDataHandler(GenericWorkerReplicationHandler): @@ -624,6 +625,12 @@ class FakeRedisPubSubProtocol(Protocol): (channel,) = args self._server.add_subscriber(self) self.send(["subscribe", channel, 1]) + + # Since we use SET/GET to cache things we can safely no-op them. + elif command == b"SET": + self.send("OK") + elif command == b"GET": + self.send(None) else: raise Exception("Unknown command") @@ -645,6 +652,8 @@ class FakeRedisPubSubProtocol(Protocol): # We assume bytes are just unicode strings. obj = obj.decode("utf-8") + if obj is None: + return "$-1\r\n" if isinstance(obj, str): return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj) if isinstance(obj, int): -- cgit 1.5.1 From 26837d5dbeae211968b3d52cdc10f005ba612a9f Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 26 Jan 2021 10:49:25 -0500 Subject: Do not require the CAS service URL setting (use public_baseurl instead). (#9199) The current configuration is handled for backwards compatibility, but is considered deprecated. --- changelog.d/9199.removal | 1 + docs/sample_config.yaml | 4 ---- synapse/config/cas.py | 12 +++++++----- synapse/config/oidc_config.py | 3 +-- synapse/handlers/cas_handler.py | 6 +----- 5 files changed, 10 insertions(+), 16 deletions(-) create mode 100644 changelog.d/9199.removal (limited to 'synapse/handlers') diff --git a/changelog.d/9199.removal b/changelog.d/9199.removal new file mode 100644 index 0000000000..fbd2916cbf --- /dev/null +++ b/changelog.d/9199.removal @@ -0,0 +1 @@ +The `service_url` parameter in `cas_config` is deprecated in favor of `public_baseurl`. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 87bfe22237..c2ccd68f3a 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1878,10 +1878,6 @@ cas_config: # #server_url: "https://cas-server.com" - # The public URL of the homeserver. - # - #service_url: "https://homeserver.domain.com:8448" - # The attribute of the CAS response to use as the display name. # # If unset, no displayname will be set. 
diff --git a/synapse/config/cas.py b/synapse/config/cas.py index c7877b4095..b226890c2a 100644 --- a/synapse/config/cas.py +++ b/synapse/config/cas.py @@ -30,7 +30,13 @@ class CasConfig(Config): if self.cas_enabled: self.cas_server_url = cas_config["server_url"] - self.cas_service_url = cas_config["service_url"] + public_base_url = cas_config.get("service_url") or self.public_baseurl + if public_base_url[-1] != "/": + public_base_url += "/" + # TODO Update this to a _synapse URL. + self.cas_service_url = ( + public_base_url + "_matrix/client/r0/login/cas/ticket" + ) self.cas_displayname_attribute = cas_config.get("displayname_attribute") self.cas_required_attributes = cas_config.get("required_attributes") or {} else: @@ -53,10 +59,6 @@ class CasConfig(Config): # #server_url: "https://cas-server.com" - # The public URL of the homeserver. - # - #service_url: "https://homeserver.domain.com:8448" - # The attribute of the CAS response to use as the display name. # # If unset, no displayname will be set. diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index bfeceeed18..0162d7f7b0 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -54,8 +54,7 @@ class OIDCConfig(Config): "Multiple OIDC providers have the idp_id %r." % idp_id ) - public_baseurl = self.public_baseurl - self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback" + self.oidc_callback_url = self.public_baseurl + "_synapse/oidc/callback" @property def oidc_enabled(self) -> bool: diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index 0f342c607b..21b6bc4992 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -99,11 +99,7 @@ class CasHandler: Returns: The URL to use as a "service" parameter. """ - return "%s%s?%s" % ( - self._cas_service_url, - "/_matrix/client/r0/login/cas/ticket", - urllib.parse.urlencode(args), - ) + return "%s?%s" % (self._cas_service_url, urllib.parse.urlencode(args),) async def _validate_ticket( self, ticket: str, service_args: Dict[str, str] -- cgit 1.5.1 From 1baab2035265cf2543fe3c0ef5412c1ac0740c7e Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 26 Jan 2021 10:50:21 -0500 Subject: Add type hints to various handlers. (#9223) With this change all handlers except the e2e_* ones have type hints enabled. --- changelog.d/9223.misc | 1 + mypy.ini | 14 ++++ synapse/handlers/acme.py | 12 ++-- synapse/handlers/acme_issuing_service.py | 27 +++++--- synapse/handlers/groups_local.py | 83 ++++++++++++------------ synapse/handlers/search.py | 38 ++++++----- synapse/handlers/set_password.py | 10 +-- synapse/handlers/state_deltas.py | 14 +++- synapse/handlers/stats.py | 39 ++++++----- synapse/handlers/typing.py | 69 +++++++++++--------- synapse/handlers/user_directory.py | 9 +-- synapse/storage/databases/main/search.py | 3 +- synapse/storage/databases/main/stats.py | 22 ++++--- synapse/storage/databases/main/user_directory.py | 2 +- 14 files changed, 205 insertions(+), 138 deletions(-) create mode 100644 changelog.d/9223.misc (limited to 'synapse/handlers') diff --git a/changelog.d/9223.misc b/changelog.d/9223.misc new file mode 100644 index 0000000000..9d44b621c9 --- /dev/null +++ b/changelog.d/9223.misc @@ -0,0 +1 @@ +Add type hints to handlers code. 
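A pattern that recurs throughout the handler diffs below is the TYPE_CHECKING guard: HomeServer is imported only while type checking, so the hs parameter can be annotated without importing it at runtime (which would risk an import cycle). A minimal sketch, with a hypothetical handler class:

import logging
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by type checkers such as mypy; skipped at runtime.
    from synapse.app.homeserver import HomeServer

logger = logging.getLogger(__name__)


class ExampleHandler:
    """Hypothetical handler, for illustration only."""

    def __init__(self, hs: "HomeServer"):
        # The string annotation means HomeServer is never needed at runtime.
        self.store = hs.get_datastore()
        self.clock = hs.get_clock()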
diff --git a/mypy.ini b/mypy.ini index bd99069c81..f3700d323c 100644 --- a/mypy.ini +++ b/mypy.ini @@ -26,6 +26,8 @@ files = synapse/handlers/_base.py, synapse/handlers/account_data.py, synapse/handlers/account_validity.py, + synapse/handlers/acme.py, + synapse/handlers/acme_issuing_service.py, synapse/handlers/admin.py, synapse/handlers/appservice.py, synapse/handlers/auth.py, @@ -36,6 +38,7 @@ files = synapse/handlers/directory.py, synapse/handlers/events.py, synapse/handlers/federation.py, + synapse/handlers/groups_local.py, synapse/handlers/identity.py, synapse/handlers/initial_sync.py, synapse/handlers/message.py, @@ -52,8 +55,13 @@ files = synapse/handlers/room_member.py, synapse/handlers/room_member_worker.py, synapse/handlers/saml_handler.py, + synapse/handlers/search.py, + synapse/handlers/set_password.py, synapse/handlers/sso.py, + synapse/handlers/state_deltas.py, + synapse/handlers/stats.py, synapse/handlers/sync.py, + synapse/handlers/typing.py, synapse/handlers/user_directory.py, synapse/handlers/ui_auth, synapse/http/client.py, @@ -194,3 +202,9 @@ ignore_missing_imports = True [mypy-hiredis] ignore_missing_imports = True + +[mypy-josepy.*] +ignore_missing_imports = True + +[mypy-txacme.*] +ignore_missing_imports = True diff --git a/synapse/handlers/acme.py b/synapse/handlers/acme.py index 8476256a59..5ecb2da1ac 100644 --- a/synapse/handlers/acme.py +++ b/synapse/handlers/acme.py @@ -14,6 +14,7 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING import twisted import twisted.internet.error @@ -22,6 +23,9 @@ from twisted.web.resource import Resource from synapse.app import check_bind_error +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) ACME_REGISTER_FAIL_ERROR = """ @@ -35,12 +39,12 @@ solutions, please read https://github.com/matrix-org/synapse/blob/master/docs/AC class AcmeHandler: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.reactor = hs.get_reactor() self._acme_domain = hs.config.acme_domain - async def start_listening(self): + async def start_listening(self) -> None: from synapse.handlers import acme_issuing_service # Configure logging for txacme, if you need to debug @@ -85,7 +89,7 @@ class AcmeHandler: logger.error(ACME_REGISTER_FAIL_ERROR) raise - async def provision_certificate(self): + async def provision_certificate(self) -> None: logger.warning("Reprovisioning %s", self._acme_domain) @@ -110,5 +114,3 @@ class AcmeHandler: except Exception: logger.exception("Failed saving!") raise - - return True diff --git a/synapse/handlers/acme_issuing_service.py b/synapse/handlers/acme_issuing_service.py index 7294649d71..ae2a9dd9c2 100644 --- a/synapse/handlers/acme_issuing_service.py +++ b/synapse/handlers/acme_issuing_service.py @@ -22,8 +22,10 @@ only need (and may only have available) if we are doing ACME, so is designed to imported conditionally. 
""" import logging +from typing import Dict, Iterable, List import attr +import pem from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import serialization from josepy import JWKRSA @@ -36,20 +38,27 @@ from txacme.util import generate_private_key from zope.interface import implementer from twisted.internet import defer +from twisted.internet.interfaces import IReactorTCP from twisted.python.filepath import FilePath from twisted.python.url import URL +from twisted.web.resource import IResource logger = logging.getLogger(__name__) -def create_issuing_service(reactor, acme_url, account_key_file, well_known_resource): +def create_issuing_service( + reactor: IReactorTCP, + acme_url: str, + account_key_file: str, + well_known_resource: IResource, +) -> AcmeIssuingService: """Create an ACME issuing service, and attach it to a web Resource Args: reactor: twisted reactor - acme_url (str): URL to use to request certificates - account_key_file (str): where to store the account key - well_known_resource (twisted.web.IResource): web resource for .well-known. + acme_url: URL to use to request certificates + account_key_file: where to store the account key + well_known_resource: web resource for .well-known. we will attach a child resource for "acme-challenge". Returns: @@ -83,18 +92,20 @@ class ErsatzStore: A store that only stores in memory. """ - certs = attr.ib(default=attr.Factory(dict)) + certs = attr.ib(type=Dict[bytes, List[bytes]], default=attr.Factory(dict)) - def store(self, server_name, pem_objects): + def store( + self, server_name: bytes, pem_objects: Iterable[pem.AbstractPEMObject] + ) -> defer.Deferred: self.certs[server_name] = [o.as_bytes() for o in pem_objects] return defer.succeed(None) -def load_or_create_client_key(key_file): +def load_or_create_client_key(key_file: str) -> JWKRSA: """Load the ACME account key from a file, creating it if it does not exist. Args: - key_file (str): name of the file to use as the account key + key_file: name of the file to use as the account key """ # this is based on txacme.endpoint.load_or_create_client_key, but doesn't # hardcode the 'client.key' filename diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index df29edeb83..71f11ef94a 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -15,9 +15,13 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Dict, Iterable, List, Set from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError -from synapse.types import GroupID, get_domain_from_id +from synapse.types import GroupID, JsonDict, get_domain_from_id + +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer logger = logging.getLogger(__name__) @@ -56,7 +60,7 @@ def _create_rerouter(func_name): class GroupsLocalWorkerHandler: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastore() self.room_list_handler = hs.get_room_list_handler() @@ -84,7 +88,9 @@ class GroupsLocalWorkerHandler: get_group_role = _create_rerouter("get_group_role") get_group_roles = _create_rerouter("get_group_roles") - async def get_group_summary(self, group_id, requester_user_id): + async def get_group_summary( + self, group_id: str, requester_user_id: str + ) -> JsonDict: """Get the group summary for a group. If the group is remote we check that the users have valid attestations. 
@@ -137,14 +143,15 @@ class GroupsLocalWorkerHandler: return res - async def get_users_in_group(self, group_id, requester_user_id): + async def get_users_in_group( + self, group_id: str, requester_user_id: str + ) -> JsonDict: """Get users in a group """ if self.is_mine_id(group_id): - res = await self.groups_server_handler.get_users_in_group( + return await self.groups_server_handler.get_users_in_group( group_id, requester_user_id ) - return res group_server_name = get_domain_from_id(group_id) @@ -178,11 +185,11 @@ class GroupsLocalWorkerHandler: return res - async def get_joined_groups(self, user_id): + async def get_joined_groups(self, user_id: str) -> JsonDict: group_ids = await self.store.get_joined_groups(user_id) return {"groups": group_ids} - async def get_publicised_groups_for_user(self, user_id): + async def get_publicised_groups_for_user(self, user_id: str) -> JsonDict: if self.hs.is_mine_id(user_id): result = await self.store.get_publicised_groups_for_user(user_id) @@ -206,8 +213,10 @@ class GroupsLocalWorkerHandler: # TODO: Verify attestations return {"groups": result} - async def bulk_get_publicised_groups(self, user_ids, proxy=True): - destinations = {} + async def bulk_get_publicised_groups( + self, user_ids: Iterable[str], proxy: bool = True + ) -> JsonDict: + destinations = {} # type: Dict[str, Set[str]] local_users = set() for user_id in user_ids: @@ -220,7 +229,7 @@ class GroupsLocalWorkerHandler: raise SynapseError(400, "Some user_ids are not local") results = {} - failed_results = [] + failed_results = [] # type: List[str] for destination, dest_user_ids in destinations.items(): try: r = await self.transport_client.bulk_get_publicised_groups( @@ -242,7 +251,7 @@ class GroupsLocalWorkerHandler: class GroupsLocalHandler(GroupsLocalWorkerHandler): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) # Ensure attestations get renewed @@ -271,7 +280,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): set_group_join_policy = _create_rerouter("set_group_join_policy") - async def create_group(self, group_id, user_id, content): + async def create_group( + self, group_id: str, user_id: str, content: JsonDict + ) -> JsonDict: """Create a group """ @@ -284,27 +295,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): local_attestation = None remote_attestation = None else: - local_attestation = self.attestations.create_attestation(group_id, user_id) - content["attestation"] = local_attestation - - content["user_profile"] = await self.profile_handler.get_profile(user_id) - - try: - res = await self.transport_client.create_group( - get_domain_from_id(group_id), group_id, user_id, content - ) - except HttpResponseException as e: - raise e.to_synapse_error() - except RequestSendFailed: - raise SynapseError(502, "Failed to contact group server") - - remote_attestation = res["attestation"] - await self.attestations.verify_attestation( - remote_attestation, - group_id=group_id, - user_id=user_id, - server_name=get_domain_from_id(group_id), - ) + raise SynapseError(400, "Unable to create remote groups") is_publicised = content.get("publicise", False) token = await self.store.register_user_group_membership( @@ -320,7 +311,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): return res - async def join_group(self, group_id, user_id, content): + async def join_group( + self, group_id: str, user_id: str, content: JsonDict + ) -> JsonDict: """Request to join a group """ if self.is_mine_id(group_id): @@ -365,7 +358,9 @@ class 
GroupsLocalHandler(GroupsLocalWorkerHandler): return {} - async def accept_invite(self, group_id, user_id, content): + async def accept_invite( + self, group_id: str, user_id: str, content: JsonDict + ) -> JsonDict: """Accept an invite to a group """ if self.is_mine_id(group_id): @@ -410,7 +405,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): return {} - async def invite(self, group_id, user_id, requester_user_id, config): + async def invite( + self, group_id: str, user_id: str, requester_user_id: str, config: JsonDict + ) -> JsonDict: """Invite a user to a group """ content = {"requester_user_id": requester_user_id, "config": config} @@ -434,7 +431,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): return res - async def on_invite(self, group_id, user_id, content): + async def on_invite( + self, group_id: str, user_id: str, content: JsonDict + ) -> JsonDict: """One of our users were invited to a group """ # TODO: Support auto join and rejection @@ -465,8 +464,8 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): return {"state": "invite", "user_profile": user_profile} async def remove_user_from_group( - self, group_id, user_id, requester_user_id, content - ): + self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict + ) -> JsonDict: """Remove a user from a group """ if user_id == requester_user_id: @@ -499,7 +498,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler): return res - async def user_removed_from_group(self, group_id, user_id, content): + async def user_removed_from_group( + self, group_id: str, user_id: str, content: JsonDict + ) -> None: """One of our users was removed/kicked from a group """ # TODO: Check if user in group diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 66f1bbcfc4..94062e79cb 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -15,23 +15,28 @@ import itertools import logging -from typing import Iterable +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional from unpaddedbase64 import decode_base64, encode_base64 from synapse.api.constants import EventTypes, Membership from synapse.api.errors import NotFoundError, SynapseError from synapse.api.filtering import Filter +from synapse.events import EventBase from synapse.storage.state import StateFilter +from synapse.types import JsonDict, UserID from synapse.visibility import filter_events_for_client from ._base import BaseHandler +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) class SearchHandler(BaseHandler): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self._event_serializer = hs.get_event_client_serializer() self.storage = hs.get_storage() @@ -87,13 +92,15 @@ class SearchHandler(BaseHandler): return historical_room_ids - async def search(self, user, content, batch=None): + async def search( + self, user: UserID, content: JsonDict, batch: Optional[str] = None + ) -> JsonDict: """Performs a full text search for a user. Args: - user (UserID) - content (dict): Search parameters - batch (str): The next_batch parameter. Used for pagination. + user + content: Search parameters + batch: The next_batch parameter. Used for pagination. 
Returns: dict to be returned to the client with results of search @@ -186,7 +193,7 @@ class SearchHandler(BaseHandler): # If doing a subset of all rooms seearch, check if any of the rooms # are from an upgraded room, and search their contents as well if search_filter.rooms: - historical_room_ids = [] + historical_room_ids = [] # type: List[str] for room_id in search_filter.rooms: # Add any previous rooms to the search if they exist ids = await self.get_old_rooms_from_upgraded_room(room_id) @@ -209,8 +216,10 @@ class SearchHandler(BaseHandler): rank_map = {} # event_id -> rank of event allowed_events = [] - room_groups = {} # Holds result of grouping by room, if applicable - sender_group = {} # Holds result of grouping by sender, if applicable + # Holds result of grouping by room, if applicable + room_groups = {} # type: Dict[str, JsonDict] + # Holds result of grouping by sender, if applicable + sender_group = {} # type: Dict[str, JsonDict] # Holds the next_batch for the entire result set if one of those exists global_next_batch = None @@ -254,7 +263,7 @@ class SearchHandler(BaseHandler): s["results"].append(e.event_id) elif order_by == "recent": - room_events = [] + room_events = [] # type: List[EventBase] i = 0 pagination_token = batch_token @@ -418,13 +427,10 @@ class SearchHandler(BaseHandler): state_results = {} if include_state: - rooms = {e.room_id for e in allowed_events} - for room_id in rooms: + for room_id in {e.room_id for e in allowed_events}: state = await self.state_handler.get_current_state(room_id) state_results[room_id] = list(state.values()) - state_results.values() - # We're now about to serialize the events. We should not make any # blocking calls after this. Otherwise the 'age' will be wrong @@ -448,9 +454,9 @@ class SearchHandler(BaseHandler): if state_results: s = {} - for room_id, state in state_results.items(): + for room_id, state_events in state_results.items(): s[room_id] = await self._event_serializer.serialize_events( - state, time_now + state_events, time_now ) rooms_cat_res["state"] = s diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py index a5d67f828f..84af2dde7e 100644 --- a/synapse/handlers/set_password.py +++ b/synapse/handlers/set_password.py @@ -13,24 +13,26 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import logging -from typing import Optional +from typing import TYPE_CHECKING, Optional from synapse.api.errors import Codes, StoreError, SynapseError from synapse.types import Requester from ._base import BaseHandler +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) class SetPasswordHandler(BaseHandler): """Handler which deals with changing user account passwords""" - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) self._auth_handler = hs.get_auth_handler() self._device_handler = hs.get_device_handler() - self._password_policy_handler = hs.get_password_policy_handler() async def set_password( self, @@ -38,7 +40,7 @@ class SetPasswordHandler(BaseHandler): password_hash: str, logout_devices: bool, requester: Optional[Requester] = None, - ): + ) -> None: if not self.hs.config.password_localdb_enabled: raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN) diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py index fb4f70e8e2..b3f9875358 100644 --- a/synapse/handlers/state_deltas.py +++ b/synapse/handlers/state_deltas.py @@ -14,15 +14,25 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer logger = logging.getLogger(__name__) class StateDeltasHandler: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() - async def _get_key_change(self, prev_event_id, event_id, key_name, public_value): + async def _get_key_change( + self, + prev_event_id: Optional[str], + event_id: Optional[str], + key_name: str, + public_value: str, + ) -> Optional[bool]: """Given two events check if the `key_name` field in content changed from not matching `public_value` to doing so. diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index dc62b21c06..d261d7cd4e 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -12,13 +12,19 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- import logging from collections import Counter +from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Tuple + +from typing_extensions import Counter as CounterType from synapse.api.constants import EventTypes, Membership from synapse.metrics import event_processing_positions from synapse.metrics.background_process_metrics import run_as_background_process +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer logger = logging.getLogger(__name__) @@ -31,7 +37,7 @@ class StatsHandler: Heavily derived from UserDirectoryHandler """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastore() self.state = hs.get_state_handler() @@ -44,7 +50,7 @@ class StatsHandler: self.stats_enabled = hs.config.stats_enabled # The current position in the current_state_delta stream - self.pos = None + self.pos = None # type: Optional[int] # Guard to ensure we only process deltas one at a time self._is_processing = False @@ -56,7 +62,7 @@ class StatsHandler: # we start populating stats self.clock.call_later(0, self.notify_new_event) - def notify_new_event(self): + def notify_new_event(self) -> None: """Called when there may be more deltas to process """ if not self.stats_enabled or self._is_processing: @@ -72,7 +78,7 @@ class StatsHandler: run_as_background_process("stats.notify_new_event", process) - async def _unsafe_process(self): + async def _unsafe_process(self) -> None: # If self.pos is None then means we haven't fetched it from DB if self.pos is None: self.pos = await self.store.get_stats_positions() @@ -110,10 +116,10 @@ class StatsHandler: ) for room_id, fields in room_count.items(): - room_deltas.setdefault(room_id, {}).update(fields) + room_deltas.setdefault(room_id, Counter()).update(fields) for user_id, fields in user_count.items(): - user_deltas.setdefault(user_id, {}).update(fields) + user_deltas.setdefault(user_id, Counter()).update(fields) logger.debug("room_deltas: %s", room_deltas) logger.debug("user_deltas: %s", user_deltas) @@ -131,19 +137,20 @@ class StatsHandler: self.pos = max_pos - async def _handle_deltas(self, deltas): + async def _handle_deltas( + self, deltas: Iterable[JsonDict] + ) -> Tuple[Dict[str, CounterType[str]], Dict[str, CounterType[str]]]: """Called with the state deltas to process Returns: - tuple[dict[str, Counter], dict[str, counter]] Two dicts: the room deltas and the user deltas, mapping from room/user ID to changes in the various fields. 
""" - room_to_stats_deltas = {} - user_to_stats_deltas = {} + room_to_stats_deltas = {} # type: Dict[str, CounterType[str]] + user_to_stats_deltas = {} # type: Dict[str, CounterType[str]] - room_to_state_updates = {} + room_to_state_updates = {} # type: Dict[str, Dict[str, Any]] for delta in deltas: typ = delta["type"] @@ -173,7 +180,7 @@ class StatsHandler: ) continue - event_content = {} + event_content = {} # type: JsonDict sender = None if event_id is not None: @@ -257,13 +264,13 @@ class StatsHandler: ) if has_changed_joinedness: - delta = +1 if membership == Membership.JOIN else -1 + membership_delta = +1 if membership == Membership.JOIN else -1 user_to_stats_deltas.setdefault(user_id, Counter())[ "joined_rooms" - ] += delta + ] += membership_delta - room_stats_delta["local_users_in_room"] += delta + room_stats_delta["local_users_in_room"] += membership_delta elif typ == EventTypes.Create: room_state["is_federatable"] = ( diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index e919a8f9ed..3f0dfc7a74 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -15,13 +15,13 @@ import logging import random from collections import namedtuple -from typing import TYPE_CHECKING, List, Set, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple from synapse.api.errors import AuthError, ShadowBanError, SynapseError from synapse.appservice import ApplicationService from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.streams import TypingStream -from synapse.types import JsonDict, UserID, get_domain_from_id +from synapse.types import JsonDict, Requester, UserID, get_domain_from_id from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.metrics import Measure from synapse.util.wheel_timer import WheelTimer @@ -65,17 +65,17 @@ class FollowerTypingHandler: ) # map room IDs to serial numbers - self._room_serials = {} + self._room_serials = {} # type: Dict[str, int] # map room IDs to sets of users currently typing - self._room_typing = {} + self._room_typing = {} # type: Dict[str, Set[str]] - self._member_last_federation_poke = {} + self._member_last_federation_poke = {} # type: Dict[RoomMember, int] self.wheel_timer = WheelTimer(bucket_size=5000) self._latest_room_serial = 0 self.clock.looping_call(self._handle_timeouts, 5000) - def _reset(self): + def _reset(self) -> None: """Reset the typing handler's data caches. """ # map room IDs to serial numbers @@ -86,7 +86,7 @@ class FollowerTypingHandler: self._member_last_federation_poke = {} self.wheel_timer = WheelTimer(bucket_size=5000) - def _handle_timeouts(self): + def _handle_timeouts(self) -> None: logger.debug("Checking for typing timeouts") now = self.clock.time_msec() @@ -96,7 +96,7 @@ class FollowerTypingHandler: for member in members: self._handle_timeout_for_member(now, member) - def _handle_timeout_for_member(self, now: int, member: RoomMember): + def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None: if not self.is_typing(member): # Nothing to do if they're no longer typing return @@ -114,10 +114,10 @@ class FollowerTypingHandler: # each person typing. 
self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000) - def is_typing(self, member): + def is_typing(self, member: RoomMember) -> bool: return member.user_id in self._room_typing.get(member.room_id, []) - async def _push_remote(self, member, typing): + async def _push_remote(self, member: RoomMember, typing: bool) -> None: if not self.federation: return @@ -148,7 +148,7 @@ class FollowerTypingHandler: def process_replication_rows( self, token: int, rows: List[TypingStream.TypingStreamRow] - ): + ) -> None: """Should be called whenever we receive updates for typing stream. """ @@ -178,7 +178,7 @@ class FollowerTypingHandler: async def _send_changes_in_typing_to_remotes( self, room_id: str, prev_typing: Set[str], now_typing: Set[str] - ): + ) -> None: """Process a change in typing of a room from replication, sending EDUs for any local users. """ @@ -194,12 +194,12 @@ class FollowerTypingHandler: if self.is_mine_id(user_id): await self._push_remote(RoomMember(room_id, user_id), False) - def get_current_token(self): + def get_current_token(self) -> int: return self._latest_room_serial class TypingWriterHandler(FollowerTypingHandler): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) assert hs.config.worker.writers.typing == hs.get_instance_name() @@ -213,14 +213,15 @@ class TypingWriterHandler(FollowerTypingHandler): hs.get_distributor().observe("user_left_room", self.user_left_room) - self._member_typing_until = {} # clock time we expect to stop + # clock time we expect to stop + self._member_typing_until = {} # type: Dict[RoomMember, int] # caches which room_ids changed at which serials self._typing_stream_change_cache = StreamChangeCache( "TypingStreamChangeCache", self._latest_room_serial ) - def _handle_timeout_for_member(self, now: int, member: RoomMember): + def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None: super()._handle_timeout_for_member(now, member) if not self.is_typing(member): @@ -233,7 +234,9 @@ class TypingWriterHandler(FollowerTypingHandler): self._stopped_typing(member) return - async def started_typing(self, target_user, requester, room_id, timeout): + async def started_typing( + self, target_user: UserID, requester: Requester, room_id: str, timeout: int + ) -> None: target_user_id = target_user.to_string() auth_user_id = requester.user.to_string() @@ -263,11 +266,13 @@ class TypingWriterHandler(FollowerTypingHandler): if was_present: # No point sending another notification - return None + return self._push_update(member=member, typing=True) - async def stopped_typing(self, target_user, requester, room_id): + async def stopped_typing( + self, target_user: UserID, requester: Requester, room_id: str + ) -> None: target_user_id = target_user.to_string() auth_user_id = requester.user.to_string() @@ -290,23 +295,23 @@ class TypingWriterHandler(FollowerTypingHandler): self._stopped_typing(member) - def user_left_room(self, user, room_id): + def user_left_room(self, user: UserID, room_id: str) -> None: user_id = user.to_string() if self.is_mine_id(user_id): member = RoomMember(room_id=room_id, user_id=user_id) self._stopped_typing(member) - def _stopped_typing(self, member): + def _stopped_typing(self, member: RoomMember) -> None: if member.user_id not in self._room_typing.get(member.room_id, set()): # No point - return None + return self._member_typing_until.pop(member, None) self._member_last_federation_poke.pop(member, None) self._push_update(member=member, typing=False) - def _push_update(self, member, 
typing): + def _push_update(self, member: RoomMember, typing: bool) -> None: if self.hs.is_mine_id(member.user_id): # Only send updates for changes to our own users. run_as_background_process( @@ -315,7 +320,7 @@ class TypingWriterHandler(FollowerTypingHandler): self._push_update_local(member=member, typing=typing) - async def _recv_edu(self, origin, content): + async def _recv_edu(self, origin: str, content: JsonDict) -> None: room_id = content["room_id"] user_id = content["user_id"] @@ -340,7 +345,7 @@ class TypingWriterHandler(FollowerTypingHandler): self.wheel_timer.insert(now=now, obj=member, then=now + FEDERATION_TIMEOUT) self._push_update_local(member=member, typing=content["typing"]) - def _push_update_local(self, member, typing): + def _push_update_local(self, member: RoomMember, typing: bool) -> None: room_set = self._room_typing.setdefault(member.room_id, set()) if typing: room_set.add(member.user_id) @@ -386,7 +391,7 @@ class TypingWriterHandler(FollowerTypingHandler): changed_rooms = self._typing_stream_change_cache.get_all_entities_changed( last_id - ) + ) # type: Optional[Iterable[str]] if changed_rooms is None: changed_rooms = self._room_serials @@ -412,13 +417,13 @@ class TypingWriterHandler(FollowerTypingHandler): def process_replication_rows( self, token: int, rows: List[TypingStream.TypingStreamRow] - ): + ) -> None: # The writing process should never get updates from replication. raise Exception("Typing writer instance got typing info over replication") class TypingNotificationEventSource: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs self.clock = hs.get_clock() # We can't call get_typing_handler here because there's a cycle: @@ -427,7 +432,7 @@ class TypingNotificationEventSource: # self.get_typing_handler = hs.get_typing_handler - def _make_event_for(self, room_id): + def _make_event_for(self, room_id: str) -> JsonDict: typing = self.get_typing_handler()._room_typing[room_id] return { "type": "m.typing", @@ -462,7 +467,9 @@ class TypingNotificationEventSource: return (events, handler._latest_room_serial) - async def get_new_events(self, from_key, room_ids, **kwargs): + async def get_new_events( + self, from_key: int, room_ids: Iterable[str], **kwargs + ) -> Tuple[List[JsonDict], int]: with Measure(self.clock, "typing.get_new_events"): from_key = int(from_key) handler = self.get_typing_handler() @@ -478,5 +485,5 @@ class TypingNotificationEventSource: return (events, handler._latest_room_serial) - def get_current_key(self): + def get_current_key(self) -> int: return self.get_typing_handler()._latest_room_serial diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index d4651c8348..8aedf5072e 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -145,10 +145,6 @@ class UserDirectoryHandler(StateDeltasHandler): if self.pos is None: self.pos = await self.store.get_user_directory_stream_pos() - # If still None then the initial background update hasn't happened yet - if self.pos is None: - return None - # Loop round handling deltas until we're up to date while True: with Measure(self.clock, "user_dir_delta"): @@ -233,6 +229,11 @@ class UserDirectoryHandler(StateDeltasHandler): if change: # The user joined event = await self.store.get_event(event_id, allow_none=True) + # It isn't expected for this event to not exist, but we + # don't want the entire background process to break. 
+ if event is None: + continue + profile = ProfileInfo( avatar_url=event.content.get("avatar_url"), display_name=event.content.get("displayname"), diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py index 871af64b11..f5e7d9ef98 100644 --- a/synapse/storage/databases/main/search.py +++ b/synapse/storage/databases/main/search.py @@ -24,6 +24,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla from synapse.storage.database import DatabasePool from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.engines import PostgresEngine, Sqlite3Engine +from synapse.types import Collection logger = logging.getLogger(__name__) @@ -460,7 +461,7 @@ class SearchStore(SearchBackgroundUpdateStore): async def search_rooms( self, - room_ids: List[str], + room_ids: Collection[str], search_term: str, keys: List[str], limit, diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index 0cdb3ec1f7..d421d18f8d 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -15,11 +15,12 @@ # limitations under the License. import logging -from collections import Counter from enum import Enum from itertools import chain from typing import Any, Dict, List, Optional, Tuple +from typing_extensions import Counter + from twisted.internet.defer import DeferredLock from synapse.api.constants import EventTypes, Membership @@ -319,7 +320,9 @@ class StatsStore(StateDeltasStore): return slice_list @cached() - async def get_earliest_token_for_stats(self, stats_type: str, id: str) -> int: + async def get_earliest_token_for_stats( + self, stats_type: str, id: str + ) -> Optional[int]: """ Fetch the "earliest token". This is used by the room stats delta processor to ignore deltas that have been processed between the @@ -339,7 +342,7 @@ class StatsStore(StateDeltasStore): ) async def bulk_update_stats_delta( - self, ts: int, updates: Dict[str, Dict[str, Dict[str, Counter]]], stream_id: int + self, ts: int, updates: Dict[str, Dict[str, Counter[str]]], stream_id: int ) -> None: """Bulk update stats tables for a given stream_id and updates the stats incremental position. @@ -665,7 +668,7 @@ class StatsStore(StateDeltasStore): async def get_changes_room_total_events_and_bytes( self, min_pos: int, max_pos: int - ) -> Dict[str, Dict[str, int]]: + ) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, int]]]: """Fetches the counts of events in the given range of stream IDs. Args: @@ -683,18 +686,19 @@ class StatsStore(StateDeltasStore): max_pos, ) - def get_changes_room_total_events_and_bytes_txn(self, txn, low_pos, high_pos): + def get_changes_room_total_events_and_bytes_txn( + self, txn, low_pos: int, high_pos: int + ) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, int]]]: """Gets the total_events and total_event_bytes counts for rooms and senders, in a range of stream_orderings (including backfilled events). 
Args: txn - low_pos (int): Low stream ordering - high_pos (int): High stream ordering + low_pos: Low stream ordering + high_pos: High stream ordering Returns: - tuple[dict[str, dict[str, int]], dict[str, dict[str, int]]]: The - room and user deltas for total_events/total_event_bytes in the + The room and user deltas for total_events/total_event_bytes in the format of `stats_id` -> fields """ diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index ef11f1c3b3..7b9729da09 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -540,7 +540,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): desc="get_user_in_directory", ) - async def update_user_directory_stream_pos(self, stream_id: str) -> None: + async def update_user_directory_stream_pos(self, stream_id: int) -> None: await self.db_pool.simple_update_one( table="user_directory_stream_pos", keyvalues={}, -- cgit 1.5.1 From a737cc27134c50059440ca33510b0baea53b4225 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 27 Jan 2021 12:41:24 +0000 Subject: Implement MSC2858 support (#9183) Fixes #8928. --- changelog.d/9183.feature | 1 + synapse/config/_base.pyi | 2 + synapse/config/experimental.py | 29 ++++++++++++ synapse/config/homeserver.py | 2 + synapse/handlers/sso.py | 23 +++++++--- synapse/http/server.py | 44 ++++++++++++++---- synapse/rest/client/v1/login.py | 55 ++++++++++++++++++++--- tests/rest/client/v1/test_login.py | 92 ++++++++++++++++++++++++++++++++++++++ tests/utils.py | 3 +- 9 files changed, 230 insertions(+), 21 deletions(-) create mode 100644 changelog.d/9183.feature create mode 100644 synapse/config/experimental.py (limited to 'synapse/handlers') diff --git a/changelog.d/9183.feature b/changelog.d/9183.feature new file mode 100644 index 0000000000..2d5c735042 --- /dev/null +++ b/changelog.d/9183.feature @@ -0,0 +1 @@ +Add experimental support for allowing clients to pick an SSO Identity Provider ([MSC2858](https://github.com/matrix-org/matrix-doc/pull/2858). diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index 29aa064e57..3ccea4b02d 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -9,6 +9,7 @@ from synapse.config import ( consent_config, database, emailconfig, + experimental, groups, jwt_config, key, @@ -48,6 +49,7 @@ def path_exists(file_path: str): ... class RootConfig: server: server.ServerConfig + experimental: experimental.ExperimentalConfig tls: tls.TlsConfig database: database.DatabaseConfig logging: logger.LoggingConfig diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py new file mode 100644 index 0000000000..b1c1c51e4d --- /dev/null +++ b/synapse/config/experimental.py @@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from synapse.config._base import Config +from synapse.types import JsonDict + + +class ExperimentalConfig(Config): + """Config section for enabling experimental features""" + + section = "experimental" + + def read_config(self, config: JsonDict, **kwargs): + experimental = config.get("experimental_features") or {} + + # MSC2858 (multiple SSO identity providers) + self.msc2858_enabled = experimental.get("msc2858_enabled", False) # type: bool diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py index 4bd2b3587b..64a2429f77 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py @@ -24,6 +24,7 @@ from .cas import CasConfig from .consent_config import ConsentConfig from .database import DatabaseConfig from .emailconfig import EmailConfig +from .experimental import ExperimentalConfig from .federation import FederationConfig from .groups import GroupsConfig from .jwt_config import JWTConfig @@ -57,6 +58,7 @@ class HomeServerConfig(RootConfig): config_classes = [ ServerConfig, + ExperimentalConfig, TlsConfig, FederationConfig, CacheConfig, diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index d493327a10..afc1341d09 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -23,7 +23,7 @@ from typing_extensions import NoReturn, Protocol from twisted.web.http import Request from synapse.api.constants import LoginType -from synapse.api.errors import Codes, RedirectException, SynapseError +from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.http import get_request_user_agent from synapse.http.server import respond_with_html @@ -235,7 +235,10 @@ class SsoHandler: respond_with_html(request, code, html) async def handle_redirect_request( - self, request: SynapseRequest, client_redirect_url: bytes, + self, + request: SynapseRequest, + client_redirect_url: bytes, + idp_id: Optional[str], ) -> str: """Handle a request to /login/sso/redirect @@ -243,6 +246,7 @@ class SsoHandler: request: incoming HTTP request client_redirect_url: the URL that we should redirect the client to after login. + idp_id: optional identity provider chosen by the client Returns: the URI to redirect to @@ -252,10 +256,19 @@ class SsoHandler: 400, "Homeserver not configured for SSO.", errcode=Codes.UNRECOGNIZED ) + # if the client chose an IdP, use that + idp = None # type: Optional[SsoIdentityProvider] + if idp_id: + idp = self._identity_providers.get(idp_id) + if not idp: + raise NotFoundError("Unknown identity provider") + # if we only have one auth provider, redirect to it directly - if len(self._identity_providers) == 1: - ap = next(iter(self._identity_providers.values())) - return await ap.handle_redirect_request(request, client_redirect_url) + elif len(self._identity_providers) == 1: + idp = next(iter(self._identity_providers.values())) + + if idp: + return await idp.handle_redirect_request(request, client_redirect_url) # otherwise, redirect to the IDP picker return "/_synapse/client/pick_idp?" 
+ urlencode( diff --git a/synapse/http/server.py b/synapse/http/server.py index e464bfe6c7..d69d579b3a 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -22,10 +22,22 @@ import types import urllib from http import HTTPStatus from io import BytesIO -from typing import Any, Callable, Dict, Iterator, List, Tuple, Union +from typing import ( + Any, + Awaitable, + Callable, + Dict, + Iterable, + Iterator, + List, + Pattern, + Tuple, + Union, +) import jinja2 from canonicaljson import iterencode_canonical_json +from typing_extensions import Protocol from zope.interface import implementer from twisted.internet import defer, interfaces @@ -168,11 +180,25 @@ def wrap_async_request_handler(h): return preserve_fn(wrapped_async_request_handler) -class HttpServer: +# Type of a callback method for processing requests +# it is actually called with a SynapseRequest and a kwargs dict for the params, +# but I can't figure out how to represent that. +ServletCallback = Callable[ + ..., Union[None, Awaitable[None], Tuple[int, Any], Awaitable[Tuple[int, Any]]] +] + + +class HttpServer(Protocol): """ Interface for registering callbacks on a HTTP server """ - def register_paths(self, method, path_patterns, callback): + def register_paths( + self, + method: str, + path_patterns: Iterable[Pattern], + callback: ServletCallback, + servlet_classname: str, + ) -> None: """ Register a callback that gets fired if we receive a http request with the given method for a path that matches the given regex. @@ -180,12 +206,14 @@ class HttpServer: an unpacked tuple. Args: - method (str): The method to listen to. - path_patterns (list): The regex used to match requests. - callback (function): The function to fire if we receive a matched + method: The HTTP method to listen to. + path_patterns: The regex used to match requests. + callback: The function to fire if we receive a matched request. The first argument will be the request object and subsequent arguments will be any matched groups from the regex. - This should return a tuple of (code, response). + This should return either tuple of (code, response), or None. + servlet_classname (str): The name of the handler to be used in prometheus + and opentracing logs. """ pass @@ -354,7 +382,7 @@ class JsonResource(DirectServeJsonResource): def _get_handler_for_request( self, request: SynapseRequest - ) -> Tuple[Callable, str, Dict[str, str]]: + ) -> Tuple[ServletCallback, str, Dict[str, str]]: """Finds a callback method to handle the given request. 
Returns: diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index be938df962..0a561eea60 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -19,7 +19,8 @@ from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional from synapse.api.errors import Codes, LoginError, SynapseError from synapse.api.ratelimiting import Ratelimiter from synapse.appservice import ApplicationService -from synapse.http.server import finish_request +from synapse.handlers.sso import SsoIdentityProvider +from synapse.http.server import HttpServer, finish_request from synapse.http.servlet import ( RestServlet, parse_json_object_from_request, @@ -60,11 +61,14 @@ class LoginRestServlet(RestServlet): self.saml2_enabled = hs.config.saml2_enabled self.cas_enabled = hs.config.cas_enabled self.oidc_enabled = hs.config.oidc_enabled + self._msc2858_enabled = hs.config.experimental.msc2858_enabled self.auth = hs.get_auth() self.auth_handler = self.hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() + self._sso_handler = hs.get_sso_handler() + self._well_known_builder = WellKnownBuilder(hs) self._address_ratelimiter = Ratelimiter( clock=hs.get_clock(), @@ -89,8 +93,17 @@ class LoginRestServlet(RestServlet): flows.append({"type": LoginRestServlet.CAS_TYPE}) if self.cas_enabled or self.saml2_enabled or self.oidc_enabled: - flows.append({"type": LoginRestServlet.SSO_TYPE}) - # While its valid for us to advertise this login type generally, + sso_flow = {"type": LoginRestServlet.SSO_TYPE} # type: JsonDict + + if self._msc2858_enabled: + sso_flow["org.matrix.msc2858.identity_providers"] = [ + _get_auth_flow_dict_for_idp(idp) + for idp in self._sso_handler.get_identity_providers().values() + ] + + flows.append(sso_flow) + + # While it's valid for us to advertise this login type generally, # synapse currently only gives out these tokens as part of the # SSO login flow. 
# Generally we don't want to advertise login flows that clients @@ -311,8 +324,20 @@ class LoginRestServlet(RestServlet): return result +def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict: + """Return an entry for the login flow dict + + Returns an entry suitable for inclusion in "identity_providers" in the + response to GET /_matrix/client/r0/login + """ + e = {"id": idp.idp_id, "name": idp.idp_name} # type: JsonDict + if idp.idp_icon: + e["icon"] = idp.idp_icon + return e + + class SsoRedirectServlet(RestServlet): - PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True) + PATTERNS = client_patterns("/login/(cas|sso)/redirect$", v1=True) def __init__(self, hs: "HomeServer"): # make sure that the relevant handlers are instantiated, so that they @@ -324,13 +349,31 @@ class SsoRedirectServlet(RestServlet): if hs.config.oidc_enabled: hs.get_oidc_handler() self._sso_handler = hs.get_sso_handler() + self._msc2858_enabled = hs.config.experimental.msc2858_enabled + + def register(self, http_server: HttpServer) -> None: + super().register(http_server) + if self._msc2858_enabled: + # expose additional endpoint for MSC2858 support + http_server.register_paths( + "GET", + client_patterns( + "/org.matrix.msc2858/login/sso/redirect/(?P[A-Za-z0-9_.~-]+)$", + releases=(), + unstable=True, + ), + self.on_GET, + self.__class__.__name__, + ) - async def on_GET(self, request: SynapseRequest): + async def on_GET( + self, request: SynapseRequest, idp_id: Optional[str] = None + ) -> None: client_redirect_url = parse_string( request, "redirectUrl", required=True, encoding=None ) sso_url = await self._sso_handler.handle_redirect_request( - request, client_redirect_url + request, client_redirect_url, idp_id, ) logger.info("Redirecting to %s", sso_url) request.redirect(sso_url) diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index 2672ce24c6..e2bb945453 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -75,6 +75,10 @@ TEST_CLIENT_REDIRECT_URL = 'https://x?&q"+%3D%2B"="fö%26=o"' # the query params in TEST_CLIENT_REDIRECT_URL EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("", ""), ('q" =+"', '"fö&=o"')] +# (possibly experimental) login flows we expect to appear in the list after the normal +# ones +ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}] + class LoginRestServletTestCase(unittest.HomeserverTestCase): @@ -426,6 +430,57 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): d["/_synapse/oidc"] = OIDCResource(self.hs) return d + def test_get_login_flows(self): + """GET /login should return password and SSO flows""" + channel = self.make_request("GET", "/_matrix/client/r0/login") + self.assertEqual(channel.code, 200, channel.result) + + expected_flows = [ + {"type": "m.login.cas"}, + {"type": "m.login.sso"}, + {"type": "m.login.token"}, + {"type": "m.login.password"}, + ] + ADDITIONAL_LOGIN_FLOWS + + self.assertCountEqual(channel.json_body["flows"], expected_flows) + + @override_config({"experimental_features": {"msc2858_enabled": True}}) + def test_get_msc2858_login_flows(self): + """The SSO flow should include IdP info if MSC2858 is enabled""" + channel = self.make_request("GET", "/_matrix/client/r0/login") + self.assertEqual(channel.code, 200, channel.result) + + # stick the flows results in a dict by type + flow_results = {} # type: Dict[str, Any] + for f in channel.json_body["flows"]: + flow_type = f["type"] + self.assertNotIn( + flow_type, flow_results, "duplicate flow 
type %s" % (flow_type,) + ) + flow_results[flow_type] = f + + self.assertIn("m.login.sso", flow_results, "m.login.sso was not returned") + sso_flow = flow_results.pop("m.login.sso") + # we should have a set of IdPs + self.assertCountEqual( + sso_flow["org.matrix.msc2858.identity_providers"], + [ + {"id": "cas", "name": "CAS"}, + {"id": "saml", "name": "SAML"}, + {"id": "oidc-idp1", "name": "IDP1"}, + {"id": "oidc", "name": "OIDC"}, + ], + ) + + # the rest of the flows are simple + expected_flows = [ + {"type": "m.login.cas"}, + {"type": "m.login.token"}, + {"type": "m.login.password"}, + ] + ADDITIONAL_LOGIN_FLOWS + + self.assertCountEqual(flow_results.values(), expected_flows) + def test_multi_sso_redirect(self): """/login/sso/redirect should redirect to an identity picker""" # first hit the redirect url, which should redirect to our idp picker @@ -564,6 +619,43 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 400, channel.result) + def test_client_idp_redirect_msc2858_disabled(self): + """If the client tries to pick an IdP but MSC2858 is disabled, return a 400""" + channel = self.make_request( + "GET", + "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL), + ) + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") + + @override_config({"experimental_features": {"msc2858_enabled": True}}) + def test_client_idp_redirect_to_unknown(self): + """If the client tries to pick an unknown IdP, return a 404""" + channel = self.make_request( + "GET", + "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/xxx?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL), + ) + self.assertEqual(channel.code, 404, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") + + @override_config({"experimental_features": {"msc2858_enabled": True}}) + def test_client_idp_redirect_to_oidc(self): + """If the client pick a known IdP, redirect to it""" + channel = self.make_request( + "GET", + "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL), + ) + + self.assertEqual(channel.code, 302, channel.result) + oidc_uri = channel.headers.getRawHeaders("Location")[0] + oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) + + # it should redirect us to the auth page of the OIDC server + self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) + @staticmethod def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str: prefix = key + " = " diff --git a/tests/utils.py b/tests/utils.py index 09614093bc..022223cf24 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -33,7 +33,6 @@ from synapse.api.room_versions import RoomVersions from synapse.config.database import DatabaseConnectionConfig from synapse.config.homeserver import HomeServerConfig from synapse.config.server import DEFAULT_ROOM_VERSION -from synapse.http.server import HttpServer from synapse.logging.context import current_context, set_current_context from synapse.server import HomeServer from synapse.storage import DataStore @@ -351,7 +350,7 @@ def mock_getRawHeaders(headers=None): # This is a mock /resource/ not an entire server -class MockHttpResource(HttpServer): +class MockHttpResource: def __init__(self, prefix=""): self.callbacks = [] # 3-tuple of method/pattern/function self.prefix = prefix -- cgit 1.5.1 From 
869667760f571c9edebab660061e17035d57f182 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 27 Jan 2021 21:28:59 +0000 Subject: Support for scraping email addresses from OIDC providers (#9245) --- changelog.d/9245.feature | 1 + docs/sample_config.yaml | 15 +++++++++--- synapse/config/oidc_config.py | 15 +++++++++--- synapse/handlers/oidc_handler.py | 52 +++++++++++++++++++++------------------- 4 files changed, 53 insertions(+), 30 deletions(-) create mode 100644 changelog.d/9245.feature (limited to 'synapse/handlers') diff --git a/changelog.d/9245.feature b/changelog.d/9245.feature new file mode 100644 index 0000000000..b9238207e2 --- /dev/null +++ b/changelog.d/9245.feature @@ -0,0 +1 @@ +Add support to the OpenID Connect integration for adding the user's email address. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 87bfe22237..1c90156db9 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1791,9 +1791,9 @@ saml2_config: # # For the default provider, the following settings are available: # -# sub: name of the claim containing a unique identifier for the -# user. Defaults to 'sub', which OpenID Connect compliant -# providers should provide. +# subject_claim: name of the claim containing a unique identifier +# for the user. Defaults to 'sub', which OpenID Connect +# compliant providers should provide. # # localpart_template: Jinja2 template for the localpart of the MXID. # If this is not set, the user will be prompted to choose their @@ -1802,6 +1802,9 @@ saml2_config: # display_name_template: Jinja2 template for the display name to set # on first login. If unset, no displayname will be set. # +# email_template: Jinja2 template for the email address of the user. +# If unset, no email address will be added to the account. +# # extra_attributes: a map of Jinja2 templates for extra attributes # to send back to the client during login. # Note that these are non-standard and clients will ignore them @@ -1837,6 +1840,12 @@ oidc_providers: # userinfo_endpoint: "https://accounts.example.com/userinfo" # jwks_uri: "https://accounts.example.com/.well-known/jwks.json" # skip_verification: true + # user_mapping_provider: + # config: + # subject_claim: "id" + # localpart_template: "{ user.login }" + # display_name_template: "{ user.name }" + # email_template: "{ user.email }" # For use with Keycloak # diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index bfeceeed18..8237b2e797 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -143,9 +143,9 @@ class OIDCConfig(Config): # # For the default provider, the following settings are available: # - # sub: name of the claim containing a unique identifier for the - # user. Defaults to 'sub', which OpenID Connect compliant - # providers should provide. + # subject_claim: name of the claim containing a unique identifier + # for the user. Defaults to 'sub', which OpenID Connect + # compliant providers should provide. # # localpart_template: Jinja2 template for the localpart of the MXID. # If this is not set, the user will be prompted to choose their @@ -154,6 +154,9 @@ class OIDCConfig(Config): # display_name_template: Jinja2 template for the display name to set # on first login. If unset, no displayname will be set. # + # email_template: Jinja2 template for the email address of the user. + # If unset, no email address will be added to the account. 
+ # # extra_attributes: a map of Jinja2 templates for extra attributes # to send back to the client during login. # Note that these are non-standard and clients will ignore them @@ -189,6 +192,12 @@ class OIDCConfig(Config): # userinfo_endpoint: "https://accounts.example.com/userinfo" # jwks_uri: "https://accounts.example.com/.well-known/jwks.json" # skip_verification: true + # user_mapping_provider: + # config: + # subject_claim: "id" + # localpart_template: "{{ user.login }}" + # display_name_template: "{{ user.name }}" + # email_template: "{{ user.email }}" # For use with Keycloak # diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 1607e12935..324ddb798c 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -1056,7 +1056,8 @@ class OidcSessionData: UserAttributeDict = TypedDict( - "UserAttributeDict", {"localpart": Optional[str], "display_name": Optional[str]} + "UserAttributeDict", + {"localpart": Optional[str], "display_name": Optional[str], "emails": List[str]}, ) C = TypeVar("C") @@ -1135,11 +1136,12 @@ def jinja_finalize(thing): env = Environment(finalize=jinja_finalize) -@attr.s +@attr.s(slots=True, frozen=True) class JinjaOidcMappingConfig: subject_claim = attr.ib(type=str) localpart_template = attr.ib(type=Optional[Template]) display_name_template = attr.ib(type=Optional[Template]) + email_template = attr.ib(type=Optional[Template]) extra_attributes = attr.ib(type=Dict[str, Template]) @@ -1156,23 +1158,17 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): def parse_config(config: dict) -> JinjaOidcMappingConfig: subject_claim = config.get("subject_claim", "sub") - localpart_template = None # type: Optional[Template] - if "localpart_template" in config: + def parse_template_config(option_name: str) -> Optional[Template]: + if option_name not in config: + return None try: - localpart_template = env.from_string(config["localpart_template"]) + return env.from_string(config[option_name]) except Exception as e: - raise ConfigError( - "invalid jinja template", path=["localpart_template"] - ) from e + raise ConfigError("invalid jinja template", path=[option_name]) from e - display_name_template = None # type: Optional[Template] - if "display_name_template" in config: - try: - display_name_template = env.from_string(config["display_name_template"]) - except Exception as e: - raise ConfigError( - "invalid jinja template", path=["display_name_template"] - ) from e + localpart_template = parse_template_config("localpart_template") + display_name_template = parse_template_config("display_name_template") + email_template = parse_template_config("email_template") extra_attributes = {} # type Dict[str, Template] if "extra_attributes" in config: @@ -1192,6 +1188,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): subject_claim=subject_claim, localpart_template=localpart_template, display_name_template=display_name_template, + email_template=email_template, extra_attributes=extra_attributes, ) @@ -1213,16 +1210,23 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]): # a usable mxid. 
localpart += str(failures) if failures else "" - display_name = None # type: Optional[str] - if self._config.display_name_template is not None: - display_name = self._config.display_name_template.render( - user=userinfo - ).strip() + def render_template_field(template: Optional[Template]) -> Optional[str]: + if template is None: + return None + return template.render(user=userinfo).strip() + + display_name = render_template_field(self._config.display_name_template) + if display_name == "": + display_name = None - if display_name == "": - display_name = None + emails = [] # type: List[str] + email = render_template_field(self._config.email_template) + if email: + emails.append(email) - return UserAttributeDict(localpart=localpart, display_name=display_name) + return UserAttributeDict( + localpart=localpart, display_name=display_name, emails=emails + ) async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict: extras = {} # type: Dict[str, str] -- cgit 1.5.1 From a083aea396dbd455858e93d6a57a236e192b68e2 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 27 Jan 2021 21:31:45 +0000 Subject: Add 'brand' field to MSC2858 response (#9242) We've decided to add a 'brand' field to help clients decide how to style the buttons. Also, fix up the allowed characters for idp_id, while I'm in the area. --- changelog.d/9183.feature | 2 +- changelog.d/9242.feature | 1 + docs/openid.md | 3 +++ docs/sample_config.yaml | 13 ++++++---- synapse/config/oidc_config.py | 52 +++++++++++++++++++++------------------- synapse/handlers/cas_handler.py | 3 ++- synapse/handlers/oidc_handler.py | 3 +++ synapse/handlers/saml_handler.py | 3 ++- synapse/handlers/sso.py | 5 ++++ synapse/rest/client/v1/login.py | 2 ++ 10 files changed, 55 insertions(+), 32 deletions(-) create mode 100644 changelog.d/9242.feature (limited to 'synapse/handlers') diff --git a/changelog.d/9183.feature b/changelog.d/9183.feature index 2d5c735042..3bcd9f15d1 100644 --- a/changelog.d/9183.feature +++ b/changelog.d/9183.feature @@ -1 +1 @@ -Add experimental support for allowing clients to pick an SSO Identity Provider ([MSC2858](https://github.com/matrix-org/matrix-doc/pull/2858). +Add experimental support for allowing clients to pick an SSO Identity Provider ([MSC2858](https://github.com/matrix-org/matrix-doc/pull/2858)). diff --git a/changelog.d/9242.feature b/changelog.d/9242.feature new file mode 100644 index 0000000000..3bcd9f15d1 --- /dev/null +++ b/changelog.d/9242.feature @@ -0,0 +1 @@ +Add experimental support for allowing clients to pick an SSO Identity Provider ([MSC2858](https://github.com/matrix-org/matrix-doc/pull/2858)). 
diff --git a/docs/openid.md b/docs/openid.md index b86ae89768..f01f46d326 100644 --- a/docs/openid.md +++ b/docs/openid.md @@ -225,6 +225,7 @@ Synapse config: oidc_providers: - idp_id: github idp_name: Github + idp_brand: "org.matrix.github" # optional: styling hint for clients discover: false issuer: "https://github.com/" client_id: "your-client-id" # TO BE FILLED @@ -250,6 +251,7 @@ oidc_providers: oidc_providers: - idp_id: google idp_name: Google + idp_brand: "org.matrix.google" # optional: styling hint for clients issuer: "https://accounts.google.com/" client_id: "your-client-id" # TO BE FILLED client_secret: "your-client-secret" # TO BE FILLED @@ -296,6 +298,7 @@ Synapse config: oidc_providers: - idp_id: gitlab idp_name: Gitlab + idp_brand: "org.matrix.gitlab" # optional: styling hint for clients issuer: "https://gitlab.com/" client_id: "your-client-id" # TO BE FILLED client_secret: "your-client-secret" # TO BE FILLED diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 1c90156db9..8777e3254d 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1727,10 +1727,14 @@ saml2_config: # offer the user a choice of login mechanisms. # # idp_icon: An optional icon for this identity provider, which is presented -# by identity picker pages. If given, must be an MXC URI of the format -# mxc:///. (An easy way to obtain such an MXC URI -# is to upload an image to an (unencrypted) room and then copy the "url" -# from the source of the event.) +# by clients and Synapse's own IdP picker page. If given, must be an +# MXC URI of the format mxc:///. (An easy way to +# obtain such an MXC URI is to upload an image to an (unencrypted) room +# and then copy the "url" from the source of the event.) +# +# idp_brand: An optional brand for this identity provider, allowing clients +# to style the login flow according to the identity provider in question. +# See the spec for possible options here. # # discover: set to 'false' to disable the use of the OIDC discovery mechanism # to discover endpoints. Defaults to true. @@ -1860,6 +1864,7 @@ oidc_providers: # #- idp_id: github # idp_name: Github + # idp_brand: org.matrix.github # discover: false # issuer: "https://github.com/" # client_id: "your-client-id" # TO BE FILLED diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index 8237b2e797..f31511e039 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import string from collections import Counter from typing import Iterable, Optional, Tuple, Type @@ -79,10 +78,14 @@ class OIDCConfig(Config): # offer the user a choice of login mechanisms. # # idp_icon: An optional icon for this identity provider, which is presented - # by identity picker pages. If given, must be an MXC URI of the format - # mxc:///. (An easy way to obtain such an MXC URI - # is to upload an image to an (unencrypted) room and then copy the "url" - # from the source of the event.) + # by clients and Synapse's own IdP picker page. If given, must be an + # MXC URI of the format mxc:///. (An easy way to + # obtain such an MXC URI is to upload an image to an (unencrypted) room + # and then copy the "url" from the source of the event.) + # + # idp_brand: An optional brand for this identity provider, allowing clients + # to style the login flow according to the identity provider in question. + # See the spec for possible options here. 
# # discover: set to 'false' to disable the use of the OIDC discovery mechanism # to discover endpoints. Defaults to true. @@ -212,6 +215,7 @@ class OIDCConfig(Config): # #- idp_id: github # idp_name: Github + # idp_brand: org.matrix.github # discover: false # issuer: "https://github.com/" # client_id: "your-client-id" # TO BE FILLED @@ -235,11 +239,22 @@ OIDC_PROVIDER_CONFIG_SCHEMA = { "type": "object", "required": ["issuer", "client_id", "client_secret"], "properties": { - # TODO: fix the maxLength here depending on what MSC2528 decides - # remember that we prefix the ID given here with `oidc-` - "idp_id": {"type": "string", "minLength": 1, "maxLength": 128}, + "idp_id": { + "type": "string", + "minLength": 1, + # MSC2858 allows a maxlen of 255, but we prefix with "oidc-" + "maxLength": 250, + "pattern": "^[A-Za-z0-9._~-]+$", + }, "idp_name": {"type": "string"}, "idp_icon": {"type": "string"}, + "idp_brand": { + "type": "string", + # MSC2758-style namespaced identifier + "minLength": 1, + "maxLength": 255, + "pattern": "^[a-z][a-z0-9_.-]*$", + }, "discover": {"type": "boolean"}, "issuer": {"type": "string"}, "client_id": {"type": "string"}, @@ -358,25 +373,8 @@ def _parse_oidc_config_dict( config_path + ("user_mapping_provider", "module"), ) - # MSC2858 will apply certain limits in what can be used as an IdP id, so let's - # enforce those limits now. - # TODO: factor out this stuff to a generic function idp_id = oidc_config.get("idp_id", "oidc") - # TODO: update this validity check based on what MSC2858 decides. - valid_idp_chars = set(string.ascii_lowercase + string.digits + "-._") - - if any(c not in valid_idp_chars for c in idp_id): - raise ConfigError( - 'idp_id may only contain a-z, 0-9, "-", ".", "_"', - config_path + ("idp_id",), - ) - - if idp_id[0] not in string.ascii_lowercase: - raise ConfigError( - "idp_id must start with a-z", config_path + ("idp_id",), - ) - # prefix the given IDP with a prefix specific to the SSO mechanism, to avoid # clashes with other mechs (such as SAML, CAS). # @@ -402,6 +400,7 @@ def _parse_oidc_config_dict( idp_id=idp_id, idp_name=oidc_config.get("idp_name", "OIDC"), idp_icon=idp_icon, + idp_brand=oidc_config.get("idp_brand"), discover=oidc_config.get("discover", True), issuer=oidc_config["issuer"], client_id=oidc_config["client_id"], @@ -432,6 +431,9 @@ class OidcProviderConfig: # Optional MXC URI for icon for this IdP. idp_icon = attr.ib(type=Optional[str]) + # Optional brand identifier for this IdP. + idp_brand = attr.ib(type=Optional[str]) + # whether the OIDC discovery mechanism is used to discover endpoints discover = attr.ib(type=bool) diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py index 0f342c607b..048523ec94 100644 --- a/synapse/handlers/cas_handler.py +++ b/synapse/handlers/cas_handler.py @@ -80,9 +80,10 @@ class CasHandler: # user-facing name of this auth provider self.idp_name = "CAS" - # we do not currently support icons for CAS auth, but this is required by + # we do not currently support brands/icons for CAS auth, but this is required by # the SsoIdentityProvider protocol type. 
self.idp_icon = None + self.idp_brand = None self._sso_handler = hs.get_sso_handler() diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index 324ddb798c..ca647fa78f 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -274,6 +274,9 @@ class OidcProvider: # MXC URI for icon for this auth provider self.idp_icon = provider.idp_icon + # optional brand identifier for this auth provider + self.idp_brand = provider.idp_brand + self._sso_handler = hs.get_sso_handler() self._sso_handler.register_identity_provider(self) diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index 38461cf79d..5946919c33 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -78,9 +78,10 @@ class SamlHandler(BaseHandler): # user-facing name of this auth provider self.idp_name = "SAML" - # we do not currently support icons for SAML auth, but this is required by + # we do not currently support icons/brands for SAML auth, but this is required by # the SsoIdentityProvider protocol type. self.idp_icon = None + self.idp_brand = None # a map from saml session id to Saml2SessionData object self._outstanding_requests_dict = {} # type: Dict[str, Saml2SessionData] diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index afc1341d09..3308b037d2 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -80,6 +80,11 @@ class SsoIdentityProvider(Protocol): """Optional MXC URI for user-facing icon""" return None + @property + def idp_brand(self) -> Optional[str]: + """Optional branding identifier""" + return None + @abc.abstractmethod async def handle_redirect_request( self, diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py index 0a561eea60..0fb9419e58 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py @@ -333,6 +333,8 @@ def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict: e = {"id": idp.idp_id, "name": idp.idp_name} # type: JsonDict if idp.idp_icon: e["icon"] = idp.idp_icon + if idp.idp_brand: + e["brand"] = idp.idp_brand return e -- cgit 1.5.1 From a78016dadfb1680f5f77daae9948086b37cbeef8 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 28 Jan 2021 08:34:19 -0500 Subject: Add type hints to E2E handler. (#9232) This finishes adding type hints to the `synapse.handlers` module. --- changelog.d/9232.misc | 1 + mypy.ini | 42 +--- synapse/handlers/device.py | 12 +- synapse/handlers/e2e_keys.py | 223 +++++++++++++--------- synapse/handlers/e2e_room_keys.py | 91 +++++---- synapse/logging/opentracing.py | 2 +- synapse/storage/databases/main/end_to_end_keys.py | 4 +- 7 files changed, 198 insertions(+), 177 deletions(-) create mode 100644 changelog.d/9232.misc (limited to 'synapse/handlers') diff --git a/changelog.d/9232.misc b/changelog.d/9232.misc new file mode 100644 index 0000000000..9d44b621c9 --- /dev/null +++ b/changelog.d/9232.misc @@ -0,0 +1 @@ +Add type hints to handlers code. 
diff --git a/mypy.ini b/mypy.ini index f3700d323c..68a4533973 100644 --- a/mypy.ini +++ b/mypy.ini @@ -23,47 +23,7 @@ files = synapse/events/validator.py, synapse/events/spamcheck.py, synapse/federation, - synapse/handlers/_base.py, - synapse/handlers/account_data.py, - synapse/handlers/account_validity.py, - synapse/handlers/acme.py, - synapse/handlers/acme_issuing_service.py, - synapse/handlers/admin.py, - synapse/handlers/appservice.py, - synapse/handlers/auth.py, - synapse/handlers/cas_handler.py, - synapse/handlers/deactivate_account.py, - synapse/handlers/device.py, - synapse/handlers/devicemessage.py, - synapse/handlers/directory.py, - synapse/handlers/events.py, - synapse/handlers/federation.py, - synapse/handlers/groups_local.py, - synapse/handlers/identity.py, - synapse/handlers/initial_sync.py, - synapse/handlers/message.py, - synapse/handlers/oidc_handler.py, - synapse/handlers/pagination.py, - synapse/handlers/password_policy.py, - synapse/handlers/presence.py, - synapse/handlers/profile.py, - synapse/handlers/read_marker.py, - synapse/handlers/receipts.py, - synapse/handlers/register.py, - synapse/handlers/room.py, - synapse/handlers/room_list.py, - synapse/handlers/room_member.py, - synapse/handlers/room_member_worker.py, - synapse/handlers/saml_handler.py, - synapse/handlers/search.py, - synapse/handlers/set_password.py, - synapse/handlers/sso.py, - synapse/handlers/state_deltas.py, - synapse/handlers/stats.py, - synapse/handlers/sync.py, - synapse/handlers/typing.py, - synapse/handlers/user_directory.py, - synapse/handlers/ui_auth, + synapse/handlers, synapse/http/client.py, synapse/http/federation/matrix_federation_agent.py, synapse/http/federation/well_known_resolver.py, diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index debb1b4f29..0863154f7a 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import logging -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple from synapse.api import errors from synapse.api.constants import EventTypes @@ -62,7 +62,7 @@ class DeviceWorkerHandler(BaseHandler): self._auth_handler = hs.get_auth_handler() @trace - async def get_devices_by_user(self, user_id: str) -> List[Dict[str, Any]]: + async def get_devices_by_user(self, user_id: str) -> List[JsonDict]: """ Retrieve the given user's devices @@ -85,7 +85,7 @@ class DeviceWorkerHandler(BaseHandler): return devices @trace - async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]: + async def get_device(self, user_id: str, device_id: str) -> JsonDict: """ Retrieve the given device Args: @@ -598,7 +598,7 @@ class DeviceHandler(DeviceWorkerHandler): def _update_device_from_client_ips( - device: Dict[str, Any], client_ips: Dict[Tuple[str, str], Dict[str, Any]] + device: JsonDict, client_ips: Dict[Tuple[str, str], JsonDict] ) -> None: ip = client_ips.get((device["user_id"], device["device_id"]), {}) device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")}) @@ -946,8 +946,8 @@ class DeviceListUpdater: async def process_cross_signing_key_update( self, user_id: str, - master_key: Optional[Dict[str, Any]], - self_signing_key: Optional[Dict[str, Any]], + master_key: Optional[JsonDict], + self_signing_key: Optional[JsonDict], ) -> List[str]: """Process the given new master and self-signing key for the given remote user. diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 929752150d..8f3a6b35a4 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -16,7 +16,7 @@ # limitations under the License. import logging -from typing import Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple import attr from canonicaljson import encode_canonical_json @@ -31,6 +31,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet from synapse.types import ( + JsonDict, UserID, get_domain_from_id, get_verify_key_from_cross_signing_key, @@ -40,11 +41,14 @@ from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.retryutils import NotRetryingDestination +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) class E2eKeysHandler: - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() self.federation = hs.get_federation_client() self.device_handler = hs.get_device_handler() @@ -78,7 +82,9 @@ class E2eKeysHandler: ) @trace - async def query_devices(self, query_body, timeout, from_user_id): + async def query_devices( + self, query_body: JsonDict, timeout: int, from_user_id: str + ) -> JsonDict: """ Handle a device key query from a client { @@ -98,12 +104,14 @@ class E2eKeysHandler: } Args: - from_user_id (str): the user making the query. This is used when + from_user_id: the user making the query. This is used when adding cross-signing signatures to limit what signatures users can see. 
""" - device_keys_query = query_body.get("device_keys", {}) + device_keys_query = query_body.get( + "device_keys", {} + ) # type: Dict[str, Iterable[str]] # separate users by domain. # make a map from domain to user_id to device_ids @@ -121,7 +129,8 @@ class E2eKeysHandler: set_tag("remote_key_query", remote_queries) # First get local devices. - failures = {} + # A map of destination -> failure response. + failures = {} # type: Dict[str, JsonDict] results = {} if local_query: local_result = await self.query_local_devices(local_query) @@ -135,9 +144,10 @@ class E2eKeysHandler: ) # Now attempt to get any remote devices from our local cache. - remote_queries_not_in_cache = {} + # A map of destination -> user ID -> device IDs. + remote_queries_not_in_cache = {} # type: Dict[str, Dict[str, Iterable[str]]] if remote_queries: - query_list = [] + query_list = [] # type: List[Tuple[str, Optional[str]]] for user_id, device_ids in remote_queries.items(): if device_ids: query_list.extend((user_id, device_id) for device_id in device_ids) @@ -284,15 +294,15 @@ class E2eKeysHandler: return ret async def get_cross_signing_keys_from_cache( - self, query, from_user_id + self, query: Iterable[str], from_user_id: Optional[str] ) -> Dict[str, Dict[str, dict]]: """Get cross-signing keys for users from the database Args: - query (Iterable[string]) an iterable of user IDs. A dict whose keys + query: an iterable of user IDs. A dict whose keys are user IDs satisfies this, so the query format used for query_devices can be used here. - from_user_id (str): the user making the query. This is used when + from_user_id: the user making the query. This is used when adding cross-signing signatures to limit what signatures users can see. @@ -315,14 +325,12 @@ class E2eKeysHandler: if "self_signing" in user_info: self_signing_keys[user_id] = user_info["self_signing"] - if ( - from_user_id in keys - and keys[from_user_id] is not None - and "user_signing" in keys[from_user_id] - ): - # users can see other users' master and self-signing keys, but can - # only see their own user-signing keys - user_signing_keys[from_user_id] = keys[from_user_id]["user_signing"] + # users can see other users' master and self-signing keys, but can + # only see their own user-signing keys + if from_user_id: + from_user_key = keys.get(from_user_id) + if from_user_key and "user_signing" in from_user_key: + user_signing_keys[from_user_id] = from_user_key["user_signing"] return { "master_keys": master_keys, @@ -344,9 +352,9 @@ class E2eKeysHandler: A map from user_id -> device_id -> device details """ set_tag("local_query", query) - local_query = [] + local_query = [] # type: List[Tuple[str, Optional[str]]] - result_dict = {} + result_dict = {} # type: Dict[str, Dict[str, dict]] for user_id, device_ids in query.items(): # we use UserID.from_string to catch invalid user ids if not self.is_mine(UserID.from_string(user_id)): @@ -380,10 +388,14 @@ class E2eKeysHandler: log_kv(results) return result_dict - async def on_federation_query_client_keys(self, query_body): + async def on_federation_query_client_keys( + self, query_body: Dict[str, Dict[str, Optional[List[str]]]] + ) -> JsonDict: """ Handle a device key query from a federated server """ - device_keys_query = query_body.get("device_keys", {}) + device_keys_query = query_body.get( + "device_keys", {} + ) # type: Dict[str, Optional[List[str]]] res = await self.query_local_devices(device_keys_query) ret = {"device_keys": res} @@ -397,31 +409,34 @@ class E2eKeysHandler: return ret @trace - async def 
claim_one_time_keys(self, query, timeout): - local_query = [] - remote_queries = {} + async def claim_one_time_keys( + self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int + ) -> JsonDict: + local_query = [] # type: List[Tuple[str, str, str]] + remote_queries = {} # type: Dict[str, Dict[str, Dict[str, str]]] - for user_id, device_keys in query.get("one_time_keys", {}).items(): + for user_id, one_time_keys in query.get("one_time_keys", {}).items(): # we use UserID.from_string to catch invalid user ids if self.is_mine(UserID.from_string(user_id)): - for device_id, algorithm in device_keys.items(): + for device_id, algorithm in one_time_keys.items(): local_query.append((user_id, device_id, algorithm)) else: domain = get_domain_from_id(user_id) - remote_queries.setdefault(domain, {})[user_id] = device_keys + remote_queries.setdefault(domain, {})[user_id] = one_time_keys set_tag("local_key_query", local_query) set_tag("remote_key_query", remote_queries) results = await self.store.claim_e2e_one_time_keys(local_query) - json_result = {} - failures = {} + # A map of user ID -> device ID -> key ID -> key. + json_result = {} # type: Dict[str, Dict[str, Dict[str, JsonDict]]] + failures = {} # type: Dict[str, JsonDict] for user_id, device_keys in results.items(): for device_id, keys in device_keys.items(): - for key_id, json_bytes in keys.items(): + for key_id, json_str in keys.items(): json_result.setdefault(user_id, {})[device_id] = { - key_id: json_decoder.decode(json_bytes) + key_id: json_decoder.decode(json_str) } @trace @@ -468,7 +483,9 @@ class E2eKeysHandler: return {"one_time_keys": json_result, "failures": failures} @tag_args - async def upload_keys_for_user(self, user_id, device_id, keys): + async def upload_keys_for_user( + self, user_id: str, device_id: str, keys: JsonDict + ) -> JsonDict: time_now = self.clock.time_msec() @@ -543,8 +560,8 @@ class E2eKeysHandler: return {"one_time_key_counts": result} async def _upload_one_time_keys_for_user( - self, user_id, device_id, time_now, one_time_keys - ): + self, user_id: str, device_id: str, time_now: int, one_time_keys: JsonDict + ) -> None: logger.info( "Adding one_time_keys %r for device %r for user %r at %d", one_time_keys.keys(), @@ -585,12 +602,14 @@ class E2eKeysHandler: log_kv({"message": "Inserting new one_time_keys.", "keys": new_keys}) await self.store.add_e2e_one_time_keys(user_id, device_id, time_now, new_keys) - async def upload_signing_keys_for_user(self, user_id, keys): + async def upload_signing_keys_for_user( + self, user_id: str, keys: JsonDict + ) -> JsonDict: """Upload signing keys for cross-signing Args: - user_id (string): the user uploading the keys - keys (dict[string, dict]): the signing keys + user_id: the user uploading the keys + keys: the signing keys """ # if a master key is uploaded, then check it. Otherwise, load the @@ -667,16 +686,17 @@ class E2eKeysHandler: return {} - async def upload_signatures_for_device_keys(self, user_id, signatures): + async def upload_signatures_for_device_keys( + self, user_id: str, signatures: JsonDict + ) -> JsonDict: """Upload device signatures for cross-signing Args: - user_id (string): the user uploading the signatures - signatures (dict[string, dict[string, dict]]): map of users to - devices to signed keys. This is the submission from the user; an - exception will be raised if it is malformed. + user_id: the user uploading the signatures + signatures: map of users to devices to signed keys. 
This is the submission + from the user; an exception will be raised if it is malformed. Returns: - dict: response to be sent back to the client. The response will have + The response to be sent back to the client. The response will have a "failures" key, which will be a dict mapping users to devices to errors for the signatures that failed. Raises: @@ -719,7 +739,9 @@ class E2eKeysHandler: return {"failures": failures} - async def _process_self_signatures(self, user_id, signatures): + async def _process_self_signatures( + self, user_id: str, signatures: JsonDict + ) -> Tuple[List["SignatureListItem"], Dict[str, Dict[str, dict]]]: """Process uploaded signatures of the user's own keys. Signatures of the user's own keys from this API come in two forms: @@ -731,15 +753,14 @@ class E2eKeysHandler: signatures (dict[string, dict]): map of devices to signed keys Returns: - (list[SignatureListItem], dict[string, dict[string, dict]]): - a list of signatures to store, and a map of users to devices to failure - reasons + A tuple of a list of signatures to store, and a map of users to + devices to failure reasons Raises: SynapseError: if the input is malformed """ - signature_list = [] - failures = {} + signature_list = [] # type: List[SignatureListItem] + failures = {} # type: Dict[str, Dict[str, JsonDict]] if not signatures: return signature_list, failures @@ -834,19 +855,24 @@ class E2eKeysHandler: return signature_list, failures def _check_master_key_signature( - self, user_id, master_key_id, signed_master_key, stored_master_key, devices - ): + self, + user_id: str, + master_key_id: str, + signed_master_key: JsonDict, + stored_master_key: JsonDict, + devices: Dict[str, Dict[str, JsonDict]], + ) -> List["SignatureListItem"]: """Check signatures of a user's master key made by their devices. Args: - user_id (string): the user whose master key is being checked - master_key_id (string): the ID of the user's master key - signed_master_key (dict): the user's signed master key that was uploaded - stored_master_key (dict): our previously-stored copy of the user's master key - devices (iterable(dict)): the user's devices + user_id: the user whose master key is being checked + master_key_id: the ID of the user's master key + signed_master_key: the user's signed master key that was uploaded + stored_master_key: our previously-stored copy of the user's master key + devices: the user's devices Returns: - list[SignatureListItem]: a list of signatures to store + A list of signatures to store Raises: SynapseError: if a signature is invalid @@ -877,25 +903,26 @@ class E2eKeysHandler: return master_key_signature_list - async def _process_other_signatures(self, user_id, signatures): + async def _process_other_signatures( + self, user_id: str, signatures: Dict[str, dict] + ) -> Tuple[List["SignatureListItem"], Dict[str, Dict[str, dict]]]: """Process uploaded signatures of other users' keys. These will be the target user's master keys, signed by the uploading user's user-signing key. 
Args: - user_id (string): the user uploading the keys - signatures (dict[string, dict]): map of users to devices to signed keys + user_id: the user uploading the keys + signatures: map of users to devices to signed keys Returns: - (list[SignatureListItem], dict[string, dict[string, dict]]): - a list of signatures to store, and a map of users to devices to failure + A list of signatures to store, and a map of users to devices to failure reasons Raises: SynapseError: if the input is malformed """ - signature_list = [] - failures = {} + signature_list = [] # type: List[SignatureListItem] + failures = {} # type: Dict[str, Dict[str, JsonDict]] if not signatures: return signature_list, failures @@ -983,7 +1010,7 @@ class E2eKeysHandler: async def _get_e2e_cross_signing_verify_key( self, user_id: str, key_type: str, from_user_id: str = None - ): + ) -> Tuple[JsonDict, str, VerifyKey]: """Fetch locally or remotely query for a cross-signing public key. First, attempt to fetch the cross-signing public key from storage. @@ -997,8 +1024,7 @@ class E2eKeysHandler: This affects what signatures are fetched. Returns: - dict, str, VerifyKey: the raw key data, the key ID, and the - signedjson verify key + The raw key data, the key ID, and the signedjson verify key Raises: NotFoundError: if the key is not found @@ -1135,16 +1161,18 @@ class E2eKeysHandler: return desired_key, desired_key_id, desired_verify_key -def _check_cross_signing_key(key, user_id, key_type, signing_key=None): +def _check_cross_signing_key( + key: JsonDict, user_id: str, key_type: str, signing_key: Optional[VerifyKey] = None +) -> None: """Check a cross-signing key uploaded by a user. Performs some basic sanity checking, and ensures that it is signed, if a signature is required. Args: - key (dict): the key data to verify - user_id (str): the user whose key is being checked - key_type (str): the type of key that the key should be - signing_key (VerifyKey): (optional) the signing key that the key should - be signed with. If omitted, signatures will not be checked. + key: the key data to verify + user_id: the user whose key is being checked + key_type: the type of key that the key should be + signing_key: the signing key that the key should be signed with. If + omitted, signatures will not be checked. """ if ( key.get("user_id") != user_id @@ -1162,16 +1190,21 @@ def _check_cross_signing_key(key, user_id, key_type, signing_key=None): ) -def _check_device_signature(user_id, verify_key, signed_device, stored_device): +def _check_device_signature( + user_id: str, + verify_key: VerifyKey, + signed_device: JsonDict, + stored_device: JsonDict, +) -> None: """Check that a signature on a device or cross-signing key is correct and matches the copy of the device/key that we have stored. Throws an exception if an error is detected. 
Args: - user_id (str): the user ID whose signature is being checked - verify_key (VerifyKey): the key to verify the device with - signed_device (dict): the uploaded signed device data - stored_device (dict): our previously stored copy of the device + user_id: the user ID whose signature is being checked + verify_key: the key to verify the device with + signed_device: the uploaded signed device data + stored_device: our previously stored copy of the device Raises: SynapseError: if the signature was invalid or the sent device is not the @@ -1201,7 +1234,7 @@ def _check_device_signature(user_id, verify_key, signed_device, stored_device): raise SynapseError(400, "Invalid signature", Codes.INVALID_SIGNATURE) -def _exception_to_failure(e): +def _exception_to_failure(e: Exception) -> JsonDict: if isinstance(e, SynapseError): return {"status": e.code, "errcode": e.errcode, "message": str(e)} @@ -1218,7 +1251,7 @@ def _exception_to_failure(e): return {"status": 503, "message": str(e)} -def _one_time_keys_match(old_key_json, new_key): +def _one_time_keys_match(old_key_json: str, new_key: JsonDict) -> bool: old_key = json_decoder.decode(old_key_json) # if either is a string rather than an object, they must match exactly @@ -1239,16 +1272,16 @@ class SignatureListItem: """An item in the signature list as used by upload_signatures_for_device_keys. """ - signing_key_id = attr.ib() - target_user_id = attr.ib() - target_device_id = attr.ib() - signature = attr.ib() + signing_key_id = attr.ib(type=str) + target_user_id = attr.ib(type=str) + target_device_id = attr.ib(type=str) + signature = attr.ib(type=JsonDict) class SigningKeyEduUpdater: """Handles incoming signing key updates from federation and updates the DB""" - def __init__(self, hs, e2e_keys_handler): + def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler): self.store = hs.get_datastore() self.federation = hs.get_federation_client() self.clock = hs.get_clock() @@ -1257,7 +1290,7 @@ class SigningKeyEduUpdater: self._remote_edu_linearizer = Linearizer(name="remote_signing_key") # user_id -> list of updates waiting to be handled. - self._pending_updates = {} + self._pending_updates = {} # type: Dict[str, List[Tuple[JsonDict, JsonDict]]] # Recently seen stream ids. We don't bother keeping these in the DB, # but they're useful to have them about to reduce the number of spurious @@ -1270,13 +1303,15 @@ class SigningKeyEduUpdater: iterable=True, ) - async def incoming_signing_key_update(self, origin, edu_content): + async def incoming_signing_key_update( + self, origin: str, edu_content: JsonDict + ) -> None: """Called on incoming signing key update from federation. Responsible for parsing the EDU and adding to pending updates list. Args: - origin (string): the server that sent the EDU - edu_content (dict): the contents of the EDU + origin: the server that sent the EDU + edu_content: the contents of the EDU """ user_id = edu_content.pop("user_id") @@ -1299,11 +1334,11 @@ class SigningKeyEduUpdater: await self._handle_signing_key_updates(user_id) - async def _handle_signing_key_updates(self, user_id): + async def _handle_signing_key_updates(self, user_id: str) -> None: """Actually handle pending updates. 
Args: - user_id (string): the user whose updates we are processing + user_id: the user whose updates we are processing """ device_handler = self.e2e_keys_handler.device_handler @@ -1315,7 +1350,7 @@ class SigningKeyEduUpdater: # This can happen since we batch updates return - device_ids = [] + device_ids = [] # type: List[str] logger.info("pending updates: %r", pending_updates) diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py index f01b090772..622cae23be 100644 --- a/synapse/handlers/e2e_room_keys.py +++ b/synapse/handlers/e2e_room_keys.py @@ -15,6 +15,7 @@ # limitations under the License. import logging +from typing import TYPE_CHECKING, List, Optional from synapse.api.errors import ( Codes, @@ -24,8 +25,12 @@ from synapse.api.errors import ( SynapseError, ) from synapse.logging.opentracing import log_kv, trace +from synapse.types import JsonDict from synapse.util.async_helpers import Linearizer +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) @@ -37,7 +42,7 @@ class E2eRoomKeysHandler: The actual payload of the encrypted keys is completely opaque to the handler. """ - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.store = hs.get_datastore() # Used to lock whenever a client is uploading key data. This prevents collisions @@ -48,21 +53,27 @@ class E2eRoomKeysHandler: self._upload_linearizer = Linearizer("upload_room_keys_lock") @trace - async def get_room_keys(self, user_id, version, room_id=None, session_id=None): + async def get_room_keys( + self, + user_id: str, + version: str, + room_id: Optional[str] = None, + session_id: Optional[str] = None, + ) -> List[JsonDict]: """Bulk get the E2E room keys for a given backup, optionally filtered to a given room, or a given session. See EndToEndRoomKeyStore.get_e2e_room_keys for full details. Args: - user_id(str): the user whose keys we're getting - version(str): the version ID of the backup we're getting keys from - room_id(string): room ID to get keys for, for None to get keys for all rooms - session_id(string): session ID to get keys for, for None to get keys for all + user_id: the user whose keys we're getting + version: the version ID of the backup we're getting keys from + room_id: room ID to get keys for, for None to get keys for all rooms + session_id: session ID to get keys for, for None to get keys for all sessions Raises: NotFoundError: if the backup version does not exist Returns: - A deferred list of dicts giving the session_data and message metadata for + A list of dicts giving the session_data and message metadata for these room keys. """ @@ -86,17 +97,23 @@ class E2eRoomKeysHandler: return results @trace - async def delete_room_keys(self, user_id, version, room_id=None, session_id=None): + async def delete_room_keys( + self, + user_id: str, + version: str, + room_id: Optional[str] = None, + session_id: Optional[str] = None, + ) -> JsonDict: """Bulk delete the E2E room keys for a given backup, optionally filtered to a given room or a given session. See EndToEndRoomKeyStore.delete_e2e_room_keys for full details. 
Args: - user_id(str): the user whose backup we're deleting - version(str): the version ID of the backup we're deleting - room_id(string): room ID to delete keys for, for None to delete keys for all + user_id: the user whose backup we're deleting + version: the version ID of the backup we're deleting + room_id: room ID to delete keys for, for None to delete keys for all rooms - session_id(string): session ID to delete keys for, for None to delete keys + session_id: session ID to delete keys for, for None to delete keys for all sessions Raises: NotFoundError: if the backup version does not exist @@ -128,15 +145,17 @@ class E2eRoomKeysHandler: return {"etag": str(version_etag), "count": count} @trace - async def upload_room_keys(self, user_id, version, room_keys): + async def upload_room_keys( + self, user_id: str, version: str, room_keys: JsonDict + ) -> JsonDict: """Bulk upload a list of room keys into a given backup version, asserting that the given version is the current backup version. room_keys are merged into the current backup as described in RoomKeysServlet.on_PUT(). Args: - user_id(str): the user whose backup we're setting - version(str): the version ID of the backup we're updating - room_keys(dict): a nested dict describing the room_keys we're setting: + user_id: the user whose backup we're setting + version: the version ID of the backup we're updating + room_keys: a nested dict describing the room_keys we're setting: { "rooms": { @@ -254,14 +273,16 @@ class E2eRoomKeysHandler: return {"etag": str(version_etag), "count": count} @staticmethod - def _should_replace_room_key(current_room_key, room_key): + def _should_replace_room_key( + current_room_key: Optional[JsonDict], room_key: JsonDict + ) -> bool: """ Determine whether to replace a given current_room_key (if any) with a newly uploaded room_key backup Args: - current_room_key (dict): Optional, the current room_key dict if any - room_key (dict): The new room_key dict which may or may not be fit to + current_room_key: Optional, the current room_key dict if any + room_key : The new room_key dict which may or may not be fit to replace the current_room_key Returns: @@ -286,14 +307,14 @@ class E2eRoomKeysHandler: return True @trace - async def create_version(self, user_id, version_info): + async def create_version(self, user_id: str, version_info: JsonDict) -> str: """Create a new backup version. This automatically becomes the new backup version for the user's keys; previous backups will no longer be writeable to. Args: - user_id(str): the user whose backup version we're creating - version_info(dict): metadata about the new version being created + user_id: the user whose backup version we're creating + version_info: metadata about the new version being created { "algorithm": "m.megolm_backup.v1", @@ -301,7 +322,7 @@ class E2eRoomKeysHandler: } Returns: - A deferred of a string that gives the new version number. + The new version number. """ # TODO: Validate the JSON to make sure it has the right keys. 
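The docstring above only describes the inputs to `_should_replace_room_key`; the decision itself is a preference order over the backed-up session metadata. The sketch below shows one such policy: the field names (`is_verified`, `first_message_index`, `forwarded_count`) come from the room-key backup format, but the exact ordering of the checks is an illustrative assumption rather than code taken from this diff:

# A sketch of a "keep the most useful copy" policy for backed-up room keys.
# Field names come from the key backup format; the ordering is an assumption.
from typing import Any, Dict, Optional

JsonDict = Dict[str, Any]


def should_replace_room_key(current: Optional[JsonDict], new: JsonDict) -> bool:
    """Prefer whichever copy of a megolm session is most useful for decryption."""
    if current is None:
        return True  # nothing stored for this session yet, keep the upload
    if new["is_verified"] and not current["is_verified"]:
        return True  # a verified copy beats an unverified one
    if new["first_message_index"] < current["first_message_index"]:
        return True  # can decrypt earlier messages in the session
    if new["forwarded_count"] < current["forwarded_count"]:
        return True  # fewer forwards, so a more direct copy of the key
    return False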
@@ -313,17 +334,19 @@ class E2eRoomKeysHandler: ) return new_version - async def get_version_info(self, user_id, version=None): + async def get_version_info( + self, user_id: str, version: Optional[str] = None + ) -> JsonDict: """Get the info about a given version of the user's backup Args: - user_id(str): the user whose current backup version we're querying - version(str): Optional; if None gives the most recent version + user_id: the user whose current backup version we're querying + version: Optional; if None gives the most recent version otherwise a historical one. Raises: NotFoundError: if the requested backup version doesn't exist Returns: - A deferred of a info dict that gives the info about the new version. + A info dict that gives the info about the new version. { "version": "1234", @@ -346,7 +369,7 @@ class E2eRoomKeysHandler: return res @trace - async def delete_version(self, user_id, version=None): + async def delete_version(self, user_id: str, version: Optional[str] = None) -> None: """Deletes a given version of the user's e2e_room_keys backup Args: @@ -366,17 +389,19 @@ class E2eRoomKeysHandler: raise @trace - async def update_version(self, user_id, version, version_info): + async def update_version( + self, user_id: str, version: str, version_info: JsonDict + ) -> JsonDict: """Update the info about a given version of the user's backup Args: - user_id(str): the user whose current backup version we're updating - version(str): the backup version we're updating - version_info(dict): the new information about the backup + user_id: the user whose current backup version we're updating + version: the backup version we're updating + version_info: the new information about the backup Raises: NotFoundError: if the requested backup version doesn't exist Returns: - A deferred of an empty dict. + An empty dict. """ if "version" not in version_info: version_info["version"] = version diff --git a/synapse/logging/opentracing.py b/synapse/logging/opentracing.py index ab586c318c..0538350f38 100644 --- a/synapse/logging/opentracing.py +++ b/synapse/logging/opentracing.py @@ -791,7 +791,7 @@ def tag_args(func): @wraps(func) def _tag_args_inner(*args, **kwargs): - argspec = inspect.getargspec(func) + argspec = inspect.getfullargspec(func) for i, arg in enumerate(argspec.args[1:]): set_tag("ARG_" + arg, args[i]) set_tag("args", args[len(argspec.args) :]) diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index c128889bf9..309f1e865b 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -634,7 +634,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): async def get_e2e_cross_signing_keys_bulk( self, user_ids: List[str], from_user_id: Optional[str] = None - ) -> Dict[str, Dict[str, dict]]: + ) -> Dict[str, Optional[Dict[str, dict]]]: """Returns the cross-signing keys for a set of users. Args: @@ -724,7 +724,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore): async def claim_e2e_one_time_keys( self, query_list: Iterable[Tuple[str, str, str]] - ) -> Dict[str, Dict[str, Dict[str, bytes]]]: + ) -> Dict[str, Dict[str, Dict[str, str]]]: """Take a list of one time keys out of the database. 
Args: -- cgit 1.5.1 From 4b73488e811714089ba447884dccb9b6ae3ac16c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 28 Jan 2021 17:39:21 +0000 Subject: Ratelimit 3PID /requestToken API (#9238) --- changelog.d/9238.feature | 1 + docs/sample_config.yaml | 6 +- synapse/config/_base.pyi | 2 +- synapse/config/ratelimiting.py | 13 ++++- synapse/handlers/identity.py | 28 ++++++++++ synapse/rest/client/v2_alpha/account.py | 12 +++- synapse/rest/client/v2_alpha/register.py | 6 ++ tests/rest/client/v2_alpha/test_account.py | 90 ++++++++++++++++++++++++++++-- tests/server.py | 9 ++- tests/unittest.py | 5 ++ tests/utils.py | 1 + 11 files changed, 159 insertions(+), 14 deletions(-) create mode 100644 changelog.d/9238.feature (limited to 'synapse/handlers') diff --git a/changelog.d/9238.feature b/changelog.d/9238.feature new file mode 100644 index 0000000000..143a3e14f5 --- /dev/null +++ b/changelog.d/9238.feature @@ -0,0 +1 @@ +Add ratelimited to 3PID `/requestToken` API. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index c2ccd68f3a..e5b6268087 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -824,6 +824,7 @@ log_config: "CONFDIR/SERVERNAME.log.config" # users are joining rooms the server is already in (this is cheap) vs # "remote" for when users are trying to join rooms not on the server (which # can be more expensive) +# - one for ratelimiting how often a user or IP can attempt to validate a 3PID. # # The defaults are as shown below. # @@ -857,7 +858,10 @@ log_config: "CONFDIR/SERVERNAME.log.config" # remote: # per_second: 0.01 # burst_count: 3 - +# +#rc_3pid_validation: +# per_second: 0.003 +# burst_count: 5 # Ratelimiting settings for incoming federation # diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi index 7ed07a801d..70025b5d60 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi @@ -54,7 +54,7 @@ class RootConfig: tls: tls.TlsConfig database: database.DatabaseConfig logging: logger.LoggingConfig - ratelimit: ratelimiting.RatelimitConfig + ratelimiting: ratelimiting.RatelimitConfig media: repository.ContentRepositoryConfig captcha: captcha.CaptchaConfig voip: voip.VoipConfig diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 14b8836197..76f382527d 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -24,7 +24,7 @@ class RateLimitConfig: defaults={"per_second": 0.17, "burst_count": 3.0}, ): self.per_second = config.get("per_second", defaults["per_second"]) - self.burst_count = config.get("burst_count", defaults["burst_count"]) + self.burst_count = int(config.get("burst_count", defaults["burst_count"])) class FederationRateLimitConfig: @@ -102,6 +102,11 @@ class RatelimitConfig(Config): defaults={"per_second": 0.01, "burst_count": 3}, ) + self.rc_3pid_validation = RateLimitConfig( + config.get("rc_3pid_validation") or {}, + defaults={"per_second": 0.003, "burst_count": 5}, + ) + def generate_config_section(self, **kwargs): return """\ ## Ratelimiting ## @@ -131,6 +136,7 @@ class RatelimitConfig(Config): # users are joining rooms the server is already in (this is cheap) vs # "remote" for when users are trying to join rooms not on the server (which # can be more expensive) + # - one for ratelimiting how often a user or IP can attempt to validate a 3PID. # # The defaults are as shown below. 
# @@ -164,7 +170,10 @@ class RatelimitConfig(Config): # remote: # per_second: 0.01 # burst_count: 3 - + # + #rc_3pid_validation: + # per_second: 0.003 + # burst_count: 5 # Ratelimiting settings for incoming federation # diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py index f61844d688..4f7137539b 100644 --- a/synapse/handlers/identity.py +++ b/synapse/handlers/identity.py @@ -27,9 +27,11 @@ from synapse.api.errors import ( HttpResponseException, SynapseError, ) +from synapse.api.ratelimiting import Ratelimiter from synapse.config.emailconfig import ThreepidBehaviour from synapse.http import RequestTimedOutError from synapse.http.client import SimpleHttpClient +from synapse.http.site import SynapseRequest from synapse.types import JsonDict, Requester from synapse.util import json_decoder from synapse.util.hash import sha256_and_url_safe_base64 @@ -57,6 +59,32 @@ class IdentityHandler(BaseHandler): self._web_client_location = hs.config.invite_client_location + # Ratelimiters for `/requestToken` endpoints. + self._3pid_validation_ratelimiter_ip = Ratelimiter( + clock=hs.get_clock(), + rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second, + burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count, + ) + self._3pid_validation_ratelimiter_address = Ratelimiter( + clock=hs.get_clock(), + rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second, + burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count, + ) + + def ratelimit_request_token_requests( + self, request: SynapseRequest, medium: str, address: str, + ): + """Used to ratelimit requests to `/requestToken` by IP and address. + + Args: + request: The associated request + medium: The type of threepid, e.g. "msisdn" or "email" + address: The actual threepid ID, e.g. the phone number or email address + """ + + self._3pid_validation_ratelimiter_ip.ratelimit((medium, request.getClientIP())) + self._3pid_validation_ratelimiter_address.ratelimit((medium, address)) + async def threepid_from_creds( self, id_server: str, creds: Dict[str, str] ) -> Optional[JsonDict]: diff --git a/synapse/rest/client/v2_alpha/account.py b/synapse/rest/client/v2_alpha/account.py index 65e68d641b..a84a2fb385 100644 --- a/synapse/rest/client/v2_alpha/account.py +++ b/synapse/rest/client/v2_alpha/account.py @@ -54,7 +54,7 @@ logger = logging.getLogger(__name__) class EmailPasswordRequestTokenRestServlet(RestServlet): PATTERNS = client_patterns("/account/password/email/requestToken$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__() self.hs = hs self.datastore = hs.get_datastore() @@ -103,6 +103,8 @@ class EmailPasswordRequestTokenRestServlet(RestServlet): # Raise if the provided next_link value isn't valid assert_valid_next_link(self.hs, next_link) + self.identity_handler.ratelimit_request_token_requests(request, "email", email) + # The email will be sent to the stored address. 
# This avoids a potential account hijack by requesting a password reset to # an email address which is controlled by the attacker but which, after @@ -379,6 +381,8 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) + self.identity_handler.ratelimit_request_token_requests(request, "email", email) + if next_link: # Raise if the provided next_link value isn't valid assert_valid_next_link(self.hs, next_link) @@ -430,7 +434,7 @@ class EmailThreepidRequestTokenRestServlet(RestServlet): class MsisdnThreepidRequestTokenRestServlet(RestServlet): PATTERNS = client_patterns("/account/3pid/msisdn/requestToken$") - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): self.hs = hs super().__init__() self.store = self.hs.get_datastore() @@ -458,6 +462,10 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) + self.identity_handler.ratelimit_request_token_requests( + request, "msisdn", msisdn + ) + if next_link: # Raise if the provided next_link value isn't valid assert_valid_next_link(self.hs, next_link) diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py index b093183e79..10e1891174 100644 --- a/synapse/rest/client/v2_alpha/register.py +++ b/synapse/rest/client/v2_alpha/register.py @@ -126,6 +126,8 @@ class EmailRegisterRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) + self.identity_handler.ratelimit_request_token_requests(request, "email", email) + existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( "email", email ) @@ -205,6 +207,10 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet): Codes.THREEPID_DENIED, ) + self.identity_handler.ratelimit_request_token_requests( + request, "msisdn", msisdn + ) + existing_user_id = await self.hs.get_datastore().get_user_id_by_threepid( "msisdn", msisdn ) diff --git a/tests/rest/client/v2_alpha/test_account.py b/tests/rest/client/v2_alpha/test_account.py index cb87b80e33..177dc476da 100644 --- a/tests/rest/client/v2_alpha/test_account.py +++ b/tests/rest/client/v2_alpha/test_account.py @@ -24,7 +24,7 @@ import pkg_resources import synapse.rest.admin from synapse.api.constants import LoginType, Membership -from synapse.api.errors import Codes +from synapse.api.errors import Codes, HttpResponseException from synapse.rest.client.v1 import login, room from synapse.rest.client.v2_alpha import account, register from synapse.rest.synapse.client.password_reset import PasswordResetSubmitTokenResource @@ -112,6 +112,56 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): # Assert we can't log in with the old password self.attempt_wrong_password_login("kermit", old_password) + @override_config({"rc_3pid_validation": {"burst_count": 3}}) + def test_ratelimit_by_email(self): + """Test that we ratelimit /requestToken for the same email. 
+ """ + old_password = "monkey" + new_password = "kangeroo" + + user_id = self.register_user("kermit", old_password) + self.login("kermit", old_password) + + email = "test1@example.com" + + # Add a threepid + self.get_success( + self.store.user_add_threepid( + user_id=user_id, + medium="email", + address=email, + validated_at=0, + added_at=0, + ) + ) + + def reset(ip): + client_secret = "foobar" + session_id = self._request_token(email, client_secret, ip) + + self.assertEquals(len(self.email_attempts), 1) + link = self._get_link_from_email() + + self._validate_token(link) + + self._reset_password(new_password, session_id, client_secret) + + self.email_attempts.clear() + + # We expect to be able to make three requests before getting rate + # limited. + # + # We change IPs to ensure that we're not being ratelimited due to the + # same IP + reset("127.0.0.1") + reset("127.0.0.2") + reset("127.0.0.3") + + with self.assertRaises(HttpResponseException) as cm: + reset("127.0.0.4") + + self.assertEqual(cm.exception.code, 429) + def test_basic_password_reset_canonicalise_email(self): """Test basic password reset flow Request password reset with different spelling @@ -239,13 +289,18 @@ class PasswordResetTestCase(unittest.HomeserverTestCase): self.assertIsNotNone(session_id) - def _request_token(self, email, client_secret): + def _request_token(self, email, client_secret, ip="127.0.0.1"): channel = self.make_request( "POST", b"account/password/email/requestToken", {"client_secret": client_secret, "email": email, "send_attempt": 1}, + client_ip=ip, ) - self.assertEquals(200, channel.code, channel.result) + + if channel.code != 200: + raise HttpResponseException( + channel.code, channel.result["reason"], channel.result["body"], + ) return channel.json_body["sid"] @@ -509,6 +564,21 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): def test_address_trim(self): self.get_success(self._add_email(" foo@test.bar ", "foo@test.bar")) + @override_config({"rc_3pid_validation": {"burst_count": 3}}) + def test_ratelimit_by_ip(self): + """Tests that adding emails is ratelimited by IP + """ + + # We expect to be able to set three emails before getting ratelimited. 
+ self.get_success(self._add_email("foo1@test.bar", "foo1@test.bar")) + self.get_success(self._add_email("foo2@test.bar", "foo2@test.bar")) + self.get_success(self._add_email("foo3@test.bar", "foo3@test.bar")) + + with self.assertRaises(HttpResponseException) as cm: + self.get_success(self._add_email("foo4@test.bar", "foo4@test.bar")) + + self.assertEqual(cm.exception.code, 429) + def test_add_email_if_disabled(self): """Test adding email to profile when doing so is disallowed """ @@ -777,7 +847,11 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): body["next_link"] = next_link channel = self.make_request("POST", b"account/3pid/email/requestToken", body,) - self.assertEquals(expect_code, channel.code, channel.result) + + if channel.code != expect_code: + raise HttpResponseException( + channel.code, channel.result["reason"], channel.result["body"], + ) return channel.json_body.get("sid") @@ -823,10 +897,12 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): def _add_email(self, request_email, expected_email): """Test adding an email to profile """ + previous_email_attempts = len(self.email_attempts) + client_secret = "foobar" session_id = self._request_token(request_email, client_secret) - self.assertEquals(len(self.email_attempts), 1) + self.assertEquals(len(self.email_attempts) - previous_email_attempts, 1) link = self._get_link_from_email() self._validate_token(link) @@ -855,4 +931,6 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase): self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual("email", channel.json_body["threepids"][0]["medium"]) - self.assertEqual(expected_email, channel.json_body["threepids"][0]["address"]) + + threepids = {threepid["address"] for threepid in channel.json_body["threepids"]} + self.assertIn(expected_email, threepids) diff --git a/tests/server.py b/tests/server.py index 5a85d5fe7f..6419c445ec 100644 --- a/tests/server.py +++ b/tests/server.py @@ -47,6 +47,7 @@ class FakeChannel: site = attr.ib(type=Site) _reactor = attr.ib() result = attr.ib(type=dict, default=attr.Factory(dict)) + _ip = attr.ib(type=str, default="127.0.0.1") _producer = None @property @@ -120,7 +121,7 @@ class FakeChannel: def getPeer(self): # We give an address so that getClientIP returns a non null entry, # causing us to record the MAU - return address.IPv4Address("TCP", "127.0.0.1", 3423) + return address.IPv4Address("TCP", self._ip, 3423) def getHost(self): return None @@ -196,6 +197,7 @@ def make_request( custom_headers: Optional[ Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] ] = None, + client_ip: str = "127.0.0.1", ) -> FakeChannel: """ Make a web request using the given method, path and content, and render it @@ -223,6 +225,9 @@ def make_request( will pump the reactor until the the renderer tells the channel the request is finished. + client_ip: The IP to use as the requesting IP. Useful for testing + ratelimiting. 
+ Returns: channel """ @@ -250,7 +255,7 @@ def make_request( if isinstance(content, str): content = content.encode("utf8") - channel = FakeChannel(site, reactor) + channel = FakeChannel(site, reactor, ip=client_ip) req = request(channel) req.content = BytesIO(content) diff --git a/tests/unittest.py b/tests/unittest.py index bbd295687c..767d5d6077 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -386,6 +386,7 @@ class HomeserverTestCase(TestCase): custom_headers: Optional[ Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] ] = None, + client_ip: str = "127.0.0.1", ) -> FakeChannel: """ Create a SynapseRequest at the path using the method and containing the @@ -410,6 +411,9 @@ class HomeserverTestCase(TestCase): custom_headers: (name, value) pairs to add as request headers + client_ip: The IP to use as the requesting IP. Useful for testing + ratelimiting. + Returns: The FakeChannel object which stores the result of the request. """ @@ -426,6 +430,7 @@ class HomeserverTestCase(TestCase): content_is_form, await_result, custom_headers, + client_ip, ) def setup_test_homeserver(self, *args, **kwargs): diff --git a/tests/utils.py b/tests/utils.py index 022223cf24..68033d7535 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -157,6 +157,7 @@ def default_config(name, parse=False): "local": {"per_second": 10000, "burst_count": 10000}, "remote": {"per_second": 10000, "burst_count": 10000}, }, + "rc_3pid_validation": {"per_second": 10000, "burst_count": 10000}, "saml2_enabled": False, "default_identity_server": None, "key_refresh_interval": 24 * 60 * 60 * 1000, -- cgit 1.5.1 From f2c1560eca1e2160087a280261ca78d0708ad721 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Fri, 29 Jan 2021 16:38:29 +0000 Subject: Ratelimit invites by room and target user (#9258) --- changelog.d/9258.feature | 1 + docs/sample_config.yaml | 10 ++++ synapse/config/ratelimiting.py | 19 +++++++ synapse/federation/federation_client.py | 2 +- synapse/handlers/federation.py | 4 ++ synapse/handlers/room.py | 7 +++ synapse/handlers/room_member.py | 25 ++++++++- tests/handlers/test_federation.py | 93 ++++++++++++++++++++++++++++++++- tests/rest/client/v1/test_rooms.py | 35 +++++++++++++ 9 files changed, 192 insertions(+), 4 deletions(-) create mode 100644 changelog.d/9258.feature (limited to 'synapse/handlers') diff --git a/changelog.d/9258.feature b/changelog.d/9258.feature new file mode 100644 index 0000000000..0028f42d26 --- /dev/null +++ b/changelog.d/9258.feature @@ -0,0 +1 @@ +Add ratelimits to invites in rooms and to specific users. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 332befd948..7fd35516dc 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -825,6 +825,8 @@ log_config: "CONFDIR/SERVERNAME.log.config" # "remote" for when users are trying to join rooms not on the server (which # can be more expensive) # - one for ratelimiting how often a user or IP can attempt to validate a 3PID. +# - two for ratelimiting how often invites can be sent in a room or to a +# specific user. # # The defaults are as shown below. 
# @@ -862,6 +864,14 @@ log_config: "CONFDIR/SERVERNAME.log.config" #rc_3pid_validation: # per_second: 0.003 # burst_count: 5 +# +#rc_invites: +# per_room: +# per_second: 0.3 +# burst_count: 10 +# per_user: +# per_second: 0.003 +# burst_count: 5 # Ratelimiting settings for incoming federation # diff --git a/synapse/config/ratelimiting.py b/synapse/config/ratelimiting.py index 76f382527d..def33a60ad 100644 --- a/synapse/config/ratelimiting.py +++ b/synapse/config/ratelimiting.py @@ -107,6 +107,15 @@ class RatelimitConfig(Config): defaults={"per_second": 0.003, "burst_count": 5}, ) + self.rc_invites_per_room = RateLimitConfig( + config.get("rc_invites", {}).get("per_room", {}), + defaults={"per_second": 0.3, "burst_count": 10}, + ) + self.rc_invites_per_user = RateLimitConfig( + config.get("rc_invites", {}).get("per_user", {}), + defaults={"per_second": 0.003, "burst_count": 5}, + ) + def generate_config_section(self, **kwargs): return """\ ## Ratelimiting ## @@ -137,6 +146,8 @@ class RatelimitConfig(Config): # "remote" for when users are trying to join rooms not on the server (which # can be more expensive) # - one for ratelimiting how often a user or IP can attempt to validate a 3PID. + # - two for ratelimiting how often invites can be sent in a room or to a + # specific user. # # The defaults are as shown below. # @@ -174,6 +185,14 @@ class RatelimitConfig(Config): #rc_3pid_validation: # per_second: 0.003 # burst_count: 5 + # + #rc_invites: + # per_room: + # per_second: 0.3 + # burst_count: 10 + # per_user: + # per_second: 0.003 + # burst_count: 5 # Ratelimiting settings for incoming federation # diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index d330ae5dbc..40e1451201 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -810,7 +810,7 @@ class FederationClient(FederationBase): "User's homeserver does not support this room version", Codes.UNSUPPORTED_ROOM_VERSION, ) - elif e.code == 403: + elif e.code in (403, 429): raise e.to_synapse_error() else: raise diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index b6dc7f99b6..dbdfd56ff5 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1617,6 +1617,10 @@ class FederationHandler(BaseHandler): if event.state_key == self._server_notices_mxid: raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user") + # We retrieve the room member handler here as to not cause a cyclic dependency + member_handler = self.hs.get_room_member_handler() + member_handler.ratelimit_invite(event.room_id, event.state_key) + # keep a record of the room version, if we don't yet know it. # (this may get overwritten if we later get a different room version in a # join dance). 
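Both the federation path above and the client-API path in `room_member.py` below go through the same check: one limiter keyed on the room and one keyed on the invited user, each a token bucket configured by `rc_invites`. A simplified stand-in for that dual-bucket check (this is not Synapse's `Ratelimiter` class, and the exception is a placeholder for the real 429 `LimitExceededError`):

# A simplified stand-in showing the dual-bucket invite check added here.
import time
from typing import Dict, Tuple


class TokenBucket:
    def __init__(self, rate_hz: float, burst_count: int):
        self.rate_hz = rate_hz
        self.burst_count = burst_count
        # key -> (tokens currently in use, time of the last request)
        self._state = {}  # type: Dict[str, Tuple[float, float]]

    def ratelimit(self, key: str) -> None:
        used, last_ts = self._state.get(key, (0.0, 0.0))
        now = time.monotonic()
        # tokens drain back out of the bucket at rate_hz
        used = max(0.0, used - (now - last_ts) * self.rate_hz)
        if used + 1 > self.burst_count:
            # placeholder for Synapse's LimitExceededError (HTTP 429)
            raise RuntimeError("429: too many invites")
        self._state[key] = (used + 1, now)


# the defaults from the sample config above
per_room_limiter = TokenBucket(rate_hz=0.3, burst_count=10)
per_user_limiter = TokenBucket(rate_hz=0.003, burst_count=5)


def ratelimit_invite(room_id: str, invitee_user_id: str) -> None:
    # both buckets must have capacity before the invite proceeds
    per_room_limiter.ratelimit(room_id)
    per_user_limiter.ratelimit(invitee_user_id)

Using two buckets means a single room cannot be flooded with invites and a single target user cannot be spammed across many rooms, which is exactly what the two new test cases below exercise.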
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index ee27d99135..07b2187eb1 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -126,6 +126,10 @@ class RoomCreationHandler(BaseHandler): self.third_party_event_rules = hs.get_third_party_event_rules() + self._invite_burst_count = ( + hs.config.ratelimiting.rc_invites_per_room.burst_count + ) + async def upgrade_room( self, requester: Requester, old_room_id: str, new_version: RoomVersion ) -> str: @@ -662,6 +666,9 @@ class RoomCreationHandler(BaseHandler): invite_3pid_list = [] invite_list = [] + if len(invite_list) + len(invite_3pid_list) > self._invite_burst_count: + raise SynapseError(400, "Cannot invite so many users at once") + await self.event_creation_handler.assert_accepted_privacy_policy(requester) power_level_content_override = config.get("power_level_content_override") diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index e001e418f9..d335da6f19 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -85,6 +85,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count, ) + self._invites_per_room_limiter = Ratelimiter( + clock=self.clock, + rate_hz=hs.config.ratelimiting.rc_invites_per_room.per_second, + burst_count=hs.config.ratelimiting.rc_invites_per_room.burst_count, + ) + self._invites_per_user_limiter = Ratelimiter( + clock=self.clock, + rate_hz=hs.config.ratelimiting.rc_invites_per_user.per_second, + burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count, + ) + # This is only used to get at ratelimit function, and # maybe_kick_guest_users. It's fine there are multiple of these as # it doesn't store state. @@ -144,6 +155,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): """ raise NotImplementedError() + def ratelimit_invite(self, room_id: str, invitee_user_id: str): + """Ratelimit invites by room and by target user. 
+ """ + self._invites_per_room_limiter.ratelimit(room_id) + self._invites_per_user_limiter.ratelimit(invitee_user_id) + async def _local_membership_update( self, requester: Requester, @@ -387,8 +404,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): raise SynapseError(403, "This room has been blocked on this server") if effective_membership_state == Membership.INVITE: + target_id = target.to_string() + if ratelimit: + self.ratelimit_invite(room_id, target_id) + # block any attempts to invite the server notices mxid - if target.to_string() == self._server_notices_mxid: + if target_id == self._server_notices_mxid: raise SynapseError(HTTPStatus.FORBIDDEN, "Cannot invite this user") block_invite = False @@ -412,7 +433,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): block_invite = True if not await self.spam_checker.user_may_invite( - requester.user.to_string(), target.to_string(), room_id + requester.user.to_string(), target_id, room_id ): logger.info("Blocking invite due to spam checker") block_invite = True diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 0b24b89a2e..74503112f5 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -16,7 +16,7 @@ import logging from unittest import TestCase from synapse.api.constants import EventTypes -from synapse.api.errors import AuthError, Codes, SynapseError +from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError from synapse.api.room_versions import RoomVersions from synapse.events import EventBase from synapse.federation.federation_base import event_from_pdu_json @@ -191,6 +191,97 @@ class FederationTestCase(unittest.HomeserverTestCase): self.assertEqual(sg, sg2) + @unittest.override_config( + {"rc_invites": {"per_room": {"per_second": 0.5, "burst_count": 3}}} + ) + def test_invite_by_room_ratelimit(self): + """Tests that invites from federation in a room are actually rate-limited. + """ + other_server = "otherserver" + other_user = "@otheruser:" + other_server + + # create the room + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test") + room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) + room_version = self.get_success(self.store.get_room_version(room_id)) + + def create_invite_for(local_user): + return event_from_pdu_json( + { + "type": EventTypes.Member, + "content": {"membership": "invite"}, + "room_id": room_id, + "sender": other_user, + "state_key": local_user, + "depth": 32, + "prev_events": [], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + room_version, + ) + + for i in range(3): + self.get_success( + self.handler.on_invite_request( + other_server, + create_invite_for("@user-%d:test" % (i,)), + room_version, + ) + ) + + self.get_failure( + self.handler.on_invite_request( + other_server, create_invite_for("@user-4:test"), room_version, + ), + exc=LimitExceededError, + ) + + @unittest.override_config( + {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}} + ) + def test_invite_by_user_ratelimit(self): + """Tests that invites from federation to a particular user are + actually rate-limited. 
+ """ + other_server = "otherserver" + other_user = "@otheruser:" + other_server + + # create the room + user_id = self.register_user("kermit", "test") + tok = self.login("kermit", "test") + + def create_invite(): + room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) + room_version = self.get_success(self.store.get_room_version(room_id)) + return event_from_pdu_json( + { + "type": EventTypes.Member, + "content": {"membership": "invite"}, + "room_id": room_id, + "sender": other_user, + "state_key": "@user:test", + "depth": 32, + "prev_events": [], + "auth_events": [], + "origin_server_ts": self.clock.time_msec(), + }, + room_version, + ) + + for i in range(3): + event = create_invite() + self.get_success( + self.handler.on_invite_request(other_server, event, event.room_version,) + ) + + event = create_invite() + self.get_failure( + self.handler.on_invite_request(other_server, event, event.room_version,), + exc=LimitExceededError, + ) + def _build_and_send_join_event(self, other_server, other_user, room_id): join_event = self.get_success( self.handler.on_make_join_request(other_server, room_id, other_user) diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index d4e3165436..2548b3a80c 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -616,6 +616,41 @@ class RoomMemberStateTestCase(RoomBase): self.assertEquals(json.loads(content), channel.json_body) +class RoomInviteRatelimitTestCase(RoomBase): + user_id = "@sid1:red" + + servlets = [ + admin.register_servlets, + profile.register_servlets, + room.register_servlets, + ] + + @unittest.override_config( + {"rc_invites": {"per_room": {"per_second": 0.5, "burst_count": 3}}} + ) + def test_invites_by_rooms_ratelimit(self): + """Tests that invites in a room are actually rate-limited.""" + room_id = self.helper.create_room_as(self.user_id) + + for i in range(3): + self.helper.invite(room_id, self.user_id, "@user-%s:red" % (i,)) + + self.helper.invite(room_id, self.user_id, "@user-4:red", expect_code=429) + + @unittest.override_config( + {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}} + ) + def test_invites_by_users_ratelimit(self): + """Tests that invites to a specific user are actually rate-limited.""" + + for i in range(3): + room_id = self.helper.create_room_as(self.user_id) + self.helper.invite(room_id, self.user_id, "@other-users:red") + + room_id = self.helper.create_room_as(self.user_id) + self.helper.invite(room_id, self.user_id, "@other-users:red", expect_code=429) + + class RoomJoinRatelimitTestCase(RoomBase): user_id = "@sid1:red" -- cgit 1.5.1 From f78d07bf005f7212bcc74256721677a3b255ea0e Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 1 Feb 2021 13:15:51 +0000 Subject: Split out a separate endpoint to complete SSO registration (#9262) There are going to be a couple of paths to get to the final step of SSO reg, and I want the URL in the browser to consistent. So, let's move the final step onto a separate path, which we redirect to. 
--- changelog.d/9262.feature | 1 + synapse/app/homeserver.py | 2 + synapse/handlers/sso.py | 81 ++++++++++++++++++++++------ synapse/http/server.py | 7 +++ synapse/rest/synapse/client/pick_username.py | 16 +++--- synapse/rest/synapse/client/sso_register.py | 50 +++++++++++++++++ tests/rest/client/v1/test_login.py | 14 ++++- 7 files changed, 145 insertions(+), 26 deletions(-) create mode 100644 changelog.d/9262.feature create mode 100644 synapse/rest/synapse/client/sso_register.py (limited to 'synapse/handlers') diff --git a/changelog.d/9262.feature b/changelog.d/9262.feature new file mode 100644 index 0000000000..c21b197ca1 --- /dev/null +++ b/changelog.d/9262.feature @@ -0,0 +1 @@ +Improve the user experience of setting up an account via single-sign on. diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 57a2f5237c..86d6f73674 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -62,6 +62,7 @@ from synapse.rest.health import HealthResource from synapse.rest.key.v2 import KeyApiV2Resource from synapse.rest.synapse.client.pick_idp import PickIdpResource from synapse.rest.synapse.client.pick_username import pick_username_resource +from synapse.rest.synapse.client.sso_register import SsoRegisterResource from synapse.rest.well_known import WellKnownResource from synapse.server import HomeServer from synapse.storage import DataStore @@ -192,6 +193,7 @@ class SynapseHomeServer(HomeServer): "/_synapse/admin": AdminRestResource(self), "/_synapse/client/pick_username": pick_username_resource(self), "/_synapse/client/pick_idp": PickIdpResource(self), + "/_synapse/client/sso_register": SsoRegisterResource(self), } ) diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 3308b037d2..50c5ae142a 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -21,12 +21,13 @@ import attr from typing_extensions import NoReturn, Protocol from twisted.web.http import Request +from twisted.web.iweb import IRequest from synapse.api.constants import LoginType from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.http import get_request_user_agent -from synapse.http.server import respond_with_html +from synapse.http.server import respond_with_html, respond_with_redirect from synapse.http.site import SynapseRequest from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters from synapse.util.async_helpers import Linearizer @@ -141,6 +142,9 @@ class UsernameMappingSession: # expiry time for the session, in milliseconds expiry_time_ms = attr.ib(type=int) + # choices made by the user + chosen_localpart = attr.ib(type=Optional[str], default=None) + # the HTTP cookie used to track the mapping session id USERNAME_MAPPING_SESSION_COOKIE_NAME = b"username_mapping_session" @@ -647,6 +651,25 @@ class SsoHandler: ) respond_with_html(request, 200, html) + def get_mapping_session(self, session_id: str) -> UsernameMappingSession: + """Look up the given username mapping session + + If it is not found, raises a SynapseError with an http code of 400 + + Args: + session_id: session to look up + Returns: + active mapping session + Raises: + SynapseError if the session is not found/has expired + """ + self._expire_old_sessions() + session = self._username_mapping_sessions.get(session_id) + if session: + return session + logger.info("Couldn't find session id %s", session_id) + raise SynapseError(400, "unknown session") + async def 
check_username_availability( self, localpart: str, session_id: str, ) -> bool: @@ -663,12 +686,7 @@ class SsoHandler: # make sure that there is a valid mapping session, to stop people dictionary- # scanning for accounts - - self._expire_old_sessions() - session = self._username_mapping_sessions.get(session_id) - if not session: - logger.info("Couldn't find session id %s", session_id) - raise SynapseError(400, "unknown session") + self.get_mapping_session(session_id) logger.info( "[session %s] Checking for availability of username %s", @@ -696,16 +714,33 @@ class SsoHandler: localpart: localpart requested by the user session_id: ID of the username mapping session, extracted from a cookie """ - self._expire_old_sessions() - session = self._username_mapping_sessions.get(session_id) - if not session: - logger.info("Couldn't find session id %s", session_id) - raise SynapseError(400, "unknown session") + session = self.get_mapping_session(session_id) + + # update the session with the user's choices + session.chosen_localpart = localpart + + # we're done; now we can register the user + respond_with_redirect(request, b"/_synapse/client/sso_register") + + async def register_sso_user(self, request: Request, session_id: str) -> None: + """Called once we have all the info we need to register a new user. - logger.info("[session %s] Registering localpart %s", session_id, localpart) + Does so and serves an HTTP response + + Args: + request: HTTP request + session_id: ID of the username mapping session, extracted from a cookie + """ + session = self.get_mapping_session(session_id) + + logger.info( + "[session %s] Registering localpart %s", + session_id, + session.chosen_localpart, + ) attributes = UserAttributes( - localpart=localpart, + localpart=session.chosen_localpart, display_name=session.display_name, emails=session.emails, ) @@ -720,7 +755,12 @@ class SsoHandler: request.getClientIP(), ) - logger.info("[session %s] Registered userid %s", session_id, user_id) + logger.info( + "[session %s] Registered userid %s with attributes %s", + session_id, + user_id, + attributes, + ) # delete the mapping session and the cookie del self._username_mapping_sessions[session_id] @@ -751,3 +791,14 @@ class SsoHandler: for session_id in to_expire: logger.info("Expiring mapping session %s", session_id) del self._username_mapping_sessions[session_id] + + +def get_username_mapping_session_cookie_from_request(request: IRequest) -> str: + """Extract the session ID from the cookie + + Raises a SynapseError if the cookie isn't found + """ + session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME) + if not session_id: + raise SynapseError(code=400, msg="missing session_id") + return session_id.decode("ascii", errors="replace") diff --git a/synapse/http/server.py b/synapse/http/server.py index d69d579b3a..8249732b27 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -761,6 +761,13 @@ def set_clickjacking_protection_headers(request: Request): request.setHeader(b"Content-Security-Policy", b"frame-ancestors 'none';") +def respond_with_redirect(request: Request, url: bytes) -> None: + """Write a 302 response to the request, if it is still alive.""" + logger.debug("Redirect to %s", url.decode("utf-8")) + request.redirect(url) + finish_request(request) + + def finish_request(request: Request): """ Finish writing the response to the request. 
diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py index d3b6803e65..1bc737bad0 100644 --- a/synapse/rest/synapse/client/pick_username.py +++ b/synapse/rest/synapse/client/pick_username.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from typing import TYPE_CHECKING import pkg_resources @@ -20,8 +21,7 @@ from twisted.web.http import Request from twisted.web.resource import Resource from twisted.web.static import File -from synapse.api.errors import SynapseError -from synapse.handlers.sso import USERNAME_MAPPING_SESSION_COOKIE_NAME +from synapse.handlers.sso import get_username_mapping_session_cookie_from_request from synapse.http.server import DirectServeHtmlResource, DirectServeJsonResource from synapse.http.servlet import parse_string from synapse.http.site import SynapseRequest @@ -61,12 +61,10 @@ class AvailabilityCheckResource(DirectServeJsonResource): async def _async_render_GET(self, request: Request): localpart = parse_string(request, "username", required=True) - session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME) - if not session_id: - raise SynapseError(code=400, msg="missing session_id") + session_id = get_username_mapping_session_cookie_from_request(request) is_available = await self._sso_handler.check_username_availability( - localpart, session_id.decode("ascii", errors="replace") + localpart, session_id ) return 200, {"available": is_available} @@ -79,10 +77,8 @@ class SubmitResource(DirectServeHtmlResource): async def _async_render_POST(self, request: SynapseRequest): localpart = parse_string(request, "username", required=True) - session_id = request.getCookie(USERNAME_MAPPING_SESSION_COOKIE_NAME) - if not session_id: - raise SynapseError(code=400, msg="missing session_id") + session_id = get_username_mapping_session_cookie_from_request(request) await self._sso_handler.handle_submit_username_request( - request, localpart, session_id.decode("ascii", errors="replace") + request, localpart, session_id ) diff --git a/synapse/rest/synapse/client/sso_register.py b/synapse/rest/synapse/client/sso_register.py new file mode 100644 index 0000000000..dfefeb7796 --- /dev/null +++ b/synapse/rest/synapse/client/sso_register.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import logging +from typing import TYPE_CHECKING + +from twisted.web.http import Request + +from synapse.api.errors import SynapseError +from synapse.handlers.sso import get_username_mapping_session_cookie_from_request +from synapse.http.server import DirectServeHtmlResource + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class SsoRegisterResource(DirectServeHtmlResource): + """A resource which completes SSO registration + + This resource gets mounted at /_synapse/client/sso_register, and is shown + after we collect username and/or consent for a new SSO user. It (finally) registers + the user, and confirms redirect to the client + """ + + def __init__(self, hs: "HomeServer"): + super().__init__() + self._sso_handler = hs.get_sso_handler() + + async def _async_render_GET(self, request: Request) -> None: + try: + session_id = get_username_mapping_session_cookie_from_request(request) + except SynapseError as e: + logger.warning("Error fetching session cookie: %s", e) + self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code) + return + await self._sso_handler.register_sso_user(request, session_id) diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index e2bb945453..f01215ed1c 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -31,6 +31,7 @@ from synapse.rest.client.v2_alpha import devices, register from synapse.rest.client.v2_alpha.account import WhoamiRestServlet from synapse.rest.synapse.client.pick_idp import PickIdpResource from synapse.rest.synapse.client.pick_username import pick_username_resource +from synapse.rest.synapse.client.sso_register import SsoRegisterResource from synapse.types import create_requester from tests import unittest @@ -1215,6 +1216,7 @@ class UsernamePickerTestCase(HomeserverTestCase): d = super().create_resource_dict() d["/_synapse/client/pick_username"] = pick_username_resource(self.hs) + d["/_synapse/client/sso_register"] = SsoRegisterResource(self.hs) d["/_synapse/oidc"] = OIDCResource(self.hs) return d @@ -1253,7 +1255,7 @@ class UsernamePickerTestCase(HomeserverTestCase): self.assertApproximates(session.expiry_time_ms, expected_expiry, tolerance=1000) # Now, submit a username to the username picker, which should serve a redirect - # back to the client + # to the completion page submit_path = picker_url + "/submit" content = urlencode({b"username": b"bobby"}).encode("utf8") chan = self.make_request( @@ -1270,6 +1272,16 @@ class UsernamePickerTestCase(HomeserverTestCase): ) self.assertEqual(chan.code, 302, chan.result) location_headers = chan.headers.getRawHeaders("Location") + + # send a request to the completion page, which should 302 to the client redirectUrl + chan = self.make_request( + "GET", + path=location_headers[0], + custom_headers=[("Cookie", "username_mapping_session=" + session_id)], + ) + self.assertEqual(chan.code, 302, chan.result) + location_headers = chan.headers.getRawHeaders("Location") + # ensure that the returned location matches the requested redirect URL path, query = location_headers[0].split("?", 1) self.assertEqual(path, "https://x") -- cgit 1.5.1 From 8aed29dc615bee75019fc526a5c91cdc2638b665 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 1 Feb 2021 15:50:56 +0000 Subject: Improve styling and wording of SSO redirect confirm template (#9272) --- changelog.d/9272.feature | 1 + docs/sample_config.yaml | 14 ++++- 
synapse/config/sso.py | 14 ++++- synapse/handlers/auth.py | 24 ++++++- synapse/handlers/sso.py | 10 ++- synapse/module_api/__init__.py | 10 ++- synapse/res/templates/sso.css | 83 +++++++++++++++++++++++++ synapse/res/templates/sso_redirect_confirm.html | 34 ++++++++-- tests/handlers/test_cas.py | 8 +-- tests/handlers/test_oidc.py | 24 ++++--- tests/handlers/test_saml.py | 8 +-- 11 files changed, 200 insertions(+), 30 deletions(-) create mode 100644 changelog.d/9272.feature create mode 100644 synapse/res/templates/sso.css (limited to 'synapse/handlers') diff --git a/changelog.d/9272.feature b/changelog.d/9272.feature new file mode 100644 index 0000000000..c21b197ca1 --- /dev/null +++ b/changelog.d/9272.feature @@ -0,0 +1 @@ +Improve the user experience of setting up an account via single-sign on. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 8777e3254d..05506a7787 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1971,7 +1971,8 @@ sso: # * HTML page for a confirmation step before redirecting back to the client # with the login token: 'sso_redirect_confirm.html'. # - # When rendering, this template is given three variables: + # When rendering, this template is given the following variables: + # # * redirect_url: the URL the user is about to be redirected to. Needs # manual escaping (see # https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping). @@ -1984,6 +1985,17 @@ sso: # # * server_name: the homeserver's name. # + # * new_user: a boolean indicating whether this is the user's first time + # logging in. + # + # * user_id: the user's matrix ID. + # + # * user_profile.avatar_url: an MXC URI for the user's avatar, if any. + # None if the user has not set an avatar. + # + # * user_profile.display_name: the user's display name. None if the user + # has not set a display name. + # # * HTML page which notifies the user that they are authenticating to confirm # an operation on their account during the user interactive authentication # process: 'sso_auth_confirm.html'. diff --git a/synapse/config/sso.py b/synapse/config/sso.py index 59be825532..a470112ed4 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -127,7 +127,8 @@ class SSOConfig(Config): # * HTML page for a confirmation step before redirecting back to the client # with the login token: 'sso_redirect_confirm.html'. # - # When rendering, this template is given three variables: + # When rendering, this template is given the following variables: + # # * redirect_url: the URL the user is about to be redirected to. Needs # manual escaping (see # https://jinja.palletsprojects.com/en/2.11.x/templates/#html-escaping). @@ -140,6 +141,17 @@ class SSOConfig(Config): # # * server_name: the homeserver's name. # + # * new_user: a boolean indicating whether this is the user's first time + # logging in. + # + # * user_id: the user's matrix ID. + # + # * user_profile.avatar_url: an MXC URI for the user's avatar, if any. + # None if the user has not set an avatar. + # + # * user_profile.display_name: the user's display name. None if the user + # has not set a display name. + # # * HTML page which notifies the user that they are authenticating to confirm # an operation on their account during the user interactive authentication # process: 'sso_auth_confirm.html'. 
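(A purely illustrative aside: the template text and values below are invented, and only the variable names come from the documentation above. A customised sso_redirect_confirm.html can be exercised from plain Python with jinja2 roughly like this:)

import jinja2

template = jinja2.Template(
    "{% if new_user %}Your account is now ready{% else %}Log in{% endif %}: "
    "{{ user_id }} ({{ user_profile.display_name or 'no display name' }}) "
    "will be redirected to {{ display_url }}"
)

print(
    template.render(
        new_user=True,
        user_id="@alice:example.com",
        user_profile={"display_name": "Alice", "avatar_url": None},
        redirect_url="https://client.example.com/?loginToken=secret",
        display_url="https://client.example.com/",
        server_name="example.com",
    )
)
# Your account is now ready: @alice:example.com (Alice) will be redirected to https://client.example.com/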
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 0e98db22b3..c722a4afa8 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -61,6 +61,7 @@ from synapse.http.site import SynapseRequest from synapse.logging.context import defer_to_thread from synapse.metrics.background_process_metrics import run_as_background_process from synapse.module_api import ModuleApi +from synapse.storage.roommember import ProfileInfo from synapse.types import JsonDict, Requester, UserID from synapse.util import stringutils as stringutils from synapse.util.async_helpers import maybe_awaitable @@ -1396,6 +1397,7 @@ class AuthHandler(BaseHandler): request: Request, client_redirect_url: str, extra_attributes: Optional[JsonDict] = None, + new_user: bool = False, ): """Having figured out a mxid for this user, complete the HTTP request @@ -1406,6 +1408,8 @@ class AuthHandler(BaseHandler): process. extra_attributes: Extra attributes which will be passed to the client during successful login. Must be JSON serializable. + new_user: True if we should use wording appropriate to a user who has just + registered. """ # If the account has been deactivated, do not proceed with the login # flow. @@ -1414,8 +1418,17 @@ class AuthHandler(BaseHandler): respond_with_html(request, 403, self._sso_account_deactivated_template) return + profile = await self.store.get_profileinfo( + UserID.from_string(registered_user_id).localpart + ) + self._complete_sso_login( - registered_user_id, request, client_redirect_url, extra_attributes + registered_user_id, + request, + client_redirect_url, + extra_attributes, + new_user=new_user, + user_profile_data=profile, ) def _complete_sso_login( @@ -1424,12 +1437,18 @@ class AuthHandler(BaseHandler): request: Request, client_redirect_url: str, extra_attributes: Optional[JsonDict] = None, + new_user: bool = False, + user_profile_data: Optional[ProfileInfo] = None, ): """ The synchronous portion of complete_sso_login. This exists purely for backwards compatibility of synapse.module_api.ModuleApi. """ + + if user_profile_data is None: + user_profile_data = ProfileInfo(None, None) + # Store any extra attributes which will be passed in the login response. # Note that this is per-user so it may overwrite a previous value, this # is considered OK since the newest SSO attributes should be most valid. @@ -1467,6 +1486,9 @@ class AuthHandler(BaseHandler): display_url=redirect_url_no_params, redirect_url=redirect_url, server_name=self._server_name, + new_user=new_user, + user_id=registered_user_id, + user_profile=user_profile_data, ) respond_with_html(request, 200, html) diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 50c5ae142a..ceaeb5a376 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -391,6 +391,8 @@ class SsoHandler: to an additional page. (e.g. to prompt for more information) """ + new_user = False + # grab a lock while we try to find a mapping for this user. This seems... # optimistic, especially for implementations that end up redirecting to # interstitial pages. 
@@ -431,9 +433,14 @@ class SsoHandler: get_request_user_agent(request), request.getClientIP(), ) + new_user = True await self._auth_handler.complete_sso_login( - user_id, request, client_redirect_url, extra_login_attributes + user_id, + request, + client_redirect_url, + extra_login_attributes, + new_user=new_user, ) async def _call_attribute_mapper( @@ -778,6 +785,7 @@ class SsoHandler: request, session.client_redirect_url, session.extra_login_attributes, + new_user=True, ) def _expire_old_sessions(self): diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 72ab5750cc..401d577293 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -279,7 +279,11 @@ class ModuleApi: ) async def complete_sso_login_async( - self, registered_user_id: str, request: SynapseRequest, client_redirect_url: str + self, + registered_user_id: str, + request: SynapseRequest, + client_redirect_url: str, + new_user: bool = False, ): """Complete a SSO login by redirecting the user to a page to confirm whether they want their access token sent to `client_redirect_url`, or redirect them to that @@ -291,9 +295,11 @@ class ModuleApi: request: The request to respond to. client_redirect_url: The URL to which to offer to redirect the user (or to redirect them directly if whitelisted). + new_user: set to true to use wording for the consent appropriate to a user + who has just registered. """ await self._auth_handler.complete_sso_login( - registered_user_id, request, client_redirect_url, + registered_user_id, request, client_redirect_url, new_user=new_user ) @defer.inlineCallbacks diff --git a/synapse/res/templates/sso.css b/synapse/res/templates/sso.css new file mode 100644 index 0000000000..ff9dc94032 --- /dev/null +++ b/synapse/res/templates/sso.css @@ -0,0 +1,83 @@ +body { + font-family: "Inter", "Helvetica", "Arial", sans-serif; + font-size: 14px; + color: #17191C; +} + +header { + max-width: 480px; + width: 100%; + margin: 24px auto; + text-align: center; +} + +header p { + color: #737D8C; + line-height: 24px; +} + +h1 { + font-size: 24px; +} + +h2 { + font-size: 14px; +} + +h2 img { + vertical-align: middle; + margin-right: 8px; + width: 24px; + height: 24px; +} + +label { + cursor: pointer; +} + +main { + max-width: 360px; + width: 100%; + margin: 24px auto; +} + +.primary-button { + border: none; + text-decoration: none; + padding: 12px; + color: white; + background-color: #418DED; + font-weight: bold; + display: block; + border-radius: 12px; + width: 100%; + margin: 16px 0; + cursor: pointer; + text-align: center; +} + +.profile { + display: flex; + justify-content: center; + margin: 24px 0; +} + +.profile .avatar { + width: 36px; + height: 36px; + border-radius: 100%; + display: block; + margin-right: 8px; +} + +.profile .display-name { + font-weight: bold; + margin-bottom: 4px; +} +.profile .user-id { + color: #737D8C; +} + +.profile .display-name, .profile .user-id { + line-height: 18px; +} \ No newline at end of file diff --git a/synapse/res/templates/sso_redirect_confirm.html b/synapse/res/templates/sso_redirect_confirm.html index 20a15e1e74..ce4f573848 100644 --- a/synapse/res/templates/sso_redirect_confirm.html +++ b/synapse/res/templates/sso_redirect_confirm.html @@ -3,12 +3,34 @@ SSO redirect confirmation + + -

    The application at {{ display_url | e }} is requesting full access to your {{ server_name }} Matrix account.
-   If you don't recognise this address, you should ignore this and close this tab.
-   I trust this address
+   {% if new_user %}
+   Your account is now ready
+   You've made your account on {{ server_name | e }}.
+   {% else %}
+   Log in
+   {% endif %}
+   Continue to confirm you trust {{ display_url | e }}.
+   {% if user_profile.avatar_url %}
+   {% if user_profile.display_name %}
+   {{ user_profile.display_name | e }}
+   {% endif %}
+   {{ user_id | e }}
+   {% endif %}
+   Continue
    - \ No newline at end of file + diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py index c37bb6440e..7baf224f7e 100644 --- a/tests/handlers/test_cas.py +++ b/tests/handlers/test_cas.py @@ -62,7 +62,7 @@ class CasHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", request, "redirect_uri", None + "@test_user:test", request, "redirect_uri", None, new_user=True ) def test_map_cas_user_to_existing_user(self): @@ -85,7 +85,7 @@ class CasHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", request, "redirect_uri", None + "@test_user:test", request, "redirect_uri", None, new_user=False ) # Subsequent calls should map to the same mxid. @@ -94,7 +94,7 @@ class CasHandlerTestCase(HomeserverTestCase): self.handler._handle_cas_response(request, cas_response, "redirect_uri", "") ) auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", request, "redirect_uri", None + "@test_user:test", request, "redirect_uri", None, new_user=False ) def test_map_cas_user_to_invalid_localpart(self): @@ -112,7 +112,7 @@ class CasHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@f=c3=b6=c3=b6:test", request, "redirect_uri", None + "@f=c3=b6=c3=b6:test", request, "redirect_uri", None, new_user=True ) diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index b3dfa40d25..d8f90b9a80 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -419,7 +419,7 @@ class OidcHandlerTestCase(HomeserverTestCase): self.get_success(self.handler.handle_oidc_callback(request)) auth_handler.complete_sso_login.assert_called_once_with( - expected_user_id, request, client_redirect_url, None, + expected_user_id, request, client_redirect_url, None, new_user=True ) self.provider._exchange_code.assert_called_once_with(code) self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce) @@ -450,7 +450,7 @@ class OidcHandlerTestCase(HomeserverTestCase): self.get_success(self.handler.handle_oidc_callback(request)) auth_handler.complete_sso_login.assert_called_once_with( - expected_user_id, request, client_redirect_url, None, + expected_user_id, request, client_redirect_url, None, new_user=False ) self.provider._exchange_code.assert_called_once_with(code) self.provider._parse_id_token.assert_not_called() @@ -623,7 +623,11 @@ class OidcHandlerTestCase(HomeserverTestCase): self.get_success(self.handler.handle_oidc_callback(request)) auth_handler.complete_sso_login.assert_called_once_with( - "@foo:test", request, client_redirect_url, {"phone": "1234567"}, + "@foo:test", + request, + client_redirect_url, + {"phone": "1234567"}, + new_user=True, ) def test_map_userinfo_to_user(self): @@ -637,7 +641,7 @@ class OidcHandlerTestCase(HomeserverTestCase): } self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", ANY, ANY, None, + "@test_user:test", ANY, ANY, None, new_user=True ) auth_handler.complete_sso_login.reset_mock() @@ -648,7 +652,7 @@ class OidcHandlerTestCase(HomeserverTestCase): } self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - "@test_user_2:test", ANY, ANY, None, + 
"@test_user_2:test", ANY, ANY, None, new_user=True ) auth_handler.complete_sso_login.reset_mock() @@ -685,14 +689,14 @@ class OidcHandlerTestCase(HomeserverTestCase): } self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - user.to_string(), ANY, ANY, None, + user.to_string(), ANY, ANY, None, new_user=False ) auth_handler.complete_sso_login.reset_mock() # Subsequent calls should map to the same mxid. self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - user.to_string(), ANY, ANY, None, + user.to_string(), ANY, ANY, None, new_user=False ) auth_handler.complete_sso_login.reset_mock() @@ -707,7 +711,7 @@ class OidcHandlerTestCase(HomeserverTestCase): } self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - user.to_string(), ANY, ANY, None, + user.to_string(), ANY, ANY, None, new_user=False ) auth_handler.complete_sso_login.reset_mock() @@ -743,7 +747,7 @@ class OidcHandlerTestCase(HomeserverTestCase): self.get_success(_make_callback_with_userinfo(self.hs, userinfo)) auth_handler.complete_sso_login.assert_called_once_with( - "@TEST_USER_2:test", ANY, ANY, None, + "@TEST_USER_2:test", ANY, ANY, None, new_user=False ) def test_map_userinfo_to_invalid_localpart(self): @@ -779,7 +783,7 @@ class OidcHandlerTestCase(HomeserverTestCase): # test_user is already taken, so test_user1 gets registered instead. auth_handler.complete_sso_login.assert_called_once_with( - "@test_user1:test", ANY, ANY, None, + "@test_user1:test", ANY, ANY, None, new_user=True ) auth_handler.complete_sso_login.reset_mock() diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py index 261c7083d1..a8d6c0f617 100644 --- a/tests/handlers/test_saml.py +++ b/tests/handlers/test_saml.py @@ -131,7 +131,7 @@ class SamlHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", request, "redirect_uri", None + "@test_user:test", request, "redirect_uri", None, new_user=True ) @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}}) @@ -157,7 +157,7 @@ class SamlHandlerTestCase(HomeserverTestCase): # check that the auth handler got called as expected auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", request, "", None + "@test_user:test", request, "", None, new_user=False ) # Subsequent calls should map to the same mxid. @@ -166,7 +166,7 @@ class SamlHandlerTestCase(HomeserverTestCase): self.handler._handle_authn_response(request, saml_response, "") ) auth_handler.complete_sso_login.assert_called_once_with( - "@test_user:test", request, "", None + "@test_user:test", request, "", None, new_user=False ) def test_map_saml_response_to_invalid_localpart(self): @@ -214,7 +214,7 @@ class SamlHandlerTestCase(HomeserverTestCase): # test_user is already taken, so test_user1 gets registered instead. 
auth_handler.complete_sso_login.assert_called_once_with( - "@test_user1:test", request, "", None + "@test_user1:test", request, "", None, new_user=True ) auth_handler.complete_sso_login.reset_mock() -- cgit 1.5.1 From 4167494c90bc0477bdf4855a79e81dc81bba1377 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 1 Feb 2021 15:52:50 +0000 Subject: Replace username picker with a template (#9275) There's some prelimiary work here to pull out the construction of a jinja environment to a separate function. I wanted to load the template at display time rather than load time, so that it's easy to update on the fly. Honestly, I think we should do this with all our templates: the risk of ending up with malformed templates is far outweighed by the improved turnaround time for an admin trying to update them. --- changelog.d/9275.feature | 1 + docs/sample_config.yaml | 32 +++++- synapse/config/_base.py | 39 +------ synapse/config/oidc_config.py | 3 +- synapse/config/sso.py | 33 +++++- synapse/handlers/sso.py | 2 +- .../res/templates/sso_auth_account_details.html | 115 +++++++++++++++++++++ synapse/res/templates/sso_auth_account_details.js | 76 ++++++++++++++ synapse/res/username_picker/index.html | 19 ---- synapse/res/username_picker/script.js | 95 ----------------- synapse/res/username_picker/style.css | 27 ----- synapse/rest/consent/consent_resource.py | 1 + synapse/rest/synapse/client/pick_username.py | 79 ++++++++++---- synapse/util/templates.py | 106 +++++++++++++++++++ tests/rest/client/v1/test_login.py | 5 +- 15 files changed, 429 insertions(+), 204 deletions(-) create mode 100644 changelog.d/9275.feature create mode 100644 synapse/res/templates/sso_auth_account_details.html create mode 100644 synapse/res/templates/sso_auth_account_details.js delete mode 100644 synapse/res/username_picker/index.html delete mode 100644 synapse/res/username_picker/script.js delete mode 100644 synapse/res/username_picker/style.css create mode 100644 synapse/util/templates.py (limited to 'synapse/handlers') diff --git a/changelog.d/9275.feature b/changelog.d/9275.feature new file mode 100644 index 0000000000..c21b197ca1 --- /dev/null +++ b/changelog.d/9275.feature @@ -0,0 +1 @@ +Improve the user experience of setting up an account via single-sign on. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index 05506a7787..a6fbcc6080 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1801,7 +1801,8 @@ saml2_config: # # localpart_template: Jinja2 template for the localpart of the MXID. # If this is not set, the user will be prompted to choose their -# own username. +# own username (see 'sso_auth_account_details.html' in the 'sso' +# section of this file). # # display_name_template: Jinja2 template for the display name to set # on first login. If unset, no displayname will be set. @@ -1968,6 +1969,35 @@ sso: # # * idp: the 'idp_id' of the chosen IDP. # + # * HTML page to prompt new users to enter a userid and confirm other + # details: 'sso_auth_account_details.html'. This is only shown if the + # SSO implementation (with any user_mapping_provider) does not return + # a localpart. + # + # When rendering, this template is given the following variables: + # + # * server_name: the homeserver's name. 
+ # + # * idp: details of the SSO Identity Provider that the user logged in + # with: an object with the following attributes: + # + # * idp_id: unique identifier for the IdP + # * idp_name: user-facing name for the IdP + # * idp_icon: if specified in the IdP config, an MXC URI for an icon + # for the IdP + # * idp_brand: if specified in the IdP config, a textual identifier + # for the brand of the IdP + # + # * user_attributes: an object containing details about the user that + # we received from the IdP. May have the following attributes: + # + # * display_name: the user's display_name + # * emails: a list of email addresses + # + # The template should render a form which submits the following fields: + # + # * username: the localpart of the user's chosen user id + # # * HTML page for a confirmation step before redirecting back to the client # with the login token: 'sso_redirect_confirm.html'. # diff --git a/synapse/config/_base.py b/synapse/config/_base.py index 94144efc87..35e5594b73 100644 --- a/synapse/config/_base.py +++ b/synapse/config/_base.py @@ -18,18 +18,18 @@ import argparse import errno import os -import time -import urllib.parse from collections import OrderedDict from hashlib import sha256 from textwrap import dedent -from typing import Any, Callable, Iterable, List, MutableMapping, Optional +from typing import Any, Iterable, List, MutableMapping, Optional import attr import jinja2 import pkg_resources import yaml +from synapse.util.templates import _create_mxc_to_http_filter, _format_ts_filter + class ConfigError(Exception): """Represents a problem parsing the configuration @@ -248,6 +248,7 @@ class Config: # Search the custom template directory as well search_directories.insert(0, custom_template_directory) + # TODO: switch to synapse.util.templates.build_jinja_env loader = jinja2.FileSystemLoader(search_directories) env = jinja2.Environment(loader=loader, autoescape=autoescape) @@ -267,38 +268,6 @@ class Config: return templates -def _format_ts_filter(value: int, format: str): - return time.strftime(format, time.localtime(value / 1000)) - - -def _create_mxc_to_http_filter(public_baseurl: str) -> Callable: - """Create and return a jinja2 filter that converts MXC urls to HTTP - - Args: - public_baseurl: The public, accessible base URL of the homeserver - """ - - def mxc_to_http_filter(value, width, height, resize_method="crop"): - if value[0:6] != "mxc://": - return "" - - server_and_media_id = value[6:] - fragment = None - if "#" in server_and_media_id: - server_and_media_id, fragment = server_and_media_id.split("#", 1) - fragment = "#" + fragment - - params = {"width": width, "height": height, "method": resize_method} - return "%s_matrix/media/v1/thumbnail/%s?%s%s" % ( - public_baseurl, - server_and_media_id, - urllib.parse.urlencode(params), - fragment or "", - ) - - return mxc_to_http_filter - - class RootConfig: """ Holder of an application's configuration. diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index f31511e039..784b416f95 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -152,7 +152,8 @@ class OIDCConfig(Config): # # localpart_template: Jinja2 template for the localpart of the MXID. # If this is not set, the user will be prompted to choose their - # own username. + # own username (see 'sso_auth_account_details.html' in the 'sso' + # section of this file). # # display_name_template: Jinja2 template for the display name to set # on first login. If unset, no displayname will be set. 
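(Aside, with an invented homeserver URL and media ID: the mxc_to_http filter that this series moves from synapse/config/_base.py into synapse.util.templates builds thumbnail URLs from mxc: URIs, the form required for, e.g., the idp_icon documented above. The copy below reproduces that logic as a standalone function, taking the public_baseurl as an argument instead of capturing it from the homeserver config:)

import urllib.parse

def mxc_to_http(value, width, height, public_baseurl, resize_method="crop"):
    # same logic as the filter in synapse/util/templates.py
    if value[0:6] != "mxc://":
        return ""
    server_and_media_id = value[6:]
    fragment = None
    if "#" in server_and_media_id:
        server_and_media_id, fragment = server_and_media_id.split("#", 1)
        fragment = "#" + fragment
    params = {"width": width, "height": height, "method": resize_method}
    return "%s_matrix/media/v1/thumbnail/%s?%s%s" % (
        public_baseurl,
        server_and_media_id,
        urllib.parse.urlencode(params),
        fragment or "",
    )

print(mxc_to_http("mxc://example.com/SEnBZKQVXhzhHxcUYMBzRmXg", 24, 24, "https://example.com/"))
# https://example.com/_matrix/media/v1/thumbnail/example.com/SEnBZKQVXhzhHxcUYMBzRmXg?width=24&height=24&method=crop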
diff --git a/synapse/config/sso.py b/synapse/config/sso.py index a470112ed4..e308fc9333 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -27,7 +27,7 @@ class SSOConfig(Config): sso_config = config.get("sso") or {} # type: Dict[str, Any] # The sso-specific template_dir - template_dir = sso_config.get("template_dir") + self.sso_template_dir = sso_config.get("template_dir") # Read templates from disk ( @@ -48,7 +48,7 @@ class SSOConfig(Config): "sso_auth_success.html", "sso_auth_bad_user.html", ], - template_dir, + self.sso_template_dir, ) # These templates have no placeholders, so render them here @@ -124,6 +124,35 @@ class SSOConfig(Config): # # * idp: the 'idp_id' of the chosen IDP. # + # * HTML page to prompt new users to enter a userid and confirm other + # details: 'sso_auth_account_details.html'. This is only shown if the + # SSO implementation (with any user_mapping_provider) does not return + # a localpart. + # + # When rendering, this template is given the following variables: + # + # * server_name: the homeserver's name. + # + # * idp: details of the SSO Identity Provider that the user logged in + # with: an object with the following attributes: + # + # * idp_id: unique identifier for the IdP + # * idp_name: user-facing name for the IdP + # * idp_icon: if specified in the IdP config, an MXC URI for an icon + # for the IdP + # * idp_brand: if specified in the IdP config, a textual identifier + # for the brand of the IdP + # + # * user_attributes: an object containing details about the user that + # we received from the IdP. May have the following attributes: + # + # * display_name: the user's display_name + # * emails: a list of email addresses + # + # The template should render a form which submits the following fields: + # + # * username: the localpart of the user's chosen user id + # # * HTML page for a confirmation step before redirecting back to the client # with the login token: 'sso_redirect_confirm.html'. # diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index ceaeb5a376..ff4750999a 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -530,7 +530,7 @@ class SsoHandler: logger.info("Recorded registration session id %s", session_id) # Set the cookie and redirect to the username picker - e = RedirectException(b"/_synapse/client/pick_username") + e = RedirectException(b"/_synapse/client/pick_username/account_details") e.cookies.append( b"%s=%s; path=/" % (USERNAME_MAPPING_SESSION_COOKIE_NAME, session_id.encode("ascii")) diff --git a/synapse/res/templates/sso_auth_account_details.html b/synapse/res/templates/sso_auth_account_details.html new file mode 100644 index 0000000000..f22b09aec1 --- /dev/null +++ b/synapse/res/templates/sso_auth_account_details.html @@ -0,0 +1,115 @@ + + + + Synapse Login + + + + + +
+   Your account is nearly ready
+   Check your details before creating an account on {{ server_name }}
+   @
+   :{{ server_name }}
+   {% if user_attributes %}
+   Information from {{ idp.idp_name }}
+   {% if user_attributes.avatar_url %}
+   {% if user_attributes.display_name %}
+   {{ user_attributes.display_name }}
+   {% endif %}
+   {% for email in user_attributes.emails %}
+   {{ email }}
+   {% endfor %}
+   {% endif %}
    + + + diff --git a/synapse/res/templates/sso_auth_account_details.js b/synapse/res/templates/sso_auth_account_details.js new file mode 100644 index 0000000000..deef419bb6 --- /dev/null +++ b/synapse/res/templates/sso_auth_account_details.js @@ -0,0 +1,76 @@ +const usernameField = document.getElementById("field-username"); + +function throttle(fn, wait) { + let timeout; + return function() { + const args = Array.from(arguments); + if (timeout) { + clearTimeout(timeout); + } + timeout = setTimeout(fn.bind.apply(fn, [null].concat(args)), wait); + } +} + +function checkUsernameAvailable(username) { + let check_uri = 'check?username=' + encodeURIComponent(username); + return fetch(check_uri, { + // include the cookie + "credentials": "same-origin", + }).then((response) => { + if(!response.ok) { + // for non-200 responses, raise the body of the response as an exception + return response.text().then((text) => { throw new Error(text); }); + } else { + return response.json(); + } + }).then((json) => { + if(json.error) { + return {message: json.error}; + } else if(json.available) { + return {available: true}; + } else { + return {message: username + " is not available, please choose another."}; + } + }); +} + +function validateUsername(username) { + usernameField.setCustomValidity(""); + if (usernameField.validity.valueMissing) { + usernameField.setCustomValidity("Please provide a username"); + return; + } + if (usernameField.validity.patternMismatch) { + usernameField.setCustomValidity("Invalid username, please only use " + allowedCharactersString); + return; + } + usernameField.setCustomValidity("Checking if username is available …"); + throttledCheckUsernameAvailable(username); +} + +const throttledCheckUsernameAvailable = throttle(function(username) { + const handleError = function(err) { + // don't prevent form submission on error + usernameField.setCustomValidity(""); + console.log(err.message); + }; + try { + checkUsernameAvailable(username).then(function(result) { + if (!result.available) { + usernameField.setCustomValidity(result.message); + usernameField.reportValidity(); + } else { + usernameField.setCustomValidity(""); + } + }, handleError); + } catch (err) { + handleError(err); + } +}, 500); + +usernameField.addEventListener("input", function(evt) { + validateUsername(usernameField.value); +}); +usernameField.addEventListener("change", function(evt) { + validateUsername(usernameField.value); +}); diff --git a/synapse/res/username_picker/index.html b/synapse/res/username_picker/index.html deleted file mode 100644 index 37ea8bb6d8..0000000000 --- a/synapse/res/username_picker/index.html +++ /dev/null @@ -1,19 +0,0 @@ - - - - Synapse Login - - - -
    - - diff --git a/synapse/res/username_picker/script.js b/synapse/res/username_picker/script.js deleted file mode 100644 index 416a7c6f41..0000000000 --- a/synapse/res/username_picker/script.js +++ /dev/null @@ -1,95 +0,0 @@ -let inputField = document.getElementById("field-username"); -let inputForm = document.getElementById("form"); -let submitButton = document.getElementById("button-submit"); -let message = document.getElementById("message"); - -// Submit username and receive response -function showMessage(messageText) { - // Unhide the message text - message.classList.remove("hidden"); - - message.textContent = messageText; -}; - -function doSubmit() { - showMessage("Success. Please wait a moment for your browser to redirect."); - - // remove the event handler before re-submitting the form. - delete inputForm.onsubmit; - inputForm.submit(); -} - -function onResponse(response) { - // Display message - showMessage(response); - - // Enable submit button and input field - submitButton.classList.remove('button--disabled'); - submitButton.value = "Submit"; -}; - -let allowedUsernameCharacters = RegExp("[^a-z0-9\\.\\_\\=\\-\\/]"); -function usernameIsValid(username) { - return !allowedUsernameCharacters.test(username); -} -let allowedCharactersString = "lowercase letters, digits, ., _, -, /, ="; - -function buildQueryString(params) { - return Object.keys(params) - .map(k => encodeURIComponent(k) + '=' + encodeURIComponent(params[k])) - .join('&'); -} - -function submitUsername(username) { - if(username.length == 0) { - onResponse("Please enter a username."); - return; - } - if(!usernameIsValid(username)) { - onResponse("Invalid username. Only the following characters are allowed: " + allowedCharactersString); - return; - } - - // if this browser doesn't support fetch, skip the availability check. - if(!window.fetch) { - doSubmit(); - return; - } - - let check_uri = 'check?' 
+ buildQueryString({"username": username}); - fetch(check_uri, { - // include the cookie - "credentials": "same-origin", - }).then((response) => { - if(!response.ok) { - // for non-200 responses, raise the body of the response as an exception - return response.text().then((text) => { throw text; }); - } else { - return response.json(); - } - }).then((json) => { - if(json.error) { - throw json.error; - } else if(json.available) { - doSubmit(); - } else { - onResponse("This username is not available, please choose another."); - } - }).catch((err) => { - onResponse("Error checking username availability: " + err); - }); -} - -function clickSubmit() { - event.preventDefault(); - if(submitButton.classList.contains('button--disabled')) { return; } - - // Disable submit button and input field - submitButton.classList.add('button--disabled'); - - // Submit username - submitButton.value = "Checking..."; - submitUsername(inputField.value); -}; - -inputForm.onsubmit = clickSubmit; diff --git a/synapse/res/username_picker/style.css b/synapse/res/username_picker/style.css deleted file mode 100644 index 745bd4c684..0000000000 --- a/synapse/res/username_picker/style.css +++ /dev/null @@ -1,27 +0,0 @@ -input[type="text"] { - font-size: 100%; - background-color: #ededf0; - border: 1px solid #fff; - border-radius: .2em; - padding: .5em .9em; - display: block; - width: 26em; -} - -.button--disabled { - border-color: #fff; - background-color: transparent; - color: #000; - text-transform: none; -} - -.hidden { - display: none; -} - -.tooltip { - background-color: #f9f9fa; - padding: 1em; - margin: 1em 0; -} - diff --git a/synapse/rest/consent/consent_resource.py b/synapse/rest/consent/consent_resource.py index b3e4d5612e..8b9ef26cf2 100644 --- a/synapse/rest/consent/consent_resource.py +++ b/synapse/rest/consent/consent_resource.py @@ -100,6 +100,7 @@ class ConsentResource(DirectServeHtmlResource): consent_template_directory = hs.config.user_consent_template_dir + # TODO: switch to synapse.util.templates.build_jinja_env loader = jinja2.FileSystemLoader(consent_template_directory) self._jinja_env = jinja2.Environment( loader=loader, autoescape=jinja2.select_autoescape(["html", "htm", "xml"]) diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py index 1bc737bad0..27540d3bbe 100644 --- a/synapse/rest/synapse/client/pick_username.py +++ b/synapse/rest/synapse/client/pick_username.py @@ -13,41 +13,41 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging from typing import TYPE_CHECKING -import pkg_resources - from twisted.web.http import Request from twisted.web.resource import Resource -from twisted.web.static import File +from synapse.api.errors import SynapseError from synapse.handlers.sso import get_username_mapping_session_cookie_from_request -from synapse.http.server import DirectServeHtmlResource, DirectServeJsonResource +from synapse.http.server import ( + DirectServeHtmlResource, + DirectServeJsonResource, + respond_with_html, +) from synapse.http.servlet import parse_string from synapse.http.site import SynapseRequest +from synapse.util.templates import build_jinja_env if TYPE_CHECKING: from synapse.server import HomeServer +logger = logging.getLogger(__name__) + def pick_username_resource(hs: "HomeServer") -> Resource: """Factory method to generate the username picker resource. - This resource gets mounted under /_synapse/client/pick_username. 
The top-level - resource is just a File resource which serves up the static files in the resources - "res" directory, but it has a couple of children: + This resource gets mounted under /_synapse/client/pick_username and has two + children: - * "submit", which does the mechanics of registering the new user, and redirects the - browser back to the client URL - - * "check": checks if a userid is free. + * "account_details": renders the form and handles the POSTed response + * "check": a JSON endpoint which checks if a userid is free. """ - # XXX should we make this path customisable so that admins can restyle it? - base_path = pkg_resources.resource_filename("synapse", "res/username_picker") - - res = File(base_path) - res.putChild(b"submit", SubmitResource(hs)) + res = Resource() + res.putChild(b"account_details", AccountDetailsResource(hs)) res.putChild(b"check", AvailabilityCheckResource(hs)) return res @@ -69,15 +69,54 @@ class AvailabilityCheckResource(DirectServeJsonResource): return 200, {"available": is_available} -class SubmitResource(DirectServeHtmlResource): +class AccountDetailsResource(DirectServeHtmlResource): def __init__(self, hs: "HomeServer"): super().__init__() self._sso_handler = hs.get_sso_handler() - async def _async_render_POST(self, request: SynapseRequest): - localpart = parse_string(request, "username", required=True) + def template_search_dirs(): + if hs.config.sso.sso_template_dir: + yield hs.config.sso.sso_template_dir + yield hs.config.sso.default_template_dir + + self._jinja_env = build_jinja_env(template_search_dirs(), hs.config) + + async def _async_render_GET(self, request: Request) -> None: + try: + session_id = get_username_mapping_session_cookie_from_request(request) + session = self._sso_handler.get_mapping_session(session_id) + except SynapseError as e: + logger.warning("Error fetching session: %s", e) + self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code) + return + + idp_id = session.auth_provider_id + template_params = { + "idp": self._sso_handler.get_identity_providers()[idp_id], + "user_attributes": { + "display_name": session.display_name, + "emails": session.emails, + }, + } + + template = self._jinja_env.get_template("sso_auth_account_details.html") + html = template.render(template_params) + respond_with_html(request, 200, html) - session_id = get_username_mapping_session_cookie_from_request(request) + async def _async_render_POST(self, request: SynapseRequest): + try: + session_id = get_username_mapping_session_cookie_from_request(request) + except SynapseError as e: + logger.warning("Error fetching session cookie: %s", e) + self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code) + return + + try: + localpart = parse_string(request, "username", required=True) + except SynapseError as e: + logger.warning("[session %s] bad param: %s", session_id, e) + self._sso_handler.render_error(request, "bad_param", e.msg, code=e.code) + return await self._sso_handler.handle_submit_username_request( request, localpart, session_id diff --git a/synapse/util/templates.py b/synapse/util/templates.py new file mode 100644 index 0000000000..7e5109d206 --- /dev/null +++ b/synapse/util/templates.py @@ -0,0 +1,106 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for dealing with jinja2 templates""" + +import time +import urllib.parse +from typing import TYPE_CHECKING, Callable, Iterable, Union + +import jinja2 + +if TYPE_CHECKING: + from synapse.config.homeserver import HomeServerConfig + + +def build_jinja_env( + template_search_directories: Iterable[str], + config: "HomeServerConfig", + autoescape: Union[bool, Callable[[str], bool], None] = None, +) -> jinja2.Environment: + """Set up a Jinja2 environment to load templates from the given search path + + The returned environment defines the following filters: + - format_ts: formats timestamps as strings in the server's local timezone + (XXX: why is that useful??) + - mxc_to_http: converts mxc: uris to http URIs. Args are: + (uri, width, height, resize_method="crop") + + and the following global variables: + - server_name: matrix server name + + Args: + template_search_directories: directories to search for templates + + config: homeserver config, for things like `server_name` and `public_baseurl` + + autoescape: whether template variables should be autoescaped. bool, or + a function mapping from template name to bool. Defaults to escaping templates + whose names end in .html, .xml or .htm. + + Returns: + jinja environment + """ + + if autoescape is None: + autoescape = jinja2.select_autoescape() + + loader = jinja2.FileSystemLoader(template_search_directories) + env = jinja2.Environment(loader=loader, autoescape=autoescape) + + # Update the environment with our custom filters + env.filters.update( + { + "format_ts": _format_ts_filter, + "mxc_to_http": _create_mxc_to_http_filter(config.public_baseurl), + } + ) + + # common variables for all templates + env.globals.update({"server_name": config.server_name}) + + return env + + +def _create_mxc_to_http_filter(public_baseurl: str) -> Callable: + """Create and return a jinja2 filter that converts MXC urls to HTTP + + Args: + public_baseurl: The public, accessible base URL of the homeserver + """ + + def mxc_to_http_filter(value, width, height, resize_method="crop"): + if value[0:6] != "mxc://": + return "" + + server_and_media_id = value[6:] + fragment = None + if "#" in server_and_media_id: + server_and_media_id, fragment = server_and_media_id.split("#", 1) + fragment = "#" + fragment + + params = {"width": width, "height": height, "method": resize_method} + return "%s_matrix/media/v1/thumbnail/%s?%s%s" % ( + public_baseurl, + server_and_media_id, + urllib.parse.urlencode(params), + fragment or "", + ) + + return mxc_to_http_filter + + +def _format_ts_filter(value: int, format: str): + return time.strftime(format, time.localtime(value / 1000)) diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py index ded22a9767..66dfdaffbc 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py @@ -1222,7 +1222,7 @@ class UsernamePickerTestCase(HomeserverTestCase): # that should redirect to the username picker self.assertEqual(channel.code, 302, channel.result) picker_url = channel.headers.getRawHeaders("Location")[0] - self.assertEqual(picker_url, 
"/_synapse/client/pick_username") + self.assertEqual(picker_url, "/_synapse/client/pick_username/account_details") # ... with a username_mapping_session cookie cookies = {} # type: Dict[str,str] @@ -1247,11 +1247,10 @@ class UsernamePickerTestCase(HomeserverTestCase): # Now, submit a username to the username picker, which should serve a redirect # to the completion page - submit_path = picker_url + "/submit" content = urlencode({b"username": b"bobby"}).encode("utf8") chan = self.make_request( "POST", - path=submit_path, + path=picker_url, content=content, content_is_form=True, custom_headers=[ -- cgit 1.5.1 From a800603561c0cb58727474035b6b27ed9e5fc277 Mon Sep 17 00:00:00 2001 From: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> Date: Mon, 1 Feb 2021 15:54:39 +0000 Subject: Prevent email UIA failures from raising a LoginError (#9265) Context, Fixes: https://github.com/matrix-org/synapse/issues/9263 In the past to fix an issue with old Riots re-requesting threepid validation tokens, we raised a `LoginError` during UIA instead of `InteractiveAuthIncompleteError`. This is now breaking the way Tchap logs in - which isn't standard, but also isn't disallowed by the spec. An easy fix is just to remove the 4 year old workaround. --- changelog.d/9265.bugfix | 1 + synapse/handlers/auth.py | 10 ---------- 2 files changed, 1 insertion(+), 10 deletions(-) create mode 100644 changelog.d/9265.bugfix (limited to 'synapse/handlers') diff --git a/changelog.d/9265.bugfix b/changelog.d/9265.bugfix new file mode 100644 index 0000000000..34f7bd8ddd --- /dev/null +++ b/changelog.d/9265.bugfix @@ -0,0 +1 @@ +Prevent password hashes from getting dropped if a client failed threepid validation during a User Interactive Auth stage. Removes a workaround for an ancient bug in Riot Web Date: Mon, 1 Feb 2021 17:30:42 +0000 Subject: Make importing display name and email optional (#9277) --- changelog.d/9277.feature | 1 + synapse/handlers/register.py | 5 ++- synapse/handlers/sso.py | 52 ++++++++++++++++++---- .../res/templates/sso_auth_account_details.html | 23 ++++++++++ synapse/rest/synapse/client/pick_username.py | 14 ++++-- 5 files changed, 82 insertions(+), 13 deletions(-) create mode 100644 changelog.d/9277.feature (limited to 'synapse/handlers') diff --git a/changelog.d/9277.feature b/changelog.d/9277.feature new file mode 100644 index 0000000000..c21b197ca1 --- /dev/null +++ b/changelog.d/9277.feature @@ -0,0 +1 @@ +Improve the user experience of setting up an account via single-sign on. diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index a2cf0f6f3e..b20a5d8605 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -14,8 +14,9 @@ # limitations under the License. 
"""Contains functions for registering clients.""" + import logging -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple from synapse import types from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType @@ -152,7 +153,7 @@ class RegistrationHandler(BaseHandler): user_type: Optional[str] = None, default_display_name: Optional[str] = None, address: Optional[str] = None, - bind_emails: List[str] = [], + bind_emails: Iterable[str] = [], by_admin: bool = False, user_agent_ips: Optional[List[Tuple[str, str]]] = None, ) -> str: diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index ff4750999a..d7ca2918f8 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -14,7 +14,16 @@ # limitations under the License. import abc import logging -from typing import TYPE_CHECKING, Awaitable, Callable, Dict, List, Mapping, Optional +from typing import ( + TYPE_CHECKING, + Awaitable, + Callable, + Dict, + Iterable, + Mapping, + Optional, + Set, +) from urllib.parse import urlencode import attr @@ -29,7 +38,7 @@ from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.http import get_request_user_agent from synapse.http.server import respond_with_html, respond_with_redirect from synapse.http.site import SynapseRequest -from synapse.types import JsonDict, UserID, contains_invalid_mxid_characters +from synapse.types import Collection, JsonDict, UserID, contains_invalid_mxid_characters from synapse.util.async_helpers import Linearizer from synapse.util.stringutils import random_string @@ -115,7 +124,7 @@ class UserAttributes: # enter one. localpart = attr.ib(type=Optional[str]) display_name = attr.ib(type=Optional[str], default=None) - emails = attr.ib(type=List[str], default=attr.Factory(list)) + emails = attr.ib(type=Collection[str], default=attr.Factory(list)) @attr.s(slots=True) @@ -130,7 +139,7 @@ class UsernameMappingSession: # attributes returned by the ID mapper display_name = attr.ib(type=Optional[str]) - emails = attr.ib(type=List[str]) + emails = attr.ib(type=Collection[str]) # An optional dictionary of extra attributes to be provided to the client in the # login response. 
@@ -144,6 +153,8 @@ class UsernameMappingSession: # choices made by the user chosen_localpart = attr.ib(type=Optional[str], default=None) + use_display_name = attr.ib(type=bool, default=True) + emails_to_use = attr.ib(type=Collection[str], default=()) # the HTTP cookie used to track the mapping session id @@ -710,7 +721,12 @@ class SsoHandler: return not user_infos async def handle_submit_username_request( - self, request: SynapseRequest, localpart: str, session_id: str + self, + request: SynapseRequest, + session_id: str, + localpart: str, + use_display_name: bool, + emails_to_use: Iterable[str], ) -> None: """Handle a request to the username-picker 'submit' endpoint @@ -720,11 +736,30 @@ class SsoHandler: request: HTTP request localpart: localpart requested by the user session_id: ID of the username mapping session, extracted from a cookie + use_display_name: whether the user wants to use the suggested display name + emails_to_use: emails that the user would like to use """ session = self.get_mapping_session(session_id) # update the session with the user's choices session.chosen_localpart = localpart + session.use_display_name = use_display_name + + emails_from_idp = set(session.emails) + filtered_emails = set() # type: Set[str] + + # we iterate through the list rather than just building a set conjunction, so + # that we can log attempts to use unknown addresses + for email in emails_to_use: + if email in emails_from_idp: + filtered_emails.add(email) + else: + logger.warning( + "[session %s] ignoring user request to use unknown email address %r", + session_id, + email, + ) + session.emails_to_use = filtered_emails # we're done; now we can register the user respond_with_redirect(request, b"/_synapse/client/sso_register") @@ -747,11 +782,12 @@ class SsoHandler: ) attributes = UserAttributes( - localpart=session.chosen_localpart, - display_name=session.display_name, - emails=session.emails, + localpart=session.chosen_localpart, emails=session.emails_to_use, ) + if session.use_display_name: + attributes.display_name = session.display_name + # the following will raise a 400 error if the username has been taken in the # meantime. user_id = await self._register_mapped_user( diff --git a/synapse/res/templates/sso_auth_account_details.html b/synapse/res/templates/sso_auth_account_details.html index f22b09aec1..105063825a 100644 --- a/synapse/res/templates/sso_auth_account_details.html +++ b/synapse/res/templates/sso_auth_account_details.html @@ -53,6 +53,14 @@ border-top: 1px solid #E9ECF1; padding: 12px; } + .idp-pick-details .check-row { + display: flex; + align-items: center; + } + + .idp-pick-details .check-row .name { + flex: 1; + } .idp-pick-details .use, .idp-pick-details .idp-value { color: #737D8C; @@ -91,16 +99,31 @@

    Information from {{ idp.idp_name }}

    {% if user_attributes.avatar_url %}
    +
    + + + +
    {% endif %} {% if user_attributes.display_name %}
    +
    + + + +

    {{ user_attributes.display_name }}

    {% endif %} {% for email in user_attributes.emails %}
    +
    + + + +

    {{ email }}

    {% endfor %} diff --git a/synapse/rest/synapse/client/pick_username.py b/synapse/rest/synapse/client/pick_username.py index 27540d3bbe..96077cfcd1 100644 --- a/synapse/rest/synapse/client/pick_username.py +++ b/synapse/rest/synapse/client/pick_username.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, List from twisted.web.http import Request from twisted.web.resource import Resource @@ -26,7 +26,7 @@ from synapse.http.server import ( DirectServeJsonResource, respond_with_html, ) -from synapse.http.servlet import parse_string +from synapse.http.servlet import parse_boolean, parse_string from synapse.http.site import SynapseRequest from synapse.util.templates import build_jinja_env @@ -113,11 +113,19 @@ class AccountDetailsResource(DirectServeHtmlResource): try: localpart = parse_string(request, "username", required=True) + use_display_name = parse_boolean(request, "use_display_name", default=False) + + try: + emails_to_use = [ + val.decode("utf-8") for val in request.args.get(b"use_email", []) + ] # type: List[str] + except ValueError: + raise SynapseError(400, "Query parameter use_email must be utf-8") except SynapseError as e: logger.warning("[session %s] bad param: %s", session_id, e) self._sso_handler.render_error(request, "bad_param", e.msg, code=e.code) return await self._sso_handler.handle_submit_username_request( - request, localpart, session_id + request, session_id, localpart, use_display_name, emails_to_use ) -- cgit 1.5.1 From e5d70c8a82f5d615dbdc7f7aa69288681ba2e9f5 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 1 Feb 2021 18:36:04 +0000 Subject: Improve styling and wording of SSO UIA templates (#9286) fixes #9171 --- changelog.d/9286.feature | 1 + docs/sample_config.yaml | 15 +++++++++++ synapse/config/sso.py | 15 +++++++++++ synapse/handlers/auth.py | 4 ++- synapse/res/templates/sso_auth_confirm.html | 32 ++++++++++++++++------- synapse/res/templates/sso_auth_success.html | 39 ++++++++++++++++++----------- 6 files changed, 81 insertions(+), 25 deletions(-) create mode 100644 changelog.d/9286.feature (limited to 'synapse/handlers') diff --git a/changelog.d/9286.feature b/changelog.d/9286.feature new file mode 100644 index 0000000000..c21b197ca1 --- /dev/null +++ b/changelog.d/9286.feature @@ -0,0 +1 @@ +Improve the user experience of setting up an account via single-sign on. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index a6fbcc6080..eec082ca8c 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1958,8 +1958,13 @@ sso: # # * providers: a list of available Identity Providers. 
Each element is # an object with the following attributes: + # # * idp_id: unique identifier for the IdP # * idp_name: user-facing name for the IdP + # * idp_icon: if specified in the IdP config, an MXC URI for an icon + # for the IdP + # * idp_brand: if specified in the IdP config, a textual identifier + # for the brand of the IdP # # The rendered HTML page should contain a form which submits its results # back as a GET request, with the following query parameters: @@ -2037,6 +2042,16 @@ sso: # # * description: the operation which the user is being asked to confirm # + # * idp: details of the Identity Provider that we will use to confirm + # the user's identity: an object with the following attributes: + # + # * idp_id: unique identifier for the IdP + # * idp_name: user-facing name for the IdP + # * idp_icon: if specified in the IdP config, an MXC URI for an icon + # for the IdP + # * idp_brand: if specified in the IdP config, a textual identifier + # for the brand of the IdP + # # * HTML page shown after a successful user interactive authentication session: # 'sso_auth_success.html'. # diff --git a/synapse/config/sso.py b/synapse/config/sso.py index e308fc9333..bf82183cdc 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -113,8 +113,13 @@ class SSOConfig(Config): # # * providers: a list of available Identity Providers. Each element is # an object with the following attributes: + # # * idp_id: unique identifier for the IdP # * idp_name: user-facing name for the IdP + # * idp_icon: if specified in the IdP config, an MXC URI for an icon + # for the IdP + # * idp_brand: if specified in the IdP config, a textual identifier + # for the brand of the IdP # # The rendered HTML page should contain a form which submits its results # back as a GET request, with the following query parameters: @@ -192,6 +197,16 @@ class SSOConfig(Config): # # * description: the operation which the user is being asked to confirm # + # * idp: details of the Identity Provider that we will use to confirm + # the user's identity: an object with the following attributes: + # + # * idp_id: unique identifier for the IdP + # * idp_name: user-facing name for the IdP + # * idp_icon: if specified in the IdP config, an MXC URI for an icon + # for the IdP + # * idp_brand: if specified in the IdP config, a textual identifier + # for the brand of the IdP + # # * HTML page shown after a successful user interactive authentication session: # 'sso_auth_success.html'. # diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index c722a4afa8..6f746711ca 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -1388,7 +1388,9 @@ class AuthHandler(BaseHandler): ) return self._sso_auth_confirm_template.render( - description=session.description, redirect_url=redirect_url, + description=session.description, + redirect_url=redirect_url, + idp=sso_auth_provider, ) async def complete_sso_login( diff --git a/synapse/res/templates/sso_auth_confirm.html b/synapse/res/templates/sso_auth_confirm.html index 0d9de9d465..d572ab87f7 100644 --- a/synapse/res/templates/sso_auth_confirm.html +++ b/synapse/res/templates/sso_auth_confirm.html @@ -1,14 +1,28 @@ - - - Authentication - + + + + + Authentication + + + -
    +
    +

    Confirm it's you to continue

    - A client is trying to {{ description | e }}. To confirm this action,
    - re-authenticate with single sign-on.
    - If you did not expect this, your account may be compromised!
    + A client is trying to {{ description | e }}. To confirm this action
    + re-authorize your account with single sign-on.

    -
    +

    + If you did not expect this, your account may be compromised.
    +

    + +
    +
    + Continue with {{ idp.idp_name | e }}
    +
    +
    diff --git a/synapse/res/templates/sso_auth_success.html b/synapse/res/templates/sso_auth_success.html index 03f1419467..3b975d7219 100644 --- a/synapse/res/templates/sso_auth_success.html +++ b/synapse/res/templates/sso_auth_success.html @@ -1,18 +1,27 @@ - - - Authentication Successful - - + + + + + Authentication successful + + + + -
    -

    Thank you

    -

    You may now close this window and return to the application

    -
    +
    +

    Thank you

    +

    + Now we know it’s you, you can close this window and return to the
    + application.
    +

    +
    -- cgit 1.5.1 From c543bf87ecf295fa68311beabd1dc013288a2e98 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 1 Feb 2021 18:37:41 +0000 Subject: Collect terms consent from the user during SSO registration (#9276) --- changelog.d/9276.feature | 1 + docs/sample_config.yaml | 22 ++++++ docs/workers.md | 1 + synapse/config/sso.py | 22 ++++++ synapse/handlers/register.py | 2 + synapse/handlers/sso.py | 44 +++++++++++ synapse/res/templates/sso_new_user_consent.html | 39 ++++++++++ synapse/rest/synapse/client/__init__.py | 2 + synapse/rest/synapse/client/new_user_consent.py | 97 +++++++++++++++++++++++++ 9 files changed, 230 insertions(+) create mode 100644 changelog.d/9276.feature create mode 100644 synapse/res/templates/sso_new_user_consent.html create mode 100644 synapse/rest/synapse/client/new_user_consent.py (limited to 'synapse/handlers') diff --git a/changelog.d/9276.feature b/changelog.d/9276.feature new file mode 100644 index 0000000000..c21b197ca1 --- /dev/null +++ b/changelog.d/9276.feature @@ -0,0 +1 @@ +Improve the user experience of setting up an account via single-sign on. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index eec082ca8c..15e9746696 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -2003,6 +2003,28 @@ sso: # # * username: the localpart of the user's chosen user id # + # * HTML page allowing the user to consent to the server's terms and + # conditions. This is only shown for new users, and only if + # `user_consent.require_at_registration` is set. + # + # When rendering, this template is given the following variables: + # + # * server_name: the homeserver's name. + # + # * user_id: the user's matrix proposed ID. + # + # * user_profile.display_name: the user's proposed display name, if any. + # + # * consent_version: the version of the terms that the user will be + # shown + # + # * terms_url: a link to the page showing the terms. + # + # The template should render a form which submits the following fields: + # + # * accepted_version: the version of the terms accepted by the user + # (ie, 'consent_version' from the input variables). + # # * HTML page for a confirmation step before redirecting back to the client # with the login token: 'sso_redirect_confirm.html'. # diff --git a/docs/workers.md b/docs/workers.md index 6b8887de36..0da805c333 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -259,6 +259,7 @@ using): ^/_matrix/client/(api/v1|r0|unstable)/login/sso/redirect ^/_synapse/client/pick_idp$ ^/_synapse/client/pick_username + ^/_synapse/client/new_user_consent$ ^/_synapse/client/sso_register$ # OpenID Connect requests. diff --git a/synapse/config/sso.py b/synapse/config/sso.py index bf82183cdc..939eeac6de 100644 --- a/synapse/config/sso.py +++ b/synapse/config/sso.py @@ -158,6 +158,28 @@ class SSOConfig(Config): # # * username: the localpart of the user's chosen user id # + # * HTML page allowing the user to consent to the server's terms and + # conditions. This is only shown for new users, and only if + # `user_consent.require_at_registration` is set. + # + # When rendering, this template is given the following variables: + # + # * server_name: the homeserver's name. + # + # * user_id: the user's matrix proposed ID. + # + # * user_profile.display_name: the user's proposed display name, if any. + # + # * consent_version: the version of the terms that the user will be + # shown + # + # * terms_url: a link to the page showing the terms. 
+ # + # The template should render a form which submits the following fields: + # + # * accepted_version: the version of the terms accepted by the user + # (ie, 'consent_version' from the input variables). + # # * HTML page for a confirmation step before redirecting back to the client # with the login token: 'sso_redirect_confirm.html'. # diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index b20a5d8605..49b085269b 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -694,6 +694,8 @@ class RegistrationHandler(BaseHandler): access_token: The access token of the newly logged in device, or None if `inhibit_login` enabled. """ + # TODO: 3pid registration can actually happen on the workers. Consider + # refactoring it. if self.hs.config.worker_app: await self._post_registration_client( user_id=user_id, auth_result=auth_result, access_token=access_token diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index d7ca2918f8..b450668f1c 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -155,6 +155,7 @@ class UsernameMappingSession: chosen_localpart = attr.ib(type=Optional[str], default=None) use_display_name = attr.ib(type=bool, default=True) emails_to_use = attr.ib(type=Collection[str], default=()) + terms_accepted_version = attr.ib(type=Optional[str], default=None) # the HTTP cookie used to track the mapping session id @@ -190,6 +191,8 @@ class SsoHandler: # map from idp_id to SsoIdentityProvider self._identity_providers = {} # type: Dict[str, SsoIdentityProvider] + self._consent_at_registration = hs.config.consent.user_consent_at_registration + def register_identity_provider(self, p: SsoIdentityProvider): p_id = p.idp_id assert p_id not in self._identity_providers @@ -761,6 +764,38 @@ class SsoHandler: ) session.emails_to_use = filtered_emails + # we may now need to collect consent from the user, in which case, redirect + # to the consent-extraction-unit + if self._consent_at_registration: + redirect_url = b"/_synapse/client/new_user_consent" + + # otherwise, redirect to the completion page + else: + redirect_url = b"/_synapse/client/sso_register" + + respond_with_redirect(request, redirect_url) + + async def handle_terms_accepted( + self, request: Request, session_id: str, terms_version: str + ): + """Handle a request to the new-user 'consent' endpoint + + Will serve an HTTP response to the request. + + Args: + request: HTTP request + session_id: ID of the username mapping session, extracted from a cookie + terms_version: the version of the terms which the user viewed and consented + to + """ + logger.info( + "[session %s] User consented to terms version %s", + session_id, + terms_version, + ) + session = self.get_mapping_session(session_id) + session.terms_accepted_version = terms_version + # we're done; now we can register the user respond_with_redirect(request, b"/_synapse/client/sso_register") @@ -816,6 +851,15 @@ class SsoHandler: path=b"/", ) + auth_result = {} + if session.terms_accepted_version: + # TODO: make this less awful. 
+ auth_result[LoginType.TERMS] = True + + await self._registration_handler.post_registration_actions( + user_id, auth_result, access_token=None + ) + await self._auth_handler.complete_sso_login( user_id, request, diff --git a/synapse/res/templates/sso_new_user_consent.html b/synapse/res/templates/sso_new_user_consent.html new file mode 100644 index 0000000000..8c33787c54 --- /dev/null +++ b/synapse/res/templates/sso_new_user_consent.html @@ -0,0 +1,39 @@ + + + + + SSO redirect confirmation + + + + +
    +

    Your account is nearly ready

    +

    Agree to the terms to create your account.

    +
    +
    + +
    + +
    +
    {{ user_profile.display_name }}
    +
    {{ user_id }}
    +
    +
    + + +
    + + diff --git a/synapse/rest/synapse/client/__init__.py b/synapse/rest/synapse/client/__init__.py index 6acbc03d73..02310c1900 100644 --- a/synapse/rest/synapse/client/__init__.py +++ b/synapse/rest/synapse/client/__init__.py @@ -17,6 +17,7 @@ from typing import TYPE_CHECKING, Mapping from twisted.web.resource import Resource +from synapse.rest.synapse.client.new_user_consent import NewUserConsentResource from synapse.rest.synapse.client.pick_idp import PickIdpResource from synapse.rest.synapse.client.pick_username import pick_username_resource from synapse.rest.synapse.client.sso_register import SsoRegisterResource @@ -39,6 +40,7 @@ def build_synapse_client_resource_tree(hs: "HomeServer") -> Mapping[str, Resourc # enabled (they just won't work very well if it's not) "/_synapse/client/pick_idp": PickIdpResource(hs), "/_synapse/client/pick_username": pick_username_resource(hs), + "/_synapse/client/new_user_consent": NewUserConsentResource(hs), "/_synapse/client/sso_register": SsoRegisterResource(hs), } diff --git a/synapse/rest/synapse/client/new_user_consent.py b/synapse/rest/synapse/client/new_user_consent.py new file mode 100644 index 0000000000..b2e0f93810 --- /dev/null +++ b/synapse/rest/synapse/client/new_user_consent.py @@ -0,0 +1,97 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +from typing import TYPE_CHECKING + +from twisted.web.http import Request + +from synapse.api.errors import SynapseError +from synapse.handlers.sso import get_username_mapping_session_cookie_from_request +from synapse.http.server import DirectServeHtmlResource, respond_with_html +from synapse.http.servlet import parse_string +from synapse.types import UserID +from synapse.util.templates import build_jinja_env + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class NewUserConsentResource(DirectServeHtmlResource): + """A resource which collects consent to the server's terms from a new user + + This resource gets mounted at /_synapse/client/new_user_consent, and is shown + when we are automatically creating a new user due to an SSO login. + + It shows a template which prompts the user to go and read the Ts and Cs, and click + a clickybox if they have done so. 
+ """ + + def __init__(self, hs: "HomeServer"): + super().__init__() + self._sso_handler = hs.get_sso_handler() + self._server_name = hs.hostname + self._consent_version = hs.config.consent.user_consent_version + + def template_search_dirs(): + if hs.config.sso.sso_template_dir: + yield hs.config.sso.sso_template_dir + yield hs.config.sso.default_template_dir + + self._jinja_env = build_jinja_env(template_search_dirs(), hs.config) + + async def _async_render_GET(self, request: Request) -> None: + try: + session_id = get_username_mapping_session_cookie_from_request(request) + session = self._sso_handler.get_mapping_session(session_id) + except SynapseError as e: + logger.warning("Error fetching session: %s", e) + self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code) + return + + user_id = UserID(session.chosen_localpart, self._server_name) + user_profile = { + "display_name": session.display_name, + } + + template_params = { + "user_id": user_id.to_string(), + "user_profile": user_profile, + "consent_version": self._consent_version, + "terms_url": "/_matrix/consent?v=%s" % (self._consent_version,), + } + + template = self._jinja_env.get_template("sso_new_user_consent.html") + html = template.render(template_params) + respond_with_html(request, 200, html) + + async def _async_render_POST(self, request: Request): + try: + session_id = get_username_mapping_session_cookie_from_request(request) + except SynapseError as e: + logger.warning("Error fetching session cookie: %s", e) + self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code) + return + + try: + accepted_version = parse_string(request, "accepted_version", required=True) + except SynapseError as e: + self._sso_handler.render_error(request, "bad_param", e.msg, code=e.code) + return + + await self._sso_handler.handle_terms_accepted( + request, session_id, accepted_version + ) -- cgit 1.5.1 From 846b9d3df033be1043710e49e89bcba68722071e Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Mon, 1 Feb 2021 22:56:01 +0000 Subject: Put OIDC callback URI under /_synapse/client. (#9288) --- CHANGES.md | 4 +++ UPGRADE.rst | 13 ++++++++- changelog.d/9288.feature | 1 + docs/openid.md | 19 ++++++------- docs/workers.md | 2 +- synapse/config/oidc_config.py | 2 +- synapse/handlers/oidc_handler.py | 8 +++--- synapse/rest/oidc/__init__.py | 27 ------------------- synapse/rest/oidc/callback_resource.py | 30 --------------------- synapse/rest/synapse/client/__init__.py | 4 +-- synapse/rest/synapse/client/oidc/__init__.py | 31 ++++++++++++++++++++++ .../rest/synapse/client/oidc/callback_resource.py | 30 +++++++++++++++++++++ tests/handlers/test_oidc.py | 15 +++++------ 13 files changed, 102 insertions(+), 84 deletions(-) create mode 100644 changelog.d/9288.feature delete mode 100644 synapse/rest/oidc/__init__.py delete mode 100644 synapse/rest/oidc/callback_resource.py create mode 100644 synapse/rest/synapse/client/oidc/__init__.py create mode 100644 synapse/rest/synapse/client/oidc/callback_resource.py (limited to 'synapse/handlers') diff --git a/CHANGES.md b/CHANGES.md index fcd782fa94..e9ff14a03d 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -3,6 +3,10 @@ Unreleased Note that this release includes a change in Synapse to use Redis as a cache ─ as well as a pub/sub mechanism ─ if Redis support is enabled. No action is needed by server administrators, and we do not expect resource usage of the Redis instance to change dramatically. 
+This release also changes the callback URI for OpenID Connect (OIDC) identity +providers. If your server is configured to use single sign-on via an +OIDC/OAuth2 IdP, you may need to make configuration changes. Please review +[UPGRADE.rst](UPGRADE.rst) for more details on these changes. Synapse 1.26.0 (2021-01-27) =========================== diff --git a/UPGRADE.rst b/UPGRADE.rst index eea0322695..d00f718cae 100644 --- a/UPGRADE.rst +++ b/UPGRADE.rst @@ -88,6 +88,17 @@ for example: Upgrading to v1.27.0 ==================== +Changes to callback URI for OAuth2 / OpenID Connect +--------------------------------------------------- + +This version changes the URI used for callbacks from OAuth2 identity providers. If +your server is configured for single sign-on via an OpenID Connect or OAuth2 identity +provider, you will need to add ``[synapse public baseurl]/_synapse/client/oidc/callback`` +to the list of permitted "redirect URIs" at the identity provider. + +See `docs/openid.md `_ for more information on setting up OpenID +Connect. + Changes to HTML templates ------------------------- @@ -235,7 +246,7 @@ shown below: return {"localpart": localpart} -Removal historical Synapse Admin API +Removal historical Synapse Admin API ------------------------------------ Historically, the Synapse Admin API has been accessible under: diff --git a/changelog.d/9288.feature b/changelog.d/9288.feature new file mode 100644 index 0000000000..efde69fb3c --- /dev/null +++ b/changelog.d/9288.feature @@ -0,0 +1 @@ +Update the redirect URI for OIDC authentication. diff --git a/docs/openid.md b/docs/openid.md index 3d07220967..9d19368845 100644 --- a/docs/openid.md +++ b/docs/openid.md @@ -54,7 +54,8 @@ Here are a few configs for providers that should work with Synapse. ### Microsoft Azure Active Directory Azure AD can act as an OpenID Connect Provider. Register a new application under *App registrations* in the Azure AD management console. The RedirectURI for your -application should point to your matrix server: `[synapse public baseurl]/_synapse/oidc/callback` +application should point to your matrix server: +`[synapse public baseurl]/_synapse/client/oidc/callback` Go to *Certificates & secrets* and register a new client secret. Make note of your Directory (tenant) ID as it will be used in the Azure links. @@ -94,7 +95,7 @@ staticClients: - id: synapse secret: secret redirectURIs: - - '[synapse public baseurl]/_synapse/oidc/callback' + - '[synapse public baseurl]/_synapse/client/oidc/callback' name: 'Synapse' ``` @@ -140,7 +141,7 @@ Follow the [Getting Started Guide](https://www.keycloak.org/getting-started) to | Enabled | `On` | | Client Protocol | `openid-connect` | | Access Type | `confidential` | -| Valid Redirect URIs | `[synapse public baseurl]/_synapse/oidc/callback` | +| Valid Redirect URIs | `[synapse public baseurl]/_synapse/client/oidc/callback` | 5. Click `Save` 6. On the Credentials tab, update the fields: @@ -168,7 +169,7 @@ oidc_providers: ### [Auth0][auth0] 1. Create a regular web application for Synapse -2. Set the Allowed Callback URLs to `[synapse public baseurl]/_synapse/oidc/callback` +2. Set the Allowed Callback URLs to `[synapse public baseurl]/_synapse/client/oidc/callback` 3. Add a rule to add the `preferred_username` claim.
    Code sample @@ -217,7 +218,7 @@ login mechanism needs an attribute to uniquely identify users, and that endpoint does not return a `sub` property, an alternative `subject_claim` has to be set. 1. Create a new OAuth application: https://github.com/settings/applications/new. -2. Set the callback URL to `[synapse public baseurl]/_synapse/oidc/callback`. +2. Set the callback URL to `[synapse public baseurl]/_synapse/client/oidc/callback`. Synapse config: @@ -262,13 +263,13 @@ oidc_providers: display_name_template: "{{ user.name }}" ``` 4. Back in the Google console, add this Authorized redirect URI: `[synapse - public baseurl]/_synapse/oidc/callback`. + public baseurl]/_synapse/client/oidc/callback`. ### Twitch 1. Setup a developer account on [Twitch](https://dev.twitch.tv/) 2. Obtain the OAuth 2.0 credentials by [creating an app](https://dev.twitch.tv/console/apps/) -3. Add this OAuth Redirect URL: `[synapse public baseurl]/_synapse/oidc/callback` +3. Add this OAuth Redirect URL: `[synapse public baseurl]/_synapse/client/oidc/callback` Synapse config: @@ -290,7 +291,7 @@ oidc_providers: 1. Create a [new application](https://gitlab.com/profile/applications). 2. Add the `read_user` and `openid` scopes. -3. Add this Callback URL: `[synapse public baseurl]/_synapse/oidc/callback` +3. Add this Callback URL: `[synapse public baseurl]/_synapse/client/oidc/callback` Synapse config: @@ -323,7 +324,7 @@ one so requires a little more configuration. 2. Once the app is created, add "Facebook Login" and choose "Web". You don't need to go through the whole form here. 3. In the left-hand menu, open "Products"/"Facebook Login"/"Settings". - * Add `[synapse public baseurl]/_synapse/oidc/callback` as an OAuth Redirect + * Add `[synapse public baseurl]/_synapse/client/oidc/callback` as an OAuth Redirect URL. 4. In the left-hand menu, open "Settings/Basic". Here you can copy the "App ID" and "App Secret" for use below. diff --git a/docs/workers.md b/docs/workers.md index c36549c621..c4a6c79238 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -266,7 +266,7 @@ using): ^/_synapse/client/sso_register$ # OpenID Connect requests. - ^/_synapse/oidc/callback$ + ^/_synapse/client/oidc/callback$ # SAML requests. ^/_matrix/saml2/authn_response$ diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py index bb122ef182..4c24c50629 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py @@ -53,7 +53,7 @@ class OIDCConfig(Config): "Multiple OIDC providers have the idp_id %r." % idp_id ) - self.oidc_callback_url = self.public_baseurl + "_synapse/oidc/callback" + self.oidc_callback_url = self.public_baseurl + "_synapse/client/oidc/callback" @property def oidc_enabled(self) -> bool: diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py index ca647fa78f..71008ec50d 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py @@ -102,7 +102,7 @@ class OidcHandler: ) from e async def handle_oidc_callback(self, request: SynapseRequest) -> None: - """Handle an incoming request to /_synapse/oidc/callback + """Handle an incoming request to /_synapse/client/oidc/callback Since we might want to display OIDC-related errors in a user-friendly way, we don't raise SynapseError from here. 
Instead, we call @@ -643,7 +643,7 @@ class OidcProvider: - ``client_id``: the client ID set in ``oidc_config.client_id`` - ``response_type``: ``code`` - - ``redirect_uri``: the callback URL ; ``{base url}/_synapse/oidc/callback`` + - ``redirect_uri``: the callback URL ; ``{base url}/_synapse/client/oidc/callback`` - ``scope``: the list of scopes set in ``oidc_config.scopes`` - ``state``: a random string - ``nonce``: a random string @@ -684,7 +684,7 @@ class OidcProvider: request.addCookie( SESSION_COOKIE_NAME, cookie, - path="/_synapse/oidc", + path="/_synapse/client/oidc", max_age="3600", httpOnly=True, sameSite="lax", @@ -705,7 +705,7 @@ class OidcProvider: async def handle_oidc_callback( self, request: SynapseRequest, session_data: "OidcSessionData", code: str ) -> None: - """Handle an incoming request to /_synapse/oidc/callback + """Handle an incoming request to /_synapse/client/oidc/callback By this time we have already validated the session on the synapse side, and now need to do the provider-specific operations. This includes: diff --git a/synapse/rest/oidc/__init__.py b/synapse/rest/oidc/__init__.py deleted file mode 100644 index d958dd65bb..0000000000 --- a/synapse/rest/oidc/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2020 Quentin Gliech -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging - -from twisted.web.resource import Resource - -from synapse.rest.oidc.callback_resource import OIDCCallbackResource - -logger = logging.getLogger(__name__) - - -class OIDCResource(Resource): - def __init__(self, hs): - Resource.__init__(self) - self.putChild(b"callback", OIDCCallbackResource(hs)) diff --git a/synapse/rest/oidc/callback_resource.py b/synapse/rest/oidc/callback_resource.py deleted file mode 100644 index f7a0bc4bdb..0000000000 --- a/synapse/rest/oidc/callback_resource.py +++ /dev/null @@ -1,30 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2020 Quentin Gliech -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-import logging - -from synapse.http.server import DirectServeHtmlResource - -logger = logging.getLogger(__name__) - - -class OIDCCallbackResource(DirectServeHtmlResource): - isLeaf = 1 - - def __init__(self, hs): - super().__init__() - self._oidc_handler = hs.get_oidc_handler() - - async def _async_render_GET(self, request): - await self._oidc_handler.handle_oidc_callback(request) diff --git a/synapse/rest/synapse/client/__init__.py b/synapse/rest/synapse/client/__init__.py index 02310c1900..381baf9729 100644 --- a/synapse/rest/synapse/client/__init__.py +++ b/synapse/rest/synapse/client/__init__.py @@ -47,9 +47,9 @@ def build_synapse_client_resource_tree(hs: "HomeServer") -> Mapping[str, Resourc # provider-specific SSO bits. Only load these if they are enabled, since they # rely on optional dependencies. if hs.config.oidc_enabled: - from synapse.rest.oidc import OIDCResource + from synapse.rest.synapse.client.oidc import OIDCResource - resources["/_synapse/oidc"] = OIDCResource(hs) + resources["/_synapse/client/oidc"] = OIDCResource(hs) if hs.config.saml2_enabled: from synapse.rest.saml2 import SAML2Resource diff --git a/synapse/rest/synapse/client/oidc/__init__.py b/synapse/rest/synapse/client/oidc/__init__.py new file mode 100644 index 0000000000..64c0deb75d --- /dev/null +++ b/synapse/rest/synapse/client/oidc/__init__.py @@ -0,0 +1,31 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Quentin Gliech +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from twisted.web.resource import Resource + +from synapse.rest.synapse.client.oidc.callback_resource import OIDCCallbackResource + +logger = logging.getLogger(__name__) + + +class OIDCResource(Resource): + def __init__(self, hs): + Resource.__init__(self) + self.putChild(b"callback", OIDCCallbackResource(hs)) + + +__all__ = ["OIDCResource"] diff --git a/synapse/rest/synapse/client/oidc/callback_resource.py b/synapse/rest/synapse/client/oidc/callback_resource.py new file mode 100644 index 0000000000..f7a0bc4bdb --- /dev/null +++ b/synapse/rest/synapse/client/oidc/callback_resource.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 Quentin Gliech +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import logging + +from synapse.http.server import DirectServeHtmlResource + +logger = logging.getLogger(__name__) + + +class OIDCCallbackResource(DirectServeHtmlResource): + isLeaf = 1 + + def __init__(self, hs): + super().__init__() + self._oidc_handler = hs.get_oidc_handler() + + async def _async_render_GET(self, request): + await self._oidc_handler.handle_oidc_callback(request) diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py index d8f90b9a80..ad20400b1d 100644 --- a/tests/handlers/test_oidc.py +++ b/tests/handlers/test_oidc.py @@ -40,7 +40,7 @@ ISSUER = "https://issuer/" CLIENT_ID = "test-client-id" CLIENT_SECRET = "test-client-secret" BASE_URL = "https://synapse/" -CALLBACK_URL = BASE_URL + "_synapse/oidc/callback" +CALLBACK_URL = BASE_URL + "_synapse/client/oidc/callback" SCOPES = ["openid"] AUTHORIZATION_ENDPOINT = ISSUER + "authorize" @@ -58,12 +58,6 @@ COMMON_CONFIG = { } -# The cookie name and path don't really matter, just that it has to be coherent -# between the callback & redirect handlers. -COOKIE_NAME = b"oidc_session" -COOKIE_PATH = "/_synapse/oidc" - - class TestMappingProvider: @staticmethod def parse_config(config): @@ -340,8 +334,11 @@ class OidcHandlerTestCase(HomeserverTestCase): # For some reason, call.args does not work with python3.5 args = calls[0][0] kwargs = calls[0][1] - self.assertEqual(args[0], COOKIE_NAME) - self.assertEqual(kwargs["path"], COOKIE_PATH) + + # The cookie name and path don't really matter, just that it has to be coherent + # between the callback & redirect handlers. + self.assertEqual(args[0], b"oidc_session") + self.assertEqual(kwargs["path"], "/_synapse/client/oidc") cookie = args[1] macaroon = pymacaroons.Macaroon.deserialize(cookie) -- cgit 1.5.1 From 8f75bf1df7f2bcb3ffe0bb89f8fe3351a48769c0 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 2 Feb 2021 09:43:50 +0000 Subject: Put SAML callback URI under /_synapse/client. (#9289) --- UPGRADE.rst | 4 +++ changelog.d/9289.removal | 1 + docs/sample_config.yaml | 4 +-- docs/workers.md | 2 +- synapse/config/saml2_config.py | 8 ++--- synapse/handlers/saml_handler.py | 2 +- synapse/rest/saml2/__init__.py | 29 ---------------- synapse/rest/saml2/metadata_resource.py | 36 -------------------- synapse/rest/saml2/response_resource.py | 39 ---------------------- synapse/rest/synapse/client/__init__.py | 9 +++-- synapse/rest/synapse/client/saml2/__init__.py | 33 ++++++++++++++++++ .../rest/synapse/client/saml2/metadata_resource.py | 36 ++++++++++++++++++++ .../rest/synapse/client/saml2/response_resource.py | 39 ++++++++++++++++++++++ 13 files changed, 127 insertions(+), 115 deletions(-) create mode 100644 changelog.d/9289.removal delete mode 100644 synapse/rest/saml2/__init__.py delete mode 100644 synapse/rest/saml2/metadata_resource.py delete mode 100644 synapse/rest/saml2/response_resource.py create mode 100644 synapse/rest/synapse/client/saml2/__init__.py create mode 100644 synapse/rest/synapse/client/saml2/metadata_resource.py create mode 100644 synapse/rest/synapse/client/saml2/response_resource.py (limited to 'synapse/handlers') diff --git a/UPGRADE.rst b/UPGRADE.rst index d00f718cae..22edfe0d60 100644 --- a/UPGRADE.rst +++ b/UPGRADE.rst @@ -99,6 +99,10 @@ to the list of permitted "redirect URIs" at the identity provider. See `docs/openid.md `_ for more information on setting up OpenID Connect. 
+(Note: a similar change is being made for SAML2; in this case the old URI +``[synapse public baseurl]/_matrix/saml2`` is being deprecated, but will continue to +work, so no immediate changes are required for existing installations.) + Changes to HTML templates ------------------------- diff --git a/changelog.d/9289.removal b/changelog.d/9289.removal new file mode 100644 index 0000000000..49158fc4d3 --- /dev/null +++ b/changelog.d/9289.removal @@ -0,0 +1 @@ +Add new endpoint `/_synapse/client/saml2` for SAML2 authentication callbacks, and deprecate the old endpoint `/_matrix/saml2`. diff --git a/docs/sample_config.yaml b/docs/sample_config.yaml index dd2981717d..6d265d2972 100644 --- a/docs/sample_config.yaml +++ b/docs/sample_config.yaml @@ -1566,10 +1566,10 @@ trusted_key_servers: # enable SAML login. # # Once SAML support is enabled, a metadata file will be exposed at -# https://:/_matrix/saml2/metadata.xml, which you may be able to +# https://:/_synapse/client/saml2/metadata.xml, which you may be able to # use to configure your SAML IdP with. Alternatively, you can manually configure # the IdP to use an ACS location of -# https://:/_matrix/saml2/authn_response. +# https://:/_synapse/client/saml2/authn_response. # saml2_config: # `sp_config` is the configuration for the pysaml2 Service Provider. diff --git a/docs/workers.md b/docs/workers.md index c4a6c79238..f7fc6df119 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -269,7 +269,7 @@ using): ^/_synapse/client/oidc/callback$ # SAML requests. - ^/_matrix/saml2/authn_response$ + ^/_synapse/client/saml2/authn_response$ # CAS requests. ^/_matrix/client/(api/v1|r0|unstable)/login/cas/ticket$ diff --git a/synapse/config/saml2_config.py b/synapse/config/saml2_config.py index f33dfa0d6a..ad865a667f 100644 --- a/synapse/config/saml2_config.py +++ b/synapse/config/saml2_config.py @@ -194,8 +194,8 @@ class SAML2Config(Config): optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute) optional_attributes -= required_attributes - metadata_url = public_baseurl + "_matrix/saml2/metadata.xml" - response_url = public_baseurl + "_matrix/saml2/authn_response" + metadata_url = public_baseurl + "_synapse/client/saml2/metadata.xml" + response_url = public_baseurl + "_synapse/client/saml2/authn_response" return { "entityid": metadata_url, "service": { @@ -233,10 +233,10 @@ class SAML2Config(Config): # enable SAML login. # # Once SAML support is enabled, a metadata file will be exposed at - # https://:/_matrix/saml2/metadata.xml, which you may be able to + # https://:/_synapse/client/saml2/metadata.xml, which you may be able to # use to configure your SAML IdP with. Alternatively, you can manually configure # the IdP to use an ACS location of - # https://:/_matrix/saml2/authn_response. + # https://:/_synapse/client/saml2/authn_response. # saml2_config: # `sp_config` is the configuration for the pysaml2 Service Provider. diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py index 5946919c33..e88fd59749 100644 --- a/synapse/handlers/saml_handler.py +++ b/synapse/handlers/saml_handler.py @@ -133,7 +133,7 @@ class SamlHandler(BaseHandler): raise Exception("prepare_for_authenticate didn't return a Location header") async def handle_saml_response(self, request: SynapseRequest) -> None: - """Handle an incoming request to /_matrix/saml2/authn_response + """Handle an incoming request to /_synapse/client/saml2/authn_response Args: request: the incoming request from the browser. 
We'll diff --git a/synapse/rest/saml2/__init__.py b/synapse/rest/saml2/__init__.py deleted file mode 100644 index 68da37ca6a..0000000000 --- a/synapse/rest/saml2/__init__.py +++ /dev/null @@ -1,29 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import logging - -from twisted.web.resource import Resource - -from synapse.rest.saml2.metadata_resource import SAML2MetadataResource -from synapse.rest.saml2.response_resource import SAML2ResponseResource - -logger = logging.getLogger(__name__) - - -class SAML2Resource(Resource): - def __init__(self, hs): - Resource.__init__(self) - self.putChild(b"metadata.xml", SAML2MetadataResource(hs)) - self.putChild(b"authn_response", SAML2ResponseResource(hs)) diff --git a/synapse/rest/saml2/metadata_resource.py b/synapse/rest/saml2/metadata_resource.py deleted file mode 100644 index 1e8526e22e..0000000000 --- a/synapse/rest/saml2/metadata_resource.py +++ /dev/null @@ -1,36 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import saml2.metadata - -from twisted.web.resource import Resource - - -class SAML2MetadataResource(Resource): - """A Twisted web resource which renders the SAML metadata""" - - isLeaf = 1 - - def __init__(self, hs): - Resource.__init__(self) - self.sp_config = hs.config.saml2_sp_config - - def render_GET(self, request): - metadata_xml = saml2.metadata.create_metadata_string( - configfile=None, config=self.sp_config - ) - request.setHeader(b"Content-Type", b"text/xml; charset=utf-8") - return metadata_xml diff --git a/synapse/rest/saml2/response_resource.py b/synapse/rest/saml2/response_resource.py deleted file mode 100644 index f6668fb5e3..0000000000 --- a/synapse/rest/saml2/response_resource.py +++ /dev/null @@ -1,39 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Copyright 2018 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from synapse.http.server import DirectServeHtmlResource - - -class SAML2ResponseResource(DirectServeHtmlResource): - """A Twisted web resource which handles the SAML response""" - - isLeaf = 1 - - def __init__(self, hs): - super().__init__() - self._saml_handler = hs.get_saml_handler() - - async def _async_render_GET(self, request): - # We're not expecting any GET request on that resource if everything goes right, - # but some IdPs sometimes end up responding with a 302 redirect on this endpoint. - # In this case, just tell the user that something went wrong and they should - # try to authenticate again. - self._saml_handler._render_error( - request, "unexpected_get", "Unexpected GET request on /saml2/authn_response" - ) - - async def _async_render_POST(self, request): - await self._saml_handler.handle_saml_response(request) diff --git a/synapse/rest/synapse/client/__init__.py b/synapse/rest/synapse/client/__init__.py index 381baf9729..e5ef515090 100644 --- a/synapse/rest/synapse/client/__init__.py +++ b/synapse/rest/synapse/client/__init__.py @@ -52,10 +52,13 @@ def build_synapse_client_resource_tree(hs: "HomeServer") -> Mapping[str, Resourc resources["/_synapse/client/oidc"] = OIDCResource(hs) if hs.config.saml2_enabled: - from synapse.rest.saml2 import SAML2Resource + from synapse.rest.synapse.client.saml2 import SAML2Resource - # This is mounted under '/_matrix' for backwards-compatibility. - resources["/_matrix/saml2"] = SAML2Resource(hs) + res = SAML2Resource(hs) + resources["/_synapse/client/saml2"] = res + + # This is also mounted under '/_matrix' for backwards-compatibility. + resources["/_matrix/saml2"] = res return resources diff --git a/synapse/rest/synapse/client/saml2/__init__.py b/synapse/rest/synapse/client/saml2/__init__.py new file mode 100644 index 0000000000..3e8235ee1e --- /dev/null +++ b/synapse/rest/synapse/client/saml2/__init__.py @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from twisted.web.resource import Resource + +from synapse.rest.synapse.client.saml2.metadata_resource import SAML2MetadataResource +from synapse.rest.synapse.client.saml2.response_resource import SAML2ResponseResource + +logger = logging.getLogger(__name__) + + +class SAML2Resource(Resource): + def __init__(self, hs): + Resource.__init__(self) + self.putChild(b"metadata.xml", SAML2MetadataResource(hs)) + self.putChild(b"authn_response", SAML2ResponseResource(hs)) + + +__all__ = ["SAML2Resource"] diff --git a/synapse/rest/synapse/client/saml2/metadata_resource.py b/synapse/rest/synapse/client/saml2/metadata_resource.py new file mode 100644 index 0000000000..1e8526e22e --- /dev/null +++ b/synapse/rest/synapse/client/saml2/metadata_resource.py @@ -0,0 +1,36 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import saml2.metadata + +from twisted.web.resource import Resource + + +class SAML2MetadataResource(Resource): + """A Twisted web resource which renders the SAML metadata""" + + isLeaf = 1 + + def __init__(self, hs): + Resource.__init__(self) + self.sp_config = hs.config.saml2_sp_config + + def render_GET(self, request): + metadata_xml = saml2.metadata.create_metadata_string( + configfile=None, config=self.sp_config + ) + request.setHeader(b"Content-Type", b"text/xml; charset=utf-8") + return metadata_xml diff --git a/synapse/rest/synapse/client/saml2/response_resource.py b/synapse/rest/synapse/client/saml2/response_resource.py new file mode 100644 index 0000000000..f6668fb5e3 --- /dev/null +++ b/synapse/rest/synapse/client/saml2/response_resource.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2018 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.http.server import DirectServeHtmlResource + + +class SAML2ResponseResource(DirectServeHtmlResource): + """A Twisted web resource which handles the SAML response""" + + isLeaf = 1 + + def __init__(self, hs): + super().__init__() + self._saml_handler = hs.get_saml_handler() + + async def _async_render_GET(self, request): + # We're not expecting any GET request on that resource if everything goes right, + # but some IdPs sometimes end up responding with a 302 redirect on this endpoint. + # In this case, just tell the user that something went wrong and they should + # try to authenticate again. + self._saml_handler._render_error( + request, "unexpected_get", "Unexpected GET request on /saml2/authn_response" + ) + + async def _async_render_POST(self, request): + await self._saml_handler.handle_saml_response(request) -- cgit 1.5.1 From b60bb28bbc3d916586a913970298baba483efc1f Mon Sep 17 00:00:00 2001 From: Travis Ralston Date: Tue, 2 Feb 2021 04:16:29 -0700 Subject: Add an admin API to get the current room state (#9168) This could arguably replace the existing admin API for `/members`, however that is out of scope of this change. This sort of endpoint is ideal for moderation use cases as well as other applications, such as needing to retrieve various bits of information about a room to perform a task (like syncing power levels between two places). This endpoint exposes nothing more than an admin would be able to access with a `select *` query on their database. 
--- changelog.d/9168.feature | 1 + docs/admin_api/rooms.md | 30 ++++++++++++++++++++++++++++++ synapse/handlers/message.py | 2 +- synapse/rest/admin/__init__.py | 2 ++ synapse/rest/admin/rooms.py | 39 +++++++++++++++++++++++++++++++++++++++ tests/rest/admin/test_room.py | 15 +++++++++++++++ 6 files changed, 88 insertions(+), 1 deletion(-) create mode 100644 changelog.d/9168.feature (limited to 'synapse/handlers') diff --git a/changelog.d/9168.feature b/changelog.d/9168.feature new file mode 100644 index 0000000000..8be1950eee --- /dev/null +++ b/changelog.d/9168.feature @@ -0,0 +1 @@ +Add an admin API for retrieving the current room state of a room. \ No newline at end of file diff --git a/docs/admin_api/rooms.md b/docs/admin_api/rooms.md index f34cec1ff7..3832b36407 100644 --- a/docs/admin_api/rooms.md +++ b/docs/admin_api/rooms.md @@ -368,6 +368,36 @@ Response: } ``` +# Room State API + +The Room State admin API allows server admins to get a list of all state events in a room. + +The response includes the following fields: + +* `state` - The current state of the room at the time of request. + +## Usage + +A standard request: + +``` +GET /_synapse/admin/v1/rooms//state + +{} +``` + +Response: + +```json +{ + "state": [ + {"type": "m.room.create", "state_key": "", "etc": true}, + {"type": "m.room.power_levels", "state_key": "", "etc": true}, + {"type": "m.room.name", "state_key": "", "etc": true} + ] +} +``` + # Delete Room API The Delete Room admin API allows server admins to remove rooms from server diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index e2a7d567fa..a15336bf00 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -174,7 +174,7 @@ class MessageHandler: raise NotFoundError("Can't find event for token %s" % (at_token,)) visible_events = await filter_events_for_client( - self.storage, user_id, last_events, filter_send_to_client=False + self.storage, user_id, last_events, filter_send_to_client=False, ) event = last_events[0] diff --git a/synapse/rest/admin/__init__.py b/synapse/rest/admin/__init__.py index 57e0a10837..f5c5d164f9 100644 --- a/synapse/rest/admin/__init__.py +++ b/synapse/rest/admin/__init__.py @@ -44,6 +44,7 @@ from synapse.rest.admin.rooms import ( MakeRoomAdminRestServlet, RoomMembersRestServlet, RoomRestServlet, + RoomStateRestServlet, ShutdownRoomRestServlet, ) from synapse.rest.admin.server_notice_servlet import SendServerNoticeServlet @@ -213,6 +214,7 @@ def register_servlets(hs, http_server): """ register_servlets_for_client_rest_resource(hs, http_server) ListRoomRestServlet(hs).register(http_server) + RoomStateRestServlet(hs).register(http_server) RoomRestServlet(hs).register(http_server) RoomMembersRestServlet(hs).register(http_server) DeleteRoomRestServlet(hs).register(http_server) diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index f14915d47e..3e57e6a4d0 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -292,6 +292,45 @@ class RoomMembersRestServlet(RestServlet): return 200, ret +class RoomStateRestServlet(RestServlet): + """ + Get full state within a room. 
+ """ + + PATTERNS = admin_patterns("/rooms/(?P[^/]+)/state") + + def __init__(self, hs: "HomeServer"): + self.hs = hs + self.auth = hs.get_auth() + self.store = hs.get_datastore() + self.clock = hs.get_clock() + self._event_serializer = hs.get_event_client_serializer() + + async def on_GET( + self, request: SynapseRequest, room_id: str + ) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + await assert_user_is_admin(self.auth, requester.user) + + ret = await self.store.get_room(room_id) + if not ret: + raise NotFoundError("Room not found") + + event_ids = await self.store.get_current_state_ids(room_id) + events = await self.store.get_events(event_ids.values()) + now = self.clock.time_msec() + room_state = await self._event_serializer.serialize_events( + events.values(), + now, + # We don't bother bundling aggregations in when asked for state + # events, as clients won't use them. + bundle_aggregations=False, + ) + ret = {"state": room_state} + + return 200, ret + + class JoinRoomAliasServlet(RestServlet): PATTERNS = admin_patterns("/join/(?P[^/]*)") diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index a0f32c5512..7c47aa7e0a 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -1180,6 +1180,21 @@ class RoomTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.json_body["total"], 3) + def test_room_state(self): + """Test that room state can be requested correctly""" + # Create two test rooms + room_id = self.helper.create_room_as(self.admin_user, tok=self.admin_user_tok) + + url = "/_synapse/admin/v1/rooms/%s/state" % (room_id,) + channel = self.make_request( + "GET", url.encode("ascii"), access_token=self.admin_user_tok, + ) + self.assertEqual(200, channel.code, msg=channel.json_body) + self.assertIn("state", channel.json_body) + # testing that the state events match is painful and not done here. We assume that + # the create_room already does the right thing, so no need to verify that we got + # the state events it created. + class JoinAliasRoomTestCase(unittest.HomeserverTestCase): -- cgit 1.5.1