diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
new file mode 100644
index 0000000000..7baf224f7e
--- /dev/null
+++ b/tests/handlers/test_cas.py
@@ -0,0 +1,121 @@
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from mock import Mock
+
+from synapse.handlers.cas_handler import CasResponse
+
+from tests.test_utils import simple_async_mock
+from tests.unittest import HomeserverTestCase
+
+# These are a few constants that are used as config parameters in the tests.
+BASE_URL = "https://synapse/"
+SERVER_URL = "https://issuer/"
+
+
+class CasHandlerTestCase(HomeserverTestCase):
+ def default_config(self):
+ config = super().default_config()
+ config["public_baseurl"] = BASE_URL
+ cas_config = {
+ "enabled": True,
+ "server_url": SERVER_URL,
+ "service_url": BASE_URL,
+ }
+ config["cas_config"] = cas_config
+
+ return config
+
+ def make_homeserver(self, reactor, clock):
+ hs = self.setup_test_homeserver()
+
+ self.handler = hs.get_cas_handler()
+
+ # Reduce the number of attempts when generating MXIDs.
+ sso_handler = hs.get_sso_handler()
+ sso_handler._MAP_USERNAME_RETRIES = 3
+
+ return hs
+
+ def test_map_cas_user_to_user(self):
+ """Ensure that mapping the CAS user returned from a provider to an MXID works properly."""
+
+ # stub out the auth handler
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
+ cas_response = CasResponse("test_user", {})
+ request = _mock_request()
+ self.get_success(
+ self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+ )
+
+ # check that the auth handler got called as expected
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user:test", request, "redirect_uri", None, new_user=True
+ )
+
+ def test_map_cas_user_to_existing_user(self):
+ """Existing users can log in with CAS account."""
+ store = self.hs.get_datastore()
+ self.get_success(
+ store.register_user(user_id="@test_user:test", password_hash=None)
+ )
+
+ # stub out the auth handler
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
+ # Map a user via SSO.
+ cas_response = CasResponse("test_user", {})
+ request = _mock_request()
+ self.get_success(
+ self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+ )
+
+ # check that the auth handler got called as expected
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user:test", request, "redirect_uri", None, new_user=False
+ )
+
+ # Subsequent calls should map to the same mxid.
+ auth_handler.complete_sso_login.reset_mock()
+ self.get_success(
+ self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+ )
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user:test", request, "redirect_uri", None, new_user=False
+ )
+
+ def test_map_cas_user_to_invalid_localpart(self):
+ """CAS automaps invalid characters to base-64 encoding."""
+
+ # stub out the auth handler
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
+ cas_response = CasResponse("föö", {})
+ request = _mock_request()
+ self.get_success(
+ self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+ )
+
+ # check that the auth handler got called as expected
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@f=c3=b6=c3=b6:test", request, "redirect_uri", None, new_user=True
+ )
+
+
+def _mock_request():
+ """Returns a mock which will stand in as a SynapseRequest"""
+ return Mock(spec=["getClientIP", "getHeader"])
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 875aaec2c6..5dfeccfeb6 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -27,7 +27,7 @@ user2 = "@theresa:bbb"
class DeviceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
- hs = self.setup_test_homeserver("server", http_client=None)
+ hs = self.setup_test_homeserver("server", federation_http_client=None)
self.handler = hs.get_device_handler()
self.store = hs.get_datastore()
return hs
@@ -229,7 +229,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
class DehydrationTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
- hs = self.setup_test_homeserver("server", http_client=None)
+ hs = self.setup_test_homeserver("server", federation_http_client=None)
self.handler = hs.get_device_handler()
self.registration = hs.get_registration_handler()
self.auth = hs.get_auth()
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index ee6ef5e6fa..a39f898608 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -42,8 +42,6 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
self.mock_registry.register_query_handler = register_query_handler
hs = self.setup_test_homeserver(
- http_client=None,
- resource_for_federation=Mock(),
federation_client=self.mock_federation,
federation_registry=self.mock_registry,
)
@@ -407,7 +405,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
def test_denied(self):
room_id = self.helper.create_room_as(self.user_id)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
b"directory/room/%23test%3Atest",
('{"room_id":"%s"}' % (room_id,)).encode("ascii"),
@@ -417,7 +415,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
def test_allowed(self):
room_id = self.helper.create_room_as(self.user_id)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT",
b"directory/room/%23unofficial_test%3Atest",
('{"room_id":"%s"}' % (room_id,)).encode("ascii"),
@@ -433,7 +431,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
room_id = self.helper.create_room_as(self.user_id)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}"
)
self.assertEquals(200, channel.code, channel.result)
@@ -448,7 +446,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
self.directory_handler.enable_room_list_search = True
# Room list is enabled so we should get some results
- request, channel = self.make_request("GET", b"publicRooms")
+ channel = self.make_request("GET", b"publicRooms")
self.assertEquals(200, channel.code, channel.result)
self.assertTrue(len(channel.json_body["chunk"]) > 0)
@@ -456,13 +454,13 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
self.directory_handler.enable_room_list_search = False
# Room list disabled so we should get no results
- request, channel = self.make_request("GET", b"publicRooms")
+ channel = self.make_request("GET", b"publicRooms")
self.assertEquals(200, channel.code, channel.result)
self.assertTrue(len(channel.json_body["chunk"]) == 0)
# Room list disabled so we shouldn't be allowed to publish rooms
room_id = self.helper.create_room_as(self.user_id)
- request, channel = self.make_request(
+ channel = self.make_request(
"PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}"
)
self.assertEquals(403, channel.code, channel.result)
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index bf866dacf3..983e368592 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -16,7 +16,7 @@ import logging
from unittest import TestCase
from synapse.api.constants import EventTypes
-from synapse.api.errors import AuthError, Codes, SynapseError
+from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase
from synapse.federation.federation_base import event_from_pdu_json
@@ -37,7 +37,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
]
def make_homeserver(self, reactor, clock):
- hs = self.setup_test_homeserver(http_client=None)
+ hs = self.setup_test_homeserver(federation_http_client=None)
self.handler = hs.get_federation_handler()
self.store = hs.get_datastore()
return hs
@@ -126,7 +126,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
room_version,
)
- with LoggingContext(request="send_rejected"):
+ with LoggingContext("send_rejected"):
d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev)
self.get_success(d)
@@ -178,7 +178,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
room_version,
)
- with LoggingContext(request="send_rejected"):
+ with LoggingContext("send_rejected"):
d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev)
self.get_success(d)
@@ -191,6 +191,50 @@ class FederationTestCase(unittest.HomeserverTestCase):
self.assertEqual(sg, sg2)
+ @unittest.override_config(
+ {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
+ )
+ def test_invite_by_user_ratelimit(self):
+ """Tests that invites from federation to a particular user are
+ actually rate-limited.
+ """
+ other_server = "otherserver"
+ other_user = "@otheruser:" + other_server
+
+ # create the room
+ user_id = self.register_user("kermit", "test")
+ tok = self.login("kermit", "test")
+
+ def create_invite():
+ room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+ room_version = self.get_success(self.store.get_room_version(room_id))
+ return event_from_pdu_json(
+ {
+ "type": EventTypes.Member,
+ "content": {"membership": "invite"},
+ "room_id": room_id,
+ "sender": other_user,
+ "state_key": "@user:test",
+ "depth": 32,
+ "prev_events": [],
+ "auth_events": [],
+ "origin_server_ts": self.clock.time_msec(),
+ },
+ room_version,
+ )
+
+ for i in range(3):
+ event = create_invite()
+ self.get_success(
+ self.handler.on_invite_request(other_server, event, event.room_version,)
+ )
+
+ event = create_invite()
+ self.get_failure(
+ self.handler.on_invite_request(other_server, event, event.room_version,),
+ exc=LimitExceededError,
+ )
+
def _build_and_send_join_event(self, other_server, other_user, room_id):
join_event = self.get_success(
self.handler.on_make_join_request(other_server, room_id, other_user)
@@ -198,7 +242,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
# the auth code requires that a signature exists, but doesn't check that
# signature... go figure.
join_event.signatures[other_server] = {"x": "y"}
- with LoggingContext(request="send_join"):
+ with LoggingContext("send_join"):
d = run_in_background(
self.handler.on_send_join_request, other_server, join_event
)
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index af42775815..f955dfa490 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -206,7 +206,7 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase):
# Redaction of event should fail.
path = "/_matrix/client/r0/rooms/%s/redact/%s" % (self.room_id, event_id)
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", path, content={}, access_token=self.access_token
)
self.assertEqual(int(channel.result["code"]), 403)
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index a308c46da9..ad20400b1d 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -13,32 +13,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
+from typing import Optional
from urllib.parse import parse_qs, urlparse
-from mock import Mock, patch
+from mock import ANY, Mock, patch
-import attr
import pymacaroons
-from twisted.python.failure import Failure
-from twisted.web._newclient import ResponseDone
-
-from synapse.handlers.oidc_handler import OidcError, OidcMappingProvider
from synapse.handlers.sso import MappingException
+from synapse.server import HomeServer
from synapse.types import UserID
+from tests.test_utils import FakeResponse, simple_async_mock
from tests.unittest import HomeserverTestCase, override_config
+try:
+ import authlib # noqa: F401
-@attr.s
-class FakeResponse:
- code = attr.ib()
- body = attr.ib()
- phrase = attr.ib()
-
- def deliverBody(self, protocol):
- protocol.dataReceived(self.body)
- protocol.connectionLost(Failure(ResponseDone()))
+ HAS_OIDC = True
+except ImportError:
+ HAS_OIDC = False
# These are a few constants that are used as config parameters in the tests.
@@ -46,7 +40,7 @@ ISSUER = "https://issuer/"
CLIENT_ID = "test-client-id"
CLIENT_SECRET = "test-client-secret"
BASE_URL = "https://synapse/"
-CALLBACK_URL = BASE_URL + "_synapse/oidc/callback"
+CALLBACK_URL = BASE_URL + "_synapse/client/oidc/callback"
SCOPES = ["openid"]
AUTHORIZATION_ENDPOINT = ISSUER + "authorize"
@@ -64,17 +58,14 @@ COMMON_CONFIG = {
}
-# The cookie name and path don't really matter, just that it has to be coherent
-# between the callback & redirect handlers.
-COOKIE_NAME = b"oidc_session"
-COOKIE_PATH = "/_synapse/oidc"
-
-
-class TestMappingProvider(OidcMappingProvider):
+class TestMappingProvider:
@staticmethod
def parse_config(config):
return
+ def __init__(self, config):
+ pass
+
def get_remote_user_id(self, userinfo):
return userinfo["sub"]
@@ -97,16 +88,6 @@ class TestMappingProviderFailures(TestMappingProvider):
}
-def simple_async_mock(return_value=None, raises=None):
- # AsyncMock is not available in python3.5, this mimics part of its behaviour
- async def cb(*args, **kwargs):
- if raises:
- raise raises
- return return_value
-
- return Mock(side_effect=cb)
-
-
async def get_json(url):
# Mock get_json calls to handle jwks & oidc discovery endpoints
if url == WELL_KNOWN:
@@ -127,6 +108,9 @@ async def get_json(url):
class OidcHandlerTestCase(HomeserverTestCase):
+ if not HAS_OIDC:
+ skip = "requires OIDC"
+
def default_config(self):
config = super().default_config()
config["public_baseurl"] = BASE_URL
@@ -155,6 +139,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
hs = self.setup_test_homeserver(proxied_http_client=self.http_client)
self.handler = hs.get_oidc_handler()
+ self.provider = self.handler._providers["oidc"]
sso_handler = hs.get_sso_handler()
# Mock the render error method.
self.render_error = Mock(return_value=None)
@@ -166,27 +151,29 @@ class OidcHandlerTestCase(HomeserverTestCase):
return hs
def metadata_edit(self, values):
- return patch.dict(self.handler._provider_metadata, values)
+ return patch.dict(self.provider._provider_metadata, values)
def assertRenderedError(self, error, error_description=None):
+ self.render_error.assert_called_once()
args = self.render_error.call_args[0]
self.assertEqual(args[1], error)
if error_description is not None:
self.assertEqual(args[2], error_description)
# Reset the render_error mock
self.render_error.reset_mock()
+ return args
def test_config(self):
"""Basic config correctly sets up the callback URL and client auth correctly."""
- self.assertEqual(self.handler._callback_url, CALLBACK_URL)
- self.assertEqual(self.handler._client_auth.client_id, CLIENT_ID)
- self.assertEqual(self.handler._client_auth.client_secret, CLIENT_SECRET)
+ self.assertEqual(self.provider._callback_url, CALLBACK_URL)
+ self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
+ self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
@override_config({"oidc_config": {"discover": True}})
def test_discovery(self):
"""The handler should discover the endpoints from OIDC discovery document."""
# This would throw if some metadata were invalid
- metadata = self.get_success(self.handler.load_metadata())
+ metadata = self.get_success(self.provider.load_metadata())
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
self.assertEqual(metadata.issuer, ISSUER)
@@ -198,47 +185,47 @@ class OidcHandlerTestCase(HomeserverTestCase):
# subsequent calls should be cached
self.http_client.reset_mock()
- self.get_success(self.handler.load_metadata())
+ self.get_success(self.provider.load_metadata())
self.http_client.get_json.assert_not_called()
@override_config({"oidc_config": COMMON_CONFIG})
def test_no_discovery(self):
"""When discovery is disabled, it should not try to load from discovery document."""
- self.get_success(self.handler.load_metadata())
+ self.get_success(self.provider.load_metadata())
self.http_client.get_json.assert_not_called()
@override_config({"oidc_config": COMMON_CONFIG})
def test_load_jwks(self):
"""JWKS loading is done once (then cached) if used."""
- jwks = self.get_success(self.handler.load_jwks())
+ jwks = self.get_success(self.provider.load_jwks())
self.http_client.get_json.assert_called_once_with(JWKS_URI)
self.assertEqual(jwks, {"keys": []})
# subsequent calls should be cached…
self.http_client.reset_mock()
- self.get_success(self.handler.load_jwks())
+ self.get_success(self.provider.load_jwks())
self.http_client.get_json.assert_not_called()
# …unless forced
self.http_client.reset_mock()
- self.get_success(self.handler.load_jwks(force=True))
+ self.get_success(self.provider.load_jwks(force=True))
self.http_client.get_json.assert_called_once_with(JWKS_URI)
# Throw if the JWKS uri is missing
with self.metadata_edit({"jwks_uri": None}):
- self.get_failure(self.handler.load_jwks(force=True), RuntimeError)
+ self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
# Return empty key set if JWKS are not used
- self.handler._scopes = [] # not asking the openid scope
+ self.provider._scopes = [] # not asking the openid scope
self.http_client.get_json.reset_mock()
- jwks = self.get_success(self.handler.load_jwks(force=True))
+ jwks = self.get_success(self.provider.load_jwks(force=True))
self.http_client.get_json.assert_not_called()
self.assertEqual(jwks, {"keys": []})
@override_config({"oidc_config": COMMON_CONFIG})
def test_validate_config(self):
"""Provider metadatas are extensively validated."""
- h = self.handler
+ h = self.provider
# Default test config does not throw
h._validate_metadata()
@@ -317,13 +304,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
"""Provider metadata validation can be disabled by config."""
with self.metadata_edit({"issuer": "http://insecure"}):
# This should not throw
- self.handler._validate_metadata()
+ self.provider._validate_metadata()
def test_redirect_request(self):
"""The redirect request has the right arguments & generates a valid session cookie."""
req = Mock(spec=["addCookie"])
url = self.get_success(
- self.handler.handle_redirect_request(req, b"http://client/redirect")
+ self.provider.handle_redirect_request(req, b"http://client/redirect")
)
url = urlparse(url)
auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT)
@@ -347,14 +334,21 @@ class OidcHandlerTestCase(HomeserverTestCase):
# For some reason, call.args does not work with python3.5
args = calls[0][0]
kwargs = calls[0][1]
- self.assertEqual(args[0], COOKIE_NAME)
- self.assertEqual(kwargs["path"], COOKIE_PATH)
+
+ # The cookie name and path don't really matter, just that it has to be coherent
+ # between the callback & redirect handlers.
+ self.assertEqual(args[0], b"oidc_session")
+ self.assertEqual(kwargs["path"], "/_synapse/client/oidc")
cookie = args[1]
macaroon = pymacaroons.Macaroon.deserialize(cookie)
- state = self.handler._get_value_from_macaroon(macaroon, "state")
- nonce = self.handler._get_value_from_macaroon(macaroon, "nonce")
- redirect = self.handler._get_value_from_macaroon(
+ state = self.handler._token_generator._get_value_from_macaroon(
+ macaroon, "state"
+ )
+ nonce = self.handler._token_generator._get_value_from_macaroon(
+ macaroon, "nonce"
+ )
+ redirect = self.handler._token_generator._get_value_from_macaroon(
macaroon, "client_redirect_url"
)
@@ -384,31 +378,29 @@ class OidcHandlerTestCase(HomeserverTestCase):
- when the userinfo fetching fails
- when the code exchange fails
"""
+
+ # ensure that we are correctly testing the fallback when "get_extra_attributes"
+ # is not implemented.
+ mapping_provider = self.provider._user_mapping_provider
+ with self.assertRaises(AttributeError):
+ _ = mapping_provider.get_extra_attributes
+
token = {
"type": "bearer",
"id_token": "id_token",
"access_token": "access_token",
}
+ username = "bar"
userinfo = {
"sub": "foo",
- "preferred_username": "bar",
+ "username": username,
}
- user_id = "@foo:domain.org"
- self.handler._exchange_code = simple_async_mock(return_value=token)
- self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
- self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
- self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
- self.handler._auth_handler.complete_sso_login = simple_async_mock()
- request = Mock(
- spec=[
- "args",
- "getCookie",
- "addCookie",
- "requestHeaders",
- "getClientIP",
- "get_user_agent",
- ]
- )
+ expected_user_id = "@%s:%s" % (username, self.hs.hostname)
+ self.provider._exchange_code = simple_async_mock(return_value=token)
+ self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
+ self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
code = "code"
state = "state"
@@ -416,74 +408,61 @@ class OidcHandlerTestCase(HomeserverTestCase):
client_redirect_url = "http://client/redirect"
user_agent = "Browser"
ip_address = "10.0.0.1"
- request.getCookie.return_value = self.handler._generate_oidc_session_token(
- state=state,
- nonce=nonce,
- client_redirect_url=client_redirect_url,
- ui_auth_session_id=None,
+ session = self._generate_oidc_session_token(state, nonce, client_redirect_url)
+ request = _build_callback_request(
+ code, state, session, user_agent=user_agent, ip_address=ip_address
)
- request.args = {}
- request.args[b"code"] = [code.encode("utf-8")]
- request.args[b"state"] = [state.encode("utf-8")]
-
- request.getClientIP.return_value = ip_address
- request.get_user_agent.return_value = user_agent
-
self.get_success(self.handler.handle_oidc_callback(request))
- self.handler._auth_handler.complete_sso_login.assert_called_once_with(
- user_id, request, client_redirect_url, {},
+ auth_handler.complete_sso_login.assert_called_once_with(
+ expected_user_id, request, client_redirect_url, None, new_user=True
)
- self.handler._exchange_code.assert_called_once_with(code)
- self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
- self.handler._map_userinfo_to_user.assert_called_once_with(
- userinfo, token, user_agent, ip_address
- )
- self.handler._fetch_userinfo.assert_not_called()
+ self.provider._exchange_code.assert_called_once_with(code)
+ self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
+ self.provider._fetch_userinfo.assert_not_called()
self.render_error.assert_not_called()
# Handle mapping errors
- self.handler._map_userinfo_to_user = simple_async_mock(
- raises=MappingException()
- )
- self.get_success(self.handler.handle_oidc_callback(request))
- self.assertRenderedError("mapping_error")
- self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
+ with patch.object(
+ self.provider,
+ "_remote_id_from_userinfo",
+ new=Mock(side_effect=MappingException()),
+ ):
+ self.get_success(self.handler.handle_oidc_callback(request))
+ self.assertRenderedError("mapping_error")
# Handle ID token errors
- self.handler._parse_id_token = simple_async_mock(raises=Exception())
+ self.provider._parse_id_token = simple_async_mock(raises=Exception())
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_token")
- self.handler._auth_handler.complete_sso_login.reset_mock()
- self.handler._exchange_code.reset_mock()
- self.handler._parse_id_token.reset_mock()
- self.handler._map_userinfo_to_user.reset_mock()
- self.handler._fetch_userinfo.reset_mock()
+ auth_handler.complete_sso_login.reset_mock()
+ self.provider._exchange_code.reset_mock()
+ self.provider._parse_id_token.reset_mock()
+ self.provider._fetch_userinfo.reset_mock()
# With userinfo fetching
- self.handler._scopes = [] # do not ask the "openid" scope
+ self.provider._scopes = [] # do not ask the "openid" scope
self.get_success(self.handler.handle_oidc_callback(request))
- self.handler._auth_handler.complete_sso_login.assert_called_once_with(
- user_id, request, client_redirect_url, {},
- )
- self.handler._exchange_code.assert_called_once_with(code)
- self.handler._parse_id_token.assert_not_called()
- self.handler._map_userinfo_to_user.assert_called_once_with(
- userinfo, token, user_agent, ip_address
+ auth_handler.complete_sso_login.assert_called_once_with(
+ expected_user_id, request, client_redirect_url, None, new_user=False
)
- self.handler._fetch_userinfo.assert_called_once_with(token)
+ self.provider._exchange_code.assert_called_once_with(code)
+ self.provider._parse_id_token.assert_not_called()
+ self.provider._fetch_userinfo.assert_called_once_with(token)
self.render_error.assert_not_called()
# Handle userinfo fetching error
- self.handler._fetch_userinfo = simple_async_mock(raises=Exception())
+ self.provider._fetch_userinfo = simple_async_mock(raises=Exception())
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("fetch_error")
# Handle code exchange failure
- self.handler._exchange_code = simple_async_mock(
+ from synapse.handlers.oidc_handler import OidcError
+
+ self.provider._exchange_code = simple_async_mock(
raises=OidcError("invalid_request")
)
self.get_success(self.handler.handle_oidc_callback(request))
@@ -513,11 +492,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertRenderedError("invalid_session")
# Mismatching session
- session = self.handler._generate_oidc_session_token(
- state="state",
- nonce="nonce",
- client_redirect_url="http://client/redirect",
- ui_auth_session_id=None,
+ session = self._generate_oidc_session_token(
+ state="state", nonce="nonce", client_redirect_url="http://client/redirect",
)
request.args = {}
request.args[b"state"] = [b"mismatching state"]
@@ -541,7 +517,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
return_value=FakeResponse(code=200, phrase=b"OK", body=token_json)
)
code = "code"
- ret = self.get_success(self.handler._exchange_code(code))
+ ret = self.get_success(self.provider._exchange_code(code))
kwargs = self.http_client.request.call_args[1]
self.assertEqual(ret, token)
@@ -563,7 +539,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
body=b'{"error": "foo", "error_description": "bar"}',
)
)
- exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+ from synapse.handlers.oidc_handler import OidcError
+
+ exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "foo")
self.assertEqual(exc.value.error_description, "bar")
@@ -573,7 +551,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
code=500, phrase=b"Internal Server Error", body=b"Not JSON",
)
)
- exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+ exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "server_error")
# Internal server error with JSON body
@@ -585,14 +563,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
)
- exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+ exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "internal_server_error")
# 4xx error without "error" field
self.http_client.request = simple_async_mock(
return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",)
)
- exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+ exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "server_error")
# 2xx error with "error" field
@@ -601,7 +579,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
code=200, phrase=b"OK", body=b'{"error": "some_error"}',
)
)
- exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+ exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "some_error")
@override_config(
@@ -624,72 +602,56 @@ class OidcHandlerTestCase(HomeserverTestCase):
}
userinfo = {
"sub": "foo",
+ "username": "foo",
"phone": "1234567",
}
- user_id = "@foo:domain.org"
- self.handler._exchange_code = simple_async_mock(return_value=token)
- self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
- self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
- self.handler._auth_handler.complete_sso_login = simple_async_mock()
- request = Mock(
- spec=[
- "args",
- "getCookie",
- "addCookie",
- "requestHeaders",
- "getClientIP",
- "get_user_agent",
- ]
- )
+ self.provider._exchange_code = simple_async_mock(return_value=token)
+ self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
state = "state"
client_redirect_url = "http://client/redirect"
- request.getCookie.return_value = self.handler._generate_oidc_session_token(
- state=state,
- nonce="nonce",
- client_redirect_url=client_redirect_url,
- ui_auth_session_id=None,
+ session = self._generate_oidc_session_token(
+ state=state, nonce="nonce", client_redirect_url=client_redirect_url,
)
-
- request.args = {}
- request.args[b"code"] = [b"code"]
- request.args[b"state"] = [state.encode("utf-8")]
-
- request.getClientIP.return_value = "10.0.0.1"
- request.get_user_agent.return_value = "Browser"
+ request = _build_callback_request("code", state, session)
self.get_success(self.handler.handle_oidc_callback(request))
- self.handler._auth_handler.complete_sso_login.assert_called_once_with(
- user_id, request, client_redirect_url, {"phone": "1234567"},
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@foo:test",
+ request,
+ client_redirect_url,
+ {"phone": "1234567"},
+ new_user=True,
)
def test_map_userinfo_to_user(self):
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
userinfo = {
"sub": "test_user",
"username": "test_user",
}
- # The token doesn't matter with the default user mapping provider.
- token = {}
- mxid = self.get_success(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- )
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user:test", ANY, ANY, None, new_user=True
)
- self.assertEqual(mxid, "@test_user:test")
+ auth_handler.complete_sso_login.reset_mock()
# Some providers return an integer ID.
userinfo = {
"sub": 1234,
"username": "test_user_2",
}
- mxid = self.get_success(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- )
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user_2:test", ANY, ANY, None, new_user=True
)
- self.assertEqual(mxid, "@test_user_2:test")
+ auth_handler.complete_sso_login.reset_mock()
# Test if the mxid is already taken
store = self.hs.get_datastore()
@@ -698,14 +660,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
store.register_user(user_id=user3.to_string(), password_hash=None)
)
userinfo = {"sub": "test3", "username": "test_user_3"}
- e = self.get_failure(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- ),
- MappingException,
- )
- self.assertEqual(
- str(e.value), "Mapping provider does not support de-duplicating Matrix IDs",
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_not_called()
+ self.assertRenderedError(
+ "mapping_error",
+ "Mapping provider does not support de-duplicating Matrix IDs",
)
@override_config({"oidc_config": {"allow_existing_users": True}})
@@ -717,26 +676,26 @@ class OidcHandlerTestCase(HomeserverTestCase):
store.register_user(user_id=user.to_string(), password_hash=None)
)
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
# Map a user via SSO.
userinfo = {
"sub": "test",
"username": "test_user",
}
- token = {}
- mxid = self.get_success(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- )
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_called_once_with(
+ user.to_string(), ANY, ANY, None, new_user=False
)
- self.assertEqual(mxid, "@test_user:test")
+ auth_handler.complete_sso_login.reset_mock()
# Subsequent calls should map to the same mxid.
- mxid = self.get_success(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- )
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_called_once_with(
+ user.to_string(), ANY, ANY, None, new_user=False
)
- self.assertEqual(mxid, "@test_user:test")
+ auth_handler.complete_sso_login.reset_mock()
# Note that a second SSO user can be mapped to the same Matrix ID. (This
# requires a unique sub, but something that maps to the same matrix ID,
@@ -747,13 +706,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test1",
"username": "test_user",
}
- token = {}
- mxid = self.get_success(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- )
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_called_once_with(
+ user.to_string(), ANY, ANY, None, new_user=False
)
- self.assertEqual(mxid, "@test_user:test")
+ auth_handler.complete_sso_login.reset_mock()
# Register some non-exact matching cases.
user2 = UserID.from_string("@TEST_user_2:test")
@@ -770,14 +727,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test2",
"username": "TEST_USER_2",
}
- e = self.get_failure(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- ),
- MappingException,
- )
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_not_called()
+ args = self.assertRenderedError("mapping_error")
self.assertTrue(
- str(e.value).startswith(
+ args[2].startswith(
"Attempted to login as '@TEST_USER_2:test' but it matches more than one user inexactly:"
)
)
@@ -788,28 +742,17 @@ class OidcHandlerTestCase(HomeserverTestCase):
store.register_user(user_id=user2.to_string(), password_hash=None)
)
- mxid = self.get_success(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- )
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@TEST_USER_2:test", ANY, ANY, None, new_user=False
)
- self.assertEqual(mxid, "@TEST_USER_2:test")
def test_map_userinfo_to_invalid_localpart(self):
"""If the mapping provider generates an invalid localpart it should be rejected."""
- userinfo = {
- "sub": "test2",
- "username": "föö",
- }
- token = {}
-
- e = self.get_failure(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- ),
- MappingException,
+ self.get_success(
+ _make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"})
)
- self.assertEqual(str(e.value), "localpart is invalid: föö")
+ self.assertRenderedError("mapping_error", "localpart is invalid: föö")
@override_config(
{
@@ -822,6 +765,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
def test_map_userinfo_to_user_retries(self):
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
store = self.hs.get_datastore()
self.get_success(
store.register_user(user_id="@test_user:test", password_hash=None)
@@ -830,14 +776,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "test",
"username": "test_user",
}
- token = {}
- mxid = self.get_success(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- )
- )
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+
# test_user is already taken, so test_user1 gets registered instead.
- self.assertEqual(mxid, "@test_user1:test")
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user1:test", ANY, ANY, None, new_user=True
+ )
+ auth_handler.complete_sso_login.reset_mock()
# Register all of the potential mxids for a particular OIDC username.
self.get_success(
@@ -853,12 +798,128 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "tester",
"username": "tester",
}
- e = self.get_failure(
- self.handler._map_userinfo_to_user(
- userinfo, token, "user-agent", "10.10.10.10"
- ),
- MappingException,
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ auth_handler.complete_sso_login.assert_not_called()
+ self.assertRenderedError(
+ "mapping_error", "Unable to generate a Matrix ID from the SSO response"
)
- self.assertEqual(
- str(e.value), "Unable to generate a Matrix ID from the SSO response"
+
+ def test_empty_localpart(self):
+ """Attempts to map onto an empty localpart should be rejected."""
+ userinfo = {
+ "sub": "tester",
+ "username": "",
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ self.assertRenderedError("mapping_error", "localpart is invalid: ")
+
+ @override_config(
+ {
+ "oidc_config": {
+ "user_mapping_provider": {
+ "config": {"localpart_template": "{{ user.username }}"}
+ }
+ }
+ }
+ )
+ def test_null_localpart(self):
+ """Mapping onto a null localpart via an empty OIDC attribute should be rejected"""
+ userinfo = {
+ "sub": "tester",
+ "username": None,
+ }
+ self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+ self.assertRenderedError("mapping_error", "localpart is invalid: ")
+
+ def _generate_oidc_session_token(
+ self,
+ state: str,
+ nonce: str,
+ client_redirect_url: str,
+ ui_auth_session_id: Optional[str] = None,
+ ) -> str:
+ from synapse.handlers.oidc_handler import OidcSessionData
+
+ return self.handler._token_generator.generate_oidc_session_token(
+ state=state,
+ session_data=OidcSessionData(
+ idp_id="oidc",
+ nonce=nonce,
+ client_redirect_url=client_redirect_url,
+ ui_auth_session_id=ui_auth_session_id,
+ ),
)
+
+
+async def _make_callback_with_userinfo(
+ hs: HomeServer, userinfo: dict, client_redirect_url: str = "http://client/redirect"
+) -> None:
+ """Mock up an OIDC callback with the given userinfo dict
+
+ We'll pull out the OIDC handler from the homeserver, stub out a couple of methods,
+ and poke in the userinfo dict as if it were the response to an OIDC userinfo call.
+
+ Args:
+ hs: the HomeServer impl to send the callback to.
+ userinfo: the OIDC userinfo dict
+ client_redirect_url: the URL to redirect to on success.
+ """
+ from synapse.handlers.oidc_handler import OidcSessionData
+
+ handler = hs.get_oidc_handler()
+ provider = handler._providers["oidc"]
+ provider._exchange_code = simple_async_mock(return_value={})
+ provider._parse_id_token = simple_async_mock(return_value=userinfo)
+ provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
+
+ state = "state"
+ session = handler._token_generator.generate_oidc_session_token(
+ state=state,
+ session_data=OidcSessionData(
+ idp_id="oidc", nonce="nonce", client_redirect_url=client_redirect_url,
+ ),
+ )
+ request = _build_callback_request("code", state, session)
+
+ await handler.handle_oidc_callback(request)
+
+
+def _build_callback_request(
+ code: str,
+ state: str,
+ session: str,
+ user_agent: str = "Browser",
+ ip_address: str = "10.0.0.1",
+):
+ """Builds a fake SynapseRequest to mock the browser callback
+
+ Returns a Mock object which looks like the SynapseRequest we get from a browser
+ after SSO (before we return to the client)
+
+ Args:
+ code: the authorization code which would have been returned by the OIDC
+ provider
+ state: the "state" param which would have been passed around in the
+ query param. Should be the same as was embedded in the session in
+ _build_oidc_session.
+ session: the "session" which would have been passed around in the cookie.
+ user_agent: the user-agent to present
+ ip_address: the IP address to pretend the request came from
+ """
+ request = Mock(
+ spec=[
+ "args",
+ "getCookie",
+ "addCookie",
+ "requestHeaders",
+ "getClientIP",
+ "getHeader",
+ ]
+ )
+
+ request.getCookie.return_value = session
+ request.args = {}
+ request.args[b"code"] = [code.encode("utf-8")]
+ request.args[b"state"] = [state.encode("utf-8")]
+ request.getClientIP.return_value = ip_address
+ return request
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index ceaf0902d2..f816594ee4 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -432,6 +432,29 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
@override_config(
{
+ **providers_config(CustomAuthProvider),
+ "password_config": {"enabled": False, "localdb_enabled": False},
+ }
+ )
+ def test_custom_auth_password_disabled_localdb_enabled(self):
+ """Check the localdb_enabled == enabled == False
+
+ Regression test for https://github.com/matrix-org/synapse/issues/8914: check
+ that setting *both* `localdb_enabled` *and* `password: enabled` to False doesn't
+ cause an exception.
+ """
+ self.register_user("localuser", "localpass")
+
+ flows = self._get_login_flows()
+ self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
+
+ # login shouldn't work and should be rejected with a 400 ("unknown login type")
+ channel = self._send_password_login("localuser", "localpass")
+ self.assertEqual(channel.code, 400, channel.result)
+ mock_password_provider.check_auth.assert_not_called()
+
+ @override_config(
+ {
**providers_config(PasswordCustomAuthProvider),
"password_config": {"enabled": False},
}
@@ -528,7 +551,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 400, channel.result)
def _get_login_flows(self) -> JsonDict:
- _, channel = self.make_request("GET", "/_matrix/client/r0/login")
+ channel = self.make_request("GET", "/_matrix/client/r0/login")
self.assertEqual(channel.code, 200, channel.result)
return channel.json_body["flows"]
@@ -537,7 +560,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
def _send_login(self, type, user, **params) -> FakeChannel:
params.update({"identifier": {"type": "m.id.user", "user": user}, "type": type})
- _, channel = self.make_request("POST", "/_matrix/client/r0/login", params)
+ channel = self.make_request("POST", "/_matrix/client/r0/login", params)
return channel
def _start_delete_device_session(self, access_token, device_id) -> str:
@@ -574,7 +597,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
self, access_token: str, device: str, body: Union[JsonDict, bytes] = b"",
) -> FakeChannel:
"""Delete an individual device."""
- _, channel = self.make_request(
+ channel = self.make_request(
"DELETE", "devices/" + device, body, access_token=access_token
)
return channel
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 8ed67640f8..0794b32c9c 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -463,7 +463,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(
- "server", http_client=None, federation_sender=Mock()
+ "server", federation_http_client=None, federation_sender=Mock()
)
return hs
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index a69fa28b41..022943a10a 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -44,8 +44,6 @@ class ProfileTestCase(unittest.TestCase):
hs = yield setup_test_homeserver(
self.addCleanup,
- http_client=None,
- resource_for_federation=Mock(),
federation_client=self.mock_federation,
federation_server=Mock(),
federation_registry=self.mock_registry,
@@ -107,6 +105,21 @@ class ProfileTestCase(unittest.TestCase):
"Frank",
)
+ # Set displayname to an empty string
+ yield defer.ensureDeferred(
+ self.handler.set_displayname(
+ self.frank, synapse.types.create_requester(self.frank), ""
+ )
+ )
+
+ self.assertIsNone(
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_displayname(self.frank.localpart)
+ )
+ )
+ )
+
@defer.inlineCallbacks
def test_set_my_name_if_disabled(self):
self.hs.config.enable_set_displayname = False
@@ -225,6 +238,21 @@ class ProfileTestCase(unittest.TestCase):
"http://my.server/me.png",
)
+ # Set avatar to an empty string
+ yield defer.ensureDeferred(
+ self.handler.set_avatar_url(
+ self.frank, synapse.types.create_requester(self.frank), "",
+ )
+ )
+
+ self.assertIsNone(
+ (
+ yield defer.ensureDeferred(
+ self.store.get_profile_avatar_url(self.frank.localpart)
+ )
+ ),
+ )
+
@defer.inlineCallbacks
def test_set_my_avatar_if_disabled(self):
self.hs.config.enable_set_avatar_url = False
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index 45dc17aba5..a8d6c0f617 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -12,13 +12,35 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from typing import Optional
+
+from mock import Mock
+
import attr
from synapse.api.errors import RedirectException
-from synapse.handlers.sso import MappingException
+from tests.test_utils import simple_async_mock
from tests.unittest import HomeserverTestCase, override_config
+# Check if we have the dependencies to run the tests.
+try:
+ import saml2.config
+ from saml2.sigver import SigverError
+
+ has_saml2 = True
+
+ # pysaml2 can be installed and imported, but might not be able to find xmlsec1.
+ config = saml2.config.SPConfig()
+ try:
+ config.load({"metadata": {}})
+ has_xmlsec1 = True
+ except SigverError:
+ has_xmlsec1 = False
+except ImportError:
+ has_saml2 = False
+ has_xmlsec1 = False
+
# These are a few constants that are used as config parameters in the tests.
BASE_URL = "https://synapse/"
@@ -26,6 +48,8 @@ BASE_URL = "https://synapse/"
@attr.s
class FakeAuthnResponse:
ava = attr.ib(type=dict)
+ assertions = attr.ib(type=list, factory=list)
+ in_response_to = attr.ib(type=Optional[str], default=None)
class TestMappingProvider:
@@ -86,17 +110,29 @@ class SamlHandlerTestCase(HomeserverTestCase):
return hs
+ if not has_saml2:
+ skip = "Requires pysaml2"
+ elif not has_xmlsec1:
+ skip = "Requires xmlsec1"
+
def test_map_saml_response_to_user(self):
"""Ensure that mapping the SAML response returned from a provider to an MXID works properly."""
+
+ # stub out the auth handler
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
+ # send a mocked-up SAML response to the callback
saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
- # The redirect_url doesn't matter with the default user mapping provider.
- redirect_url = ""
- mxid = self.get_success(
- self.handler._map_saml_response_to_user(
- saml_response, redirect_url, "user-agent", "10.10.10.10"
- )
+ request = _mock_request()
+ self.get_success(
+ self.handler._handle_authn_response(request, saml_response, "redirect_uri")
+ )
+
+ # check that the auth handler got called as expected
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user:test", request, "redirect_uri", None, new_user=True
)
- self.assertEqual(mxid, "@test_user:test")
@override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
def test_map_saml_response_to_existing_user(self):
@@ -106,53 +142,81 @@ class SamlHandlerTestCase(HomeserverTestCase):
store.register_user(user_id="@test_user:test", password_hash=None)
)
+ # stub out the auth handler
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
# Map a user via SSO.
saml_response = FakeAuthnResponse(
{"uid": "tester", "mxid": ["test_user"], "username": "test_user"}
)
- redirect_url = ""
- mxid = self.get_success(
- self.handler._map_saml_response_to_user(
- saml_response, redirect_url, "user-agent", "10.10.10.10"
- )
+ request = _mock_request()
+ self.get_success(
+ self.handler._handle_authn_response(request, saml_response, "")
+ )
+
+ # check that the auth handler got called as expected
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user:test", request, "", None, new_user=False
)
- self.assertEqual(mxid, "@test_user:test")
# Subsequent calls should map to the same mxid.
- mxid = self.get_success(
- self.handler._map_saml_response_to_user(
- saml_response, redirect_url, "user-agent", "10.10.10.10"
- )
+ auth_handler.complete_sso_login.reset_mock()
+ self.get_success(
+ self.handler._handle_authn_response(request, saml_response, "")
+ )
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user:test", request, "", None, new_user=False
)
- self.assertEqual(mxid, "@test_user:test")
def test_map_saml_response_to_invalid_localpart(self):
"""If the mapping provider generates an invalid localpart it should be rejected."""
+
+ # stub out the auth handler
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+
+ # mock out the error renderer too
+ sso_handler = self.hs.get_sso_handler()
+ sso_handler.render_error = Mock(return_value=None)
+
saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"})
- redirect_url = ""
- e = self.get_failure(
- self.handler._map_saml_response_to_user(
- saml_response, redirect_url, "user-agent", "10.10.10.10"
- ),
- MappingException,
+ request = _mock_request()
+ self.get_success(
+ self.handler._handle_authn_response(request, saml_response, ""),
+ )
+ sso_handler.render_error.assert_called_once_with(
+ request, "mapping_error", "localpart is invalid: föö"
)
- self.assertEqual(str(e.value), "localpart is invalid: föö")
+ auth_handler.complete_sso_login.assert_not_called()
def test_map_saml_response_to_user_retries(self):
"""The mapping provider can retry generating an MXID if the MXID is already in use."""
+
+ # stub out the auth handler and error renderer
+ auth_handler = self.hs.get_auth_handler()
+ auth_handler.complete_sso_login = simple_async_mock()
+ sso_handler = self.hs.get_sso_handler()
+ sso_handler.render_error = Mock(return_value=None)
+
+ # register a user to occupy the first-choice MXID
store = self.hs.get_datastore()
self.get_success(
store.register_user(user_id="@test_user:test", password_hash=None)
)
+
+ # send the fake SAML response
saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
- redirect_url = ""
- mxid = self.get_success(
- self.handler._map_saml_response_to_user(
- saml_response, redirect_url, "user-agent", "10.10.10.10"
- )
+ request = _mock_request()
+ self.get_success(
+ self.handler._handle_authn_response(request, saml_response, ""),
)
+
# test_user is already taken, so test_user1 gets registered instead.
- self.assertEqual(mxid, "@test_user1:test")
+ auth_handler.complete_sso_login.assert_called_once_with(
+ "@test_user1:test", request, "", None, new_user=True
+ )
+ auth_handler.complete_sso_login.reset_mock()
# Register all of the potential mxids for a particular SAML username.
self.get_success(
@@ -165,15 +229,15 @@ class SamlHandlerTestCase(HomeserverTestCase):
# Now attempt to map to a username, this will fail since all potential usernames are taken.
saml_response = FakeAuthnResponse({"uid": "tester", "username": "tester"})
- e = self.get_failure(
- self.handler._map_saml_response_to_user(
- saml_response, redirect_url, "user-agent", "10.10.10.10"
- ),
- MappingException,
+ self.get_success(
+ self.handler._handle_authn_response(request, saml_response, ""),
)
- self.assertEqual(
- str(e.value), "Unable to generate a Matrix ID from the SSO response"
+ sso_handler.render_error.assert_called_once_with(
+ request,
+ "mapping_error",
+ "Unable to generate a Matrix ID from the SSO response",
)
+ auth_handler.complete_sso_login.assert_not_called()
@override_config(
{
@@ -185,12 +249,17 @@ class SamlHandlerTestCase(HomeserverTestCase):
}
)
def test_map_saml_response_redirect(self):
+ """Test a mapping provider that raises a RedirectException"""
+
saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
- redirect_url = ""
+ request = _mock_request()
e = self.get_failure(
- self.handler._map_saml_response_to_user(
- saml_response, redirect_url, "user-agent", "10.10.10.10"
- ),
+ self.handler._handle_authn_response(request, saml_response, ""),
RedirectException,
)
self.assertEqual(e.value.location, b"https://custom-saml-redirect/")
+
+
+def _mock_request():
+ """Returns a mock which will stand in as a SynapseRequest"""
+ return Mock(spec=["getClientIP", "getHeader"])
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index abbdf2d524..96e5bdac4a 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -15,18 +15,20 @@
import json
+from typing import Dict
from mock import ANY, Mock, call
from twisted.internet import defer
+from twisted.web.resource import Resource
from synapse.api.errors import AuthError
+from synapse.federation.transport.server import TransportLayerServer
from synapse.types import UserID, create_requester
from tests import unittest
from tests.test_utils import make_awaitable
from tests.unittest import override_config
-from tests.utils import register_federation_servlets
# Some local users to test with
U_APPLE = UserID.from_string("@apple:test")
@@ -53,8 +55,6 @@ def _make_edu_transaction_json(edu_type, content):
class TypingNotificationsTestCase(unittest.HomeserverTestCase):
- servlets = [register_federation_servlets]
-
def make_homeserver(self, reactor, clock):
# we mock out the keyring so as to skip the authentication check on the
# federation API call.
@@ -70,13 +70,18 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
hs = self.setup_test_homeserver(
notifier=Mock(),
- http_client=mock_federation_client,
+ federation_http_client=mock_federation_client,
keyring=mock_keyring,
replication_streams={},
)
return hs
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ d = super().create_resource_dict()
+ d["/_matrix/federation"] = TransportLayerServer(self.hs)
+ return d
+
def prepare(self, reactor, clock, hs):
mock_notifier = hs.get_notifier()
self.on_new_event = mock_notifier.on_new_event
@@ -192,7 +197,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
)
)
- put_json = self.hs.get_http_client().put_json
+ put_json = self.hs.get_federation_http_client().put_json
put_json.assert_called_once_with(
"farm",
path="/_matrix/federation/v1/send/1000000",
@@ -215,7 +220,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(self.event_source.get_current_key(), 0)
- (request, channel) = self.make_request(
+ channel = self.make_request(
"PUT",
"/_matrix/federation/v1/send/1000000",
_make_edu_transaction_json(
@@ -270,7 +275,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
- put_json = self.hs.get_http_client().put_json
+ put_json = self.hs.get_federation_http_client().put_json
put_json.assert_called_once_with(
"farm",
path="/_matrix/federation/v1/send/1000000",
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 98e5af2072..9c886d671a 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -54,6 +54,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
user_id=support_user_id, password_hash=None, user_type=UserTypes.SUPPORT
)
)
+ regular_user_id = "@regular:test"
+ self.get_success(
+ self.store.register_user(user_id=regular_user_id, password_hash=None)
+ )
self.get_success(
self.handler.handle_local_profile_change(support_user_id, None)
@@ -63,13 +67,47 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
display_name = "display_name"
profile_info = ProfileInfo(avatar_url="avatar_url", display_name=display_name)
- regular_user_id = "@regular:test"
self.get_success(
self.handler.handle_local_profile_change(regular_user_id, profile_info)
)
profile = self.get_success(self.store.get_user_in_directory(regular_user_id))
self.assertTrue(profile["display_name"] == display_name)
+ def test_handle_local_profile_change_with_deactivated_user(self):
+ # create user
+ r_user_id = "@regular:test"
+ self.get_success(
+ self.store.register_user(user_id=r_user_id, password_hash=None)
+ )
+
+ # update profile
+ display_name = "Regular User"
+ profile_info = ProfileInfo(avatar_url="avatar_url", display_name=display_name)
+ self.get_success(
+ self.handler.handle_local_profile_change(r_user_id, profile_info)
+ )
+
+ # profile is in directory
+ profile = self.get_success(self.store.get_user_in_directory(r_user_id))
+ self.assertTrue(profile["display_name"] == display_name)
+
+ # deactivate user
+ self.get_success(self.store.set_user_deactivated_status(r_user_id, True))
+ self.get_success(self.handler.handle_user_deactivated(r_user_id))
+
+ # profile is not in directory
+ profile = self.get_success(self.store.get_user_in_directory(r_user_id))
+ self.assertTrue(profile is None)
+
+ # update profile after deactivation
+ self.get_success(
+ self.handler.handle_local_profile_change(r_user_id, profile_info)
+ )
+
+ # profile is furthermore not in directory
+ profile = self.get_success(self.store.get_user_in_directory(r_user_id))
+ self.assertTrue(profile is None)
+
def test_handle_user_deactivated_support_user(self):
s_user_id = "@support:test"
self.get_success(
@@ -270,7 +308,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
spam_checker = self.hs.get_spam_checker()
class AllowAll:
- def check_username_for_spam(self, user_profile):
+ async def check_username_for_spam(self, user_profile):
# Allow all users.
return False
@@ -283,7 +321,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
# Configure a spam checker that filters all users.
class BlockAll:
- def check_username_for_spam(self, user_profile):
+ async def check_username_for_spam(self, user_profile):
# All users are spammy.
return True
@@ -534,7 +572,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
self.helper.join(room, user=u2)
# Assert user directory is not empty
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", b"user_directory/search", b'{"search_term":"user2"}'
)
self.assertEquals(200, channel.code, channel.result)
@@ -542,7 +580,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
# Disable user directory and check search returns nothing
self.config.user_directory_search_enabled = False
- request, channel = self.make_request(
+ channel = self.make_request(
"POST", b"user_directory/search", b'{"search_term":"user2"}'
)
self.assertEquals(200, channel.code, channel.result)
|