diff --git a/.buildkite/scripts/test_old_deps.sh b/.buildkite/scripts/test_old_deps.sh
index 9905c4bc4f..28e6694b5d 100755
--- a/.buildkite/scripts/test_old_deps.sh
+++ b/.buildkite/scripts/test_old_deps.sh
@@ -10,4 +10,7 @@ apt-get install -y python3.5 python3.5-dev python3-pip libxml2-dev libxslt-dev x
export LANG="C.UTF-8"
+# Prevent virtualenv from auto-updating pip to an incompatible version
+export VIRTUALENV_NO_DOWNLOAD=1
+
exec tox -e py35-old,combine
diff --git a/changelog.d/9129.misc b/changelog.d/9129.misc
new file mode 100644
index 0000000000..7800be3e7e
--- /dev/null
+++ b/changelog.d/9129.misc
@@ -0,0 +1 @@
+Various improvements to the federation client.
diff --git a/changelog.d/9135.doc b/changelog.d/9135.doc
new file mode 100644
index 0000000000..d11ba70de4
--- /dev/null
+++ b/changelog.d/9135.doc
@@ -0,0 +1 @@
+Add link to Matrix VoIP tester for turn-howto.
diff --git a/changelog.d/9180.misc b/changelog.d/9180.misc
new file mode 100644
index 0000000000..69dd86110d
--- /dev/null
+++ b/changelog.d/9180.misc
@@ -0,0 +1 @@
+Add a `long_description_type` to the package metadata.
diff --git a/changelog.d/9181.misc b/changelog.d/9181.misc
new file mode 100644
index 0000000000..7820d09cd0
--- /dev/null
+++ b/changelog.d/9181.misc
@@ -0,0 +1 @@
+Speed up batch insertion when using PostgreSQL.
diff --git a/changelog.d/9183.feature b/changelog.d/9183.feature
new file mode 100644
index 0000000000..2d5c735042
--- /dev/null
+++ b/changelog.d/9183.feature
@@ -0,0 +1 @@
+Add experimental support for allowing clients to pick an SSO Identity Provider ([MSC2858](https://github.com/matrix-org/matrix-doc/pull/2858).
diff --git a/changelog.d/9184.misc b/changelog.d/9184.misc
new file mode 100644
index 0000000000..70da3d6cf5
--- /dev/null
+++ b/changelog.d/9184.misc
@@ -0,0 +1 @@
+Emit an error at startup if different Identity Providers are configured with the same `idp_id`.
diff --git a/changelog.d/9217.misc b/changelog.d/9217.misc
new file mode 100644
index 0000000000..72bacc7110
--- /dev/null
+++ b/changelog.d/9217.misc
@@ -0,0 +1 @@
+Fix the Python 3.5 old dependencies build.
diff --git a/docs/turn-howto.md b/docs/turn-howto.md
index a470c274a5..e8f13ad484 100644
--- a/docs/turn-howto.md
+++ b/docs/turn-howto.md
@@ -232,6 +232,12 @@ Here are a few things to try:
(Understanding the output is beyond the scope of this document!)
+ * You can test your Matrix homeserver TURN setup with https://test.voip.librepush.net/.
+ Note that this test is not fully reliable yet, so don't be discouraged if
+ the test fails.
+ [Here](https://github.com/matrix-org/voip-tester) is the github repo of the
+ source of the tester, where you can file bug reports.
+
* There is a WebRTC test tool at
https://webrtc.github.io/samples/src/content/peerconnection/trickle-ice/. To
use it, you will need a username/password for your TURN server. You can
diff --git a/setup.py b/setup.py
index 9730afb41b..ddbe9f511a 100755
--- a/setup.py
+++ b/setup.py
@@ -121,6 +121,7 @@ setup(
include_package_data=True,
zip_safe=False,
long_description=long_description,
+ long_description_content_type="text/x-rst",
python_requires="~=3.5",
classifiers=[
"Development Status :: 5 - Production/Stable",
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index 29aa064e57..3ccea4b02d 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -9,6 +9,7 @@ from synapse.config import (
consent_config,
database,
emailconfig,
+ experimental,
groups,
jwt_config,
key,
@@ -48,6 +49,7 @@ def path_exists(file_path: str): ...
class RootConfig:
server: server.ServerConfig
+ experimental: experimental.ExperimentalConfig
tls: tls.TlsConfig
database: database.DatabaseConfig
logging: logger.LoggingConfig
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
new file mode 100644
index 0000000000..b1c1c51e4d
--- /dev/null
+++ b/synapse/config/experimental.py
@@ -0,0 +1,29 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.config._base import Config
+from synapse.types import JsonDict
+
+
+class ExperimentalConfig(Config):
+ """Config section for enabling experimental features"""
+
+ section = "experimental"
+
+ def read_config(self, config: JsonDict, **kwargs):
+ experimental = config.get("experimental_features") or {}
+
+ # MSC2858 (multiple SSO identity providers)
+ self.msc2858_enabled = experimental.get("msc2858_enabled", False) # type: bool
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index 4bd2b3587b..64a2429f77 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -24,6 +24,7 @@ from .cas import CasConfig
from .consent_config import ConsentConfig
from .database import DatabaseConfig
from .emailconfig import EmailConfig
+from .experimental import ExperimentalConfig
from .federation import FederationConfig
from .groups import GroupsConfig
from .jwt_config import JWTConfig
@@ -57,6 +58,7 @@ class HomeServerConfig(RootConfig):
config_classes = [
ServerConfig,
+ ExperimentalConfig,
TlsConfig,
FederationConfig,
CacheConfig,
diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
index d58a83be7f..bfeceeed18 100644
--- a/synapse/config/oidc_config.py
+++ b/synapse/config/oidc_config.py
@@ -15,6 +15,7 @@
# limitations under the License.
import string
+from collections import Counter
from typing import Iterable, Optional, Tuple, Type
import attr
@@ -43,6 +44,16 @@ class OIDCConfig(Config):
except DependencyException as e:
raise ConfigError(e.message) from e
+ # check we don't have any duplicate idp_ids now. (The SSO handler will also
+ # check for duplicates when the REST listeners get registered, but that happens
+ # after synapse has forked so doesn't give nice errors.)
+ c = Counter([i.idp_id for i in self.oidc_providers])
+ for idp_id, count in c.items():
+ if count > 1:
+ raise ConfigError(
+ "Multiple OIDC providers have the idp_id %r." % idp_id
+ )
+
public_baseurl = self.public_baseurl
self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback"
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 302b2f69bc..d330ae5dbc 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -18,6 +18,7 @@ import copy
import itertools
import logging
from typing import (
+ TYPE_CHECKING,
Any,
Awaitable,
Callable,
@@ -26,7 +27,6 @@ from typing import (
List,
Mapping,
Optional,
- Sequence,
Tuple,
TypeVar,
Union,
@@ -61,6 +61,9 @@ from synapse.util import unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["type"])
@@ -80,10 +83,10 @@ class InvalidResponseError(RuntimeError):
class FederationClient(FederationBase):
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
- self.pdu_destination_tried = {}
+ self.pdu_destination_tried = {} # type: Dict[str, Dict[str, int]]
self._clock.looping_call(self._clear_tried_cache, 60 * 1000)
self.state = hs.get_state_handler()
self.transport_layer = hs.get_federation_transport_client()
@@ -116,33 +119,32 @@ class FederationClient(FederationBase):
self.pdu_destination_tried[event_id] = destination_dict
@log_function
- def make_query(
+ async def make_query(
self,
- destination,
- query_type,
- args,
- retry_on_dns_fail=False,
- ignore_backoff=False,
- ):
+ destination: str,
+ query_type: str,
+ args: dict,
+ retry_on_dns_fail: bool = False,
+ ignore_backoff: bool = False,
+ ) -> JsonDict:
"""Sends a federation Query to a remote homeserver of the given type
and arguments.
Args:
- destination (str): Domain name of the remote homeserver
- query_type (str): Category of the query type; should match the
+ destination: Domain name of the remote homeserver
+ query_type: Category of the query type; should match the
handler name used in register_query_handler().
- args (dict): Mapping of strings to strings containing the details
+ args: Mapping of strings to strings containing the details
of the query request.
- ignore_backoff (bool): true to ignore the historical backoff data
+ ignore_backoff: true to ignore the historical backoff data
and try the request anyway.
Returns:
- a Awaitable which will eventually yield a JSON object from the
- response
+ The JSON object from the response
"""
sent_queries_counter.labels(query_type).inc()
- return self.transport_layer.make_query(
+ return await self.transport_layer.make_query(
destination,
query_type,
args,
@@ -151,42 +153,52 @@ class FederationClient(FederationBase):
)
@log_function
- def query_client_keys(self, destination, content, timeout):
+ async def query_client_keys(
+ self, destination: str, content: JsonDict, timeout: int
+ ) -> JsonDict:
"""Query device keys for a device hosted on a remote server.
Args:
- destination (str): Domain name of the remote homeserver
- content (dict): The query content.
+ destination: Domain name of the remote homeserver
+ content: The query content.
Returns:
- an Awaitable which will eventually yield a JSON object from the
- response
+ The JSON object from the response
"""
sent_queries_counter.labels("client_device_keys").inc()
- return self.transport_layer.query_client_keys(destination, content, timeout)
+ return await self.transport_layer.query_client_keys(
+ destination, content, timeout
+ )
@log_function
- def query_user_devices(self, destination, user_id, timeout=30000):
+ async def query_user_devices(
+ self, destination: str, user_id: str, timeout: int = 30000
+ ) -> JsonDict:
"""Query the device keys for a list of user ids hosted on a remote
server.
"""
sent_queries_counter.labels("user_devices").inc()
- return self.transport_layer.query_user_devices(destination, user_id, timeout)
+ return await self.transport_layer.query_user_devices(
+ destination, user_id, timeout
+ )
@log_function
- def claim_client_keys(self, destination, content, timeout):
+ async def claim_client_keys(
+ self, destination: str, content: JsonDict, timeout: int
+ ) -> JsonDict:
"""Claims one-time keys for a device hosted on a remote server.
Args:
- destination (str): Domain name of the remote homeserver
- content (dict): The query content.
+ destination: Domain name of the remote homeserver
+ content: The query content.
Returns:
- an Awaitable which will eventually yield a JSON object from the
- response
+ The JSON object from the response
"""
sent_queries_counter.labels("client_one_time_keys").inc()
- return self.transport_layer.claim_client_keys(destination, content, timeout)
+ return await self.transport_layer.claim_client_keys(
+ destination, content, timeout
+ )
async def backfill(
self, dest: str, room_id: str, limit: int, extremities: Iterable[str]
@@ -195,10 +207,10 @@ class FederationClient(FederationBase):
given destination server.
Args:
- dest (str): The remote homeserver to ask.
- room_id (str): The room_id to backfill.
- limit (int): The maximum number of events to return.
- extremities (list): our current backwards extremities, to backfill from
+ dest: The remote homeserver to ask.
+ room_id: The room_id to backfill.
+ limit: The maximum number of events to return.
+ extremities: our current backwards extremities, to backfill from
"""
logger.debug("backfill extrem=%s", extremities)
@@ -370,7 +382,7 @@ class FederationClient(FederationBase):
for events that have failed their checks
Returns:
- Deferred : A list of PDUs that have valid signatures and hashes.
+ A list of PDUs that have valid signatures and hashes.
"""
deferreds = self._check_sigs_and_hashes(room_version, pdus)
@@ -418,7 +430,9 @@ class FederationClient(FederationBase):
else:
return [p for p in valid_pdus if p]
- async def get_event_auth(self, destination, room_id, event_id):
+ async def get_event_auth(
+ self, destination: str, room_id: str, event_id: str
+ ) -> List[EventBase]:
res = await self.transport_layer.get_event_auth(destination, room_id, event_id)
room_version = await self.store.get_room_version(room_id)
@@ -700,18 +714,16 @@ class FederationClient(FederationBase):
return await self._try_destination_list("send_join", destinations, send_request)
- async def _do_send_join(self, destination: str, pdu: EventBase):
+ async def _do_send_join(self, destination: str, pdu: EventBase) -> JsonDict:
time_now = self._clock.time_msec()
try:
- content = await self.transport_layer.send_join_v2(
+ return await self.transport_layer.send_join_v2(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)
-
- return content
except HttpResponseException as e:
if e.code in [400, 404]:
err = e.to_synapse_error()
@@ -769,7 +781,7 @@ class FederationClient(FederationBase):
time_now = self._clock.time_msec()
try:
- content = await self.transport_layer.send_invite_v2(
+ return await self.transport_layer.send_invite_v2(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
@@ -779,7 +791,6 @@ class FederationClient(FederationBase):
"invite_room_state": pdu.unsigned.get("invite_room_state", []),
},
)
- return content
except HttpResponseException as e:
if e.code in [400, 404]:
err = e.to_synapse_error()
@@ -842,18 +853,16 @@ class FederationClient(FederationBase):
"send_leave", destinations, send_request
)
- async def _do_send_leave(self, destination, pdu):
+ async def _do_send_leave(self, destination: str, pdu: EventBase) -> JsonDict:
time_now = self._clock.time_msec()
try:
- content = await self.transport_layer.send_leave_v2(
+ return await self.transport_layer.send_leave_v2(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)
-
- return content
except HttpResponseException as e:
if e.code in [400, 404]:
err = e.to_synapse_error()
@@ -879,7 +888,7 @@ class FederationClient(FederationBase):
# content.
return resp[1]
- def get_public_rooms(
+ async def get_public_rooms(
self,
remote_server: str,
limit: Optional[int] = None,
@@ -887,7 +896,7 @@ class FederationClient(FederationBase):
search_filter: Optional[Dict] = None,
include_all_networks: bool = False,
third_party_instance_id: Optional[str] = None,
- ):
+ ) -> JsonDict:
"""Get the list of public rooms from a remote homeserver
Args:
@@ -901,8 +910,7 @@ class FederationClient(FederationBase):
party instance
Returns:
- Awaitable[Dict[str, Any]]: The response from the remote server, or None if
- `remote_server` is the same as the local server_name
+ The response from the remote server.
Raises:
HttpResponseException: There was an exception returned from the remote server
@@ -910,7 +918,7 @@ class FederationClient(FederationBase):
requests over federation
"""
- return self.transport_layer.get_public_rooms(
+ return await self.transport_layer.get_public_rooms(
remote_server,
limit,
since_token,
@@ -923,7 +931,7 @@ class FederationClient(FederationBase):
self,
destination: str,
room_id: str,
- earliest_events_ids: Sequence[str],
+ earliest_events_ids: Iterable[str],
latest_events: Iterable[EventBase],
limit: int,
min_depth: int,
@@ -974,7 +982,9 @@ class FederationClient(FederationBase):
return signed_events
- async def forward_third_party_invite(self, destinations, room_id, event_dict):
+ async def forward_third_party_invite(
+ self, destinations: Iterable[str], room_id: str, event_dict: JsonDict
+ ) -> None:
for destination in destinations:
if destination == self.server_name:
continue
@@ -983,7 +993,7 @@ class FederationClient(FederationBase):
await self.transport_layer.exchange_third_party_invite(
destination=destination, room_id=room_id, event_dict=event_dict
)
- return None
+ return
except CodeMessageException:
raise
except Exception as e:
@@ -995,7 +1005,7 @@ class FederationClient(FederationBase):
async def get_room_complexity(
self, destination: str, room_id: str
- ) -> Optional[dict]:
+ ) -> Optional[JsonDict]:
"""
Fetch the complexity of a remote room from another server.
@@ -1008,10 +1018,9 @@ class FederationClient(FederationBase):
could not fetch the complexity.
"""
try:
- complexity = await self.transport_layer.get_room_complexity(
+ return await self.transport_layer.get_room_complexity(
destination=destination, room_id=room_id
)
- return complexity
except CodeMessageException as e:
# We didn't manage to get it -- probably a 404. We are okay if other
# servers don't give it to us.
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index d493327a10..afc1341d09 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -23,7 +23,7 @@ from typing_extensions import NoReturn, Protocol
from twisted.web.http import Request
from synapse.api.constants import LoginType
-from synapse.api.errors import Codes, RedirectException, SynapseError
+from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
from synapse.handlers.ui_auth import UIAuthSessionDataConstants
from synapse.http import get_request_user_agent
from synapse.http.server import respond_with_html
@@ -235,7 +235,10 @@ class SsoHandler:
respond_with_html(request, code, html)
async def handle_redirect_request(
- self, request: SynapseRequest, client_redirect_url: bytes,
+ self,
+ request: SynapseRequest,
+ client_redirect_url: bytes,
+ idp_id: Optional[str],
) -> str:
"""Handle a request to /login/sso/redirect
@@ -243,6 +246,7 @@ class SsoHandler:
request: incoming HTTP request
client_redirect_url: the URL that we should redirect the
client to after login.
+ idp_id: optional identity provider chosen by the client
Returns:
the URI to redirect to
@@ -252,10 +256,19 @@ class SsoHandler:
400, "Homeserver not configured for SSO.", errcode=Codes.UNRECOGNIZED
)
+ # if the client chose an IdP, use that
+ idp = None # type: Optional[SsoIdentityProvider]
+ if idp_id:
+ idp = self._identity_providers.get(idp_id)
+ if not idp:
+ raise NotFoundError("Unknown identity provider")
+
# if we only have one auth provider, redirect to it directly
- if len(self._identity_providers) == 1:
- ap = next(iter(self._identity_providers.values()))
- return await ap.handle_redirect_request(request, client_redirect_url)
+ elif len(self._identity_providers) == 1:
+ idp = next(iter(self._identity_providers.values()))
+
+ if idp:
+ return await idp.handle_redirect_request(request, client_redirect_url)
# otherwise, redirect to the IDP picker
return "/_synapse/client/pick_idp?" + urlencode(
diff --git a/synapse/http/server.py b/synapse/http/server.py
index e464bfe6c7..d69d579b3a 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -22,10 +22,22 @@ import types
import urllib
from http import HTTPStatus
from io import BytesIO
-from typing import Any, Callable, Dict, Iterator, List, Tuple, Union
+from typing import (
+ Any,
+ Awaitable,
+ Callable,
+ Dict,
+ Iterable,
+ Iterator,
+ List,
+ Pattern,
+ Tuple,
+ Union,
+)
import jinja2
from canonicaljson import iterencode_canonical_json
+from typing_extensions import Protocol
from zope.interface import implementer
from twisted.internet import defer, interfaces
@@ -168,11 +180,25 @@ def wrap_async_request_handler(h):
return preserve_fn(wrapped_async_request_handler)
-class HttpServer:
+# Type of a callback method for processing requests
+# it is actually called with a SynapseRequest and a kwargs dict for the params,
+# but I can't figure out how to represent that.
+ServletCallback = Callable[
+ ..., Union[None, Awaitable[None], Tuple[int, Any], Awaitable[Tuple[int, Any]]]
+]
+
+
+class HttpServer(Protocol):
""" Interface for registering callbacks on a HTTP server
"""
- def register_paths(self, method, path_patterns, callback):
+ def register_paths(
+ self,
+ method: str,
+ path_patterns: Iterable[Pattern],
+ callback: ServletCallback,
+ servlet_classname: str,
+ ) -> None:
""" Register a callback that gets fired if we receive a http request
with the given method for a path that matches the given regex.
@@ -180,12 +206,14 @@ class HttpServer:
an unpacked tuple.
Args:
- method (str): The method to listen to.
- path_patterns (list<SRE_Pattern>): The regex used to match requests.
- callback (function): The function to fire if we receive a matched
+ method: The HTTP method to listen to.
+ path_patterns: The regex used to match requests.
+ callback: The function to fire if we receive a matched
request. The first argument will be the request object and
subsequent arguments will be any matched groups from the regex.
- This should return a tuple of (code, response).
+ This should return either tuple of (code, response), or None.
+ servlet_classname (str): The name of the handler to be used in prometheus
+ and opentracing logs.
"""
pass
@@ -354,7 +382,7 @@ class JsonResource(DirectServeJsonResource):
def _get_handler_for_request(
self, request: SynapseRequest
- ) -> Tuple[Callable, str, Dict[str, str]]:
+ ) -> Tuple[ServletCallback, str, Dict[str, str]]:
"""Finds a callback method to handle the given request.
Returns:
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index be938df962..0a561eea60 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -19,7 +19,8 @@ from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
from synapse.appservice import ApplicationService
-from synapse.http.server import finish_request
+from synapse.handlers.sso import SsoIdentityProvider
+from synapse.http.server import HttpServer, finish_request
from synapse.http.servlet import (
RestServlet,
parse_json_object_from_request,
@@ -60,11 +61,14 @@ class LoginRestServlet(RestServlet):
self.saml2_enabled = hs.config.saml2_enabled
self.cas_enabled = hs.config.cas_enabled
self.oidc_enabled = hs.config.oidc_enabled
+ self._msc2858_enabled = hs.config.experimental.msc2858_enabled
self.auth = hs.get_auth()
self.auth_handler = self.hs.get_auth_handler()
self.registration_handler = hs.get_registration_handler()
+ self._sso_handler = hs.get_sso_handler()
+
self._well_known_builder = WellKnownBuilder(hs)
self._address_ratelimiter = Ratelimiter(
clock=hs.get_clock(),
@@ -89,8 +93,17 @@ class LoginRestServlet(RestServlet):
flows.append({"type": LoginRestServlet.CAS_TYPE})
if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
- flows.append({"type": LoginRestServlet.SSO_TYPE})
- # While its valid for us to advertise this login type generally,
+ sso_flow = {"type": LoginRestServlet.SSO_TYPE} # type: JsonDict
+
+ if self._msc2858_enabled:
+ sso_flow["org.matrix.msc2858.identity_providers"] = [
+ _get_auth_flow_dict_for_idp(idp)
+ for idp in self._sso_handler.get_identity_providers().values()
+ ]
+
+ flows.append(sso_flow)
+
+ # While it's valid for us to advertise this login type generally,
# synapse currently only gives out these tokens as part of the
# SSO login flow.
# Generally we don't want to advertise login flows that clients
@@ -311,8 +324,20 @@ class LoginRestServlet(RestServlet):
return result
+def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict:
+ """Return an entry for the login flow dict
+
+ Returns an entry suitable for inclusion in "identity_providers" in the
+ response to GET /_matrix/client/r0/login
+ """
+ e = {"id": idp.idp_id, "name": idp.idp_name} # type: JsonDict
+ if idp.idp_icon:
+ e["icon"] = idp.idp_icon
+ return e
+
+
class SsoRedirectServlet(RestServlet):
- PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
+ PATTERNS = client_patterns("/login/(cas|sso)/redirect$", v1=True)
def __init__(self, hs: "HomeServer"):
# make sure that the relevant handlers are instantiated, so that they
@@ -324,13 +349,31 @@ class SsoRedirectServlet(RestServlet):
if hs.config.oidc_enabled:
hs.get_oidc_handler()
self._sso_handler = hs.get_sso_handler()
+ self._msc2858_enabled = hs.config.experimental.msc2858_enabled
+
+ def register(self, http_server: HttpServer) -> None:
+ super().register(http_server)
+ if self._msc2858_enabled:
+ # expose additional endpoint for MSC2858 support
+ http_server.register_paths(
+ "GET",
+ client_patterns(
+ "/org.matrix.msc2858/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$",
+ releases=(),
+ unstable=True,
+ ),
+ self.on_GET,
+ self.__class__.__name__,
+ )
- async def on_GET(self, request: SynapseRequest):
+ async def on_GET(
+ self, request: SynapseRequest, idp_id: Optional[str] = None
+ ) -> None:
client_redirect_url = parse_string(
request, "redirectUrl", required=True, encoding=None
)
sso_url = await self._sso_handler.handle_redirect_request(
- request, client_redirect_url
+ request, client_redirect_url, idp_id,
)
logger.info("Redirecting to %s", sso_url)
request.redirect(sso_url)
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index a19d65ad23..c7220bc778 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -267,8 +267,7 @@ class LoggingTransaction:
self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
else:
- for val in args:
- self.execute(sql, val)
+ self.executemany(sql, args)
def execute_values(self, sql: str, *args: Any) -> List[Tuple]:
"""Corresponds to psycopg2.extras.execute_values. Only available when
@@ -888,7 +887,7 @@ class DatabasePool:
", ".join("?" for _ in keys[0]),
)
- txn.executemany(sql, vals)
+ txn.execute_batch(sql, vals)
async def simple_upsert(
self,
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 3216b3f3c8..5db7d7aaa8 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -876,7 +876,7 @@ class PersistEventsStore:
WHERE room_id = ? AND type = ? AND state_key = ?
)
"""
- txn.executemany(
+ txn.execute_batch(
sql,
(
(
@@ -895,7 +895,7 @@ class PersistEventsStore:
)
# Now we actually update the current_state_events table
- txn.executemany(
+ txn.execute_batch(
"DELETE FROM current_state_events"
" WHERE room_id = ? AND type = ? AND state_key = ?",
(
@@ -907,7 +907,7 @@ class PersistEventsStore:
# We include the membership in the current state table, hence we do
# a lookup when we insert. This assumes that all events have already
# been inserted into room_memberships.
- txn.executemany(
+ txn.execute_batch(
"""INSERT INTO current_state_events
(room_id, type, state_key, event_id, membership)
VALUES (?, ?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
@@ -927,7 +927,7 @@ class PersistEventsStore:
# we have no record of the fact the user *was* a member of the
# room but got, say, state reset out of it.
if to_delete or to_insert:
- txn.executemany(
+ txn.execute_batch(
"DELETE FROM local_current_membership"
" WHERE room_id = ? AND user_id = ?",
(
@@ -938,7 +938,7 @@ class PersistEventsStore:
)
if to_insert:
- txn.executemany(
+ txn.execute_batch(
"""INSERT INTO local_current_membership
(room_id, user_id, event_id, membership)
VALUES (?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
@@ -1738,7 +1738,7 @@ class PersistEventsStore:
"""
if events_and_contexts:
- txn.executemany(
+ txn.execute_batch(
sql,
(
(
@@ -1767,7 +1767,7 @@ class PersistEventsStore:
# Now we delete the staging area for *all* events that were being
# persisted.
- txn.executemany(
+ txn.execute_batch(
"DELETE FROM event_push_actions_staging WHERE event_id = ?",
((event.event_id,) for event, _ in all_events_and_contexts),
)
@@ -1886,7 +1886,7 @@ class PersistEventsStore:
" )"
)
- txn.executemany(
+ txn.execute_batch(
query,
[
(e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
@@ -1900,7 +1900,7 @@ class PersistEventsStore:
"DELETE FROM event_backward_extremities"
" WHERE event_id = ? AND room_id = ?"
)
- txn.executemany(
+ txn.execute_batch(
query,
[
(ev.event_id, ev.room_id)
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 2672ce24c6..e2bb945453 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -75,6 +75,10 @@ TEST_CLIENT_REDIRECT_URL = 'https://x?<ab c>&q"+%3D%2B"="fö%26=o"'
# the query params in TEST_CLIENT_REDIRECT_URL
EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("<ab c>", ""), ('q" =+"', '"fö&=o"')]
+# (possibly experimental) login flows we expect to appear in the list after the normal
+# ones
+ADDITIONAL_LOGIN_FLOWS = [{"type": "uk.half-shot.msc2778.login.application_service"}]
+
class LoginRestServletTestCase(unittest.HomeserverTestCase):
@@ -426,6 +430,57 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
d["/_synapse/oidc"] = OIDCResource(self.hs)
return d
+ def test_get_login_flows(self):
+ """GET /login should return password and SSO flows"""
+ channel = self.make_request("GET", "/_matrix/client/r0/login")
+ self.assertEqual(channel.code, 200, channel.result)
+
+ expected_flows = [
+ {"type": "m.login.cas"},
+ {"type": "m.login.sso"},
+ {"type": "m.login.token"},
+ {"type": "m.login.password"},
+ ] + ADDITIONAL_LOGIN_FLOWS
+
+ self.assertCountEqual(channel.json_body["flows"], expected_flows)
+
+ @override_config({"experimental_features": {"msc2858_enabled": True}})
+ def test_get_msc2858_login_flows(self):
+ """The SSO flow should include IdP info if MSC2858 is enabled"""
+ channel = self.make_request("GET", "/_matrix/client/r0/login")
+ self.assertEqual(channel.code, 200, channel.result)
+
+ # stick the flows results in a dict by type
+ flow_results = {} # type: Dict[str, Any]
+ for f in channel.json_body["flows"]:
+ flow_type = f["type"]
+ self.assertNotIn(
+ flow_type, flow_results, "duplicate flow type %s" % (flow_type,)
+ )
+ flow_results[flow_type] = f
+
+ self.assertIn("m.login.sso", flow_results, "m.login.sso was not returned")
+ sso_flow = flow_results.pop("m.login.sso")
+ # we should have a set of IdPs
+ self.assertCountEqual(
+ sso_flow["org.matrix.msc2858.identity_providers"],
+ [
+ {"id": "cas", "name": "CAS"},
+ {"id": "saml", "name": "SAML"},
+ {"id": "oidc-idp1", "name": "IDP1"},
+ {"id": "oidc", "name": "OIDC"},
+ ],
+ )
+
+ # the rest of the flows are simple
+ expected_flows = [
+ {"type": "m.login.cas"},
+ {"type": "m.login.token"},
+ {"type": "m.login.password"},
+ ] + ADDITIONAL_LOGIN_FLOWS
+
+ self.assertCountEqual(flow_results.values(), expected_flows)
+
def test_multi_sso_redirect(self):
"""/login/sso/redirect should redirect to an identity picker"""
# first hit the redirect url, which should redirect to our idp picker
@@ -564,6 +619,43 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 400, channel.result)
+ def test_client_idp_redirect_msc2858_disabled(self):
+ """If the client tries to pick an IdP but MSC2858 is disabled, return a 400"""
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl="
+ + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
+ )
+ self.assertEqual(channel.code, 400, channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_UNRECOGNIZED")
+
+ @override_config({"experimental_features": {"msc2858_enabled": True}})
+ def test_client_idp_redirect_to_unknown(self):
+ """If the client tries to pick an unknown IdP, return a 404"""
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/xxx?redirectUrl="
+ + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
+ )
+ self.assertEqual(channel.code, 404, channel.result)
+ self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
+
+ @override_config({"experimental_features": {"msc2858_enabled": True}})
+ def test_client_idp_redirect_to_oidc(self):
+ """If the client pick a known IdP, redirect to it"""
+ channel = self.make_request(
+ "GET",
+ "/_matrix/client/unstable/org.matrix.msc2858/login/sso/redirect/oidc?redirectUrl="
+ + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
+ )
+
+ self.assertEqual(channel.code, 302, channel.result)
+ oidc_uri = channel.headers.getRawHeaders("Location")[0]
+ oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
+
+ # it should redirect us to the auth page of the OIDC server
+ self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
+
@staticmethod
def _get_value_from_macaroon(macaroon: pymacaroons.Macaroon, key: str) -> str:
prefix = key + " = "
diff --git a/tests/utils.py b/tests/utils.py
index 09614093bc..022223cf24 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -33,7 +33,6 @@ from synapse.api.room_versions import RoomVersions
from synapse.config.database import DatabaseConnectionConfig
from synapse.config.homeserver import HomeServerConfig
from synapse.config.server import DEFAULT_ROOM_VERSION
-from synapse.http.server import HttpServer
from synapse.logging.context import current_context, set_current_context
from synapse.server import HomeServer
from synapse.storage import DataStore
@@ -351,7 +350,7 @@ def mock_getRawHeaders(headers=None):
# This is a mock /resource/ not an entire server
-class MockHttpResource(HttpServer):
+class MockHttpResource:
def __init__(self, prefix=""):
self.callbacks = [] # 3-tuple of method/pattern/function
self.prefix = prefix
diff --git a/tox.ini b/tox.ini
index 801e6dea2c..0479186348 100644
--- a/tox.ini
+++ b/tox.ini
@@ -18,11 +18,13 @@ deps =
# installed on that).
#
# anyway, make sure that we have a recent enough setuptools.
- setuptools>=18.5
+ setuptools>=18.5 ; python_version >= '3.6'
+ setuptools>=18.5,<51.0.0 ; python_version < '3.6'
# we also need a semi-recent version of pip, because old ones fail to
# install the "enum34" dependency of cryptography.
- pip>=10
+ pip>=10 ; python_version >= '3.6'
+ pip>=10,<21.0 ; python_version < '3.6'
# directories/files we run the linters on
lint_targets =
@@ -103,15 +105,10 @@ usedevelop=true
[testenv:py35-old]
skip_install=True
deps =
- # Ensure a version of setuptools that supports Python 3.5 is installed.
- setuptools < 51.0.0
-
# Old automat version for Twisted
Automat == 0.3.0
-
lxml
- coverage
- coverage-enable-subprocess==1.0
+ {[base]deps}
commands =
# Make all greater-thans equals so we test the oldest version of our direct
@@ -168,6 +165,8 @@ commands = {toxinidir}/scripts-dev/generate_sample_config --check
skip_install = True
deps =
coverage
+ pip>=10 ; python_version >= '3.6'
+ pip>=10,<21.0 ; python_version < '3.6'
commands=
coverage combine
coverage report
|