diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 5910772aa8..d5087e58be 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -21,7 +21,6 @@ from mock import Mock, patch
import attr
import pymacaroons
-from twisted.internet import defer
from twisted.python.failure import Failure
from twisted.web._newclient import ResponseDone
@@ -87,6 +86,13 @@ class TestMappingProvider(OidcMappingProvider):
async def map_user_attributes(self, userinfo, token):
return {"localpart": userinfo["username"], "display_name": None}
+ # Do not include get_extra_attributes to test backwards compatibility paths.
+
+
+class TestMappingProviderExtra(TestMappingProvider):
+ async def get_extra_attributes(self, userinfo, token):
+ return {"phone": userinfo["phone"]}
+
def simple_async_mock(return_value=None, raises=None):
# AsyncMock is not available in python3.5, this mimics part of its behaviour
@@ -126,7 +132,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
config = self.default_config()
config["public_baseurl"] = BASE_URL
- oidc_config = config.get("oidc_config", {})
+ oidc_config = {}
oidc_config["enabled"] = True
oidc_config["client_id"] = CLIENT_ID
oidc_config["client_secret"] = CLIENT_SECRET
@@ -135,6 +141,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
oidc_config["user_mapping_provider"] = {
"module": __name__ + ".TestMappingProvider",
}
+
+ # Update this config with what's in the default config so that
+ # override_config works as expected.
+ oidc_config.update(config.get("oidc_config", {}))
config["oidc_config"] = oidc_config
hs = self.setup_test_homeserver(
@@ -165,11 +175,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(self.handler._client_auth.client_secret, CLIENT_SECRET)
@override_config({"oidc_config": {"discover": True}})
- @defer.inlineCallbacks
def test_discovery(self):
"""The handler should discover the endpoints from OIDC discovery document."""
# This would throw if some metadata were invalid
- metadata = yield defer.ensureDeferred(self.handler.load_metadata())
+ metadata = self.get_success(self.handler.load_metadata())
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
self.assertEqual(metadata.issuer, ISSUER)
@@ -181,43 +190,40 @@ class OidcHandlerTestCase(HomeserverTestCase):
# subsequent calls should be cached
self.http_client.reset_mock()
- yield defer.ensureDeferred(self.handler.load_metadata())
+ self.get_success(self.handler.load_metadata())
self.http_client.get_json.assert_not_called()
@override_config({"oidc_config": COMMON_CONFIG})
- @defer.inlineCallbacks
def test_no_discovery(self):
"""When discovery is disabled, it should not try to load from discovery document."""
- yield defer.ensureDeferred(self.handler.load_metadata())
+ self.get_success(self.handler.load_metadata())
self.http_client.get_json.assert_not_called()
@override_config({"oidc_config": COMMON_CONFIG})
- @defer.inlineCallbacks
def test_load_jwks(self):
"""JWKS loading is done once (then cached) if used."""
- jwks = yield defer.ensureDeferred(self.handler.load_jwks())
+ jwks = self.get_success(self.handler.load_jwks())
self.http_client.get_json.assert_called_once_with(JWKS_URI)
self.assertEqual(jwks, {"keys": []})
# subsequent calls should be cached…
self.http_client.reset_mock()
- yield defer.ensureDeferred(self.handler.load_jwks())
+ self.get_success(self.handler.load_jwks())
self.http_client.get_json.assert_not_called()
# …unless forced
self.http_client.reset_mock()
- yield defer.ensureDeferred(self.handler.load_jwks(force=True))
+ self.get_success(self.handler.load_jwks(force=True))
self.http_client.get_json.assert_called_once_with(JWKS_URI)
# Throw if the JWKS uri is missing
with self.metadata_edit({"jwks_uri": None}):
- with self.assertRaises(RuntimeError):
- yield defer.ensureDeferred(self.handler.load_jwks(force=True))
+ self.get_failure(self.handler.load_jwks(force=True), RuntimeError)
# Return empty key set if JWKS are not used
self.handler._scopes = [] # not asking the openid scope
self.http_client.get_json.reset_mock()
- jwks = yield defer.ensureDeferred(self.handler.load_jwks(force=True))
+ jwks = self.get_success(self.handler.load_jwks(force=True))
self.http_client.get_json.assert_not_called()
self.assertEqual(jwks, {"keys": []})
@@ -299,11 +305,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
# This should not throw
self.handler._validate_metadata()
- @defer.inlineCallbacks
def test_redirect_request(self):
"""The redirect request has the right arguments & generates a valid session cookie."""
req = Mock(spec=["addCookie"])
- url = yield defer.ensureDeferred(
+ url = self.get_success(
self.handler.handle_redirect_request(req, b"http://client/redirect")
)
url = urlparse(url)
@@ -343,20 +348,18 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertEqual(params["nonce"], [nonce])
self.assertEqual(redirect, "http://client/redirect")
- @defer.inlineCallbacks
def test_callback_error(self):
"""Errors from the provider returned in the callback are displayed."""
self.handler._render_error = Mock()
request = Mock(args={})
request.args[b"error"] = [b"invalid_client"]
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_client", "")
request.args[b"error_description"] = [b"some description"]
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_client", "some description")
- @defer.inlineCallbacks
def test_callback(self):
"""Code callback works and display errors if something went wrong.
@@ -377,7 +380,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
"sub": "foo",
"preferred_username": "bar",
}
- user_id = UserID("foo", "domain.org")
+ user_id = "@foo:domain.org"
self.handler._render_error = Mock(return_value=None)
self.handler._exchange_code = simple_async_mock(return_value=token)
self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
@@ -394,13 +397,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
client_redirect_url = "http://client/redirect"
user_agent = "Browser"
ip_address = "10.0.0.1"
- session = self.handler._generate_oidc_session_token(
+ request.getCookie.return_value = self.handler._generate_oidc_session_token(
state=state,
nonce=nonce,
client_redirect_url=client_redirect_url,
ui_auth_session_id=None,
)
- request.getCookie.return_value = session
request.args = {}
request.args[b"code"] = [code.encode("utf-8")]
@@ -410,10 +412,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
request.requestHeaders.getRawHeaders.return_value = [user_agent.encode("ascii")]
request.getClientIP.return_value = ip_address
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.handler._auth_handler.complete_sso_login.assert_called_once_with(
- user_id, request, client_redirect_url,
+ user_id, request, client_redirect_url, {},
)
self.handler._exchange_code.assert_called_once_with(code)
self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
@@ -427,13 +429,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.handler._map_userinfo_to_user = simple_async_mock(
raises=MappingException()
)
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("mapping_error")
self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
# Handle ID token errors
self.handler._parse_id_token = simple_async_mock(raises=Exception())
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_token")
self.handler._auth_handler.complete_sso_login.reset_mock()
@@ -444,10 +446,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
# With userinfo fetching
self.handler._scopes = [] # do not ask the "openid" scope
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.handler._auth_handler.complete_sso_login.assert_called_once_with(
- user_id, request, client_redirect_url,
+ user_id, request, client_redirect_url, {},
)
self.handler._exchange_code.assert_called_once_with(code)
self.handler._parse_id_token.assert_not_called()
@@ -459,17 +461,16 @@ class OidcHandlerTestCase(HomeserverTestCase):
# Handle userinfo fetching error
self.handler._fetch_userinfo = simple_async_mock(raises=Exception())
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("fetch_error")
# Handle code exchange failure
self.handler._exchange_code = simple_async_mock(
raises=OidcError("invalid_request")
)
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_request")
- @defer.inlineCallbacks
def test_callback_session(self):
"""The callback verifies the session presence and validity"""
self.handler._render_error = Mock(return_value=None)
@@ -478,20 +479,20 @@ class OidcHandlerTestCase(HomeserverTestCase):
# Missing cookie
request.args = {}
request.getCookie.return_value = None
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("missing_session", "No session cookie found")
# Missing session parameter
request.args = {}
request.getCookie.return_value = "session"
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_request", "State parameter is missing")
# Invalid cookie
request.args = {}
request.args[b"state"] = [b"state"]
request.getCookie.return_value = "session"
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_session")
# Mismatching session
@@ -504,18 +505,17 @@ class OidcHandlerTestCase(HomeserverTestCase):
request.args = {}
request.args[b"state"] = [b"mismatching state"]
request.getCookie.return_value = session
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("mismatching_session")
# Valid session
request.args = {}
request.args[b"state"] = [b"state"]
request.getCookie.return_value = session
- yield defer.ensureDeferred(self.handler.handle_oidc_callback(request))
+ self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_request")
@override_config({"oidc_config": {"client_auth_method": "client_secret_post"}})
- @defer.inlineCallbacks
def test_exchange_code(self):
"""Code exchange behaves correctly and handles various error scenarios."""
token = {"type": "bearer"}
@@ -524,7 +524,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
return_value=FakeResponse(code=200, phrase=b"OK", body=token_json)
)
code = "code"
- ret = yield defer.ensureDeferred(self.handler._exchange_code(code))
+ ret = self.get_success(self.handler._exchange_code(code))
kwargs = self.http_client.request.call_args[1]
self.assertEqual(ret, token)
@@ -546,10 +546,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
body=b'{"error": "foo", "error_description": "bar"}',
)
)
- with self.assertRaises(OidcError) as exc:
- yield defer.ensureDeferred(self.handler._exchange_code(code))
- self.assertEqual(exc.exception.error, "foo")
- self.assertEqual(exc.exception.error_description, "bar")
+ exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+ self.assertEqual(exc.value.error, "foo")
+ self.assertEqual(exc.value.error_description, "bar")
# Internal server error with no JSON body
self.http_client.request = simple_async_mock(
@@ -557,9 +556,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
code=500, phrase=b"Internal Server Error", body=b"Not JSON",
)
)
- with self.assertRaises(OidcError) as exc:
- yield defer.ensureDeferred(self.handler._exchange_code(code))
- self.assertEqual(exc.exception.error, "server_error")
+ exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+ self.assertEqual(exc.value.error, "server_error")
# Internal server error with JSON body
self.http_client.request = simple_async_mock(
@@ -569,17 +567,16 @@ class OidcHandlerTestCase(HomeserverTestCase):
body=b'{"error": "internal_server_error"}',
)
)
- with self.assertRaises(OidcError) as exc:
- yield defer.ensureDeferred(self.handler._exchange_code(code))
- self.assertEqual(exc.exception.error, "internal_server_error")
+
+ exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+ self.assertEqual(exc.value.error, "internal_server_error")
# 4xx error without "error" field
self.http_client.request = simple_async_mock(
return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",)
)
- with self.assertRaises(OidcError) as exc:
- yield defer.ensureDeferred(self.handler._exchange_code(code))
- self.assertEqual(exc.exception.error, "server_error")
+ exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+ self.assertEqual(exc.value.error, "server_error")
# 2xx error with "error" field
self.http_client.request = simple_async_mock(
@@ -587,9 +584,62 @@ class OidcHandlerTestCase(HomeserverTestCase):
code=200, phrase=b"OK", body=b'{"error": "some_error"}',
)
)
- with self.assertRaises(OidcError) as exc:
- yield defer.ensureDeferred(self.handler._exchange_code(code))
- self.assertEqual(exc.exception.error, "some_error")
+ exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+ self.assertEqual(exc.value.error, "some_error")
+
+ @override_config(
+ {
+ "oidc_config": {
+ "user_mapping_provider": {
+ "module": __name__ + ".TestMappingProviderExtra"
+ }
+ }
+ }
+ )
+ def test_extra_attributes(self):
+ """
+ Login while using a mapping provider that implements get_extra_attributes.
+ """
+ token = {
+ "type": "bearer",
+ "id_token": "id_token",
+ "access_token": "access_token",
+ }
+ userinfo = {
+ "sub": "foo",
+ "phone": "1234567",
+ }
+ user_id = "@foo:domain.org"
+ self.handler._exchange_code = simple_async_mock(return_value=token)
+ self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
+ self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
+ self.handler._auth_handler.complete_sso_login = simple_async_mock()
+ request = Mock(
+ spec=["args", "getCookie", "addCookie", "requestHeaders", "getClientIP"]
+ )
+
+ state = "state"
+ client_redirect_url = "http://client/redirect"
+ request.getCookie.return_value = self.handler._generate_oidc_session_token(
+ state=state,
+ nonce="nonce",
+ client_redirect_url=client_redirect_url,
+ ui_auth_session_id=None,
+ )
+
+ request.args = {}
+ request.args[b"code"] = [b"code"]
+ request.args[b"state"] = [state.encode("utf-8")]
+
+ request.requestHeaders = Mock(spec=["getRawHeaders"])
+ request.requestHeaders.getRawHeaders.return_value = [b"Browser"]
+ request.getClientIP.return_value = "10.0.0.1"
+
+ self.get_success(self.handler.handle_oidc_callback(request))
+
+ self.handler._auth_handler.complete_sso_login.assert_called_once_with(
+ user_id, request, client_redirect_url, {"phone": "1234567"},
+ )
def test_map_userinfo_to_user(self):
"""Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 0a567b032f..0d809d25d5 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -905,6 +905,7 @@ class RoomMessageListTestCase(RoomBase):
first_token = self.get_success(
store.get_topological_token_for_event(first_event_id)
)
+ first_token_str = self.get_success(first_token.to_string(store))
# Send a second message in the room, which won't be removed, and which we'll
# use as the marker to purge events before.
@@ -912,6 +913,7 @@ class RoomMessageListTestCase(RoomBase):
second_token = self.get_success(
store.get_topological_token_for_event(second_event_id)
)
+ second_token_str = self.get_success(second_token.to_string(store))
# Send a third event in the room to ensure we don't fall under any edge case
# due to our marker being the latest forward extremity in the room.
@@ -921,7 +923,11 @@ class RoomMessageListTestCase(RoomBase):
request, channel = self.make_request(
"GET",
"/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
- % (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})),
+ % (
+ self.room_id,
+ second_token_str,
+ json.dumps({"types": [EventTypes.Message]}),
+ ),
)
self.render(request)
self.assertEqual(channel.code, 200, channel.json_body)
@@ -936,7 +942,7 @@ class RoomMessageListTestCase(RoomBase):
pagination_handler._purge_history(
purge_id=purge_id,
room_id=self.room_id,
- token=second_token,
+ token=second_token_str,
delete_local_events=True,
)
)
@@ -946,7 +952,11 @@ class RoomMessageListTestCase(RoomBase):
request, channel = self.make_request(
"GET",
"/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
- % (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})),
+ % (
+ self.room_id,
+ second_token_str,
+ json.dumps({"types": [EventTypes.Message]}),
+ ),
)
self.render(request)
self.assertEqual(channel.code, 200, channel.json_body)
@@ -960,7 +970,11 @@ class RoomMessageListTestCase(RoomBase):
request, channel = self.make_request(
"GET",
"/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
- % (self.room_id, first_token, json.dumps({"types": [EventTypes.Message]})),
+ % (
+ self.room_id,
+ first_token_str,
+ json.dumps({"types": [EventTypes.Message]}),
+ ),
)
self.render(request)
self.assertEqual(channel.code, 200, channel.json_body)
diff --git a/tests/storage/test_event_metrics.py b/tests/storage/test_event_metrics.py
index 949846fe33..3957471f3f 100644
--- a/tests/storage/test_event_metrics.py
+++ b/tests/storage/test_event_metrics.py
@@ -52,14 +52,14 @@ class ExtremStatisticsTestCase(HomeserverTestCase):
self.reactor.advance(60 * 60 * 1000)
self.pump(1)
- items = set(
+ items = list(
filter(
lambda x: b"synapse_forward_extremities_" in x,
- generate_latest(REGISTRY).split(b"\n"),
+ generate_latest(REGISTRY, emit_help=False).split(b"\n"),
)
)
- expected = {
+ expected = [
b'synapse_forward_extremities_bucket{le="1.0"} 0.0',
b'synapse_forward_extremities_bucket{le="2.0"} 2.0',
b'synapse_forward_extremities_bucket{le="3.0"} 2.0',
@@ -72,9 +72,12 @@ class ExtremStatisticsTestCase(HomeserverTestCase):
b'synapse_forward_extremities_bucket{le="100.0"} 3.0',
b'synapse_forward_extremities_bucket{le="200.0"} 3.0',
b'synapse_forward_extremities_bucket{le="500.0"} 3.0',
- b'synapse_forward_extremities_bucket{le="+Inf"} 3.0',
- b"synapse_forward_extremities_count 3.0",
- b"synapse_forward_extremities_sum 10.0",
- }
-
+ # per https://docs.google.com/document/d/1KwV0mAXwwbvvifBvDKH_LU1YjyXE_wxCkHNoCGq1GX0/edit#heading=h.wghdjzzh72j9,
+ # "inf" is valid: "this includes variants such as inf"
+ b'synapse_forward_extremities_bucket{le="inf"} 3.0',
+ b"# TYPE synapse_forward_extremities_gcount gauge",
+ b"synapse_forward_extremities_gcount 3.0",
+ b"# TYPE synapse_forward_extremities_gsum gauge",
+ b"synapse_forward_extremities_gsum 10.0",
+ ]
self.assertEqual(items, expected)
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index 918387733b..cc1f3c53c5 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -47,12 +47,15 @@ class PurgeTests(HomeserverTestCase):
storage = self.hs.get_storage()
# Get the topological token
- event = self.get_success(
+ token = self.get_success(
store.get_topological_token_for_event(last["event_id"])
)
+ token_str = self.get_success(token.to_string(self.hs.get_datastore()))
# Purge everything before this topological token
- self.get_success(storage.purge_events.purge_history(self.room_id, event, True))
+ self.get_success(
+ storage.purge_events.purge_history(self.room_id, token_str, True)
+ )
# 1-3 should fail and last will succeed, meaning that 1-3 are deleted
# and last is not.
@@ -74,12 +77,10 @@ class PurgeTests(HomeserverTestCase):
storage = self.hs.get_datastore()
# Set the topological token higher than it should be
- event = self.get_success(
+ token = self.get_success(
storage.get_topological_token_for_event(last["event_id"])
)
- event = "t{}-{}".format(
- *list(map(lambda x: x + 1, map(int, event[1:].split("-"))))
- )
+ event = "t{}-{}".format(token.topological + 1, token.stream + 1)
# Purge everything before this topological token
purge = defer.ensureDeferred(storage.purge_history(self.room_id, event, True))
|