diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/rest/client/test_rendezvous.py | 401 | ||||
-rw-r--r-- | tests/server.py | 7 | ||||
-rw-r--r-- | tests/unittest.py | 5 |
3 files changed, 411 insertions, 2 deletions
diff --git a/tests/rest/client/test_rendezvous.py b/tests/rest/client/test_rendezvous.py index c84704c090..0ab754a11a 100644 --- a/tests/rest/client/test_rendezvous.py +++ b/tests/rest/client/test_rendezvous.py @@ -2,7 +2,7 @@ # This file is licensed under the Affero General Public License (AGPL) version 3. # # Copyright 2022 The Matrix.org Foundation C.I.C. -# Copyright (C) 2023 New Vector, Ltd +# Copyright (C) 2023-2024 New Vector, Ltd # # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as @@ -19,9 +19,14 @@ # # +from typing import Dict +from urllib.parse import urlparse + from twisted.test.proto_helpers import MemoryReactor +from twisted.web.resource import Resource from synapse.rest.client import rendezvous +from synapse.rest.synapse.client.rendezvous import MSC4108RendezvousSessionResource from synapse.server import HomeServer from synapse.util import Clock @@ -42,6 +47,12 @@ class RendezvousServletTestCase(unittest.HomeserverTestCase): self.hs = self.setup_test_homeserver() return self.hs + def create_resource_dict(self) -> Dict[str, Resource]: + return { + **super().create_resource_dict(), + "/_synapse/client/rendezvous": MSC4108RendezvousSessionResource(self.hs), + } + def test_disabled(self) -> None: channel = self.make_request("POST", msc3886_endpoint, {}, access_token=None) self.assertEqual(channel.code, 404) @@ -75,3 +86,391 @@ class RendezvousServletTestCase(unittest.HomeserverTestCase): channel = self.make_request("POST", msc4108_endpoint, {}, access_token=None) self.assertEqual(channel.code, 307) self.assertEqual(channel.headers.getRawHeaders("Location"), ["https://asd"]) + + @unittest.skip_unless(HAS_AUTHLIB, "requires authlib") + @override_config( + { + "disable_registration": True, + "experimental_features": { + "msc4108_enabled": True, + "msc3861": { + "enabled": True, + "issuer": "https://issuer", + "client_id": "client_id", + "client_auth_method": "client_secret_post", + "client_secret": "client_secret", + "admin_token": "admin_token_value", + }, + }, + } + ) + def test_msc4108(self) -> None: + """ + Test the MSC4108 rendezvous endpoint, including: + - Creating a session + - Getting the data back + - Updating the data + - Deleting the data + - ETag handling + """ + # We can post arbitrary data to the endpoint + channel = self.make_request( + "POST", + msc4108_endpoint, + "foo=bar", + content_type=b"text/plain", + access_token=None, + ) + self.assertEqual(channel.code, 201) + self.assertSubstring("/_synapse/client/rendezvous/", channel.json_body["url"]) + headers = dict(channel.headers.getAllRawHeaders()) + self.assertIn(b"ETag", headers) + self.assertIn(b"Expires", headers) + self.assertEqual(headers[b"Content-Type"], [b"application/json"]) + self.assertEqual(headers[b"Access-Control-Allow-Origin"], [b"*"]) + self.assertEqual(headers[b"Access-Control-Expose-Headers"], [b"etag"]) + self.assertEqual(headers[b"Cache-Control"], [b"no-store"]) + self.assertEqual(headers[b"Pragma"], [b"no-cache"]) + self.assertIn("url", channel.json_body) + self.assertTrue(channel.json_body["url"].startswith("https://")) + + url = urlparse(channel.json_body["url"]) + session_endpoint = url.path + etag = headers[b"ETag"][0] + + # We can get the data back + channel = self.make_request( + "GET", + session_endpoint, + access_token=None, + ) + + self.assertEqual(channel.code, 200) + headers = dict(channel.headers.getAllRawHeaders()) + self.assertEqual(headers[b"ETag"], [etag]) + self.assertIn(b"Expires", headers) + self.assertEqual(headers[b"Content-Type"], [b"text/plain"]) + self.assertEqual(headers[b"Access-Control-Allow-Origin"], [b"*"]) + self.assertEqual(headers[b"Access-Control-Expose-Headers"], [b"etag"]) + self.assertEqual(headers[b"Cache-Control"], [b"no-store"]) + self.assertEqual(headers[b"Pragma"], [b"no-cache"]) + self.assertEqual(channel.text_body, "foo=bar") + + # We can make sure the data hasn't changed + channel = self.make_request( + "GET", + session_endpoint, + access_token=None, + custom_headers=[("If-None-Match", etag)], + ) + + self.assertEqual(channel.code, 304) + + # We can update the data + channel = self.make_request( + "PUT", + session_endpoint, + "foo=baz", + content_type=b"text/plain", + access_token=None, + custom_headers=[("If-Match", etag)], + ) + + self.assertEqual(channel.code, 202) + headers = dict(channel.headers.getAllRawHeaders()) + old_etag = etag + new_etag = headers[b"ETag"][0] + + # If we try to update it again with the old etag, it should fail + channel = self.make_request( + "PUT", + session_endpoint, + "bar=baz", + content_type=b"text/plain", + access_token=None, + custom_headers=[("If-Match", old_etag)], + ) + + self.assertEqual(channel.code, 412) + self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN") + self.assertEqual( + channel.json_body["org.matrix.msc4108.errcode"], "M_CONCURRENT_WRITE" + ) + + # If we try to get with the old etag, we should get the updated data + channel = self.make_request( + "GET", + session_endpoint, + access_token=None, + custom_headers=[("If-None-Match", old_etag)], + ) + + self.assertEqual(channel.code, 200) + headers = dict(channel.headers.getAllRawHeaders()) + self.assertEqual(headers[b"ETag"], [new_etag]) + self.assertEqual(channel.text_body, "foo=baz") + + # We can delete the data + channel = self.make_request( + "DELETE", + session_endpoint, + access_token=None, + ) + + self.assertEqual(channel.code, 204) + + # If we try to get the data again, it should fail + channel = self.make_request( + "GET", + session_endpoint, + access_token=None, + ) + + self.assertEqual(channel.code, 404) + self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND") + + @unittest.skip_unless(HAS_AUTHLIB, "requires authlib") + @override_config( + { + "disable_registration": True, + "experimental_features": { + "msc4108_enabled": True, + "msc3861": { + "enabled": True, + "issuer": "https://issuer", + "client_id": "client_id", + "client_auth_method": "client_secret_post", + "client_secret": "client_secret", + "admin_token": "admin_token_value", + }, + }, + } + ) + def test_msc4108_expiration(self) -> None: + """ + Test that entries are evicted after a TTL. + """ + # Start a new session + channel = self.make_request( + "POST", + msc4108_endpoint, + "foo=bar", + content_type=b"text/plain", + access_token=None, + ) + self.assertEqual(channel.code, 201) + session_endpoint = urlparse(channel.json_body["url"]).path + + # Sanity check that we can get the data back + channel = self.make_request( + "GET", + session_endpoint, + access_token=None, + ) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.text_body, "foo=bar") + + # Advance the clock, TTL of entries is 1 minute + self.reactor.advance(60) + + # Get the data back, it should be gone + channel = self.make_request( + "GET", + session_endpoint, + access_token=None, + ) + self.assertEqual(channel.code, 404) + + @unittest.skip_unless(HAS_AUTHLIB, "requires authlib") + @override_config( + { + "disable_registration": True, + "experimental_features": { + "msc4108_enabled": True, + "msc3861": { + "enabled": True, + "issuer": "https://issuer", + "client_id": "client_id", + "client_auth_method": "client_secret_post", + "client_secret": "client_secret", + "admin_token": "admin_token_value", + }, + }, + } + ) + def test_msc4108_capacity(self) -> None: + """ + Test that a capacity limit is enforced on the rendezvous sessions, as old + entries are evicted at an interval when the limit is reached. + """ + # Start a new session + channel = self.make_request( + "POST", + msc4108_endpoint, + "foo=bar", + content_type=b"text/plain", + access_token=None, + ) + self.assertEqual(channel.code, 201) + session_endpoint = urlparse(channel.json_body["url"]).path + + # Sanity check that we can get the data back + channel = self.make_request( + "GET", + session_endpoint, + access_token=None, + ) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.text_body, "foo=bar") + + # Start a lot of new sessions + for _ in range(100): + channel = self.make_request( + "POST", + msc4108_endpoint, + "foo=bar", + content_type=b"text/plain", + access_token=None, + ) + self.assertEqual(channel.code, 201) + + # Get the data back, it should still be there, as the eviction hasn't run yet + channel = self.make_request( + "GET", + session_endpoint, + access_token=None, + ) + + self.assertEqual(channel.code, 200) + + # Advance the clock, as it will trigger the eviction + self.reactor.advance(1) + + # Get the data back, it should be gone + channel = self.make_request( + "GET", + session_endpoint, + access_token=None, + ) + + @unittest.skip_unless(HAS_AUTHLIB, "requires authlib") + @override_config( + { + "disable_registration": True, + "experimental_features": { + "msc4108_enabled": True, + "msc3861": { + "enabled": True, + "issuer": "https://issuer", + "client_id": "client_id", + "client_auth_method": "client_secret_post", + "client_secret": "client_secret", + "admin_token": "admin_token_value", + }, + }, + } + ) + def test_msc4108_hard_capacity(self) -> None: + """ + Test that a hard capacity limit is enforced on the rendezvous sessions, as old + entries are evicted immediately when the limit is reached. + """ + # Start a new session + channel = self.make_request( + "POST", + msc4108_endpoint, + "foo=bar", + content_type=b"text/plain", + access_token=None, + ) + self.assertEqual(channel.code, 201) + session_endpoint = urlparse(channel.json_body["url"]).path + # We advance the clock to make sure that this entry is the "lowest" in the session list + self.reactor.advance(1) + + # Sanity check that we can get the data back + channel = self.make_request( + "GET", + session_endpoint, + access_token=None, + ) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.text_body, "foo=bar") + + # Start a lot of new sessions + for _ in range(200): + channel = self.make_request( + "POST", + msc4108_endpoint, + "foo=bar", + content_type=b"text/plain", + access_token=None, + ) + self.assertEqual(channel.code, 201) + + # Get the data back, it should already be gone as we hit the hard limit + channel = self.make_request( + "GET", + session_endpoint, + access_token=None, + ) + + self.assertEqual(channel.code, 404) + + @unittest.skip_unless(HAS_AUTHLIB, "requires authlib") + @override_config( + { + "disable_registration": True, + "experimental_features": { + "msc4108_enabled": True, + "msc3861": { + "enabled": True, + "issuer": "https://issuer", + "client_id": "client_id", + "client_auth_method": "client_secret_post", + "client_secret": "client_secret", + "admin_token": "admin_token_value", + }, + }, + } + ) + def test_msc4108_content_type(self) -> None: + """ + Test that the content-type is restricted to text/plain. + """ + # We cannot post invalid content-type arbitrary data to the endpoint + channel = self.make_request( + "POST", + msc4108_endpoint, + "foo=bar", + content_is_form=True, + access_token=None, + ) + self.assertEqual(channel.code, 400) + self.assertEqual(channel.json_body["errcode"], "M_INVALID_PARAM") + + # Make a valid request + channel = self.make_request( + "POST", + msc4108_endpoint, + "foo=bar", + content_type=b"text/plain", + access_token=None, + ) + self.assertEqual(channel.code, 201) + url = urlparse(channel.json_body["url"]) + session_endpoint = url.path + headers = dict(channel.headers.getAllRawHeaders()) + etag = headers[b"ETag"][0] + + # We can't update the data with invalid content-type + channel = self.make_request( + "PUT", + session_endpoint, + "foo=baz", + content_is_form=True, + access_token=None, + custom_headers=[("If-Match", etag)], + ) + self.assertEqual(channel.code, 400) + self.assertEqual(channel.json_body["errcode"], "M_INVALID_PARAM") diff --git a/tests/server.py b/tests/server.py index 4aaa91e956..434be3d22c 100644 --- a/tests/server.py +++ b/tests/server.py @@ -351,6 +351,7 @@ def make_request( request: Type[Request] = SynapseRequest, shorthand: bool = True, federation_auth_origin: Optional[bytes] = None, + content_type: Optional[bytes] = None, content_is_form: bool = False, await_result: bool = True, custom_headers: Optional[Iterable[CustomHeaderType]] = None, @@ -373,6 +374,8 @@ def make_request( with the usual REST API path, if it doesn't contain it. federation_auth_origin: if set to not-None, we will add a fake Authorization header pretenting to be the given server name. + content_type: The content-type to use for the request. If not set then will default to + application/json unless content_is_form is true. content_is_form: Whether the content is URL encoded form data. Adds the 'Content-Type': 'application/x-www-form-urlencoded' header. await_result: whether to wait for the request to complete rendering. If true, @@ -436,7 +439,9 @@ def make_request( ) if content: - if content_is_form: + if content_type is not None: + req.requestHeaders.addRawHeader(b"Content-Type", content_type) + elif content_is_form: req.requestHeaders.addRawHeader( b"Content-Type", b"application/x-www-form-urlencoded" ) diff --git a/tests/unittest.py b/tests/unittest.py index 6fe0cd4a2d..e6aad9ed40 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -523,6 +523,7 @@ class HomeserverTestCase(TestCase): request: Type[Request] = SynapseRequest, shorthand: bool = True, federation_auth_origin: Optional[bytes] = None, + content_type: Optional[bytes] = None, content_is_form: bool = False, await_result: bool = True, custom_headers: Optional[Iterable[CustomHeaderType]] = None, @@ -541,6 +542,9 @@ class HomeserverTestCase(TestCase): with the usual REST API path, if it doesn't contain it. federation_auth_origin: if set to not-None, we will add a fake Authorization header pretenting to be the given server name. + + content_type: The content-type to use for the request. If not set then will default to + application/json unless content_is_form is true. content_is_form: Whether the content is URL encoded form data. Adds the 'Content-Type': 'application/x-www-form-urlencoded' header. @@ -566,6 +570,7 @@ class HomeserverTestCase(TestCase): request, shorthand, federation_auth_origin, + content_type, content_is_form, await_result, custom_headers, |