diff --git a/tests/rest/client/v1/test_directory.py b/tests/rest/client/v1/test_directory.py
index 914cf54927..633b7dbda0 100644
--- a/tests/rest/client/v1/test_directory.py
+++ b/tests/rest/client/v1/test_directory.py
@@ -51,30 +51,26 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
self.user = self.register_user("user", "test")
self.user_tok = self.login("user", "test")
- def test_cannot_set_alias_via_state_event(self):
- self.ensure_user_joined_room()
- url = "/_matrix/client/r0/rooms/%s/state/m.room.aliases/%s" % (
- self.room_id,
- self.hs.hostname,
- )
-
- data = {"aliases": [self.random_alias(5)]}
- request_data = json.dumps(data)
-
- request, channel = self.make_request(
- "PUT", url, request_data, access_token=self.user_tok
- )
- self.render(request)
- self.assertEqual(channel.code, 400, channel.result)
+ def test_state_event_not_in_room(self):
+ self.ensure_user_left_room()
+ self.set_alias_via_state_event(403)
def test_directory_endpoint_not_in_room(self):
self.ensure_user_left_room()
self.set_alias_via_directory(403)
+ def test_state_event_in_room_too_long(self):
+ self.ensure_user_joined_room()
+ self.set_alias_via_state_event(400, alias_length=256)
+
def test_directory_in_room_too_long(self):
self.ensure_user_joined_room()
self.set_alias_via_directory(400, alias_length=256)
+ def test_state_event_in_room(self):
+ self.ensure_user_joined_room()
+ self.set_alias_via_state_event(200)
+
def test_directory_in_room(self):
self.ensure_user_joined_room()
self.set_alias_via_directory(200)
@@ -106,6 +102,21 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
self.render(request)
self.assertEqual(channel.code, 200, channel.result)
+ def set_alias_via_state_event(self, expected_code, alias_length=5):
+ url = "/_matrix/client/r0/rooms/%s/state/m.room.aliases/%s" % (
+ self.room_id,
+ self.hs.hostname,
+ )
+
+ data = {"aliases": [self.random_alias(alias_length)]}
+ request_data = json.dumps(data)
+
+ request, channel = self.make_request(
+ "PUT", url, request_data, access_token=self.user_tok
+ )
+ self.render(request)
+ self.assertEqual(channel.code, expected_code, channel.result)
+
def set_alias_via_directory(self, expected_code, alias_length=5):
url = "/_matrix/client/r0/directory/room/%s" % self.random_alias(alias_length)
data = {"room_id": self.room_id}
diff --git a/tests/rest/client/v1/test_login.py b/tests/rest/client/v1/test_login.py
index eae5411325..da2c9bfa1e 100644
--- a/tests/rest/client/v1/test_login.py
+++ b/tests/rest/client/v1/test_login.py
@@ -1,4 +1,7 @@
import json
+import urllib.parse
+
+from mock import Mock
import synapse.rest.admin
from synapse.rest.client.v1 import login
@@ -252,3 +255,111 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
)
self.render(request)
self.assertEquals(channel.code, 200, channel.result)
+
+
+class CASRedirectConfirmTestCase(unittest.HomeserverTestCase):
+
+ servlets = [
+ login.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor, clock):
+ self.base_url = "https://matrix.goodserver.com/"
+ self.redirect_path = "_synapse/client/login/sso/redirect/confirm"
+
+ config = self.default_config()
+ config["cas_config"] = {
+ "enabled": True,
+ "server_url": "https://fake.test",
+ "service_url": "https://matrix.goodserver.com:8448",
+ }
+
+ async def get_raw(uri, args):
+ """Return an example response payload from a call to the `/proxyValidate`
+ endpoint of a CAS server, copied from
+ https://apereo.github.io/cas/5.0.x/protocol/CAS-Protocol-V2-Specification.html#26-proxyvalidate-cas-20
+
+ This needs to be returned by an async function (as opposed to set as the
+ mock's return value) because the corresponding Synapse code awaits on it.
+ """
+ return """
+ <cas:serviceResponse xmlns:cas='http://www.yale.edu/tp/cas'>
+ <cas:authenticationSuccess>
+ <cas:user>username</cas:user>
+ <cas:proxyGrantingTicket>PGTIOU-84678-8a9d...</cas:proxyGrantingTicket>
+ <cas:proxies>
+ <cas:proxy>https://proxy2/pgtUrl</cas:proxy>
+ <cas:proxy>https://proxy1/pgtUrl</cas:proxy>
+ </cas:proxies>
+ </cas:authenticationSuccess>
+ </cas:serviceResponse>
+ """
+
+ mocked_http_client = Mock(spec=["get_raw"])
+ mocked_http_client.get_raw.side_effect = get_raw
+
+ self.hs = self.setup_test_homeserver(
+ config=config, proxied_http_client=mocked_http_client,
+ )
+
+ return self.hs
+
+ def test_cas_redirect_confirm(self):
+ """Tests that the SSO login flow serves a confirmation page before redirecting a
+ user to the redirect URL.
+ """
+ base_url = "/_matrix/client/r0/login/cas/ticket?redirectUrl"
+ redirect_url = "https://dodgy-site.com/"
+
+ url_parts = list(urllib.parse.urlparse(base_url))
+ query = dict(urllib.parse.parse_qsl(url_parts[4]))
+ query.update({"redirectUrl": redirect_url})
+ query.update({"ticket": "ticket"})
+ url_parts[4] = urllib.parse.urlencode(query)
+ cas_ticket_url = urllib.parse.urlunparse(url_parts)
+
+ # Get Synapse to call the fake CAS and serve the template.
+ request, channel = self.make_request("GET", cas_ticket_url)
+ self.render(request)
+
+ # Test that the response is HTML.
+ self.assertEqual(channel.code, 200)
+ content_type_header_value = ""
+ for header in channel.result.get("headers", []):
+ if header[0] == b"Content-Type":
+ content_type_header_value = header[1].decode("utf8")
+
+ self.assertTrue(content_type_header_value.startswith("text/html"))
+
+ # Test that the body isn't empty.
+ self.assertTrue(len(channel.result["body"]) > 0)
+
+ # And that it contains our redirect link
+ self.assertIn(redirect_url, channel.result["body"].decode("UTF-8"))
+
+ @override_config(
+ {
+ "sso": {
+ "client_whitelist": [
+ "https://legit-site.com/",
+ "https://other-site.com/",
+ ]
+ }
+ }
+ )
+ def test_cas_redirect_whitelisted(self):
+ """Tests that the SSO login flow serves a redirect to a whitelisted url
+ """
+ redirect_url = "https://legit-site.com/"
+ cas_ticket_url = (
+ "/_matrix/client/r0/login/cas/ticket?redirectUrl=%s&ticket=ticket"
+ % (urllib.parse.quote(redirect_url))
+ )
+
+ # Get Synapse to call the fake CAS and serve the template.
+ request, channel = self.make_request("GET", cas_ticket_url)
+ self.render(request)
+
+ self.assertEqual(channel.code, 302)
+ location_headers = channel.headers.getRawHeaders("Location")
+ self.assertEqual(location_headers[0][: len(redirect_url)], redirect_url)
diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py
index 2f3df5f88f..7dd86d0c27 100644
--- a/tests/rest/client/v1/test_rooms.py
+++ b/tests/rest/client/v1/test_rooms.py
@@ -1821,3 +1821,163 @@ class RoomAliasListTestCase(unittest.HomeserverTestCase):
)
self.render(request)
self.assertEqual(channel.code, expected_code, channel.result)
+
+
+class RoomCanonicalAliasTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ directory.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, homeserver):
+ self.room_owner = self.register_user("room_owner", "test")
+ self.room_owner_tok = self.login("room_owner", "test")
+
+ self.room_id = self.helper.create_room_as(
+ self.room_owner, tok=self.room_owner_tok
+ )
+
+ self.alias = "#alias:test"
+ self._set_alias_via_directory(self.alias)
+
+ def _set_alias_via_directory(self, alias: str, expected_code: int = 200):
+ url = "/_matrix/client/r0/directory/room/" + alias
+ data = {"room_id": self.room_id}
+ request_data = json.dumps(data)
+
+ request, channel = self.make_request(
+ "PUT", url, request_data, access_token=self.room_owner_tok
+ )
+ self.render(request)
+ self.assertEqual(channel.code, expected_code, channel.result)
+
+ def _get_canonical_alias(self, expected_code: int = 200) -> JsonDict:
+ """Calls the endpoint under test. returns the json response object."""
+ request, channel = self.make_request(
+ "GET",
+ "rooms/%s/state/m.room.canonical_alias" % (self.room_id,),
+ access_token=self.room_owner_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, expected_code, channel.result)
+ res = channel.json_body
+ self.assertIsInstance(res, dict)
+ return res
+
+ def _set_canonical_alias(self, content: str, expected_code: int = 200) -> JsonDict:
+ """Calls the endpoint under test. returns the json response object."""
+ request, channel = self.make_request(
+ "PUT",
+ "rooms/%s/state/m.room.canonical_alias" % (self.room_id,),
+ json.dumps(content),
+ access_token=self.room_owner_tok,
+ )
+ self.render(request)
+ self.assertEqual(channel.code, expected_code, channel.result)
+ res = channel.json_body
+ self.assertIsInstance(res, dict)
+ return res
+
+ def test_canonical_alias(self):
+ """Test a basic alias message."""
+ # There is no canonical alias to start with.
+ self._get_canonical_alias(expected_code=404)
+
+ # Create an alias.
+ self._set_canonical_alias({"alias": self.alias})
+
+ # Canonical alias now exists!
+ res = self._get_canonical_alias()
+ self.assertEqual(res, {"alias": self.alias})
+
+ # Now remove the alias.
+ self._set_canonical_alias({})
+
+ # There is an alias event, but it is empty.
+ res = self._get_canonical_alias()
+ self.assertEqual(res, {})
+
+ def test_alt_aliases(self):
+ """Test a canonical alias message with alt_aliases."""
+ # Create an alias.
+ self._set_canonical_alias({"alt_aliases": [self.alias]})
+
+ # Canonical alias now exists!
+ res = self._get_canonical_alias()
+ self.assertEqual(res, {"alt_aliases": [self.alias]})
+
+ # Now remove the alt_aliases.
+ self._set_canonical_alias({})
+
+ # There is an alias event, but it is empty.
+ res = self._get_canonical_alias()
+ self.assertEqual(res, {})
+
+ def test_alias_alt_aliases(self):
+ """Test a canonical alias message with an alias and alt_aliases."""
+ # Create an alias.
+ self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]})
+
+ # Canonical alias now exists!
+ res = self._get_canonical_alias()
+ self.assertEqual(res, {"alias": self.alias, "alt_aliases": [self.alias]})
+
+ # Now remove the alias and alt_aliases.
+ self._set_canonical_alias({})
+
+ # There is an alias event, but it is empty.
+ res = self._get_canonical_alias()
+ self.assertEqual(res, {})
+
+ def test_partial_modify(self):
+ """Test removing only the alt_aliases."""
+ # Create an alias.
+ self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]})
+
+ # Canonical alias now exists!
+ res = self._get_canonical_alias()
+ self.assertEqual(res, {"alias": self.alias, "alt_aliases": [self.alias]})
+
+ # Now remove the alt_aliases.
+ self._set_canonical_alias({"alias": self.alias})
+
+ # There is an alias event, but it is empty.
+ res = self._get_canonical_alias()
+ self.assertEqual(res, {"alias": self.alias})
+
+ def test_add_alias(self):
+ """Test removing only the alt_aliases."""
+ # Create an additional alias.
+ second_alias = "#second:test"
+ self._set_alias_via_directory(second_alias)
+
+ # Add the canonical alias.
+ self._set_canonical_alias({"alias": self.alias, "alt_aliases": [self.alias]})
+
+ # Then add the second alias.
+ self._set_canonical_alias(
+ {"alias": self.alias, "alt_aliases": [self.alias, second_alias]}
+ )
+
+ # Canonical alias now exists!
+ res = self._get_canonical_alias()
+ self.assertEqual(
+ res, {"alias": self.alias, "alt_aliases": [self.alias, second_alias]}
+ )
+
+ def test_bad_data(self):
+ """Invalid data for alt_aliases should cause errors."""
+ self._set_canonical_alias({"alt_aliases": "@bad:test"}, expected_code=400)
+ self._set_canonical_alias({"alt_aliases": None}, expected_code=400)
+ self._set_canonical_alias({"alt_aliases": 0}, expected_code=400)
+ self._set_canonical_alias({"alt_aliases": 1}, expected_code=400)
+ self._set_canonical_alias({"alt_aliases": False}, expected_code=400)
+ self._set_canonical_alias({"alt_aliases": True}, expected_code=400)
+ self._set_canonical_alias({"alt_aliases": {}}, expected_code=400)
+
+ def test_bad_alias(self):
+ """An alias which does not point to the room raises a SynapseError."""
+ self._set_canonical_alias({"alias": "@unknown:test"}, expected_code=400)
+ self._set_canonical_alias({"alt_aliases": ["@unknown:test"]}, expected_code=400)
|