diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 2abd7a83b5..5d338bea87 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -151,6 +151,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
hs = self.setup_test_homeserver(proxied_http_client=self.http_client)
self.handler = hs.get_oidc_handler()
+ self.provider = self.handler._provider
sso_handler = hs.get_sso_handler()
# Mock the render error method.
self.render_error = Mock(return_value=None)
@@ -162,9 +163,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
return hs
def metadata_edit(self, values):
- return patch.dict(self.handler._provider_metadata, values)
+ return patch.dict(self.provider._provider_metadata, values)
def assertRenderedError(self, error, error_description=None):
+ self.render_error.assert_called_once()
args = self.render_error.call_args[0]
self.assertEqual(args[1], error)
if error_description is not None:
@@ -175,15 +177,15 @@ class OidcHandlerTestCase(HomeserverTestCase):
def test_config(self):
"""Basic config correctly sets up the callback URL and client auth correctly."""
- self.assertEqual(self.handler._callback_url, CALLBACK_URL)
- self.assertEqual(self.handler._client_auth.client_id, CLIENT_ID)
- self.assertEqual(self.handler._client_auth.client_secret, CLIENT_SECRET)
+ self.assertEqual(self.provider._callback_url, CALLBACK_URL)
+ self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
+ self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
@override_config({"oidc_config": {"discover": True}})
def test_discovery(self):
"""The handler should discover the endpoints from OIDC discovery document."""
# This would throw if some metadata were invalid
- metadata = self.get_success(self.handler.load_metadata())
+ metadata = self.get_success(self.provider.load_metadata())
self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
self.assertEqual(metadata.issuer, ISSUER)
@@ -195,47 +197,47 @@ class OidcHandlerTestCase(HomeserverTestCase):
# subsequent calls should be cached
self.http_client.reset_mock()
- self.get_success(self.handler.load_metadata())
+ self.get_success(self.provider.load_metadata())
self.http_client.get_json.assert_not_called()
@override_config({"oidc_config": COMMON_CONFIG})
def test_no_discovery(self):
"""When discovery is disabled, it should not try to load from discovery document."""
- self.get_success(self.handler.load_metadata())
+ self.get_success(self.provider.load_metadata())
self.http_client.get_json.assert_not_called()
@override_config({"oidc_config": COMMON_CONFIG})
def test_load_jwks(self):
"""JWKS loading is done once (then cached) if used."""
- jwks = self.get_success(self.handler.load_jwks())
+ jwks = self.get_success(self.provider.load_jwks())
self.http_client.get_json.assert_called_once_with(JWKS_URI)
self.assertEqual(jwks, {"keys": []})
# subsequent calls should be cached…
self.http_client.reset_mock()
- self.get_success(self.handler.load_jwks())
+ self.get_success(self.provider.load_jwks())
self.http_client.get_json.assert_not_called()
# …unless forced
self.http_client.reset_mock()
- self.get_success(self.handler.load_jwks(force=True))
+ self.get_success(self.provider.load_jwks(force=True))
self.http_client.get_json.assert_called_once_with(JWKS_URI)
# Throw if the JWKS uri is missing
with self.metadata_edit({"jwks_uri": None}):
- self.get_failure(self.handler.load_jwks(force=True), RuntimeError)
+ self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
# Return empty key set if JWKS are not used
- self.handler._scopes = [] # not asking the openid scope
+ self.provider._scopes = [] # not asking the openid scope
self.http_client.get_json.reset_mock()
- jwks = self.get_success(self.handler.load_jwks(force=True))
+ jwks = self.get_success(self.provider.load_jwks(force=True))
self.http_client.get_json.assert_not_called()
self.assertEqual(jwks, {"keys": []})
@override_config({"oidc_config": COMMON_CONFIG})
def test_validate_config(self):
"""Provider metadatas are extensively validated."""
- h = self.handler
+ h = self.provider
# Default test config does not throw
h._validate_metadata()
@@ -314,13 +316,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
"""Provider metadata validation can be disabled by config."""
with self.metadata_edit({"issuer": "http://insecure"}):
# This should not throw
- self.handler._validate_metadata()
+ self.provider._validate_metadata()
def test_redirect_request(self):
"""The redirect request has the right arguments & generates a valid session cookie."""
req = Mock(spec=["addCookie"])
url = self.get_success(
- self.handler.handle_redirect_request(req, b"http://client/redirect")
+ self.provider.handle_redirect_request(req, b"http://client/redirect")
)
url = urlparse(url)
auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT)
@@ -388,7 +390,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
# ensure that we are correctly testing the fallback when "get_extra_attributes"
# is not implemented.
- mapping_provider = self.handler._user_mapping_provider
+ mapping_provider = self.provider._user_mapping_provider
with self.assertRaises(AttributeError):
_ = mapping_provider.get_extra_attributes
@@ -403,9 +405,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": username,
}
expected_user_id = "@%s:%s" % (username, self.hs.hostname)
- self.handler._exchange_code = simple_async_mock(return_value=token)
- self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
- self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
+ self.provider._exchange_code = simple_async_mock(return_value=token)
+ self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
+ self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
@@ -425,14 +427,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
auth_handler.complete_sso_login.assert_called_once_with(
expected_user_id, request, client_redirect_url, None,
)
- self.handler._exchange_code.assert_called_once_with(code)
- self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
- self.handler._fetch_userinfo.assert_not_called()
+ self.provider._exchange_code.assert_called_once_with(code)
+ self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
+ self.provider._fetch_userinfo.assert_not_called()
self.render_error.assert_not_called()
# Handle mapping errors
with patch.object(
- self.handler,
+ self.provider,
"_remote_id_from_userinfo",
new=Mock(side_effect=MappingException()),
):
@@ -440,36 +442,36 @@ class OidcHandlerTestCase(HomeserverTestCase):
self.assertRenderedError("mapping_error")
# Handle ID token errors
- self.handler._parse_id_token = simple_async_mock(raises=Exception())
+ self.provider._parse_id_token = simple_async_mock(raises=Exception())
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("invalid_token")
auth_handler.complete_sso_login.reset_mock()
- self.handler._exchange_code.reset_mock()
- self.handler._parse_id_token.reset_mock()
- self.handler._fetch_userinfo.reset_mock()
+ self.provider._exchange_code.reset_mock()
+ self.provider._parse_id_token.reset_mock()
+ self.provider._fetch_userinfo.reset_mock()
# With userinfo fetching
- self.handler._scopes = [] # do not ask the "openid" scope
+ self.provider._scopes = [] # do not ask the "openid" scope
self.get_success(self.handler.handle_oidc_callback(request))
auth_handler.complete_sso_login.assert_called_once_with(
expected_user_id, request, client_redirect_url, None,
)
- self.handler._exchange_code.assert_called_once_with(code)
- self.handler._parse_id_token.assert_not_called()
- self.handler._fetch_userinfo.assert_called_once_with(token)
+ self.provider._exchange_code.assert_called_once_with(code)
+ self.provider._parse_id_token.assert_not_called()
+ self.provider._fetch_userinfo.assert_called_once_with(token)
self.render_error.assert_not_called()
# Handle userinfo fetching error
- self.handler._fetch_userinfo = simple_async_mock(raises=Exception())
+ self.provider._fetch_userinfo = simple_async_mock(raises=Exception())
self.get_success(self.handler.handle_oidc_callback(request))
self.assertRenderedError("fetch_error")
# Handle code exchange failure
from synapse.handlers.oidc_handler import OidcError
- self.handler._exchange_code = simple_async_mock(
+ self.provider._exchange_code = simple_async_mock(
raises=OidcError("invalid_request")
)
self.get_success(self.handler.handle_oidc_callback(request))
@@ -524,7 +526,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
return_value=FakeResponse(code=200, phrase=b"OK", body=token_json)
)
code = "code"
- ret = self.get_success(self.handler._exchange_code(code))
+ ret = self.get_success(self.provider._exchange_code(code))
kwargs = self.http_client.request.call_args[1]
self.assertEqual(ret, token)
@@ -548,7 +550,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
from synapse.handlers.oidc_handler import OidcError
- exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+ exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "foo")
self.assertEqual(exc.value.error_description, "bar")
@@ -558,7 +560,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
code=500, phrase=b"Internal Server Error", body=b"Not JSON",
)
)
- exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+ exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "server_error")
# Internal server error with JSON body
@@ -570,14 +572,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
)
)
- exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+ exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "internal_server_error")
# 4xx error without "error" field
self.http_client.request = simple_async_mock(
return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",)
)
- exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+ exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "server_error")
# 2xx error with "error" field
@@ -586,7 +588,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
code=200, phrase=b"OK", body=b'{"error": "some_error"}',
)
)
- exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+ exc = self.get_failure(self.provider._exchange_code(code), OidcError)
self.assertEqual(exc.value.error, "some_error")
@override_config(
@@ -612,8 +614,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
"username": "foo",
"phone": "1234567",
}
- self.handler._exchange_code = simple_async_mock(return_value=token)
- self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
+ self.provider._exchange_code = simple_async_mock(return_value=token)
+ self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
auth_handler = self.hs.get_auth_handler()
auth_handler.complete_sso_login = simple_async_mock()
@@ -979,9 +981,10 @@ async def _make_callback_with_userinfo(
from synapse.handlers.oidc_handler import OidcSessionData
handler = hs.get_oidc_handler()
- handler._exchange_code = simple_async_mock(return_value={})
- handler._parse_id_token = simple_async_mock(return_value=userinfo)
- handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
+ provider = handler._provider
+ provider._exchange_code = simple_async_mock(return_value={})
+ provider._parse_id_token = simple_async_mock(return_value=userinfo)
+ provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
state = "state"
session = handler._token_generator.generate_oidc_session_token(
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index 83c377824b..ff67a73749 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -20,7 +20,10 @@ from twisted.trial import unittest
from synapse.api.constants import EventTypes
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
from synapse.storage.databases.main.events import _LinkMap
+from synapse.types import create_requester
from tests.unittest import HomeserverTestCase
@@ -470,3 +473,114 @@ class LinkMapTestCase(unittest.TestCase):
self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)])
self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3), (2, 5, 1, 3)])
+
+
+class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
+
+ servlets = [
+ admin.register_servlets,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def test_background_update(self):
+ """Test that the background update to calculate auth chains for historic
+ rooms works correctly.
+ """
+
+ # Create a room
+ user_id = self.register_user("foo", "pass")
+ token = self.login("foo", "pass")
+ room_id = self.helper.create_room_as(user_id, tok=token)
+ requester = create_requester(user_id)
+
+ store = self.hs.get_datastore()
+
+ # Mark the room as not having a chain cover index
+ self.get_success(
+ store.db_pool.simple_update(
+ table="rooms",
+ keyvalues={"room_id": room_id},
+ updatevalues={"has_auth_chain_index": False},
+ desc="test",
+ )
+ )
+
+ # Create a fork in the DAG with different events.
+ event_handler = self.hs.get_event_creation_handler()
+ latest_event_ids = self.get_success(store.get_prev_events_for_room(room_id))
+ event, context = self.get_success(
+ event_handler.create_event(
+ requester,
+ {
+ "type": "some_state_type",
+ "state_key": "",
+ "content": {},
+ "room_id": room_id,
+ "sender": user_id,
+ },
+ prev_event_ids=latest_event_ids,
+ )
+ )
+ self.get_success(
+ event_handler.handle_new_client_event(requester, event, context)
+ )
+ state1 = list(self.get_success(context.get_current_state_ids()).values())
+
+ event, context = self.get_success(
+ event_handler.create_event(
+ requester,
+ {
+ "type": "some_state_type",
+ "state_key": "",
+ "content": {},
+ "room_id": room_id,
+ "sender": user_id,
+ },
+ prev_event_ids=latest_event_ids,
+ )
+ )
+ self.get_success(
+ event_handler.handle_new_client_event(requester, event, context)
+ )
+ state2 = list(self.get_success(context.get_current_state_ids()).values())
+
+ # Delete the chain cover info.
+
+ def _delete_tables(txn):
+ txn.execute("DELETE FROM event_auth_chains")
+ txn.execute("DELETE FROM event_auth_chain_links")
+
+ self.get_success(store.db_pool.runInteraction("test", _delete_tables))
+
+ # Insert and run the background update.
+ self.get_success(
+ store.db_pool.simple_insert(
+ "background_updates",
+ {"update_name": "chain_cover", "progress_json": "{}"},
+ )
+ )
+
+ # Ugh, have to reset this flag
+ store.db_pool.updates._all_done = False
+
+ while not self.get_success(
+ store.db_pool.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ store.db_pool.updates.do_next_background_update(100), by=0.1
+ )
+
+ # Test that the `has_auth_chain_index` has been set
+ self.assertTrue(self.get_success(store.has_auth_chain_index(room_id)))
+
+ # Test that calculating the auth chain difference using the newly
+ # calculated chain cover works.
+ self.get_success(
+ store.db_pool.runInteraction(
+ "test",
+ store._get_auth_chain_difference_using_cover_index_txn,
+ room_id,
+ [state1, state2],
+ )
+ )
diff --git a/tests/util/test_itertools.py b/tests/util/test_itertools.py
index 1184cea5a3..522c8061f9 100644
--- a/tests/util/test_itertools.py
+++ b/tests/util/test_itertools.py
@@ -56,6 +56,14 @@ class SortTopologically(TestCase):
graph = {} # type: Dict[int, List[int]]
self.assertEqual(list(sorted_topologically([], graph)), [])
+ def test_handle_empty_graph(self):
+ "Test that a graph where a node doesn't have an entry is treated as empty"
+
+ graph = {} # type: Dict[int, List[int]]
+
+ # For disconnected nodes the output is simply sorted.
+ self.assertEqual(list(sorted_topologically([1, 2], graph)), [1, 2])
+
def test_disconnected(self):
"Test that a graph with no edges work"
|