summary refs log tree commit diff
diff options
context:
space:
mode:
-rwxr-xr-x.buildkite/scripts/test_old_deps.sh3
-rw-r--r--changelog.d/9129.misc1
-rw-r--r--changelog.d/9135.doc1
-rw-r--r--changelog.d/9180.misc1
-rw-r--r--changelog.d/9181.misc1
-rw-r--r--changelog.d/9183.feature1
-rw-r--r--changelog.d/9184.misc1
-rw-r--r--changelog.d/9217.misc1
-rw-r--r--docs/turn-howto.md6
-rwxr-xr-xsetup.py1
-rw-r--r--synapse/config/_base.pyi2
-rw-r--r--synapse/config/experimental.py29
-rw-r--r--synapse/config/homeserver.py2
-rw-r--r--synapse/config/oidc_config.py11
-rw-r--r--synapse/federation/federation_client.py125
-rw-r--r--synapse/handlers/sso.py23
-rw-r--r--synapse/http/server.py44
-rw-r--r--synapse/rest/client/v1/login.py55
-rw-r--r--synapse/storage/database.py5
-rw-r--r--synapse/storage/databases/main/events.py18
-rw-r--r--tests/rest/client/v1/test_login.py92
-rw-r--r--tests/utils.py3
-rw-r--r--tox.ini15
23 files changed, 342 insertions, 99 deletions
diff --git a/.buildkite/scripts/test_old_deps.sh b/.buildkite/scripts/test_old_deps.sh

index 9905c4bc4f..28e6694b5d 100755 --- a/.buildkite/scripts/test_old_deps.sh +++ b/.buildkite/scripts/test_old_deps.sh
@@ -10,4 +10,7 @@ apt-get install -y python3.5 python3.5-dev python3-pip libxml2-dev libxslt-dev x export LANG="C.UTF-8" +# Prevent virtualenv from auto-updating pip to an incompatible version +export VIRTUALENV_NO_DOWNLOAD=1 + exec tox -e py35-old,combine diff --git a/changelog.d/9129.misc b/changelog.d/9129.misc new file mode 100644
index 0000000000..7800be3e7e --- /dev/null +++ b/changelog.d/9129.misc
@@ -0,0 +1 @@ +Various improvements to the federation client. diff --git a/changelog.d/9135.doc b/changelog.d/9135.doc new file mode 100644
index 0000000000..d11ba70de4 --- /dev/null +++ b/changelog.d/9135.doc
@@ -0,0 +1 @@ +Add link to Matrix VoIP tester for turn-howto. diff --git a/changelog.d/9180.misc b/changelog.d/9180.misc new file mode 100644
index 0000000000..69dd86110d --- /dev/null +++ b/changelog.d/9180.misc
@@ -0,0 +1 @@ +Add a `long_description_type` to the package metadata. diff --git a/changelog.d/9181.misc b/changelog.d/9181.misc new file mode 100644
index 0000000000..7820d09cd0 --- /dev/null +++ b/changelog.d/9181.misc
@@ -0,0 +1 @@ +Speed up batch insertion when using PostgreSQL. diff --git a/changelog.d/9183.feature b/changelog.d/9183.feature new file mode 100644
index 0000000000..2d5c735042 --- /dev/null +++ b/changelog.d/9183.feature
@@ -0,0 +1 @@ +Add experimental support for allowing clients to pick an SSO Identity Provider ([MSC2858](https://github.com/matrix-org/matrix-doc/pull/2858). diff --git a/changelog.d/9184.misc b/changelog.d/9184.misc new file mode 100644
index 0000000000..70da3d6cf5 --- /dev/null +++ b/changelog.d/9184.misc
@@ -0,0 +1 @@ +Emit an error at startup if different Identity Providers are configured with the same `idp_id`. diff --git a/changelog.d/9217.misc b/changelog.d/9217.misc new file mode 100644
index 0000000000..72bacc7110 --- /dev/null +++ b/changelog.d/9217.misc
@@ -0,0 +1 @@ +Fix the Python 3.5 old dependencies build. diff --git a/docs/turn-howto.md b/docs/turn-howto.md
index a470c274a5..e8f13ad484 100644 --- a/docs/turn-howto.md +++ b/docs/turn-howto.md
@@ -232,6 +232,12 @@ Here are a few things to try: (Understanding the output is beyond the scope of this document!) + * You can test your Matrix homeserver TURN setup with https://test.voip.librepush.net/. + Note that this test is not fully reliable yet, so don't be discouraged if + the test fails. + [Here](https://github.com/matrix-org/voip-tester) is the github repo of the + source of the tester, where you can file bug reports. + * There is a WebRTC test tool at https://webrtc.github.io/samples/src/content/peerconnection/trickle-ice/. To use it, you will need a username/password for your TURN server. You can diff --git a/setup.py b/setup.py
index 9730afb41b..ddbe9f511a 100755 --- a/setup.py +++ b/setup.py
@@ -121,6 +121,7 @@ setup( include_package_data=True, zip_safe=False, long_description=long_description, + long_description_content_type="text/x-rst", python_requires="~=3.5", classifiers=[ "Development Status :: 5 - Production/Stable", diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index 29aa064e57..3ccea4b02d 100644 --- a/synapse/config/_base.pyi +++ b/synapse/config/_base.pyi
@@ -9,6 +9,7 @@ from synapse.config import ( consent_config, database, emailconfig, + experimental, groups, jwt_config, key, @@ -48,6 +49,7 @@ def path_exists(file_path: str): ... class RootConfig: server: server.ServerConfig + experimental: experimental.ExperimentalConfig tls: tls.TlsConfig database: database.DatabaseConfig logging: logger.LoggingConfig diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py new file mode 100644
index 0000000000..b1c1c51e4d --- /dev/null +++ b/synapse/config/experimental.py
@@ -0,0 +1,29 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.config._base import Config +from synapse.types import JsonDict + + +class ExperimentalConfig(Config): + """Config section for enabling experimental features""" + + section = "experimental" + + def read_config(self, config: JsonDict, **kwargs): + experimental = config.get("experimental_features") or {} + + # MSC2858 (multiple SSO identity providers) + self.msc2858_enabled = experimental.get("msc2858_enabled", False) # type: bool diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index 4bd2b3587b..64a2429f77 100644 --- a/synapse/config/homeserver.py +++ b/synapse/config/homeserver.py
@@ -24,6 +24,7 @@ from .cas import CasConfig from .consent_config import ConsentConfig from .database import DatabaseConfig from .emailconfig import EmailConfig +from .experimental import ExperimentalConfig from .federation import FederationConfig from .groups import GroupsConfig from .jwt_config import JWTConfig @@ -57,6 +58,7 @@ class HomeServerConfig(RootConfig): config_classes = [ ServerConfig, + ExperimentalConfig, TlsConfig, FederationConfig, CacheConfig, diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
index d58a83be7f..bfeceeed18 100644 --- a/synapse/config/oidc_config.py +++ b/synapse/config/oidc_config.py
@@ -15,6 +15,7 @@ # limitations under the License. import string +from collections import Counter from typing import Iterable, Optional, Tuple, Type import attr @@ -43,6 +44,16 @@ class OIDCConfig(Config): except DependencyException as e: raise ConfigError(e.message) from e + # check we don't have any duplicate idp_ids now. (The SSO handler will also + # check for duplicates when the REST listeners get registered, but that happens + # after synapse has forked so doesn't give nice errors.) + c = Counter([i.idp_id for i in self.oidc_providers]) + for idp_id, count in c.items(): + if count > 1: + raise ConfigError( + "Multiple OIDC providers have the idp_id %r." % idp_id + ) + public_baseurl = self.public_baseurl self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback" diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 302b2f69bc..d330ae5dbc 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py
@@ -18,6 +18,7 @@ import copy import itertools import logging from typing import ( + TYPE_CHECKING, Any, Awaitable, Callable, @@ -26,7 +27,6 @@ from typing import ( List, Mapping, Optional, - Sequence, Tuple, TypeVar, Union, @@ -61,6 +61,9 @@ from synapse.util import unwrapFirstError from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.retryutils import NotRetryingDestination +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["type"]) @@ -80,10 +83,10 @@ class InvalidResponseError(RuntimeError): class FederationClient(FederationBase): - def __init__(self, hs): + def __init__(self, hs: "HomeServer"): super().__init__(hs) - self.pdu_destination_tried = {} + self.pdu_destination_tried = {} # type: Dict[str, Dict[str, int]] self._clock.looping_call(self._clear_tried_cache, 60 * 1000) self.state = hs.get_state_handler() self.transport_layer = hs.get_federation_transport_client() @@ -116,33 +119,32 @@ class FederationClient(FederationBase): self.pdu_destination_tried[event_id] = destination_dict @log_function - def make_query( + async def make_query( self, - destination, - query_type, - args, - retry_on_dns_fail=False, - ignore_backoff=False, - ): + destination: str, + query_type: str, + args: dict, + retry_on_dns_fail: bool = False, + ignore_backoff: bool = False, + ) -> JsonDict: """Sends a federation Query to a remote homeserver of the given type and arguments. Args: - destination (str): Domain name of the remote homeserver - query_type (str): Category of the query type; should match the + destination: Domain name of the remote homeserver + query_type: Category of the query type; should match the handler name used in register_query_handler(). - args (dict): Mapping of strings to strings containing the details + args: Mapping of strings to strings containing the details of the query request. - ignore_backoff (bool): true to ignore the historical backoff data + ignore_backoff: true to ignore the historical backoff data and try the request anyway. Returns: - a Awaitable which will eventually yield a JSON object from the - response + The JSON object from the response """ sent_queries_counter.labels(query_type).inc() - return self.transport_layer.make_query( + return await self.transport_layer.make_query( destination, query_type, args, @@ -151,42 +153,52 @@ class FederationClient(FederationBase): ) @log_function - def query_client_keys(self, destination, content, timeout): + async def query_client_keys( + self, destination: str, content: JsonDict, timeout: int + ) -> JsonDict: """Query device keys for a device hosted on a remote server. Args: - destination (str): Domain name of the remote homeserver - content (dict): The query content. + destination: Domain name of the remote homeserver + content: The query content. Returns: - an Awaitable which will eventually yield a JSON object from the - response + The JSON object from the response """ sent_queries_counter.labels("client_device_keys").inc() - return self.transport_layer.query_client_keys(destination, content, timeout) + return await self.transport_layer.query_client_keys( + destination, content, timeout + ) @log_function - def query_user_devices(self, destination, user_id, timeout=30000): + async def query_user_devices( + self, destination: str, user_id: str, timeout: int = 30000 + ) -> JsonDict: """Query the device keys for a list of user ids hosted on a remote server. """ sent_queries_counter.labels("user_devices").inc() - return self.transport_layer.query_user_devices(destination, user_id, timeout) + return await self.transport_layer.query_user_devices( + destination, user_id, timeout + ) @log_function - def claim_client_keys(self, destination, content, timeout): + async def claim_client_keys( + self, destination: str, content: JsonDict, timeout: int + ) -> JsonDict: """Claims one-time keys for a device hosted on a remote server. Args: - destination (str): Domain name of the remote homeserver - content (dict): The query content. + destination: Domain name of the remote homeserver + content: The query content. Returns: - an Awaitable which will eventually yield a JSON object from the - response + The JSON object from the response """ sent_queries_counter.labels("client_one_time_keys").inc() - return self.transport_layer.claim_client_keys(destination, content, timeout) + return await self.transport_layer.claim_client_keys( + destination, content, timeout + ) async def backfill( self, dest: str, room_id: str, limit: int, extremities: Iterable[str] @@ -195,10 +207,10 @@ class FederationClient(FederationBase): given destination server. Args: - dest (str): The remote homeserver to ask. - room_id (str): The room_id to backfill. - limit (int): The maximum number of events to return. - extremities (list): our current backwards extremities, to backfill from + dest: The remote homeserver to ask. + room_id: The room_id to backfill. + limit: The maximum number of events to return. + extremities: our current backwards extremities, to backfill from """ logger.debug("backfill extrem=%s", extremities) @@ -370,7 +382,7 @@ class FederationClient(FederationBase): for events that have failed their checks Returns: - Deferred : A list of PDUs that have valid signatures and hashes. + A list of PDUs that have valid signatures and hashes. """ deferreds = self._check_sigs_and_hashes(room_version, pdus) @@ -418,7 +430,9 @@ class FederationClient(FederationBase): else: return [p for p in valid_pdus if p] - async def get_event_auth(self, destination, room_id, event_id): + async def get_event_auth( + self, destination: str, room_id: str, event_id: str + ) -> List[EventBase]: res = await self.transport_layer.get_event_auth(destination, room_id, event_id) room_version = await self.store.get_room_version(room_id) @@ -700,18 +714,16 @@ class FederationClient(FederationBase): return await self._try_destination_list("send_join", destinations, send_request) - async def _do_send_join(self, destination: str, pdu: EventBase): + async def _do_send_join(self, destination: str, pdu: EventBase) -> JsonDict: time_now = self._clock.time_msec() try: - content = await self.transport_layer.send_join_v2( + return await self.transport_layer.send_join_v2( destination=destination, room_id=pdu.room_id, event_id=pdu.event_id, content=pdu.get_pdu_json(time_now), ) - - return content except HttpResponseException as e: if e.code in [400, 404]: err = e.to_synapse_error() @@ -769,7 +781,7 @@ class FederationClient(FederationBase): time_now = self._clock.time_msec() try: - content = await self.transport_layer.send_invite_v2( + return await self.transport_layer.send_invite_v2( destination=destination, room_id=pdu.room_id, event_id=pdu.event_id, @@ -779,7 +791,6 @@ class FederationClient(FederationBase): "invite_room_state": pdu.unsigned.get("invite_room_state", []), }, ) - return content except HttpResponseException as e: if e.code in [400, 404]: err = e.to_synapse_error() @@ -842,18 +853,16 @@ class FederationClient(FederationBase): "send_leave", destinations, send_request ) - async def _do_send_leave(self, destination, pdu): + async def _do_send_leave(self, destination: str, pdu: EventBase) -> JsonDict: time_now = self._clock.time_msec() try: - content = await self.transport_layer.send_leave_v2( + return await self.transport_layer.send_leave_v2( destination=destination, room_id=pdu.room_id, event_id=pdu.event_id, content=pdu.get_pdu_json(time_now), ) - - return content except HttpResponseException as e: if e.code in [400, 404]: err = e.to_synapse_error() @@ -879,7 +888,7 @@ class FederationClient(FederationBase): # content. return resp[1] - def get_public_rooms( + async def get_public_rooms( self, remote_server: str, limit: Optional[int] = None, @@ -887,7 +896,7 @@ class FederationClient(FederationBase): search_filter: Optional[Dict] = None, include_all_networks: bool = False, third_party_instance_id: Optional[str] = None, - ): + ) -> JsonDict: """Get the list of public rooms from a remote homeserver Args: @@ -901,8 +910,7 @@ class FederationClient(FederationBase): party instance Returns: - Awaitable[Dict[str, Any]]: The response from the remote server, or None if - `remote_server` is the same as the local server_name + The response from the remote server. Raises: HttpResponseException: There was an exception returned from the remote server @@ -910,7 +918,7 @@ class FederationClient(FederationBase): requests over federation """ - return self.transport_layer.get_public_rooms( + return await self.transport_layer.get_public_rooms( remote_server, limit, since_token, @@ -923,7 +931,7 @@ class FederationClient(FederationBase): self, destination: str, room_id: str, - earliest_events_ids: Sequence[str], + earliest_events_ids: Iterable[str], latest_events: Iterable[EventBase], limit: int, min_depth: int, @@ -974,7 +982,9 @@ class FederationClient(FederationBase): return signed_events - async def forward_third_party_invite(self, destinations, room_id, event_dict): + async def forward_third_party_invite( + self, destinations: Iterable[str], room_id: str, event_dict: JsonDict + ) -> None: for destination in destinations: if destination == self.server_name: continue @@ -983,7 +993,7 @@ class FederationClient(FederationBase): await self.transport_layer.exchange_third_party_invite( destination=destination, room_id=room_id, event_dict=event_dict ) - return None + return except CodeMessageException: raise except Exception as e: @@ -995,7 +1005,7 @@ class FederationClient(FederationBase): async def get_room_complexity( self, destination: str, room_id: str - ) -> Optional[dict]: + ) -> Optional[JsonDict]: """ Fetch the complexity of a remote room from another server. @@ -1008,10 +1018,9 @@ class FederationClient(FederationBase): could not fetch the complexity. """ try: - complexity = await self.transport_layer.get_room_complexity( + return await self.transport_layer.get_room_complexity( destination=destination, room_id=room_id ) - return complexity except CodeMessageException as e: # We didn't manage to get it -- probably a 404. We are okay if other # servers don't give it to us. diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index d493327a10..afc1341d09 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py
@@ -23,7 +23,7 @@ from typing_extensions import NoReturn, Protocol from twisted.web.http import Request from synapse.api.constants import LoginType -from synapse.api.errors import Codes, RedirectException, SynapseError +from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError from synapse.handlers.ui_auth import UIAuthSessionDataConstants from synapse.http import get_request_user_agent from synapse.http.server import respond_with_html @@ -235,7 +235,10 @@ class SsoHandler: respond_with_html(request, code, html) async def handle_redirect_request( - self, request: SynapseRequest, client_redirect_url: bytes, + self, + request: SynapseRequest, + client_redirect_url: bytes, + idp_id: Optional[str], ) -> str: """Handle a request to /login/sso/redirect @@ -243,6 +246,7 @@ class SsoHandler: request: incoming HTTP request client_redirect_url: the URL that we should redirect the client to after login. + idp_id: optional identity provider chosen by the client Returns: the URI to redirect to @@ -252,10 +256,19 @@ class SsoHandler: 400, "Homeserver not configured for SSO.", errcode=Codes.UNRECOGNIZED ) + # if the client chose an IdP, use that + idp = None # type: Optional[SsoIdentityProvider] + if idp_id: + idp = self._identity_providers.get(idp_id) + if not idp: + raise NotFoundError("Unknown identity provider") + # if we only have one auth provider, redirect to it directly - if len(self._identity_providers) == 1: - ap = next(iter(self._identity_providers.values())) - return await ap.handle_redirect_request(request, client_redirect_url) + elif len(self._identity_providers) == 1: + idp = next(iter(self._identity_providers.values())) + + if idp: + return await idp.handle_redirect_request(request, client_redirect_url) # otherwise, redirect to the IDP picker return "/_synapse/client/pick_idp?" + urlencode( diff --git a/synapse/http/server.py b/synapse/http/server.py
index e464bfe6c7..d69d579b3a 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py
@@ -22,10 +22,22 @@ import types import urllib from http import HTTPStatus from io import BytesIO -from typing import Any, Callable, Dict, Iterator, List, Tuple, Union +from typing import ( + Any, + Awaitable, + Callable, + Dict, + Iterable, + Iterator, + List, + Pattern, + Tuple, + Union, +) import jinja2 from canonicaljson import iterencode_canonical_json +from typing_extensions import Protocol from zope.interface import implementer from twisted.internet import defer, interfaces @@ -168,11 +180,25 @@ def wrap_async_request_handler(h): return preserve_fn(wrapped_async_request_handler) -class HttpServer: +# Type of a callback method for processing requests +# it is actually called with a SynapseRequest and a kwargs dict for the params, +# but I can't figure out how to represent that. +ServletCallback = Callable[ + ..., Union[None, Awaitable[None], Tuple[int, Any], Awaitable[Tuple[int, Any]]] +] + + +class HttpServer(Protocol): """ Interface for registering callbacks on a HTTP server """ - def register_paths(self, method, path_patterns, callback): + def register_paths( + self, + method: str, + path_patterns: Iterable[Pattern], + callback: ServletCallback, + servlet_classname: str, + ) -> None: """ Register a callback that gets fired if we receive a http request with the given method for a path that matches the given regex. @@ -180,12 +206,14 @@ class HttpServer: an unpacked tuple. Args: - method (str): The method to listen to. - path_patterns (list<SRE_Pattern>): The regex used to match requests. - callback (function): The function to fire if we receive a matched + method: The HTTP method to listen to. + path_patterns: The regex used to match requests. + callback: The function to fire if we receive a matched request. The first argument will be the request object and subsequent arguments will be any matched groups from the regex. - This should return a tuple of (code, response). + This should return either tuple of (code, response), or None. + servlet_classname (str): The name of the handler to be used in prometheus + and opentracing logs. """ pass @@ -354,7 +382,7 @@ class JsonResource(DirectServeJsonResource): def _get_handler_for_request( self, request: SynapseRequest - ) -> Tuple[Callable, str, Dict[str, str]]: + ) -> Tuple[ServletCallback, str, Dict[str, str]]: """Finds a callback method to handle the given request. Returns: diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index be938df962..0a561eea60 100644 --- a/synapse/rest/client/v1/login.py +++ b/synapse/rest/client/v1/login.py
@@ -19,7 +19,8 @@ from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional from synapse.api.errors import Codes, LoginError, SynapseError from synapse.api.ratelimiting import Ratelimiter from synapse.appservice import ApplicationService -from synapse.http.server import finish_request +from synapse.handlers.sso import SsoIdentityProvider +from synapse.http.server import HttpServer, finish_request from synapse.http.servlet import ( RestServlet, parse_json_object_from_request, @@ -60,11 +61,14 @@ class LoginRestServlet(RestServlet): self.saml2_enabled = hs.config.saml2_enabled self.cas_enabled = hs.config.cas_enabled self.oidc_enabled = hs.config.oidc_enabled + self._msc2858_enabled = hs.config.experimental.msc2858_enabled self.auth = hs.get_auth() self.auth_handler = self.hs.get_auth_handler() self.registration_handler = hs.get_registration_handler() + self._sso_handler = hs.get_sso_handler() + self._well_known_builder = WellKnownBuilder(hs) self._address_ratelimiter = Ratelimiter( clock=hs.get_clock(), @@ -89,8 +93,17 @@ class LoginRestServlet(RestServlet): flows.append({"type": LoginRestServlet.CAS_TYPE}) if self.cas_enabled or self.saml2_enabled or self.oidc_enabled: - flows.append({"type": LoginRestServlet.SSO_TYPE}) - # While its valid for us to advertise this login type generally, + sso_flow = {"type": LoginRestServlet.SSO_TYPE} # type: JsonDict + + if self._msc2858_enabled: + sso_flow["org.matrix.msc2858.identity_providers"] = [ + _get_auth_flow_dict_for_idp(idp) + for idp in self._sso_handler.get_identity_providers().values() + ] + + flows.append(sso_flow) + + # While it's valid for us to advertise this login type generally, # synapse currently only gives out these tokens as part of the # SSO login flow. # Generally we don't want to advertise login flows that clients @@ -311,8 +324,20 @@ class LoginRestServlet(RestServlet): return result +def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict: + """Return an entry for the login flow dict + + Returns an entry suitable for inclusion in "identity_providers" in the + response to GET /_matrix/client/r0/login + """ + e = {"id": idp.idp_id, "name": idp.idp_name} # type: JsonDict + if idp.idp_icon: + e["icon"] = idp.idp_icon + return e + + class SsoRedirectServlet(RestServlet): - PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True) + PATTERNS = client_patterns("/login/(cas|sso)/redirect$", v1=True) def __init__(self, hs: "HomeServer"): # make sure that the relevant handlers are instantiated, so that they @@ -324,13 +349,31 @@ class SsoRedirectServlet(RestServlet): if hs.config.oidc_enabled: hs.get_oidc_handler() self._sso_handler = hs.get_sso_handler() + self._msc2858_enabled = hs.config.experimental.msc2858_enabled + + def register(self, http_server: HttpServer) -> None: + super().register(http_server) + if self._msc2858_enabled: + # expose additional endpoint for MSC2858 support + http_server.register_paths( + "GET", + client_patterns( + "/org.matrix.msc2858/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$", + releases=(), + unstable=True, + ), + self.on_GET, + self.__class__.__name__, + ) - async def on_GET(self, request: SynapseRequest): + async def on_GET( + self, request: SynapseRequest, idp_id: Optional[str] = None + ) -> None: client_redirect_url = parse_string( request, "redirectUrl", required=True, encoding=None ) sso_url = await self._sso_handler.handle_redirect_request( - request, client_redirect_url + request, client_redirect_url, idp_id, ) logger.info("Redirecting to %s", sso_url) request.redirect(sso_url) diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index a19d65ad23..c7220bc778 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py
@@ -267,8 +267,7 @@ class LoggingTransaction: self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args) else: - for val in args: - self.execute(sql, val) + self.executemany(sql, args) def execute_values(self, sql: str, *args: Any) -> List[Tuple]: """Corresponds to psycopg2.extras.execute_values. Only available when @@ -888,7 +887,7 @@ class DatabasePool: ", ".join("?" for _ in keys[0]), ) - txn.executemany(sql, vals) + txn.execute_batch(sql, vals) async def simple_upsert( self, diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 3216b3f3c8..5db7d7aaa8 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py
@@ -876,7 +876,7 @@ class PersistEventsStore: WHERE room_id = ? AND type = ? AND state_key = ? ) """ - txn.executemany( + txn.execute_batch( sql, ( ( @@ -895,7 +895,7 @@ class PersistEventsStore: ) # Now we actually update the current_state_events table - txn.executemany( + txn.execute_batch( "DELETE FROM current_state_events" " WHERE room_id = ? AND type = ? AND state_key = ?", ( @@ -907,7 +907,7 @@ class PersistEventsStore: # We include the membership in the current state table, hence we do # a lookup when we insert. This assumes that all events have already # been inserted into room_memberships. - txn.executemany( + txn.execute_batch( """INSERT INTO current_state_events (room_id, type, state_key, event_id, membership) VALUES (?, ?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?)) @@ -927,7 +927,7 @@ class PersistEventsStore: # we have no record of the fact the user *was* a member of the # room but got, say, state reset out of it. if to_delete or to_insert: - txn.executemany( + txn.execute_batch( "DELETE FROM local_current_membership" " WHERE room_id = ? AND user_id = ?", ( @@ -938,7 +938,7 @@ class PersistEventsStore: ) if to_insert: - txn.executemany( + txn.execute_batch( """INSERT INTO local_current_membership (room_id, user_id, event_id, membership) VALUES (?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?)) @@ -1738,7 +1738,7 @@ class PersistEventsStore: """ if events_and_contexts: - txn.executemany( + txn.execute_batch( sql, ( ( @@ -1767,7 +1767,7 @@ class PersistEventsStore: # Now we delete the staging area for *all* events that were being # persisted. - txn.executemany( + txn.execute_batch( "DELETE FROM event_push_actions_staging WHERE event_id = ?", ((event.event_id,) for event, _ in all_events_and_contexts), ) @@ -1886,7 +1886,7 @@ class PersistEventsStore: " )" ) - txn.executemany( + txn.execute_batch( query, [ (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False) @@ -1900,7 +1900,7 @@ class PersistEventsStore: "DELETE FROM event_backward_extremities" " WHERE event_id = ? AND room_id = ?" ) - txn.executemany( + txn.execute_batch( query, [ (ev.event_id, ev.room_id) diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 2672ce24c6..e2bb945453 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py
@@ -75,6 +75,10 @@ TEST_CLIENT_REDIRECT_URL = 'https://x?<ab c>&q"+%3D%2B"="fö%26=o"' # the query params in TEST_CLIENT_REDIRECT_URL EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("<ab c>", ""), ('q" =+"', '"fö&=o"')] +# (possibly experimental) login flows we expect to appear in the list after the normal +# ones +ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}] + class LoginRestServletTestCase(unittest.HomeserverTestCase): @@ -426,6 +430,57 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): d["/_synapse/oidc"] = OIDCResource(self.hs) return d + def test_get_login_flows(self): + """GET /login should return password and SSO flows""" + channel = self.make_request("GET", "/_matrix/client/r0/login") + self.assertEqual(channel.code, 200, channel.result) + + expected_flows = [ + {"type": "m.login.cas"}, + {"type": "m.login.sso"}, + {"type": "m.login.token"}, + {"type": "m.login.password"}, + ] + ADDITIONAL_LOGIN_FLOWS + + self.assertCountEqual(channel.json_body["flows"], expected_flows) + + @override_config({"experimental_features": {"msc2858_enabled": True}}) + def test_get_msc2858_login_flows(self): + """The SSO flow should include IdP info if MSC2858 is enabled""" + channel = self.make_request("GET", "/_matrix/client/r0/login") + self.assertEqual(channel.code, 200, channel.result) + + # stick the flows results in a dict by type + flow_results = {} # type: Dict[str, Any] + for f in channel.json_body["flows"]: + flow_type = f["type"] + self.assertNotIn( + flow_type, flow_results, "duplicate flow type %s" % (flow_type,) + ) + flow_results[flow_type] = f + + self.assertIn("m.login.sso", flow_results, "m.login.sso was not returned") + sso_flow = flow_results.pop("m.login.sso") + # we should have a set of IdPs + self.assertCountEqual( + sso_flow["org.matrix.msc2858.identity_providers"], + [ + {"id": "cas", "name": "CAS"}, + {"id": "saml", "name": "SAML"}, + {"id": "oidc-idp1", "name": "IDP1"}, + {"id": "oidc", "name": "OIDC"}, + ], + ) + + # the rest of the flows are simple + expected_flows = [ + {"type": "m.login.cas"}, + {"type": "m.login.token"}, + {"type": "m.login.password"}, + ] + ADDITIONAL_LOGIN_FLOWS + + self.assertCountEqual(flow_results.values(), expected_flows) + def test_multi_sso_redirect(self): """/login/sso/redirect should redirect to an identity picker""" # first hit the redirect url, which should redirect to our idp picker @@ -564,6 +619,43 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): ) self.assertEqual(channel.code, 400, channel.result) + def test_client_idp_redirect_msc2858_disabled(self): + """If the client tries to pick an IdP but MSC2858 is disabled, return a 400""" + channel = self.make_request( + "GET", + "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL), + ) + self.assertEqual(channel.code, 400, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED") + + @override_config({"experimental_features": {"msc2858_enabled": True}}) + def test_client_idp_redirect_to_unknown(self): + """If the client tries to pick an unknown IdP, return a 404""" + channel = self.make_request( + "GET", + "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/xxx?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL), + ) + self.assertEqual(channel.code, 404, channel.result) + self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") + + @override_config({"experimental_features": {"msc2858_enabled": True}}) + def test_client_idp_redirect_to_oidc(self): + """If the client pick a known IdP, redirect to it""" + channel = self.make_request( + "GET", + "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL), + ) + + self.assertEqual(channel.code, 302, channel.result) + oidc_uri = channel.headers.getRawHeaders("Location")[0] + oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1) + + # it should redirect us to the auth page of the OIDC server + self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT) + @staticmethod def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str: prefix = key + " = " diff --git a/tests/utils.py b/tests/utils.py
index 09614093bc..022223cf24 100644 --- a/tests/utils.py +++ b/tests/utils.py
@@ -33,7 +33,6 @@ from synapse.api.room_versions import RoomVersions from synapse.config.database import DatabaseConnectionConfig from synapse.config.homeserver import HomeServerConfig from synapse.config.server import DEFAULT_ROOM_VERSION -from synapse.http.server import HttpServer from synapse.logging.context import current_context, set_current_context from synapse.server import HomeServer from synapse.storage import DataStore @@ -351,7 +350,7 @@ def mock_getRawHeaders(headers=None): # This is a mock /resource/ not an entire server -class MockHttpResource(HttpServer): +class MockHttpResource: def __init__(self, prefix=""): self.callbacks = [] # 3-tuple of method/pattern/function self.prefix = prefix diff --git a/tox.ini b/tox.ini
index 801e6dea2c..0479186348 100644 --- a/tox.ini +++ b/tox.ini
@@ -18,11 +18,13 @@ deps = # installed on that). # # anyway, make sure that we have a recent enough setuptools. - setuptools>=18.5 + setuptools>=18.5 ; python_version >= '3.6' + setuptools>=18.5,<51.0.0 ; python_version < '3.6' # we also need a semi-recent version of pip, because old ones fail to # install the "enum34" dependency of cryptography. - pip>=10 + pip>=10 ; python_version >= '3.6' + pip>=10,<21.0 ; python_version < '3.6' # directories/files we run the linters on lint_targets = @@ -103,15 +105,10 @@ usedevelop=true [testenv:py35-old] skip_install=True deps = - # Ensure a version of setuptools that supports Python 3.5 is installed. - setuptools < 51.0.0 - # Old automat version for Twisted Automat == 0.3.0 - lxml - coverage - coverage-enable-subprocess==1.0 + {[base]deps} commands = # Make all greater-thans equals so we test the oldest version of our direct @@ -168,6 +165,8 @@ commands = {toxinidir}/scripts-dev/generate_sample_config --check skip_install = True deps = coverage + pip>=10 ; python_version >= '3.6' + pip>=10,<21.0 ; python_version < '3.6' commands= coverage combine coverage report