diff --git a/changelog.d/9117.bugfix b/changelog.d/9117.bugfix
new file mode 100644
index 0000000000..233a76d18b
--- /dev/null
+++ b/changelog.d/9117.bugfix
@@ -0,0 +1 @@
+Fix corruption of `pushers` data when a postgres bouncer is used.
diff --git a/changelog.d/9124.misc b/changelog.d/9124.misc
new file mode 100644
index 0000000000..346741d982
--- /dev/null
+++ b/changelog.d/9124.misc
@@ -0,0 +1 @@
+Improve efficiency of large state resolutions.
diff --git a/changelog.d/9127.feature b/changelog.d/9127.feature
new file mode 100644
index 0000000000..01a24dcf49
--- /dev/null
+++ b/changelog.d/9127.feature
@@ -0,0 +1 @@
+Add support for multiple SSO Identity Providers.
diff --git a/changelog.d/9128.bugfix b/changelog.d/9128.bugfix
new file mode 100644
index 0000000000..f87b9fb9aa
--- /dev/null
+++ b/changelog.d/9128.bugfix
@@ -0,0 +1 @@
+Fix minor bugs in handling the `clientRedirectUrl` parameter for SSO login.
diff --git a/changelog.d/9130.feature b/changelog.d/9130.feature
new file mode 100644
index 0000000000..4ec319f1f2
--- /dev/null
+++ b/changelog.d/9130.feature
@@ -0,0 +1 @@
+Add experimental support for handling and persistence of to-device messages to happen on worker processes.
diff --git a/changelog.d/9145.bugfix b/changelog.d/9145.bugfix
new file mode 100644
index 0000000000..947cf1dc25
--- /dev/null
+++ b/changelog.d/9145.bugfix
@@ -0,0 +1 @@
+Fix "UnboundLocalError: local variable 'length' referenced before assignment" errors when the response body exceeds the expected size. This bug was introduced in v1.25.0.
diff --git a/docs/workers.md b/docs/workers.md
index 7fb651bba4..cc5090f224 100644
--- a/docs/workers.md
+++ b/docs/workers.md
@@ -16,6 +16,9 @@ workers only work with PostgreSQL-based Synapse deployments. SQLite should only
be used for demo purposes and any admin considering workers should already be
running PostgreSQL.
+See also https://matrix.org/blog/2020/11/03/how-we-fixed-synapses-scalability
+for a higher level overview.
+
## Main process/worker communication
The processes communicate with each other via a Synapse-specific protocol called
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 8b18038720..3127357964 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -1494,8 +1494,8 @@ class AuthHandler(BaseHandler):
@staticmethod
def add_query_param_to_url(url: str, param_name: str, param: Any):
url_parts = list(urllib.parse.urlparse(url))
- query = dict(urllib.parse.parse_qsl(url_parts[4]))
- query.update({param_name: param})
+ query = urllib.parse.parse_qsl(url_parts[4], keep_blank_values=True)
+ query.append((param_name, param))
url_parts[4] = urllib.parse.urlencode(query)
return urllib.parse.urlunparse(url_parts)
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index fc974a82e8..0c7737e09d 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -163,7 +163,7 @@ class DeviceMessageHandler:
await self.store.mark_remote_user_device_cache_as_stale(sender_user_id)
# Immediately attempt a resync in the background
- run_in_background(self._user_device_resync, sender_user_id)
+ run_in_background(self._user_device_resync, user_id=sender_user_id)
async def send_device_message(
self,
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 5e5fda7b2f..ba686d74b2 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -85,7 +85,7 @@ class OidcHandler:
self._token_generator = OidcSessionTokenGenerator(hs)
self._providers = {
p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs
- }
+ } # type: Dict[str, OidcProvider]
async def load_metadata(self) -> None:
"""Validate the config and load the metadata from the remote endpoint.
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 4e5ef106a0..8eb93ba73e 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -722,7 +722,7 @@ class SimpleHttpClient:
read_body_with_max_size(response, output_stream, max_size)
)
except BodyExceededMaxSize:
- SynapseError(
+ raise SynapseError(
502,
"Requested file is too large > %r bytes" % (max_size,),
Codes.TOO_LARGE,
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index b7103d6541..19293bf673 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -996,7 +996,7 @@ class MatrixFederationHttpClient:
logger.warning(
"{%s} [%s] %s", request.txn_id, request.destination, msg,
)
- SynapseError(502, msg, Codes.TOO_LARGE)
+ raise SynapseError(502, msg, Codes.TOO_LARGE)
except Exception as e:
logger.warning(
"{%s} [%s] Error reading response: %s",
diff --git a/synapse/rest/synapse/client/pick_idp.py b/synapse/rest/synapse/client/pick_idp.py
index e5b720bbca..9550b82998 100644
--- a/synapse/rest/synapse/client/pick_idp.py
+++ b/synapse/rest/synapse/client/pick_idp.py
@@ -45,7 +45,9 @@ class PickIdpResource(DirectServeHtmlResource):
self._server_name = hs.hostname
async def _async_render_GET(self, request: SynapseRequest) -> None:
- client_redirect_url = parse_string(request, "redirectUrl", required=True)
+ client_redirect_url = parse_string(
+ request, "redirectUrl", required=True, encoding="utf-8"
+ )
idp = parse_string(request, "idp", required=False)
# if we need to pick an IdP, do so
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 7128dc1742..e46e44ba54 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -16,6 +16,8 @@
import logging
from typing import Dict, List, Optional, Tuple
+import attr
+
from synapse.api.constants import EventContentFields
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import make_event_from_dict
@@ -28,6 +30,25 @@ from synapse.types import JsonDict
logger = logging.getLogger(__name__)
+@attr.s(slots=True, frozen=True)
+class _CalculateChainCover:
+ """Return value for _calculate_chain_cover_txn.
+ """
+
+ # The last room_id/depth/stream processed.
+ room_id = attr.ib(type=str)
+ depth = attr.ib(type=int)
+ stream = attr.ib(type=int)
+
+ # Number of rows processed
+ processed_count = attr.ib(type=int)
+
+ # Map from room_id to last depth/stream processed for each room that we have
+ # processed all events for (i.e. the rooms we can flip the
+ # `has_auth_chain_index` for)
+ finished_room_map = attr.ib(type=Dict[str, Tuple[int, int]])
+
+
class EventsBackgroundUpdatesStore(SQLBaseStore):
EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
@@ -719,138 +740,29 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
current_room_id = progress.get("current_room_id", "")
- # Have we finished processing the current room.
- finished = progress.get("finished", True)
-
# Where we've processed up to in the room, defaults to the start of the
# room.
last_depth = progress.get("last_depth", -1)
last_stream = progress.get("last_stream", -1)
- # Have we set the `has_auth_chain_index` for the room yet.
- has_set_room_has_chain_index = progress.get(
- "has_set_room_has_chain_index", False
+ result = await self.db_pool.runInteraction(
+ "_chain_cover_index",
+ self._calculate_chain_cover_txn,
+ current_room_id,
+ last_depth,
+ last_stream,
+ batch_size,
+ single_room=False,
)
- if finished:
- # If we've finished with the previous room (or its our first
- # iteration) we move on to the next room.
-
- def _get_next_room(txn: Cursor) -> Optional[str]:
- sql = """
- SELECT room_id FROM rooms
- WHERE room_id > ?
- AND (
- NOT has_auth_chain_index
- OR has_auth_chain_index IS NULL
- )
- ORDER BY room_id
- LIMIT 1
- """
- txn.execute(sql, (current_room_id,))
- row = txn.fetchone()
- if row:
- return row[0]
+ finished = result.processed_count == 0
- return None
-
- current_room_id = await self.db_pool.runInteraction(
- "_chain_cover_index", _get_next_room
- )
- if not current_room_id:
- await self.db_pool.updates._end_background_update("chain_cover")
- return 0
-
- logger.debug("Adding chain cover to %s", current_room_id)
-
- def _calculate_auth_chain(
- txn: Cursor, last_depth: int, last_stream: int
- ) -> Tuple[int, int, int]:
- # Get the next set of events in the room (that we haven't already
- # computed chain cover for). We do this in topological order.
-
- # We want to do a `(topological_ordering, stream_ordering) > (?,?)`
- # comparison, but that is not supported on older SQLite versions
- tuple_clause, tuple_args = make_tuple_comparison_clause(
- self.database_engine,
- [
- ("topological_ordering", last_depth),
- ("stream_ordering", last_stream),
- ],
- )
+ total_rows_processed = result.processed_count
+ current_room_id = result.room_id
+ last_depth = result.depth
+ last_stream = result.stream
- sql = """
- SELECT
- event_id, state_events.type, state_events.state_key,
- topological_ordering, stream_ordering
- FROM events
- INNER JOIN state_events USING (event_id)
- LEFT JOIN event_auth_chains USING (event_id)
- LEFT JOIN event_auth_chain_to_calculate USING (event_id)
- WHERE events.room_id = ?
- AND event_auth_chains.event_id IS NULL
- AND event_auth_chain_to_calculate.event_id IS NULL
- AND %(tuple_cmp)s
- ORDER BY topological_ordering, stream_ordering
- LIMIT ?
- """ % {
- "tuple_cmp": tuple_clause,
- }
-
- args = [current_room_id]
- args.extend(tuple_args)
- args.append(batch_size)
-
- txn.execute(sql, args)
- rows = txn.fetchall()
-
- # Put the results in the necessary format for
- # `_add_chain_cover_index`
- event_to_room_id = {row[0]: current_room_id for row in rows}
- event_to_types = {row[0]: (row[1], row[2]) for row in rows}
-
- new_last_depth = rows[-1][3] if rows else last_depth # type: int
- new_last_stream = rows[-1][4] if rows else last_stream # type: int
-
- count = len(rows)
-
- # We also need to fetch the auth events for them.
- auth_events = self.db_pool.simple_select_many_txn(
- txn,
- table="event_auth",
- column="event_id",
- iterable=event_to_room_id,
- keyvalues={},
- retcols=("event_id", "auth_id"),
- )
-
- event_to_auth_chain = {} # type: Dict[str, List[str]]
- for row in auth_events:
- event_to_auth_chain.setdefault(row["event_id"], []).append(
- row["auth_id"]
- )
-
- # Calculate and persist the chain cover index for this set of events.
- #
- # Annoyingly we need to gut wrench into the persit event store so that
- # we can reuse the function to calculate the chain cover for rooms.
- PersistEventsStore._add_chain_cover_index(
- txn,
- self.db_pool,
- event_to_room_id,
- event_to_types,
- event_to_auth_chain,
- )
-
- return new_last_depth, new_last_stream, count
-
- last_depth, last_stream, count = await self.db_pool.runInteraction(
- "_chain_cover_index", _calculate_auth_chain, last_depth, last_stream
- )
-
- total_rows_processed = count
-
- if count < batch_size and not has_set_room_has_chain_index:
+ for room_id, (depth, stream) in result.finished_room_map.items():
# If we've done all the events in the room we flip the
# `has_auth_chain_index` in the DB. Note that its possible for
# further events to be persisted between the above and setting the
@@ -860,42 +772,159 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
await self.db_pool.simple_update(
table="rooms",
- keyvalues={"room_id": current_room_id},
+ keyvalues={"room_id": room_id},
updatevalues={"has_auth_chain_index": True},
desc="_chain_cover_index",
)
- has_set_room_has_chain_index = True
# Handle any events that might have raced with us flipping the
# bit above.
- last_depth, last_stream, count = await self.db_pool.runInteraction(
- "_chain_cover_index", _calculate_auth_chain, last_depth, last_stream
+ result = await self.db_pool.runInteraction(
+ "_chain_cover_index",
+ self._calculate_chain_cover_txn,
+ room_id,
+ depth,
+ stream,
+ batch_size=None,
+ single_room=True,
)
- total_rows_processed += count
+ total_rows_processed += result.processed_count
- # Note that at this point its technically possible that more events
- # than our `batch_size` have been persisted without their chain
- # cover, so we need to continue processing this room if the last
- # count returned was equal to the `batch_size`.
+ if finished:
+ await self.db_pool.updates._end_background_update("chain_cover")
+ return total_rows_processed
- if count < batch_size:
- # We've finished calculating the index for this room, move on to the
- # next room.
- await self.db_pool.updates._background_update_progress(
- "chain_cover", {"current_room_id": current_room_id, "finished": True},
- )
- else:
- # We still have outstanding events to calculate the index for.
- await self.db_pool.updates._background_update_progress(
- "chain_cover",
- {
- "current_room_id": current_room_id,
- "last_depth": last_depth,
- "last_stream": last_stream,
- "has_auth_chain_index": has_set_room_has_chain_index,
- "finished": False,
- },
- )
+ await self.db_pool.updates._background_update_progress(
+ "chain_cover",
+ {
+ "current_room_id": current_room_id,
+ "last_depth": last_depth,
+ "last_stream": last_stream,
+ },
+ )
return total_rows_processed
+
+ def _calculate_chain_cover_txn(
+ self,
+ txn: Cursor,
+ last_room_id: str,
+ last_depth: int,
+ last_stream: int,
+ batch_size: Optional[int],
+ single_room: bool,
+ ) -> _CalculateChainCover:
+ """Calculate the chain cover for `batch_size` events, ordered by
+ `(room_id, depth, stream)`.
+
+ Args:
+ txn,
+ last_room_id, last_depth, last_stream: The `(room_id, depth, stream)`
+ tuple to fetch results after.
+ batch_size: The maximum number of events to process. If None then
+ no limit.
+ single_room: Whether to calculate the index for just the given
+ room.
+ """
+
+ # Get the next set of events in the room (that we haven't already
+ # computed chain cover for). We do this in topological order.
+
+ # We want to do a `(topological_ordering, stream_ordering) > (?,?)`
+ # comparison, but that is not supported on older SQLite versions
+ tuple_clause, tuple_args = make_tuple_comparison_clause(
+ self.database_engine,
+ [
+ ("events.room_id", last_room_id),
+ ("topological_ordering", last_depth),
+ ("stream_ordering", last_stream),
+ ],
+ )
+
+ extra_clause = ""
+ if single_room:
+ extra_clause = "AND events.room_id = ?"
+ tuple_args.append(last_room_id)
+
+ sql = """
+ SELECT
+ event_id, state_events.type, state_events.state_key,
+ topological_ordering, stream_ordering,
+ events.room_id
+ FROM events
+ INNER JOIN state_events USING (event_id)
+ LEFT JOIN event_auth_chains USING (event_id)
+ LEFT JOIN event_auth_chain_to_calculate USING (event_id)
+ WHERE event_auth_chains.event_id IS NULL
+ AND event_auth_chain_to_calculate.event_id IS NULL
+ AND %(tuple_cmp)s
+ %(extra)s
+ ORDER BY events.room_id, topological_ordering, stream_ordering
+ %(limit)s
+ """ % {
+ "tuple_cmp": tuple_clause,
+ "limit": "LIMIT ?" if batch_size is not None else "",
+ "extra": extra_clause,
+ }
+
+ if batch_size is not None:
+ tuple_args.append(batch_size)
+
+ txn.execute(sql, tuple_args)
+ rows = txn.fetchall()
+
+ # Put the results in the necessary format for
+ # `_add_chain_cover_index`
+ event_to_room_id = {row[0]: row[5] for row in rows}
+ event_to_types = {row[0]: (row[1], row[2]) for row in rows}
+
+ # Calculate the new last position we've processed up to.
+ new_last_depth = rows[-1][3] if rows else last_depth # type: int
+ new_last_stream = rows[-1][4] if rows else last_stream # type: int
+ new_last_room_id = rows[-1][5] if rows else "" # type: str
+
+ # Map from room_id to last depth/stream_ordering processed for the room,
+ # excluding the last room (which we're likely still processing). We also
+ # need to include the room passed in if it's not included in the result
+ # set (as we then know we've processed all events in said room).
+ #
+ # This is the set of rooms that we can now safely flip the
+ # `has_auth_chain_index` bit for.
+ finished_rooms = {
+ row[5]: (row[3], row[4]) for row in rows if row[5] != new_last_room_id
+ }
+ if last_room_id not in finished_rooms and last_room_id != new_last_room_id:
+ finished_rooms[last_room_id] = (last_depth, last_stream)
+
+ count = len(rows)
+
+ # We also need to fetch the auth events for them.
+ auth_events = self.db_pool.simple_select_many_txn(
+ txn,
+ table="event_auth",
+ column="event_id",
+ iterable=event_to_room_id,
+ keyvalues={},
+ retcols=("event_id", "auth_id"),
+ )
+
+ event_to_auth_chain = {} # type: Dict[str, List[str]]
+ for row in auth_events:
+ event_to_auth_chain.setdefault(row["event_id"], []).append(row["auth_id"])
+
+ # Calculate and persist the chain cover index for this set of events.
+ #
+ # Annoyingly we need to gut wrench into the persit event store so that
+ # we can reuse the function to calculate the chain cover for rooms.
+ PersistEventsStore._add_chain_cover_index(
+ txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain,
+ )
+
+ return _CalculateChainCover(
+ room_id=new_last_room_id,
+ depth=new_last_depth,
+ stream=new_last_stream,
+ processed_count=count,
+ finished_room_map=finished_rooms,
+ )
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 77ba9d819e..bc7621b8d6 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -17,14 +17,13 @@
import logging
from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Tuple
-from canonicaljson import encode_canonical_json
-
from synapse.push import PusherConfig, ThrottleParams
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
from synapse.storage.types import Connection
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import JsonDict
+from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
if TYPE_CHECKING:
@@ -315,7 +314,7 @@ class PusherStore(PusherWorkerStore):
"device_display_name": device_display_name,
"ts": pushkey_ts,
"lang": lang,
- "data": bytearray(encode_canonical_json(data)),
+ "data": json_encoder.encode(data),
"last_stream_ordering": last_stream_ordering,
"profile_tag": profile_tag,
"id": stream_id,
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index 73a009efd1..2d25490374 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -15,9 +15,8 @@
import time
import urllib.parse
-from html.parser import HTMLParser
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
-from urllib.parse import parse_qs, urlencode, urlparse
+from typing import Any, Dict, Union
+from urllib.parse import urlencode
from mock import Mock
@@ -38,6 +37,7 @@ from tests import unittest
from tests.handlers.test_oidc import HAS_OIDC
from tests.handlers.test_saml import has_saml2
from tests.rest.client.v1.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG
+from tests.test_utils.html_parsers import TestHtmlParser
from tests.unittest import HomeserverTestCase, override_config, skip_unless
try:
@@ -69,6 +69,12 @@ TEST_SAML_METADATA = """
LOGIN_URL = b"/_matrix/client/r0/login"
TEST_URL = b"/_matrix/client/r0/account/whoami"
+# a (valid) url with some annoying characters in. %3D is =, %26 is &, %2B is +
+TEST_CLIENT_REDIRECT_URL = 'https://x?<ab c>&q"+%3D%2B"="fö%26=o"'
+
+# the query params in TEST_CLIENT_REDIRECT_URL
+EXPECTED_CLIENT_REDIRECT_URL_PARAMS = [("<ab c>", ""), ('q" =+"', '"fö&=o"')]
+
class LoginRestServletTestCase(unittest.HomeserverTestCase):
@@ -389,23 +395,44 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
},
}
+ # default OIDC provider
config["oidc_config"] = TEST_OIDC_CONFIG
+ # additional OIDC providers
+ config["oidc_providers"] = [
+ {
+ "idp_id": "idp1",
+ "idp_name": "IDP1",
+ "discover": False,
+ "issuer": "https://issuer1",
+ "client_id": "test-client-id",
+ "client_secret": "test-client-secret",
+ "scopes": ["profile"],
+ "authorization_endpoint": "https://issuer1/auth",
+ "token_endpoint": "https://issuer1/token",
+ "userinfo_endpoint": "https://issuer1/userinfo",
+ "user_mapping_provider": {
+ "config": {"localpart_template": "{{ user.sub }}"}
+ },
+ }
+ ]
return config
def create_resource_dict(self) -> Dict[str, Resource]:
+ from synapse.rest.oidc import OIDCResource
+
d = super().create_resource_dict()
d["/_synapse/client/pick_idp"] = PickIdpResource(self.hs)
+ d["/_synapse/oidc"] = OIDCResource(self.hs)
return d
def test_multi_sso_redirect(self):
"""/login/sso/redirect should redirect to an identity picker"""
- client_redirect_url = "https://x?<abc>"
-
# first hit the redirect url, which should redirect to our idp picker
channel = self.make_request(
"GET",
- "/_matrix/client/r0/login/sso/redirect?redirectUrl=" + client_redirect_url,
+ "/_matrix/client/r0/login/sso/redirect?redirectUrl="
+ + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL),
)
self.assertEqual(channel.code, 302, channel.result)
uri = channel.headers.getRawHeaders("Location")[0]
@@ -415,46 +442,22 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.result)
# parse the form to check it has fields assumed elsewhere in this class
- class FormPageParser(HTMLParser):
- def __init__(self):
- super().__init__()
-
- # the values of the hidden inputs: map from name to value
- self.hiddens = {} # type: Dict[str, Optional[str]]
-
- # the values of the radio buttons
- self.radios = [] # type: List[Optional[str]]
-
- def handle_starttag(
- self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]]
- ) -> None:
- attr_dict = dict(attrs)
- if tag == "input":
- if attr_dict["type"] == "radio" and attr_dict["name"] == "idp":
- self.radios.append(attr_dict["value"])
- elif attr_dict["type"] == "hidden":
- input_name = attr_dict["name"]
- assert input_name
- self.hiddens[input_name] = attr_dict["value"]
-
- def error(_, message):
- self.fail(message)
-
- p = FormPageParser()
+ p = TestHtmlParser()
p.feed(channel.result["body"].decode("utf-8"))
p.close()
- self.assertCountEqual(p.radios, ["cas", "oidc", "saml"])
+ self.assertCountEqual(p.radios["idp"], ["cas", "oidc", "idp1", "saml"])
- self.assertEqual(p.hiddens["redirectUrl"], client_redirect_url)
+ self.assertEqual(p.hiddens["redirectUrl"], TEST_CLIENT_REDIRECT_URL)
def test_multi_sso_redirect_to_cas(self):
"""If CAS is chosen, should redirect to the CAS server"""
- client_redirect_url = "https://x?<abc>"
channel = self.make_request(
"GET",
- "/_synapse/client/pick_idp?redirectUrl=" + client_redirect_url + "&idp=cas",
+ "/_synapse/client/pick_idp?redirectUrl="
+ + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+ + "&idp=cas",
shorthand=False,
)
self.assertEqual(channel.code, 302, channel.result)
@@ -470,16 +473,14 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
service_uri = cas_uri_params["service"][0]
_, service_uri_query = service_uri.split("?", 1)
service_uri_params = urllib.parse.parse_qs(service_uri_query)
- self.assertEqual(service_uri_params["redirectUrl"][0], client_redirect_url)
+ self.assertEqual(service_uri_params["redirectUrl"][0], TEST_CLIENT_REDIRECT_URL)
def test_multi_sso_redirect_to_saml(self):
"""If SAML is chosen, should redirect to the SAML server"""
- client_redirect_url = "https://x?<abc>"
-
channel = self.make_request(
"GET",
"/_synapse/client/pick_idp?redirectUrl="
- + client_redirect_url
+ + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+ "&idp=saml",
)
self.assertEqual(channel.code, 302, channel.result)
@@ -492,16 +493,16 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
# the RelayState is used to carry the client redirect url
saml_uri_params = urllib.parse.parse_qs(saml_uri_query)
relay_state_param = saml_uri_params["RelayState"][0]
- self.assertEqual(relay_state_param, client_redirect_url)
+ self.assertEqual(relay_state_param, TEST_CLIENT_REDIRECT_URL)
- def test_multi_sso_redirect_to_oidc(self):
+ def test_login_via_oidc(self):
"""If OIDC is chosen, should redirect to the OIDC auth endpoint"""
- client_redirect_url = "https://x?<abc>"
+ # pick the default OIDC provider
channel = self.make_request(
"GET",
"/_synapse/client/pick_idp?redirectUrl="
- + client_redirect_url
+ + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+ "&idp=oidc",
)
self.assertEqual(channel.code, 302, channel.result)
@@ -521,9 +522,41 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
macaroon = pymacaroons.Macaroon.deserialize(oidc_session_cookie)
self.assertEqual(
self._get_value_from_macaroon(macaroon, "client_redirect_url"),
- client_redirect_url,
+ TEST_CLIENT_REDIRECT_URL,
)
+ channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"})
+
+ # that should serve a confirmation page
+ self.assertEqual(channel.code, 200, channel.result)
+ self.assertTrue(
+ channel.headers.getRawHeaders("Content-Type")[-1].startswith("text/html")
+ )
+ p = TestHtmlParser()
+ p.feed(channel.text_body)
+ p.close()
+
+ # ... which should contain our redirect link
+ self.assertEqual(len(p.links), 1)
+ path, query = p.links[0].split("?", 1)
+ self.assertEqual(path, "https://x")
+
+ # it will have url-encoded the params properly, so we'll have to parse them
+ params = urllib.parse.parse_qsl(
+ query, keep_blank_values=True, strict_parsing=True, errors="strict"
+ )
+ self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS)
+ self.assertEqual(params[2][0], "loginToken")
+
+ # finally, submit the matrix login token to the login API, which gives us our
+ # matrix access token, mxid, and device id.
+ login_token = params[2][1]
+ chan = self.make_request(
+ "POST", "/login", content={"type": "m.login.token", "token": login_token},
+ )
+ self.assertEqual(chan.code, 200, chan.result)
+ self.assertEqual(chan.json_body["user_id"], "@user1:test")
+
def test_multi_sso_redirect_to_unknown(self):
"""An unknown IdP should cause a 400"""
channel = self.make_request(
@@ -1082,7 +1115,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
# whitelist this client URI so we redirect straight to it rather than
# serving a confirmation page
- config["sso"] = {"client_whitelist": ["https://whitelisted.client"]}
+ config["sso"] = {"client_whitelist": ["https://x"]}
return config
def create_resource_dict(self) -> Dict[str, Resource]:
@@ -1095,11 +1128,10 @@ class UsernamePickerTestCase(HomeserverTestCase):
def test_username_picker(self):
"""Test the happy path of a username picker flow."""
- client_redirect_url = "https://whitelisted.client"
# do the start of the login flow
channel = self.helper.auth_via_oidc(
- {"sub": "tester", "displayname": "Jonny"}, client_redirect_url
+ {"sub": "tester", "displayname": "Jonny"}, TEST_CLIENT_REDIRECT_URL
)
# that should redirect to the username picker
@@ -1122,7 +1154,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
session = username_mapping_sessions[session_id]
self.assertEqual(session.remote_user_id, "tester")
self.assertEqual(session.display_name, "Jonny")
- self.assertEqual(session.client_redirect_url, client_redirect_url)
+ self.assertEqual(session.client_redirect_url, TEST_CLIENT_REDIRECT_URL)
# the expiry time should be about 15 minutes away
expected_expiry = self.clock.time_msec() + (15 * 60 * 1000)
@@ -1146,15 +1178,19 @@ class UsernamePickerTestCase(HomeserverTestCase):
)
self.assertEqual(chan.code, 302, chan.result)
location_headers = chan.headers.getRawHeaders("Location")
- # ensure that the returned location starts with the requested redirect URL
- self.assertEqual(
- location_headers[0][: len(client_redirect_url)], client_redirect_url
+ # ensure that the returned location matches the requested redirect URL
+ path, query = location_headers[0].split("?", 1)
+ self.assertEqual(path, "https://x")
+
+ # it will have url-encoded the params properly, so we'll have to parse them
+ params = urllib.parse.parse_qsl(
+ query, keep_blank_values=True, strict_parsing=True, errors="strict"
)
+ self.assertEqual(params[0:2], EXPECTED_CLIENT_REDIRECT_URL_PARAMS)
+ self.assertEqual(params[2][0], "loginToken")
# fish the login token out of the returned redirect uri
- parts = urlparse(location_headers[0])
- query = parse_qs(parts.query)
- login_token = query["loginToken"][0]
+ login_token = params[2][1]
# finally, submit the matrix login token to the login API, which gives us our
# matrix access token, mxid, and device id.
diff --git a/tests/rest/client/v1/utils.py b/tests/rest/client/v1/utils.py
index c6647dbe08..b1333df82d 100644
--- a/tests/rest/client/v1/utils.py
+++ b/tests/rest/client/v1/utils.py
@@ -20,8 +20,7 @@ import json
import re
import time
import urllib.parse
-from html.parser import HTMLParser
-from typing import Any, Dict, Iterable, List, MutableMapping, Optional, Tuple
+from typing import Any, Dict, Mapping, MutableMapping, Optional
from mock import patch
@@ -35,6 +34,7 @@ from synapse.types import JsonDict
from tests.server import FakeChannel, FakeSite, make_request
from tests.test_utils import FakeResponse
+from tests.test_utils.html_parsers import TestHtmlParser
@attr.s
@@ -440,10 +440,36 @@ class RestHelper:
# param that synapse passes to the IdP via query params, as well as the cookie
# that synapse passes to the client.
- oauth_uri_path, oauth_uri_qs = oauth_uri.split("?", 1)
+ oauth_uri_path, _ = oauth_uri.split("?", 1)
assert oauth_uri_path == TEST_OIDC_AUTH_ENDPOINT, (
"unexpected SSO URI " + oauth_uri_path
)
+ return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict)
+
+ def complete_oidc_auth(
+ self, oauth_uri: str, cookies: Mapping[str, str], user_info_dict: JsonDict,
+ ) -> FakeChannel:
+ """Mock out an OIDC authentication flow
+
+ Assumes that an OIDC auth has been initiated by one of initiate_sso_login or
+ initiate_sso_ui_auth; completes the OIDC bits of the flow by making a request to
+ Synapse's OIDC callback endpoint, intercepting the HTTP requests that will get
+ sent back to the OIDC provider.
+
+ Requires the OIDC callback resource to be mounted at the normal place.
+
+ Args:
+ oauth_uri: the OIDC URI returned by synapse's redirect endpoint (ie,
+ from initiate_sso_login or initiate_sso_ui_auth).
+ cookies: the cookies set by synapse's redirect endpoint, which will be
+ sent back to the callback endpoint.
+ user_info_dict: the remote userinfo that the OIDC provider should present.
+ Typically this should be '{"sub": "<remote user id>"}'.
+
+ Returns:
+ A FakeChannel containing the result of calling the OIDC callback endpoint.
+ """
+ _, oauth_uri_qs = oauth_uri.split("?", 1)
params = urllib.parse.parse_qs(oauth_uri_qs)
callback_uri = "%s?%s" % (
urllib.parse.urlparse(params["redirect_uri"][0]).path,
@@ -456,9 +482,9 @@ class RestHelper:
expected_requests = [
# first we get a hit to the token endpoint, which we tell to return
# a dummy OIDC access token
- ("https://issuer.test/token", {"access_token": "TEST"}),
+ (TEST_OIDC_TOKEN_ENDPOINT, {"access_token": "TEST"}),
# and then one to the user_info endpoint, which returns our remote user id.
- ("https://issuer.test/userinfo", user_info_dict),
+ (TEST_OIDC_USERINFO_ENDPOINT, user_info_dict),
]
async def mock_req(method: str, uri: str, data=None, headers=None):
@@ -542,25 +568,7 @@ class RestHelper:
channel.extract_cookies(cookies)
# parse the confirmation page to fish out the link.
- class ConfirmationPageParser(HTMLParser):
- def __init__(self):
- super().__init__()
-
- self.links = [] # type: List[str]
-
- def handle_starttag(
- self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]]
- ) -> None:
- attr_dict = dict(attrs)
- if tag == "a":
- href = attr_dict["href"]
- if href:
- self.links.append(href)
-
- def error(_, message):
- raise AssertionError(message)
-
- p = ConfirmationPageParser()
+ p = TestHtmlParser()
p.feed(channel.text_body)
p.close()
assert len(p.links) == 1, "not exactly one link in confirmation page"
@@ -570,6 +578,8 @@ class RestHelper:
# an 'oidc_config' suitable for login_via_oidc.
TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth"
+TEST_OIDC_TOKEN_ENDPOINT = "https://issuer.test/token"
+TEST_OIDC_USERINFO_ENDPOINT = "https://issuer.test/userinfo"
TEST_OIDC_CONFIG = {
"enabled": True,
"discover": False,
@@ -578,7 +588,7 @@ TEST_OIDC_CONFIG = {
"client_secret": "test-client-secret",
"scopes": ["profile"],
"authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT,
- "token_endpoint": "https://issuer.test/token",
- "userinfo_endpoint": "https://issuer.test/userinfo",
+ "token_endpoint": TEST_OIDC_TOKEN_ENDPOINT,
+ "userinfo_endpoint": TEST_OIDC_USERINFO_ENDPOINT,
"user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},
}
diff --git a/tests/rest/client/v2_alpha/test_auth.py b/tests/rest/client/v2_alpha/test_auth.py
index 7728884bae..a6488a3d29 100644
--- a/tests/rest/client/v2_alpha/test_auth.py
+++ b/tests/rest/client/v2_alpha/test_auth.py
@@ -475,7 +475,9 @@ class UIAuthTests(unittest.HomeserverTestCase):
session_id = channel.json_body["session"]
# do the OIDC auth, but auth as the wrong user
- channel = self.helper.auth_via_oidc({"sub": "wrong_user"}, ui_auth_session_id=session_id)
+ channel = self.helper.auth_via_oidc(
+ {"sub": "wrong_user"}, ui_auth_session_id=session_id
+ )
# that should return a failure message
self.assertSubstring("We were unable to validate", channel.text_body)
diff --git a/tests/server.py b/tests/server.py
index 5a1b66270f..5a85d5fe7f 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -74,7 +74,7 @@ class FakeChannel:
return int(self.result["code"])
@property
- def headers(self):
+ def headers(self) -> Headers:
if not self.result:
raise Exception("No result yet.")
h = Headers()
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index ff67a73749..0c46ad595b 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from typing import Dict, List, Tuple
+from typing import Dict, List, Set, Tuple
from twisted.trial import unittest
@@ -483,22 +483,20 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
login.register_servlets,
]
- def test_background_update(self):
- """Test that the background update to calculate auth chains for historic
- rooms works correctly.
- """
-
- # Create a room
- user_id = self.register_user("foo", "pass")
- token = self.login("foo", "pass")
- room_id = self.helper.create_room_as(user_id, tok=token)
- requester = create_requester(user_id)
+ def prepare(self, reactor, clock, hs):
+ self.store = hs.get_datastore()
+ self.user_id = self.register_user("foo", "pass")
+ self.token = self.login("foo", "pass")
+ self.requester = create_requester(self.user_id)
- store = self.hs.get_datastore()
+ def _generate_room(self) -> Tuple[str, List[Set[str]]]:
+ """Insert a room without a chain cover index.
+ """
+ room_id = self.helper.create_room_as(self.user_id, tok=self.token)
# Mark the room as not having a chain cover index
self.get_success(
- store.db_pool.simple_update(
+ self.store.db_pool.simple_update(
table="rooms",
keyvalues={"room_id": room_id},
updatevalues={"has_auth_chain_index": False},
@@ -508,42 +506,44 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
# Create a fork in the DAG with different events.
event_handler = self.hs.get_event_creation_handler()
- latest_event_ids = self.get_success(store.get_prev_events_for_room(room_id))
+ latest_event_ids = self.get_success(
+ self.store.get_prev_events_for_room(room_id)
+ )
event, context = self.get_success(
event_handler.create_event(
- requester,
+ self.requester,
{
"type": "some_state_type",
"state_key": "",
"content": {},
"room_id": room_id,
- "sender": user_id,
+ "sender": self.user_id,
},
prev_event_ids=latest_event_ids,
)
)
self.get_success(
- event_handler.handle_new_client_event(requester, event, context)
+ event_handler.handle_new_client_event(self.requester, event, context)
)
- state1 = list(self.get_success(context.get_current_state_ids()).values())
+ state1 = set(self.get_success(context.get_current_state_ids()).values())
event, context = self.get_success(
event_handler.create_event(
- requester,
+ self.requester,
{
"type": "some_state_type",
"state_key": "",
"content": {},
"room_id": room_id,
- "sender": user_id,
+ "sender": self.user_id,
},
prev_event_ids=latest_event_ids,
)
)
self.get_success(
- event_handler.handle_new_client_event(requester, event, context)
+ event_handler.handle_new_client_event(self.requester, event, context)
)
- state2 = list(self.get_success(context.get_current_state_ids()).values())
+ state2 = set(self.get_success(context.get_current_state_ids()).values())
# Delete the chain cover info.
@@ -551,36 +551,191 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
txn.execute("DELETE FROM event_auth_chains")
txn.execute("DELETE FROM event_auth_chain_links")
- self.get_success(store.db_pool.runInteraction("test", _delete_tables))
+ self.get_success(self.store.db_pool.runInteraction("test", _delete_tables))
+
+ return room_id, [state1, state2]
+
+ def test_background_update_single_room(self):
+ """Test that the background update to calculate auth chains for historic
+ rooms works correctly.
+ """
+
+ # Create a room
+ room_id, states = self._generate_room()
# Insert and run the background update.
self.get_success(
- store.db_pool.simple_insert(
+ self.store.db_pool.simple_insert(
"background_updates",
{"update_name": "chain_cover", "progress_json": "{}"},
)
)
# Ugh, have to reset this flag
- store.db_pool.updates._all_done = False
+ self.store.db_pool.updates._all_done = False
while not self.get_success(
- store.db_pool.updates.has_completed_background_updates()
+ self.store.db_pool.updates.has_completed_background_updates()
):
self.get_success(
- store.db_pool.updates.do_next_background_update(100), by=0.1
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
)
# Test that the `has_auth_chain_index` has been set
- self.assertTrue(self.get_success(store.has_auth_chain_index(room_id)))
+ self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id)))
# Test that calculating the auth chain difference using the newly
# calculated chain cover works.
self.get_success(
- store.db_pool.runInteraction(
+ self.store.db_pool.runInteraction(
"test",
- store._get_auth_chain_difference_using_cover_index_txn,
+ self.store._get_auth_chain_difference_using_cover_index_txn,
room_id,
- [state1, state2],
+ states,
+ )
+ )
+
+ def test_background_update_multiple_rooms(self):
+ """Test that the background update to calculate auth chains for historic
+ rooms works correctly.
+ """
+ # Create a room
+ room_id1, states1 = self._generate_room()
+ room_id2, states2 = self._generate_room()
+ room_id3, states2 = self._generate_room()
+
+ # Insert and run the background update.
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ {"update_name": "chain_cover", "progress_json": "{}"},
+ )
+ )
+
+ # Ugh, have to reset this flag
+ self.store.db_pool.updates._all_done = False
+
+ while not self.get_success(
+ self.store.db_pool.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
+ )
+
+ # Test that the `has_auth_chain_index` has been set
+ self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id1)))
+ self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id2)))
+ self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id3)))
+
+ # Test that calculating the auth chain difference using the newly
+ # calculated chain cover works.
+ self.get_success(
+ self.store.db_pool.runInteraction(
+ "test",
+ self.store._get_auth_chain_difference_using_cover_index_txn,
+ room_id1,
+ states1,
)
)
+
+ def test_background_update_single_large_room(self):
+ """Test that the background update to calculate auth chains for historic
+ rooms works correctly.
+ """
+
+ # Create a room
+ room_id, states = self._generate_room()
+
+ # Add a bunch of state so that it takes multiple iterations of the
+ # background update to process the room.
+ for i in range(0, 150):
+ self.helper.send_state(
+ room_id, event_type="m.test", body={"index": i}, tok=self.token
+ )
+
+ # Insert and run the background update.
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ {"update_name": "chain_cover", "progress_json": "{}"},
+ )
+ )
+
+ # Ugh, have to reset this flag
+ self.store.db_pool.updates._all_done = False
+
+ iterations = 0
+ while not self.get_success(
+ self.store.db_pool.updates.has_completed_background_updates()
+ ):
+ iterations += 1
+ self.get_success(
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
+ )
+
+ # Ensure that we did actually take multiple iterations to process the
+ # room.
+ self.assertGreater(iterations, 1)
+
+ # Test that the `has_auth_chain_index` has been set
+ self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id)))
+
+ # Test that calculating the auth chain difference using the newly
+ # calculated chain cover works.
+ self.get_success(
+ self.store.db_pool.runInteraction(
+ "test",
+ self.store._get_auth_chain_difference_using_cover_index_txn,
+ room_id,
+ states,
+ )
+ )
+
+ def test_background_update_multiple_large_room(self):
+ """Test that the background update to calculate auth chains for historic
+ rooms works correctly.
+ """
+
+ # Create the rooms
+ room_id1, _ = self._generate_room()
+ room_id2, _ = self._generate_room()
+
+ # Add a bunch of state so that it takes multiple iterations of the
+ # background update to process the room.
+ for i in range(0, 150):
+ self.helper.send_state(
+ room_id1, event_type="m.test", body={"index": i}, tok=self.token
+ )
+
+ for i in range(0, 150):
+ self.helper.send_state(
+ room_id2, event_type="m.test", body={"index": i}, tok=self.token
+ )
+
+ # Insert and run the background update.
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ "background_updates",
+ {"update_name": "chain_cover", "progress_json": "{}"},
+ )
+ )
+
+ # Ugh, have to reset this flag
+ self.store.db_pool.updates._all_done = False
+
+ iterations = 0
+ while not self.get_success(
+ self.store.db_pool.updates.has_completed_background_updates()
+ ):
+ iterations += 1
+ self.get_success(
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
+ )
+
+ # Ensure that we did actually take multiple iterations to process the
+ # room.
+ self.assertGreater(iterations, 1)
+
+ # Test that the `has_auth_chain_index` has been set
+ self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id1)))
+ self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id2)))
diff --git a/tests/test_utils/html_parsers.py b/tests/test_utils/html_parsers.py
new file mode 100644
index 0000000000..ad563eb3f0
--- /dev/null
+++ b/tests/test_utils/html_parsers.py
@@ -0,0 +1,53 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from html.parser import HTMLParser
+from typing import Dict, Iterable, List, Optional, Tuple
+
+
+class TestHtmlParser(HTMLParser):
+ """A generic HTML page parser which extracts useful things from the HTML"""
+
+ def __init__(self):
+ super().__init__()
+
+ # a list of links found in the doc
+ self.links = [] # type: List[str]
+
+ # the values of any hidden <input>s: map from name to value
+ self.hiddens = {} # type: Dict[str, Optional[str]]
+
+ # the values of any radio buttons: map from name to list of values
+ self.radios = {} # type: Dict[str, List[Optional[str]]]
+
+ def handle_starttag(
+ self, tag: str, attrs: Iterable[Tuple[str, Optional[str]]]
+ ) -> None:
+ attr_dict = dict(attrs)
+ if tag == "a":
+ href = attr_dict["href"]
+ if href:
+ self.links.append(href)
+ elif tag == "input":
+ input_name = attr_dict.get("name")
+ if attr_dict["type"] == "radio":
+ assert input_name
+ self.radios.setdefault(input_name, []).append(attr_dict["value"])
+ elif attr_dict["type"] == "hidden":
+ assert input_name
+ self.hiddens[input_name] = attr_dict["value"]
+
+ def error(_, message):
+ raise AssertionError(message)
diff --git a/tox.ini b/tox.ini
index 3cf68a47a6..92a59d79c3 100644
--- a/tox.ini
+++ b/tox.ini
@@ -105,6 +105,9 @@ usedevelop=true
[testenv:py35-old]
skip_install=True
deps =
+ # Ensure a version of setuptools that supports Python 3.5 is installed.
+ setuptools < 51.0.0
+
# Old automat version for Twisted
Automat == 0.3.0
lxml
|