diff --git a/changelog.d/8858.bugfix b/changelog.d/8858.bugfix
new file mode 100644
index 0000000000..0d58cb9abc
--- /dev/null
+++ b/changelog.d/8858.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug on Synapse instances supporting Single-Sign-On, where users would be prompted to enter their password to confirm certain actions, even though they have not set a password.
diff --git a/changelog.d/8864.misc b/changelog.d/8864.misc
new file mode 100644
index 0000000000..a780883495
--- /dev/null
+++ b/changelog.d/8864.misc
@@ -0,0 +1 @@
+Remove unused `FakeResponse` class from unit tests.
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index b53e7a20ec..434718ddfc 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -1462,7 +1462,7 @@ def register_servlets(hs, resource, authenticator, ratelimiter, servlet_groups=N
Args:
hs (synapse.server.HomeServer): homeserver
- resource (TransportLayerServer): resource class to register to
+ resource (JsonResource): resource class to register to
authenticator (Authenticator): authenticator to use
ratelimiter (util.ratelimitutils.FederationRateLimiter): ratelimiter to use
servlet_groups (list[str], optional): List of servlet groups to register.
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index c7dc07008a..2e72298e05 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -193,9 +193,7 @@ class AuthHandler(BaseHandler):
self.hs = hs # FIXME better possibility to access registrationHandler later?
self.macaroon_gen = hs.get_macaroon_generator()
self._password_enabled = hs.config.password_enabled
- self._sso_enabled = (
- hs.config.cas_enabled or hs.config.saml2_enabled or hs.config.oidc_enabled
- )
+ self._password_localdb_enabled = hs.config.password_localdb_enabled
# we keep this as a list despite the O(N^2) implication so that we can
# keep PASSWORD first and avoid confusing clients which pick the first
@@ -205,7 +203,7 @@ class AuthHandler(BaseHandler):
# start out by assuming PASSWORD is enabled; we will remove it later if not.
login_types = []
- if hs.config.password_localdb_enabled:
+ if self._password_localdb_enabled:
login_types.append(LoginType.PASSWORD)
for provider in self.password_providers:
@@ -219,14 +217,6 @@ class AuthHandler(BaseHandler):
self._supported_login_types = login_types
- # Login types and UI Auth types have a heavy overlap, but are not
- # necessarily identical. Login types have SSO (and other login types)
- # added in the rest layer, see synapse.rest.client.v1.login.LoginRestServerlet.on_GET.
- ui_auth_types = login_types.copy()
- if self._sso_enabled:
- ui_auth_types.append(LoginType.SSO)
- self._supported_ui_auth_types = ui_auth_types
-
# Ratelimiter for failed auth during UIA. Uses same ratelimit config
# as per `rc_login.failed_attempts`.
self._failed_uia_attempts_ratelimiter = Ratelimiter(
@@ -339,7 +329,10 @@ class AuthHandler(BaseHandler):
self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False)
# build a list of supported flows
- flows = [[login_type] for login_type in self._supported_ui_auth_types]
+ supported_ui_auth_types = await self._get_available_ui_auth_types(
+ requester.user
+ )
+ flows = [[login_type] for login_type in supported_ui_auth_types]
try:
result, params, session_id = await self.check_ui_auth(
@@ -351,7 +344,7 @@ class AuthHandler(BaseHandler):
raise
# find the completed login type
- for login_type in self._supported_ui_auth_types:
+ for login_type in supported_ui_auth_types:
if login_type not in result:
continue
@@ -367,6 +360,41 @@ class AuthHandler(BaseHandler):
return params, session_id
+ async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]:
+ """Get a list of the authentication types this user can use
+ """
+
+ ui_auth_types = set()
+
+ # if the HS supports password auth, and the user has a non-null password, we
+ # support password auth
+ if self._password_localdb_enabled and self._password_enabled:
+ lookupres = await self._find_user_id_and_pwd_hash(user.to_string())
+ if lookupres:
+ _, password_hash = lookupres
+ if password_hash:
+ ui_auth_types.add(LoginType.PASSWORD)
+
+ # also allow auth from password providers
+ for provider in self.password_providers:
+ for t in provider.get_supported_login_types().keys():
+ if t == LoginType.PASSWORD and not self._password_enabled:
+ continue
+ ui_auth_types.add(t)
+
+ # if sso is enabled, allow the user to log in via SSO iff they have a mapping
+ # from sso to mxid.
+ if self.hs.config.saml2.saml2_enabled or self.hs.config.oidc.oidc_enabled:
+ if await self.store.get_external_ids_by_user(user.to_string()):
+ ui_auth_types.add(LoginType.SSO)
+
+ # Our CAS impl does not (yet) correctly register users in user_external_ids,
+ # so always offer that if it's available.
+ if self.hs.config.cas.cas_enabled:
+ ui_auth_types.add(LoginType.SSO)
+
+ return ui_auth_types
+
def get_enabled_auth_types(self):
"""Return the enabled user-interactive authentication types
@@ -1029,7 +1057,7 @@ class AuthHandler(BaseHandler):
if result:
return result
- if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
+ if login_type == LoginType.PASSWORD and self._password_localdb_enabled:
known_login_type = True
# we've already checked that there is a (valid) password field
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index fedb8a6c26..ff96c34c2e 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -463,6 +463,23 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
desc="get_user_by_external_id",
)
+ async def get_external_ids_by_user(self, mxid: str) -> List[Tuple[str, str]]:
+ """Look up external ids for the given user
+
+ Args:
+ mxid: the MXID to be looked up
+
+ Returns:
+ Tuples of (auth_provider, external_id)
+ """
+ res = await self.db_pool.simple_select_list(
+ table="user_external_ids",
+ keyvalues={"user_id": mxid},
+ retcols=("auth_provider", "external_id"),
+ desc="get_external_ids_by_user",
+ )
+ return [(r["auth_provider"], r["external_id"]) for r in res]
+
async def count_all_users(self):
"""Counts all users registered on the homeserver."""
@@ -963,6 +980,14 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
"users_set_deactivated_flag", self._background_update_set_deactivated_flag
)
+ self.db_pool.updates.register_background_index_update(
+ "user_external_ids_user_id_idx",
+ index_name="user_external_ids_user_id_idx",
+ table="user_external_ids",
+ columns=["user_id"],
+ unique=False,
+ )
+
async def _background_update_set_deactivated_flag(self, progress, batch_size):
"""Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
for each of them.
diff --git a/synapse/storage/databases/main/schema/delta/58/25user_external_ids_user_id_idx.sql b/synapse/storage/databases/main/schema/delta/58/25user_external_ids_user_id_idx.sql
new file mode 100644
index 0000000000..8f5e65aa71
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/58/25user_external_ids_user_id_idx.sql
@@ -0,0 +1,17 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+ (5825, 'user_external_ids_user_id_idx', '{}');
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index d485af52fd..23ea1e3543 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -17,30 +17,15 @@ from urllib.parse import parse_qs, urlparse
from mock import Mock, patch
-import attr
import pymacaroons
-from twisted.python.failure import Failure
-from twisted.web._newclient import ResponseDone
-
from synapse.handlers.oidc_handler import OidcError, OidcMappingProvider
from synapse.handlers.sso import MappingException
from synapse.types import UserID
+from tests.test_utils import FakeResponse
from tests.unittest import HomeserverTestCase, override_config
-
-@attr.s
-class FakeResponse:
- code = attr.ib()
- body = attr.ib()
- phrase = attr.ib()
-
- def deliverBody(self, protocol):
- protocol.dataReceived(self.body)
- protocol.connectionLost(Failure(ResponseDone()))
-
-
# These are a few constants that are used as config parameters in the tests.
ISSUER = "https://issuer/"
CLIENT_ID = "test-client-id"
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index 294b71e8e2..f21de958f1 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -15,15 +15,16 @@
import json
+from typing import Dict
from mock import ANY, Mock, call
from twisted.internet import defer
+from twisted.web.resource import Resource
from synapse.api.errors import AuthError
-from synapse.federation.transport import server as federation_server
+from synapse.federation.transport.server import TransportLayerServer
from synapse.types import UserID, create_requester
-from synapse.util.ratelimitutils import FederationRateLimiter
from tests import unittest
from tests.test_utils import make_awaitable
@@ -53,20 +54,7 @@ def _make_edu_transaction_json(edu_type, content):
return json.dumps(_expect_edu_transaction(edu_type, content)).encode("utf8")
-def register_federation_servlets(hs, resource):
- federation_server.register_servlets(
- hs,
- resource=resource,
- authenticator=federation_server.Authenticator(hs),
- ratelimiter=FederationRateLimiter(
- hs.get_clock(), config=hs.config.rc_federation
- ),
- )
-
-
class TypingNotificationsTestCase(unittest.HomeserverTestCase):
- servlets = [register_federation_servlets]
-
def make_homeserver(self, reactor, clock):
# we mock out the keyring so as to skip the authentication check on the
# federation API call.
@@ -89,6 +77,11 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
return hs
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ d = super().create_resource_dict()
+ d["/_matrix/federation"] = TransportLayerServer(self.hs)
+ return d
+
def prepare(self, reactor, clock, hs):
mock_notifier = hs.get_notifier()
self.on_new_event = mock_notifier.on_new_event
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 728de28277..3379189785 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Any, Callable, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
import attr
@@ -21,6 +21,7 @@ from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
from twisted.internet.protocol import Protocol
from twisted.internet.task import LoopingCall
from twisted.web.http import HTTPChannel
+from twisted.web.resource import Resource
from synapse.app.generic_worker import (
GenericWorkerReplicationHandler,
@@ -28,7 +29,7 @@ from synapse.app.generic_worker import (
)
from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest, SynapseSite
-from synapse.replication.http import ReplicationRestResource, streams
+from synapse.replication.http import ReplicationRestResource
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@@ -54,10 +55,6 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
if not hiredis:
skip = "Requires hiredis"
- servlets = [
- streams.register_servlets,
- ]
-
def prepare(self, reactor, clock, hs):
# build a replication server
server_factory = ReplicationStreamProtocolFactory(hs)
@@ -88,6 +85,11 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self._client_transport = None
self._server_transport = None
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ d = super().create_resource_dict()
+ d["/_synapse/replication"] = ReplicationRestResource(self.hs)
+ return d
+
def _get_worker_hs_config(self) -> dict:
config = self.default_config()
config["worker_app"] = "synapse.app.generic_worker"
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index 737c38c396..5a18af8d34 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -2,7 +2,7 @@
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2017 Vector Creations Ltd
# Copyright 2018-2019 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019-2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -17,17 +17,23 @@
# limitations under the License.
import json
+import re
import time
+import urllib.parse
from typing import Any, Dict, Optional
+from mock import patch
+
import attr
from twisted.web.resource import Resource
from twisted.web.server import Site
from synapse.api.constants import Membership
+from synapse.types import JsonDict
from tests.server import FakeSite, make_request
+from tests.test_utils import FakeResponse
@attr.s
@@ -344,3 +350,111 @@ class RestHelper:
)
return channel.json_body
+
+ def login_via_oidc(self, remote_user_id: str) -> JsonDict:
+ """Log in (as a new user) via OIDC
+
+ Returns the result of the final token login.
+
+ Requires that "oidc_config" in the homeserver config be set appropriately
+ (TEST_OIDC_CONFIG is a suitable example) - and by implication, needs a
+ "public_base_url".
+
+ Also requires the login servlet and the OIDC callback resource to be mounted at
+ the normal places.
+ """
+ client_redirect_url = "https://x"
+
+ # first hit the redirect url (which will issue a cookie and state)
+ _, channel = make_request(
+ self.hs.get_reactor(),
+ self.site,
+ "GET",
+ "/login/sso/redirect?redirectUrl=" + client_redirect_url,
+ )
+ # that will redirect to the OIDC IdP, but we skip that and go straight
+ # back to synapse's OIDC callback resource. However, we do need the "state"
+ # param that synapse passes to the IdP via query params, and the cookie that
+ # synapse passes to the client.
+ assert channel.code == 302
+ oauth_uri = channel.headers.getRawHeaders("Location")[0]
+ params = urllib.parse.parse_qs(urllib.parse.urlparse(oauth_uri).query)
+ redirect_uri = "%s?%s" % (
+ urllib.parse.urlparse(params["redirect_uri"][0]).path,
+ urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}),
+ )
+ cookies = {}
+ for h in channel.headers.getRawHeaders("Set-Cookie"):
+ parts = h.split(";")
+ k, v = parts[0].split("=", maxsplit=1)
+ cookies[k] = v
+
+ # before we hit the callback uri, stub out some methods in the http client so
+ # that we don't have to handle full HTTPS requests.
+
+ # (expected url, json response) pairs, in the order we expect them.
+ expected_requests = [
+ # first we get a hit to the token endpoint, which we tell to return
+ # a dummy OIDC access token
+ ("https://issuer.test/token", {"access_token": "TEST"}),
+ # and then one to the user_info endpoint, which returns our remote user id.
+ ("https://issuer.test/userinfo", {"sub": remote_user_id}),
+ ]
+
+ async def mock_req(method: str, uri: str, data=None, headers=None):
+ (expected_uri, resp_obj) = expected_requests.pop(0)
+ assert uri == expected_uri
+ resp = FakeResponse(
+ code=200, phrase=b"OK", body=json.dumps(resp_obj).encode("utf-8"),
+ )
+ return resp
+
+ with patch.object(self.hs.get_proxied_http_client(), "request", mock_req):
+ # now hit the callback URI with the right params and a made-up code
+ _, channel = make_request(
+ self.hs.get_reactor(),
+ self.site,
+ "GET",
+ redirect_uri,
+ custom_headers=[
+ ("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items()
+ ],
+ )
+
+ # expect a confirmation page
+ assert channel.code == 200
+
+ # fish the matrix login token out of the body of the confirmation page
+ m = re.search(
+ 'a href="%s.*loginToken=([^"]*)"' % (client_redirect_url,),
+ channel.result["body"].decode("utf-8"),
+ )
+ assert m
+ login_token = m.group(1)
+
+ # finally, submit the matrix login token to the login API, which gives us our
+ # matrix access token and device id.
+ _, channel = make_request(
+ self.hs.get_reactor(),
+ self.site,
+ "POST",
+ "/login",
+ content={"type": "m.login.token", "token": login_token},
+ )
+ assert channel.code == 200
+ return channel.json_body
+
+
+# an 'oidc_config' suitable for login_with_oidc.
+TEST_OIDC_CONFIG = {
+ "enabled": True,
+ "discover": False,
+ "issuer": "https://issuer.test",
+ "client_id": "test-client-id",
+ "client_secret": "test-client-secret",
+ "scopes": ["profile"],
+ "authorization_endpoint": "https://z",
+ "token_endpoint": "https://issuer.test/token",
+ "userinfo_endpoint": "https://issuer.test/userinfo",
+ "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},
+}
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index 77246e478f..ac67a9de29 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
from typing import List, Union
from twisted.internet.defer import succeed
@@ -22,9 +23,11 @@ from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
from synapse.http.site import SynapseRequest
from synapse.rest.client.v1 import login
from synapse.rest.client.v2_alpha import auth, devices, register
-from synapse.types import JsonDict
+from synapse.rest.oidc import OIDCResource
+from synapse.types import JsonDict, UserID
from tests import unittest
+from tests.rest.client.v1.utils import TEST_OIDC_CONFIG
from tests.server import FakeChannel
@@ -156,27 +159,45 @@ class UIAuthTests(unittest.HomeserverTestCase):
register.register_servlets,
]
+ def default_config(self):
+ config = super().default_config()
+
+ # we enable OIDC as a way of testing SSO flows
+ oidc_config = {}
+ oidc_config.update(TEST_OIDC_CONFIG)
+ oidc_config["allow_existing_users"] = True
+
+ config["oidc_config"] = oidc_config
+ config["public_baseurl"] = "https://synapse.test"
+ return config
+
+ def create_resource_dict(self):
+ resource_dict = super().create_resource_dict()
+ # mount the OIDC resource at /_synapse/oidc
+ resource_dict["/_synapse/oidc"] = OIDCResource(self.hs)
+ return resource_dict
+
def prepare(self, reactor, clock, hs):
self.user_pass = "pass"
self.user = self.register_user("test", self.user_pass)
self.user_tok = self.login("test", self.user_pass)
- def get_device_ids(self) -> List[str]:
+ def get_device_ids(self, access_token: str) -> List[str]:
# Get the list of devices so one can be deleted.
- request, channel = self.make_request(
- "GET", "devices", access_token=self.user_tok,
- ) # type: SynapseRequest, FakeChannel
-
- # Get the ID of the device.
- self.assertEqual(request.code, 200)
+ _, channel = self.make_request("GET", "devices", access_token=access_token,)
+ self.assertEqual(channel.code, 200)
return [d["device_id"] for d in channel.json_body["devices"]]
def delete_device(
- self, device: str, expected_response: int, body: Union[bytes, JsonDict] = b""
+ self,
+ access_token: str,
+ device: str,
+ expected_response: int,
+ body: Union[bytes, JsonDict] = b"",
) -> FakeChannel:
"""Delete an individual device."""
request, channel = self.make_request(
- "DELETE", "devices/" + device, body, access_token=self.user_tok
+ "DELETE", "devices/" + device, body, access_token=access_token,
) # type: SynapseRequest, FakeChannel
# Ensure the response is sane.
@@ -201,11 +222,11 @@ class UIAuthTests(unittest.HomeserverTestCase):
"""
Test user interactive authentication outside of registration.
"""
- device_id = self.get_device_ids()[0]
+ device_id = self.get_device_ids(self.user_tok)[0]
# Attempt to delete this device.
# Returns a 401 as per the spec
- channel = self.delete_device(device_id, 401)
+ channel = self.delete_device(self.user_tok, device_id, 401)
# Grab the session
session = channel.json_body["session"]
@@ -214,6 +235,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
# Make another request providing the UI auth flow.
self.delete_device(
+ self.user_tok,
device_id,
200,
{
@@ -233,12 +255,13 @@ class UIAuthTests(unittest.HomeserverTestCase):
UIA - check that still works.
"""
- device_id = self.get_device_ids()[0]
- channel = self.delete_device(device_id, 401)
+ device_id = self.get_device_ids(self.user_tok)[0]
+ channel = self.delete_device(self.user_tok, device_id, 401)
session = channel.json_body["session"]
# Make another request providing the UI auth flow.
self.delete_device(
+ self.user_tok,
device_id,
200,
{
@@ -264,7 +287,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
# Create a second login.
self.login("test", self.user_pass)
- device_ids = self.get_device_ids()
+ device_ids = self.get_device_ids(self.user_tok)
self.assertEqual(len(device_ids), 2)
# Attempt to delete the first device.
@@ -298,12 +321,12 @@ class UIAuthTests(unittest.HomeserverTestCase):
# Create a second login.
self.login("test", self.user_pass)
- device_ids = self.get_device_ids()
+ device_ids = self.get_device_ids(self.user_tok)
self.assertEqual(len(device_ids), 2)
# Attempt to delete the first device.
# Returns a 401 as per the spec
- channel = self.delete_device(device_ids[0], 401)
+ channel = self.delete_device(self.user_tok, device_ids[0], 401)
# Grab the session
session = channel.json_body["session"]
@@ -313,6 +336,7 @@ class UIAuthTests(unittest.HomeserverTestCase):
# Make another request providing the UI auth flow, but try to delete the
# second device. This results in an error.
self.delete_device(
+ self.user_tok,
device_ids[1],
403,
{
@@ -324,3 +348,39 @@ class UIAuthTests(unittest.HomeserverTestCase):
},
},
)
+
+ def test_does_not_offer_password_for_sso_user(self):
+ login_resp = self.helper.login_via_oidc("username")
+ user_tok = login_resp["access_token"]
+ device_id = login_resp["device_id"]
+
+ # now call the device deletion API: we should get the option to auth with SSO
+ # and not password.
+ channel = self.delete_device(user_tok, device_id, 401)
+
+ flows = channel.json_body["flows"]
+ self.assertEqual(flows, [{"stages": ["m.login.sso"]}])
+
+ def test_does_not_offer_sso_for_password_user(self):
+ # now call the device deletion API: we should get the option to auth with SSO
+ # and not password.
+ device_ids = self.get_device_ids(self.user_tok)
+ channel = self.delete_device(self.user_tok, device_ids[0], 401)
+
+ flows = channel.json_body["flows"]
+ self.assertEqual(flows, [{"stages": ["m.login.password"]}])
+
+ def test_offers_both_flows_for_upgraded_user(self):
+ """A user that had a password and then logged in with SSO should get both flows
+ """
+ login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
+ self.assertEqual(login_resp["user_id"], self.user)
+
+ device_ids = self.get_device_ids(self.user_tok)
+ channel = self.delete_device(self.user_tok, device_ids[0], 401)
+
+ flows = channel.json_body["flows"]
+ # we have no particular expectations of ordering here
+ self.assertIn({"stages": ["m.login.password"]}, flows)
+ self.assertIn({"stages": ["m.login.sso"]}, flows)
+ self.assertEqual(len(flows), 2)
diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py
index ccdc8c2ecf..529b6bcded 100644
--- a/tests/rest/media/v1/test_url_preview.py
+++ b/tests/rest/media/v1/test_url_preview.py
@@ -18,41 +18,15 @@ import re
from mock import patch
-import attr
-
from twisted.internet._resolver import HostResolution
from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.error import DNSLookupError
-from twisted.python.failure import Failure
from twisted.test.proto_helpers import AccumulatingProtocol
-from twisted.web._newclient import ResponseDone
from tests import unittest
from tests.server import FakeTransport
-@attr.s
-class FakeResponse:
- version = attr.ib()
- code = attr.ib()
- phrase = attr.ib()
- headers = attr.ib()
- body = attr.ib()
- absoluteURI = attr.ib()
-
- @property
- def request(self):
- @attr.s
- class FakeTransport:
- absoluteURI = self.absoluteURI
-
- return FakeTransport()
-
- def deliverBody(self, protocol):
- protocol.dataReceived(self.body)
- protocol.connectionLost(Failure(ResponseDone()))
-
-
class URLPreviewTests(unittest.HomeserverTestCase):
hijack_auth = True
diff --git a/tests/server.py b/tests/server.py
index a51ad0c14e..4faf32e335 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -216,8 +216,9 @@ def make_request(
and not path.startswith(b"/_matrix")
and not path.startswith(b"/_synapse")
):
+ if path.startswith(b"/"):
+ path = path[1:]
path = b"/_matrix/client/r0/" + path
- path = path.replace(b"//", b"/")
if not path.startswith(b"/"):
path = b"/" + path
@@ -258,6 +259,7 @@ def make_request(
for k, v in custom_headers:
req.requestHeaders.addRawHeader(k, v)
+ req.parseCookies()
req.requestReceived(method, path, b"1.1")
if await_result:
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index d232b72264..6873d45eb6 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -22,6 +22,11 @@ import warnings
from asyncio import Future
from typing import Any, Awaitable, Callable, TypeVar
+import attr
+
+from twisted.python.failure import Failure
+from twisted.web.client import ResponseDone
+
TV = TypeVar("TV")
@@ -80,3 +85,25 @@ def setup_awaitable_errors() -> Callable[[], None]:
sys.unraisablehook = unraisablehook # type: ignore
return cleanup
+
+
+@attr.s
+class FakeResponse:
+ """A fake twisted.web.IResponse object
+
+ there is a similar class at treq.test.test_response, but it lacks a `phrase`
+ attribute, and didn't support deliverBody until recently.
+ """
+
+ # HTTP response code
+ code = attr.ib(type=int)
+
+ # HTTP response phrase (eg b'OK' for a 200)
+ phrase = attr.ib(type=bytes)
+
+ # body of the response
+ body = attr.ib(type=bytes)
+
+ def deliverBody(self, protocol):
+ protocol.dataReceived(self.body)
+ protocol.connectionLost(Failure(ResponseDone()))
diff --git a/tests/unittest.py b/tests/unittest.py
index a9d59e31f7..102b0a1f34 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -20,7 +20,7 @@ import hmac
import inspect
import logging
import time
-from typing import Optional, Tuple, Type, TypeVar, Union, overload
+from typing import Dict, Optional, Tuple, Type, TypeVar, Union, overload
from mock import Mock, patch
@@ -46,6 +46,7 @@ from synapse.logging.context import (
)
from synapse.server import HomeServer
from synapse.types import UserID, create_requester
+from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.ratelimitutils import FederationRateLimiter
from tests.server import FakeChannel, get_clock, make_request, setup_test_homeserver
@@ -320,15 +321,28 @@ class HomeserverTestCase(TestCase):
"""
Create a the root resource for the test server.
- The default implementation creates a JsonResource and calls each function in
- `servlets` to register servletes against it
+ The default calls `self.create_resource_dict` and builds the resultant dict
+ into a tree.
"""
- resource = JsonResource(self.hs)
+ root_resource = Resource()
+ create_resource_tree(self.create_resource_dict(), root_resource)
+ return root_resource
- for servlet in self.servlets:
- servlet(self.hs, resource)
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ """Create a resource tree for the test server
- return resource
+ A resource tree is a mapping from path to twisted.web.resource.
+
+ The default implementation creates a JsonResource and calls each function in
+ `servlets` to register servlets against it.
+ """
+ servlet_resource = JsonResource(self.hs)
+ for servlet in self.servlets:
+ servlet(self.hs, servlet_resource)
+ return {
+ "/_matrix/client": servlet_resource,
+ "/_synapse/admin": servlet_resource,
+ }
def default_config(self):
"""
@@ -691,13 +705,29 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
A federating homeserver that authenticates incoming requests as `other.example.com`.
"""
- def prepare(self, reactor, clock, homeserver):
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ d = super().create_resource_dict()
+ d["/_matrix/federation"] = TestTransportLayerServer(self.hs)
+ return d
+
+
+class TestTransportLayerServer(JsonResource):
+ """A test implementation of TransportLayerServer
+
+ authenticates incoming requests as `other.example.com`.
+ """
+
+ def __init__(self, hs):
+ super().__init__(hs)
+
class Authenticator:
def authenticate_request(self, request, content):
return succeed("other.example.com")
+ authenticator = Authenticator()
+
ratelimiter = FederationRateLimiter(
- clock,
+ hs.get_clock(),
FederationRateLimitConfig(
window_size=1,
sleep_limit=1,
@@ -706,11 +736,8 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
concurrent_requests=1000,
),
)
- federation_server.register_servlets(
- homeserver, self.resource, Authenticator(), ratelimiter
- )
- return super().prepare(reactor, clock, homeserver)
+ federation_server.register_servlets(hs, self, authenticator, ratelimiter)
def override_config(extra_config):
|