summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/9117.bugfix1
-rw-r--r--changelog.d/9124.misc1
-rw-r--r--changelog.d/9127.feature1
-rw-r--r--changelog.d/9128.bugfix1
-rw-r--r--changelog.d/9130.feature1
-rw-r--r--changelog.d/9145.bugfix1
-rw-r--r--docs/workers.md3
-rw-r--r--synapse/handlers/auth.py4
-rw-r--r--synapse/handlers/devicemessage.py2
-rw-r--r--synapse/handlers/oidc_handler.py2
-rw-r--r--synapse/http/client.py2
-rw-r--r--synapse/http/matrixfederationclient.py2
-rw-r--r--synapse/rest/synapse/client/pick_idp.py4
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py329
-rw-r--r--synapse/storage/databases/main/pusher.py5
-rw-r--r--tests/rest/client/v1/test_login.py146
-rw-r--r--tests/rest/client/v1/utils.py62
-rw-r--r--tests/rest/client/v2_alpha/test_auth.py4
-rw-r--r--tests/server.py2
-rw-r--r--tests/storage/test_event_chain.py217
-rw-r--r--tests/test_utils/html_parsers.py53
-rw-r--r--tox.ini3
22 files changed, 572 insertions, 274 deletions
diff --git a/changelog.d/9117.bugfix b/changelog.d/9117.bugfix
new file mode 100644

index 0000000000..233a76d18b --- /dev/null +++ b/changelog.d/9117.bugfix
@@ -0,0 +1 @@ +Fix corruption of `pushers` data when a postgres bouncer is used. diff --git a/changelog.d/9124.misc b/changelog.d/9124.misc new file mode 100644
index 0000000000..346741d982 --- /dev/null +++ b/changelog.d/9124.misc
@@ -0,0 +1 @@ +Improve efficiency of large state resolutions. diff --git a/changelog.d/9127.feature b/changelog.d/9127.feature new file mode 100644
index 0000000000..01a24dcf49 --- /dev/null +++ b/changelog.d/9127.feature
@@ -0,0 +1 @@ +Add support for multiple SSO Identity Providers. diff --git a/changelog.d/9128.bugfix b/changelog.d/9128.bugfix new file mode 100644
index 0000000000..f87b9fb9aa --- /dev/null +++ b/changelog.d/9128.bugfix
@@ -0,0 +1 @@ +Fix minor bugs in handling the `clientRedirectUrl` parameter for SSO login. diff --git a/changelog.d/9130.feature b/changelog.d/9130.feature new file mode 100644
index 0000000000..4ec319f1f2 --- /dev/null +++ b/changelog.d/9130.feature
@@ -0,0 +1 @@ +Add experimental support for handling and persistence of to-device messages to happen on worker processes. diff --git a/changelog.d/9145.bugfix b/changelog.d/9145.bugfix new file mode 100644
index 0000000000..947cf1dc25 --- /dev/null +++ b/changelog.d/9145.bugfix
@@ -0,0 +1 @@ +Fix "UnboundLocalError: local variable 'length' referenced before assignment" errors when the response body exceeds the expected size. This bug was introduced in v1.25.0. diff --git a/docs/workers.md b/docs/workers.md
index 7fb651bba4..cc5090f224 100644 --- a/docs/workers.md +++ b/docs/workers.md
@@ -16,6 +16,9 @@ workers only work with PostgreSQL-based Synapse deployments. SQLite should only be used for demo purposes and any admin considering workers should already be running PostgreSQL. +See also https://matrix.org/blog/2020/11/03/how-we-fixed-synapses-scalability +for a higher level overview. + ## Main process/worker communication The processes communicate with each other via a Synapse-specific protocol called diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 8b18038720..3127357964 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py
@@ -1494,8 +1494,8 @@ class AuthHandler(BaseHandler): @staticmethod def add_query_param_to_url(url: str, param_name: str, param: Any): url_parts = list(urllib.parse.urlparse(url)) - query = dict(urllib.parse.parse_qsl(url_parts[4])) - query.update({param_name: param}) + query = urllib.parse.parse_qsl(url_parts[4], keep_blank_values=True) + query.append((param_name, param)) url_parts[4] = urllib.parse.urlencode(query) return urllib.parse.urlunparse(url_parts) diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index fc974a82e8..0c7737e09d 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py
@@ -163,7 +163,7 @@ class DeviceMessageHandler: await self.store.mark_remote_user_device_cache_as_stale(sender_user_id) # Immediately attempt a resync in the background - run_in_background(self._user_device_resync, sender_user_id) + run_in_background(self._user_device_resync, user_id=sender_user_id) async def send_device_message( self, diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 5e5fda7b2f..ba686d74b2 100644 --- a/synapse/handlers/oidc_handler.py +++ b/synapse/handlers/oidc_handler.py
@@ -85,7 +85,7 @@ class OidcHandler: self._token_generator = OidcSessionTokenGenerator(hs) self._providers = { p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs - } + } # type: Dict[str, OidcProvider] async def load_metadata(self) -> None: """Validate the config and load the metadata from the remote endpoint. diff --git a/synapse/http/client.py b/synapse/http/client.py
index 4e5ef106a0..8eb93ba73e 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py
@@ -722,7 +722,7 @@ class SimpleHttpClient: read_body_with_max_size(response, output_stream, max_size) ) except BodyExceededMaxSize: - SynapseError( + raise SynapseError( 502, "Requested file is too large > %r bytes" % (max_size,), Codes.TOO_LARGE, diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index b7103d6541..19293bf673 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py
@@ -996,7 +996,7 @@ class MatrixFederationHttpClient: logger.warning( "{%s} [%s] %s", request.txn_id, request.destination, msg, ) - SynapseError(502, msg, Codes.TOO_LARGE) + raise SynapseError(502, msg, Codes.TOO_LARGE) except Exception as e: logger.warning( "{%s} [%s] Error reading response: %s", diff --git a/synapse/rest/synapse/client/pick_idp.py b/synapse/rest/synapse/client/pick_idp.py
index e5b720bbca..9550b82998 100644 --- a/synapse/rest/synapse/client/pick_idp.py +++ b/synapse/rest/synapse/client/pick_idp.py
@@ -45,7 +45,9 @@ class PickIdpResource(DirectServeHtmlResource): self._server_name = hs.hostname async def _async_render_GET(self, request: SynapseRequest) -> None: - client_redirect_url = parse_string(request, "redirectUrl", required=True) + client_redirect_url = parse_string( + request, "redirectUrl", required=True, encoding="utf-8" + ) idp = parse_string(request, "idp", required=False) # if we need to pick an IdP, do so diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 7128dc1742..e46e44ba54 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -16,6 +16,8 @@ import logging from typing import Dict, List, Optional, Tuple +import attr + from synapse.api.constants import EventContentFields from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import make_event_from_dict @@ -28,6 +30,25 @@ from synapse.types import JsonDict logger = logging.getLogger(__name__) +@attr.s(slots=True, frozen=True) +class _CalculateChainCover: + """Return value for _calculate_chain_cover_txn. + """ + + # The last room_id/depth/stream processed. + room_id = attr.ib(type=str) + depth = attr.ib(type=int) + stream = attr.ib(type=int) + + # Number of rows processed + processed_count = attr.ib(type=int) + + # Map from room_id to last depth/stream processed for each room that we have + # processed all events for (i.e. the rooms we can flip the + # `has_auth_chain_index` for) + finished_room_map = attr.ib(type=Dict[str, Tuple[int, int]]) + + class EventsBackgroundUpdatesStore(SQLBaseStore): EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" @@ -719,138 +740,29 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): current_room_id = progress.get("current_room_id", "") - # Have we finished processing the current room. - finished = progress.get("finished", True) - # Where we've processed up to in the room, defaults to the start of the # room. last_depth = progress.get("last_depth", -1) last_stream = progress.get("last_stream", -1) - # Have we set the `has_auth_chain_index` for the room yet. - has_set_room_has_chain_index = progress.get( - "has_set_room_has_chain_index", False + result = await self.db_pool.runInteraction( + "_chain_cover_index", + self._calculate_chain_cover_txn, + current_room_id, + last_depth, + last_stream, + batch_size, + single_room=False, ) - if finished: - # If we've finished with the previous room (or its our first - # iteration) we move on to the next room. - - def _get_next_room(txn: Cursor) -> Optional[str]: - sql = """ - SELECT room_id FROM rooms - WHERE room_id > ? - AND ( - NOT has_auth_chain_index - OR has_auth_chain_index IS NULL - ) - ORDER BY room_id - LIMIT 1 - """ - txn.execute(sql, (current_room_id,)) - row = txn.fetchone() - if row: - return row[0] + finished = result.processed_count == 0 - return None - - current_room_id = await self.db_pool.runInteraction( - "_chain_cover_index", _get_next_room - ) - if not current_room_id: - await self.db_pool.updates._end_background_update("chain_cover") - return 0 - - logger.debug("Adding chain cover to %s", current_room_id) - - def _calculate_auth_chain( - txn: Cursor, last_depth: int, last_stream: int - ) -> Tuple[int, int, int]: - # Get the next set of events in the room (that we haven't already - # computed chain cover for). We do this in topological order. - - # We want to do a `(topological_ordering, stream_ordering) > (?,?)` - # comparison, but that is not supported on older SQLite versions - tuple_clause, tuple_args = make_tuple_comparison_clause( - self.database_engine, - [ - ("topological_ordering", last_depth), - ("stream_ordering", last_stream), - ], - ) + total_rows_processed = result.processed_count + current_room_id = result.room_id + last_depth = result.depth + last_stream = result.stream - sql = """ - SELECT - event_id, state_events.type, state_events.state_key, - topological_ordering, stream_ordering - FROM events - INNER JOIN state_events USING (event_id) - LEFT JOIN event_auth_chains USING (event_id) - LEFT JOIN event_auth_chain_to_calculate USING (event_id) - WHERE events.room_id = ? - AND event_auth_chains.event_id IS NULL - AND event_auth_chain_to_calculate.event_id IS NULL - AND %(tuple_cmp)s - ORDER BY topological_ordering, stream_ordering - LIMIT ? - """ % { - "tuple_cmp": tuple_clause, - } - - args = [current_room_id] - args.extend(tuple_args) - args.append(batch_size) - - txn.execute(sql, args) - rows = txn.fetchall() - - # Put the results in the necessary format for - # `_add_chain_cover_index` - event_to_room_id = {row[0]: current_room_id for row in rows} - event_to_types = {row[0]: (row[1], row[2]) for row in rows} - - new_last_depth = rows[-1][3] if rows else last_depth # type: int - new_last_stream = rows[-1][4] if rows else last_stream # type: int - - count = len(rows) - - # We also need to fetch the auth events for them. - auth_events = self.db_pool.simple_select_many_txn( - txn, - table="event_auth", - column="event_id", - iterable=event_to_room_id, - keyvalues={}, - retcols=("event_id", "auth_id"), - ) - - event_to_auth_chain = {} # type: Dict[str, List[str]] - for row in auth_events: - event_to_auth_chain.setdefault(row["event_id"], []).append( - row["auth_id"] - ) - - # Calculate and persist the chain cover index for this set of events. - # - # Annoyingly we need to gut wrench into the persit event store so that - # we can reuse the function to calculate the chain cover for rooms. - PersistEventsStore._add_chain_cover_index( - txn, - self.db_pool, - event_to_room_id, - event_to_types, - event_to_auth_chain, - ) - - return new_last_depth, new_last_stream, count - - last_depth, last_stream, count = await self.db_pool.runInteraction( - "_chain_cover_index", _calculate_auth_chain, last_depth, last_stream - ) - - total_rows_processed = count - - if count < batch_size and not has_set_room_has_chain_index: + for room_id, (depth, stream) in result.finished_room_map.items(): # If we've done all the events in the room we flip the # `has_auth_chain_index` in the DB. Note that its possible for # further events to be persisted between the above and setting the @@ -860,42 +772,159 @@ class EventsBackgroundUpdatesStore(SQLBaseStore): await self.db_pool.simple_update( table="rooms", - keyvalues={"room_id": current_room_id}, + keyvalues={"room_id": room_id}, updatevalues={"has_auth_chain_index": True}, desc="_chain_cover_index", ) - has_set_room_has_chain_index = True # Handle any events that might have raced with us flipping the # bit above. - last_depth, last_stream, count = await self.db_pool.runInteraction( - "_chain_cover_index", _calculate_auth_chain, last_depth, last_stream + result = await self.db_pool.runInteraction( + "_chain_cover_index", + self._calculate_chain_cover_txn, + room_id, + depth, + stream, + batch_size=None, + single_room=True, ) - total_rows_processed += count + total_rows_processed += result.processed_count - # Note that at this point its technically possible that more events - # than our `batch_size` have been persisted without their chain - # cover, so we need to continue processing this room if the last - # count returned was equal to the `batch_size`. + if finished: + await self.db_pool.updates._end_background_update("chain_cover") + return total_rows_processed - if count < batch_size: - # We've finished calculating the index for this room, move on to the - # next room. - await self.db_pool.updates._background_update_progress( - "chain_cover", {"current_room_id": current_room_id, "finished": True}, - ) - else: - # We still have outstanding events to calculate the index for. - await self.db_pool.updates._background_update_progress( - "chain_cover", - { - "current_room_id": current_room_id, - "last_depth": last_depth, - "last_stream": last_stream, - "has_auth_chain_index": has_set_room_has_chain_index, - "finished": False, - }, - ) + await self.db_pool.updates._background_update_progress( + "chain_cover", + { + "current_room_id": current_room_id, + "last_depth": last_depth, + "last_stream": last_stream, + }, + ) return total_rows_processed + + def _calculate_chain_cover_txn( + self, + txn: Cursor, + last_room_id: str, + last_depth: int, + last_stream: int, + batch_size: Optional[int], + single_room: bool, + ) -> _CalculateChainCover: + """Calculate the chain cover for `batch_size` events, ordered by + `(room_id, depth, stream)`. + + Args: + txn, + last_room_id, last_depth, last_stream: The `(room_id, depth, stream)` + tuple to fetch results after. + batch_size: The maximum number of events to process. If None then + no limit. + single_room: Whether to calculate the index for just the given + room. + """ + + # Get the next set of events in the room (that we haven't already + # computed chain cover for). We do this in topological order. + + # We want to do a `(topological_ordering, stream_ordering) > (?,?)` + # comparison, but that is not supported on older SQLite versions + tuple_clause, tuple_args = make_tuple_comparison_clause( + self.database_engine, + [ + ("events.room_id", last_room_id), + ("topological_ordering", last_depth), + ("stream_ordering", last_stream), + ], + ) + + extra_clause = "" + if single_room: + extra_clause = "AND events.room_id = ?" + tuple_args.append(last_room_id) + + sql = """ + SELECT + event_id, state_events.type, state_events.state_key, + topological_ordering, stream_ordering, + events.room_id + FROM events + INNER JOIN state_events USING (event_id) + LEFT JOIN event_auth_chains USING (event_id) + LEFT JOIN event_auth_chain_to_calculate USING (event_id) + WHERE event_auth_chains.event_id IS NULL + AND event_auth_chain_to_calculate.event_id IS NULL + AND %(tuple_cmp)s + %(extra)s + ORDER BY events.room_id, topological_ordering, stream_ordering + %(limit)s + """ % { + "tuple_cmp": tuple_clause, + "limit": "LIMIT ?" if batch_size is not None else "", + "extra": extra_clause, + } + + if batch_size is not None: + tuple_args.append(batch_size) + + txn.execute(sql, tuple_args) + rows = txn.fetchall() + + # Put the results in the necessary format for + # `_add_chain_cover_index` + event_to_room_id = {row[0]: row[5] for row in rows} + event_to_types = {row[0]: (row[1], row[2]) for row in rows} + + # Calculate the new last position we've processed up to. + new_last_depth = rows[-1][3] if rows else last_depth # type: int + new_last_stream = rows[-1][4] if rows else last_stream # type: int + new_last_room_id = rows[-1][5] if rows else "" # type: str + + # Map from room_id to last depth/stream_ordering processed for the room, + # excluding the last room (which we're likely still processing). We also + # need to include the room passed in if it's not included in the result + # set (as we then know we've processed all events in said room). + # + # This is the set of rooms that we can now safely flip the + # `has_auth_chain_index` bit for. + finished_rooms = { + row[5]: (row[3], row[4]) for row in rows if row[5] != new_last_room_id + } + if last_room_id not in finished_rooms and last_room_id != new_last_room_id: + finished_rooms[last_room_id] = (last_depth, last_stream) + + count = len(rows) + + # We also need to fetch the auth events for them. + auth_events = self.db_pool.simple_select_many_txn( + txn, + table="event_auth", + column="event_id", + iterable=event_to_room_id, + keyvalues={}, + retcols=("event_id", "auth_id"), + ) + + event_to_auth_chain = {} # type: Dict[str, List[str]] + for row in auth_events: + event_to_auth_chain.setdefault(row["event_id"], []).append(row["auth_id"]) + + # Calculate and persist the chain cover index for this set of events. + # + # Annoyingly we need to gut wrench into the persit event store so that + # we can reuse the function to calculate the chain cover for rooms. + PersistEventsStore._add_chain_cover_index( + txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain, + ) + + return _CalculateChainCover( + room_id=new_last_room_id, + depth=new_last_depth, + stream=new_last_stream, + processed_count=count, + finished_room_map=finished_rooms, + ) diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 77ba9d819e..bc7621b8d6 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py
@@ -17,14 +17,13 @@ import logging from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Tuple -from canonicaljson import encode_canonical_json - from synapse.push import PusherConfig, ThrottleParams from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage.database import DatabasePool from synapse.storage.types import Connection from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import JsonDict +from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList if TYPE_CHECKING: @@ -315,7 +314,7 @@ class PusherStore(PusherWorkerStore): "device_display_name": device_display_name, "ts": pushkey_ts, "lang": lang, - "data": bytearray(encode_canonical_json(data)), + "data": json_encoder.encode(data), "last_stream_ordering": last_stream_ordering, "profile_tag": profile_tag, "id": stream_id, diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 73a009efd1..2d25490374 100644 --- a/tests/rest/client/v1/test_login.py +++ b/tests/rest/client/v1/test_login.py
@@ -15,9 +15,8 @@ import time import urllib.parse -from html.parser import HTMLParser -from typing import Any, Dict, Iterable, List, Optional, Tuple, Union -from urllib.parse import parse_qs, urlencode, urlparse +from typing import Any, Dict, Union +from urllib.parse import urlencode from mock import Mock @@ -38,6 +37,7 @@ from tests import unittest from tests.handlers.test_oidc import HAS_OIDC from tests.handlers.test_saml import has_saml2 from tests.rest.client.v1.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG +from tests.test_utils.html_parsers import TestHtmlParser from tests.unittest import HomeserverTestCase, override_config, skip_unless try: @@ -69,6 +69,12 @@ TEST_SAML_METADATA = """ LOGIN_URL = b"/_matrix/client/r0/login" TEST_URL = b"/_matrix/client/r0/account/whoami" +# a (valid) url with some annoying characters in. %3D is =, %26 is &, %2B is + +TEST_CLIENT_REDIRECT_URL = 'https://x?<ab c>&q"+%3D%2B"="fö%26=o"' + +# the query params in TEST_CLIENT_REDIRECT_URL +EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("<ab c>", ""), ('q" =+"', '"fö&=o"')] + class LoginRestServletTestCase(unittest.HomeserverTestCase): @@ -389,23 +395,44 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): }, } + # default OIDC provider config["oidc_config"] = TEST_OIDC_CONFIG + # additional OIDC providers + config["oidc_providers"] = [ + { + "idp_id": "idp1", + "idp_name": "IDP1", + "discover": False, + "issuer": "https://issuer1", + "client_id": "test-client-id", + "client_secret": "test-client-secret", + "scopes": ["profile"], + "authorization_endpoint": "https://issuer1/auth", + "token_endpoint": "https://issuer1/token", + "userinfo_endpoint": "https://issuer1/userinfo", + "user_mapping_provider": { + "config": {"localpart_template": "{{ user.sub }}"} + }, + } + ] return config def create_resource_dict(self) -> Dict[str, Resource]: + from synapse.rest.oidc import OIDCResource + d = super().create_resource_dict() d["/_synapse/client/pick_idp"] = PickIdpResource(self.hs) + d["/_synapse/oidc"] = OIDCResource(self.hs) return d def test_multi_sso_redirect(self): """/login/sso/redirect should redirect to an identity picker""" - client_redirect_url = "https://x?<abc>" - # first hit the redirect url, which should redirect to our idp picker channel = self.make_request( "GET", - "/_matrix/client/r0/login/sso/redirect?redirectUrl=" + client_redirect_url, + "/_matrix/client/r0/login/sso/redirect?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL), ) self.assertEqual(channel.code, 302, channel.result) uri = channel.headers.getRawHeaders("Location")[0] @@ -415,46 +442,22 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200, channel.result) # parse the form to check it has fields assumed elsewhere in this class - class FormPageParser(HTMLParser): - def __init__(self): - super().__init__() - - # the values of the hidden inputs: map from name to value - self.hiddens = {} # type: Dict[str, Optional[str]] - - # the values of the radio buttons - self.radios = [] # type: List[Optional[str]] - - def handle_starttag( - self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]] - ) -> None: - attr_dict = dict(attrs) - if tag == "input": - if attr_dict["type"] == "radio" and attr_dict["name"] == "idp": - self.radios.append(attr_dict["value"]) - elif attr_dict["type"] == "hidden": - input_name = attr_dict["name"] - assert input_name - self.hiddens[input_name] = attr_dict["value"] - - def error(_, message): - self.fail(message) - - p = FormPageParser() + p = TestHtmlParser() p.feed(channel.result["body"].decode("utf-8")) p.close() - self.assertCountEqual(p.radios, ["cas", "oidc", "saml"]) + self.assertCountEqual(p.radios["idp"], ["cas", "oidc", "idp1", "saml"]) - self.assertEqual(p.hiddens["redirectUrl"], client_redirect_url) + self.assertEqual(p.hiddens["redirectUrl"], TEST_CLIENT_REDIRECT_URL) def test_multi_sso_redirect_to_cas(self): """If CAS is chosen, should redirect to the CAS server""" - client_redirect_url = "https://x?<abc>" channel = self.make_request( "GET", - "/_synapse/client/pick_idp?redirectUrl=" + client_redirect_url + "&idp=cas", + "/_synapse/client/pick_idp?redirectUrl=" + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + + "&idp=cas", shorthand=False, ) self.assertEqual(channel.code, 302, channel.result) @@ -470,16 +473,14 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): service_uri = cas_uri_params["service"][0] _, service_uri_query = service_uri.split("?", 1) service_uri_params = urllib.parse.parse_qs(service_uri_query) - self.assertEqual(service_uri_params["redirectUrl"][0], client_redirect_url) + self.assertEqual(service_uri_params["redirectUrl"][0], TEST_CLIENT_REDIRECT_URL) def test_multi_sso_redirect_to_saml(self): """If SAML is chosen, should redirect to the SAML server""" - client_redirect_url = "https://x?<abc>" - channel = self.make_request( "GET", "/_synapse/client/pick_idp?redirectUrl=" - + client_redirect_url + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + "&idp=saml", ) self.assertEqual(channel.code, 302, channel.result) @@ -492,16 +493,16 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): # the RelayState is used to carry the client redirect url saml_uri_params = urllib.parse.parse_qs(saml_uri_query) relay_state_param = saml_uri_params["RelayState"][0] - self.assertEqual(relay_state_param, client_redirect_url) + self.assertEqual(relay_state_param, TEST_CLIENT_REDIRECT_URL) - def test_multi_sso_redirect_to_oidc(self): + def test_login_via_oidc(self): """If OIDC is chosen, should redirect to the OIDC auth endpoint""" - client_redirect_url = "https://x?<abc>" + # pick the default OIDC provider channel = self.make_request( "GET", "/_synapse/client/pick_idp?redirectUrl=" - + client_redirect_url + + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL) + "&idp=oidc", ) self.assertEqual(channel.code, 302, channel.result) @@ -521,9 +522,41 @@ class MultiSSOTestCase(unittest.HomeserverTestCase): macaroon = pymacaroons.Macaroon.deserialize(oidc_session_cookie) self.assertEqual( self._get_value_from_macaroon(macaroon, "client_redirect_url"), - client_redirect_url, + TEST_CLIENT_REDIRECT_URL, ) + channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"}) + + # that should serve a confirmation page + self.assertEqual(channel.code, 200, channel.result) + self.assertTrue( + channel.headers.getRawHeaders("Content-Type")[-1].startswith("text/html") + ) + p = TestHtmlParser() + p.feed(channel.text_body) + p.close() + + # ... which should contain our redirect link + self.assertEqual(len(p.links), 1) + path, query = p.links[0].split("?", 1) + self.assertEqual(path, "https://x") + + # it will have url-encoded the params properly, so we'll have to parse them + params = urllib.parse.parse_qsl( + query, keep_blank_values=True, strict_parsing=True, errors="strict" + ) + self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS) + self.assertEqual(params[2][0], "loginToken") + + # finally, submit the matrix login token to the login API, which gives us our + # matrix access token, mxid, and device id. + login_token = params[2][1] + chan = self.make_request( + "POST", "/login", content={"type": "m.login.token", "token": login_token}, + ) + self.assertEqual(chan.code, 200, chan.result) + self.assertEqual(chan.json_body["user_id"], "@user1:test") + def test_multi_sso_redirect_to_unknown(self): """An unknown IdP should cause a 400""" channel = self.make_request( @@ -1082,7 +1115,7 @@ class UsernamePickerTestCase(HomeserverTestCase): # whitelist this client URI so we redirect straight to it rather than # serving a confirmation page - config["sso"] = {"client_whitelist": ["https://whitelisted.client"]} + config["sso"] = {"client_whitelist": ["https://x"]} return config def create_resource_dict(self) -> Dict[str, Resource]: @@ -1095,11 +1128,10 @@ class UsernamePickerTestCase(HomeserverTestCase): def test_username_picker(self): """Test the happy path of a username picker flow.""" - client_redirect_url = "https://whitelisted.client" # do the start of the login flow channel = self.helper.auth_via_oidc( - {"sub": "tester", "displayname": "Jonny"}, client_redirect_url + {"sub": "tester", "displayname": "Jonny"}, TEST_CLIENT_REDIRECT_URL ) # that should redirect to the username picker @@ -1122,7 +1154,7 @@ class UsernamePickerTestCase(HomeserverTestCase): session = username_mapping_sessions[session_id] self.assertEqual(session.remote_user_id, "tester") self.assertEqual(session.display_name, "Jonny") - self.assertEqual(session.client_redirect_url, client_redirect_url) + self.assertEqual(session.client_redirect_url, TEST_CLIENT_REDIRECT_URL) # the expiry time should be about 15 minutes away expected_expiry = self.clock.time_msec() + (15 * 60 * 1000) @@ -1146,15 +1178,19 @@ class UsernamePickerTestCase(HomeserverTestCase): ) self.assertEqual(chan.code, 302, chan.result) location_headers = chan.headers.getRawHeaders("Location") - # ensure that the returned location starts with the requested redirect URL - self.assertEqual( - location_headers[0][: len(client_redirect_url)], client_redirect_url + # ensure that the returned location matches the requested redirect URL + path, query = location_headers[0].split("?", 1) + self.assertEqual(path, "https://x") + + # it will have url-encoded the params properly, so we'll have to parse them + params = urllib.parse.parse_qsl( + query, keep_blank_values=True, strict_parsing=True, errors="strict" ) + self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS) + self.assertEqual(params[2][0], "loginToken") # fish the login token out of the returned redirect uri - parts = urlparse(location_headers[0]) - query = parse_qs(parts.query) - login_token = query["loginToken"][0] + login_token = params[2][1] # finally, submit the matrix login token to the login API, which gives us our # matrix access token, mxid, and device id. diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index c6647dbe08..b1333df82d 100644 --- a/tests/rest/client/v1/utils.py +++ b/tests/rest/client/v1/utils.py
@@ -20,8 +20,7 @@ import json import re import time import urllib.parse -from html.parser import HTMLParser -from typing import Any, Dict, Iterable, List, MutableMapping, Optional, Tuple +from typing import Any, Dict, Mapping, MutableMapping, Optional from mock import patch @@ -35,6 +34,7 @@ from synapse.types import JsonDict from tests.server import FakeChannel, FakeSite, make_request from tests.test_utils import FakeResponse +from tests.test_utils.html_parsers import TestHtmlParser @attr.s @@ -440,10 +440,36 @@ class RestHelper: # param that synapse passes to the IdP via query params, as well as the cookie # that synapse passes to the client. - oauth_uri_path, oauth_uri_qs = oauth_uri.split("?", 1) + oauth_uri_path, _ = oauth_uri.split("?", 1) assert oauth_uri_path == TEST_OIDC_AUTH_ENDPOINT, ( "unexpected SSO URI " + oauth_uri_path ) + return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict) + + def complete_oidc_auth( + self, oauth_uri: str, cookies: Mapping[str, str], user_info_dict: JsonDict, + ) -> FakeChannel: + """Mock out an OIDC authentication flow + + Assumes that an OIDC auth has been initiated by one of initiate_sso_login or + initiate_sso_ui_auth; completes the OIDC bits of the flow by making a request to + Synapse's OIDC callback endpoint, intercepting the HTTP requests that will get + sent back to the OIDC provider. + + Requires the OIDC callback resource to be mounted at the normal place. + + Args: + oauth_uri: the OIDC URI returned by synapse's redirect endpoint (ie, + from initiate_sso_login or initiate_sso_ui_auth). + cookies: the cookies set by synapse's redirect endpoint, which will be + sent back to the callback endpoint. + user_info_dict: the remote userinfo that the OIDC provider should present. + Typically this should be '{"sub": "<remote user id>"}'. + + Returns: + A FakeChannel containing the result of calling the OIDC callback endpoint. + """ + _, oauth_uri_qs = oauth_uri.split("?", 1) params = urllib.parse.parse_qs(oauth_uri_qs) callback_uri = "%s?%s" % ( urllib.parse.urlparse(params["redirect_uri"][0]).path, @@ -456,9 +482,9 @@ class RestHelper: expected_requests = [ # first we get a hit to the token endpoint, which we tell to return # a dummy OIDC access token - ("https://issuer.test/token", {"access_token": "TEST"}), + (TEST_OIDC_TOKEN_ENDPOINT, {"access_token": "TEST"}), # and then one to the user_info endpoint, which returns our remote user id. - ("https://issuer.test/userinfo", user_info_dict), + (TEST_OIDC_USERINFO_ENDPOINT, user_info_dict), ] async def mock_req(method: str, uri: str, data=None, headers=None): @@ -542,25 +568,7 @@ class RestHelper: channel.extract_cookies(cookies) # parse the confirmation page to fish out the link. - class ConfirmationPageParser(HTMLParser): - def __init__(self): - super().__init__() - - self.links = [] # type: List[str] - - def handle_starttag( - self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]] - ) -> None: - attr_dict = dict(attrs) - if tag == "a": - href = attr_dict["href"] - if href: - self.links.append(href) - - def error(_, message): - raise AssertionError(message) - - p = ConfirmationPageParser() + p = TestHtmlParser() p.feed(channel.text_body) p.close() assert len(p.links) == 1, "not exactly one link in confirmation page" @@ -570,6 +578,8 @@ class RestHelper: # an 'oidc_config' suitable for login_via_oidc. TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth" +TEST_OIDC_TOKEN_ENDPOINT = "https://issuer.test/token" +TEST_OIDC_USERINFO_ENDPOINT = "https://issuer.test/userinfo" TEST_OIDC_CONFIG = { "enabled": True, "discover": False, @@ -578,7 +588,7 @@ TEST_OIDC_CONFIG = { "client_secret": "test-client-secret", "scopes": ["profile"], "authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT, - "token_endpoint": "https://issuer.test/token", - "userinfo_endpoint": "https://issuer.test/userinfo", + "token_endpoint": TEST_OIDC_TOKEN_ENDPOINT, + "userinfo_endpoint": TEST_OIDC_USERINFO_ENDPOINT, "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}}, } diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index 7728884bae..a6488a3d29 100644 --- a/tests/rest/client/v2_alpha/test_auth.py +++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -475,7 +475,9 @@ class UIAuthTests(unittest.HomeserverTestCase): session_id = channel.json_body["session"] # do the OIDC auth, but auth as the wrong user - channel = self.helper.auth_via_oidc({"sub": "wrong_user"}, ui_auth_session_id=session_id) + channel = self.helper.auth_via_oidc( + {"sub": "wrong_user"}, ui_auth_session_id=session_id + ) # that should return a failure message self.assertSubstring("We were unable to validate", channel.text_body) diff --git a/tests/server.py b/tests/server.py
index 5a1b66270f..5a85d5fe7f 100644 --- a/tests/server.py +++ b/tests/server.py
@@ -74,7 +74,7 @@ class FakeChannel: return int(self.result["code"]) @property - def headers(self): + def headers(self) -> Headers: if not self.result: raise Exception("No result yet.") h = Headers() diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index ff67a73749..0c46ad595b 100644 --- a/tests/storage/test_event_chain.py +++ b/tests/storage/test_event_chain.py
@@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Tuple +from typing import Dict, List, Set, Tuple from twisted.trial import unittest @@ -483,22 +483,20 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): login.register_servlets, ] - def test_background_update(self): - """Test that the background update to calculate auth chains for historic - rooms works correctly. - """ - - # Create a room - user_id = self.register_user("foo", "pass") - token = self.login("foo", "pass") - room_id = self.helper.create_room_as(user_id, tok=token) - requester = create_requester(user_id) + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.user_id = self.register_user("foo", "pass") + self.token = self.login("foo", "pass") + self.requester = create_requester(self.user_id) - store = self.hs.get_datastore() + def _generate_room(self) -> Tuple[str, List[Set[str]]]: + """Insert a room without a chain cover index. + """ + room_id = self.helper.create_room_as(self.user_id, tok=self.token) # Mark the room as not having a chain cover index self.get_success( - store.db_pool.simple_update( + self.store.db_pool.simple_update( table="rooms", keyvalues={"room_id": room_id}, updatevalues={"has_auth_chain_index": False}, @@ -508,42 +506,44 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): # Create a fork in the DAG with different events. event_handler = self.hs.get_event_creation_handler() - latest_event_ids = self.get_success(store.get_prev_events_for_room(room_id)) + latest_event_ids = self.get_success( + self.store.get_prev_events_for_room(room_id) + ) event, context = self.get_success( event_handler.create_event( - requester, + self.requester, { "type": "some_state_type", "state_key": "", "content": {}, "room_id": room_id, - "sender": user_id, + "sender": self.user_id, }, prev_event_ids=latest_event_ids, ) ) self.get_success( - event_handler.handle_new_client_event(requester, event, context) + event_handler.handle_new_client_event(self.requester, event, context) ) - state1 = list(self.get_success(context.get_current_state_ids()).values()) + state1 = set(self.get_success(context.get_current_state_ids()).values()) event, context = self.get_success( event_handler.create_event( - requester, + self.requester, { "type": "some_state_type", "state_key": "", "content": {}, "room_id": room_id, - "sender": user_id, + "sender": self.user_id, }, prev_event_ids=latest_event_ids, ) ) self.get_success( - event_handler.handle_new_client_event(requester, event, context) + event_handler.handle_new_client_event(self.requester, event, context) ) - state2 = list(self.get_success(context.get_current_state_ids()).values()) + state2 = set(self.get_success(context.get_current_state_ids()).values()) # Delete the chain cover info. @@ -551,36 +551,191 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase): txn.execute("DELETE FROM event_auth_chains") txn.execute("DELETE FROM event_auth_chain_links") - self.get_success(store.db_pool.runInteraction("test", _delete_tables)) + self.get_success(self.store.db_pool.runInteraction("test", _delete_tables)) + + return room_id, [state1, state2] + + def test_background_update_single_room(self): + """Test that the background update to calculate auth chains for historic + rooms works correctly. + """ + + # Create a room + room_id, states = self._generate_room() # Insert and run the background update. self.get_success( - store.db_pool.simple_insert( + self.store.db_pool.simple_insert( "background_updates", {"update_name": "chain_cover", "progress_json": "{}"}, ) ) # Ugh, have to reset this flag - store.db_pool.updates._all_done = False + self.store.db_pool.updates._all_done = False while not self.get_success( - store.db_pool.updates.has_completed_background_updates() + self.store.db_pool.updates.has_completed_background_updates() ): self.get_success( - store.db_pool.updates.do_next_background_update(100), by=0.1 + self.store.db_pool.updates.do_next_background_update(100), by=0.1 ) # Test that the `has_auth_chain_index` has been set - self.assertTrue(self.get_success(store.has_auth_chain_index(room_id))) + self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id))) # Test that calculating the auth chain difference using the newly # calculated chain cover works. self.get_success( - store.db_pool.runInteraction( + self.store.db_pool.runInteraction( "test", - store._get_auth_chain_difference_using_cover_index_txn, + self.store._get_auth_chain_difference_using_cover_index_txn, room_id, - [state1, state2], + states, + ) + ) + + def test_background_update_multiple_rooms(self): + """Test that the background update to calculate auth chains for historic + rooms works correctly. + """ + # Create a room + room_id1, states1 = self._generate_room() + room_id2, states2 = self._generate_room() + room_id3, states2 = self._generate_room() + + # Insert and run the background update. + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + {"update_name": "chain_cover", "progress_json": "{}"}, + ) + ) + + # Ugh, have to reset this flag + self.store.db_pool.updates._all_done = False + + while not self.get_success( + self.store.db_pool.updates.has_completed_background_updates() + ): + self.get_success( + self.store.db_pool.updates.do_next_background_update(100), by=0.1 + ) + + # Test that the `has_auth_chain_index` has been set + self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id1))) + self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id2))) + self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id3))) + + # Test that calculating the auth chain difference using the newly + # calculated chain cover works. + self.get_success( + self.store.db_pool.runInteraction( + "test", + self.store._get_auth_chain_difference_using_cover_index_txn, + room_id1, + states1, ) ) + + def test_background_update_single_large_room(self): + """Test that the background update to calculate auth chains for historic + rooms works correctly. + """ + + # Create a room + room_id, states = self._generate_room() + + # Add a bunch of state so that it takes multiple iterations of the + # background update to process the room. + for i in range(0, 150): + self.helper.send_state( + room_id, event_type="m.test", body={"index": i}, tok=self.token + ) + + # Insert and run the background update. + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + {"update_name": "chain_cover", "progress_json": "{}"}, + ) + ) + + # Ugh, have to reset this flag + self.store.db_pool.updates._all_done = False + + iterations = 0 + while not self.get_success( + self.store.db_pool.updates.has_completed_background_updates() + ): + iterations += 1 + self.get_success( + self.store.db_pool.updates.do_next_background_update(100), by=0.1 + ) + + # Ensure that we did actually take multiple iterations to process the + # room. + self.assertGreater(iterations, 1) + + # Test that the `has_auth_chain_index` has been set + self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id))) + + # Test that calculating the auth chain difference using the newly + # calculated chain cover works. + self.get_success( + self.store.db_pool.runInteraction( + "test", + self.store._get_auth_chain_difference_using_cover_index_txn, + room_id, + states, + ) + ) + + def test_background_update_multiple_large_room(self): + """Test that the background update to calculate auth chains for historic + rooms works correctly. + """ + + # Create the rooms + room_id1, _ = self._generate_room() + room_id2, _ = self._generate_room() + + # Add a bunch of state so that it takes multiple iterations of the + # background update to process the room. + for i in range(0, 150): + self.helper.send_state( + room_id1, event_type="m.test", body={"index": i}, tok=self.token + ) + + for i in range(0, 150): + self.helper.send_state( + room_id2, event_type="m.test", body={"index": i}, tok=self.token + ) + + # Insert and run the background update. + self.get_success( + self.store.db_pool.simple_insert( + "background_updates", + {"update_name": "chain_cover", "progress_json": "{}"}, + ) + ) + + # Ugh, have to reset this flag + self.store.db_pool.updates._all_done = False + + iterations = 0 + while not self.get_success( + self.store.db_pool.updates.has_completed_background_updates() + ): + iterations += 1 + self.get_success( + self.store.db_pool.updates.do_next_background_update(100), by=0.1 + ) + + # Ensure that we did actually take multiple iterations to process the + # room. + self.assertGreater(iterations, 1) + + # Test that the `has_auth_chain_index` has been set + self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id1))) + self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id2))) diff --git a/tests/test_utils/html_parsers.py b/tests/test_utils/html_parsers.py new file mode 100644
index 0000000000..ad563eb3f0 --- /dev/null +++ b/tests/test_utils/html_parsers.py
@@ -0,0 +1,53 @@ +# -*- coding: utf-8 -*- +# Copyright 2021 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from html.parser import HTMLParser +from typing import Dict, Iterable, List, Optional, Tuple + + +class TestHtmlParser(HTMLParser): + """A generic HTML page parser which extracts useful things from the HTML""" + + def __init__(self): + super().__init__() + + # a list of links found in the doc + self.links = [] # type: List[str] + + # the values of any hidden <input>s: map from name to value + self.hiddens = {} # type: Dict[str, Optional[str]] + + # the values of any radio buttons: map from name to list of values + self.radios = {} # type: Dict[str, List[Optional[str]]] + + def handle_starttag( + self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]] + ) -> None: + attr_dict = dict(attrs) + if tag == "a": + href = attr_dict["href"] + if href: + self.links.append(href) + elif tag == "input": + input_name = attr_dict.get("name") + if attr_dict["type"] == "radio": + assert input_name + self.radios.setdefault(input_name, []).append(attr_dict["value"]) + elif attr_dict["type"] == "hidden": + assert input_name + self.hiddens[input_name] = attr_dict["value"] + + def error(_, message): + raise AssertionError(message) diff --git a/tox.ini b/tox.ini
index 3cf68a47a6..92a59d79c3 100644 --- a/tox.ini +++ b/tox.ini
@@ -105,6 +105,9 @@ usedevelop=true [testenv:py35-old] skip_install=True deps = + # Ensure a version of setuptools that supports Python 3.5 is installed. + setuptools < 51.0.0 + # Old automat version for Twisted Automat == 0.3.0 lxml