summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/rest/client/test_rendezvous.py401
-rw-r--r--tests/server.py7
-rw-r--r--tests/unittest.py5
3 files changed, 411 insertions, 2 deletions
diff --git a/tests/rest/client/test_rendezvous.py b/tests/rest/client/test_rendezvous.py
index c84704c090..0ab754a11a 100644
--- a/tests/rest/client/test_rendezvous.py
+++ b/tests/rest/client/test_rendezvous.py
@@ -2,7 +2,7 @@
 # This file is licensed under the Affero General Public License (AGPL) version 3.
 #
 # Copyright 2022 The Matrix.org Foundation C.I.C.
-# Copyright (C) 2023 New Vector, Ltd
+# Copyright (C) 2023-2024 New Vector, Ltd
 #
 # This program is free software: you can redistribute it and/or modify
 # it under the terms of the GNU Affero General Public License as
@@ -19,9 +19,14 @@
 #
 #
 
+from typing import Dict
+from urllib.parse import urlparse
+
 from twisted.test.proto_helpers import MemoryReactor
+from twisted.web.resource import Resource
 
 from synapse.rest.client import rendezvous
+from synapse.rest.synapse.client.rendezvous import MSC4108RendezvousSessionResource
 from synapse.server import HomeServer
 from synapse.util import Clock
 
@@ -42,6 +47,12 @@ class RendezvousServletTestCase(unittest.HomeserverTestCase):
         self.hs = self.setup_test_homeserver()
         return self.hs
 
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        return {
+            **super().create_resource_dict(),
+            "/_synapse/client/rendezvous": MSC4108RendezvousSessionResource(self.hs),
+        }
+
     def test_disabled(self) -> None:
         channel = self.make_request("POST", msc3886_endpoint, {}, access_token=None)
         self.assertEqual(channel.code, 404)
@@ -75,3 +86,391 @@ class RendezvousServletTestCase(unittest.HomeserverTestCase):
         channel = self.make_request("POST", msc4108_endpoint, {}, access_token=None)
         self.assertEqual(channel.code, 307)
         self.assertEqual(channel.headers.getRawHeaders("Location"), ["https://asd"])
+
+    @unittest.skip_unless(HAS_AUTHLIB, "requires authlib")
+    @override_config(
+        {
+            "disable_registration": True,
+            "experimental_features": {
+                "msc4108_enabled": True,
+                "msc3861": {
+                    "enabled": True,
+                    "issuer": "https://issuer",
+                    "client_id": "client_id",
+                    "client_auth_method": "client_secret_post",
+                    "client_secret": "client_secret",
+                    "admin_token": "admin_token_value",
+                },
+            },
+        }
+    )
+    def test_msc4108(self) -> None:
+        """
+        Test the MSC4108 rendezvous endpoint, including:
+            - Creating a session
+            - Getting the data back
+            - Updating the data
+            - Deleting the data
+            - ETag handling
+        """
+        # We can post arbitrary data to the endpoint
+        channel = self.make_request(
+            "POST",
+            msc4108_endpoint,
+            "foo=bar",
+            content_type=b"text/plain",
+            access_token=None,
+        )
+        self.assertEqual(channel.code, 201)
+        self.assertSubstring("/_synapse/client/rendezvous/", channel.json_body["url"])
+        headers = dict(channel.headers.getAllRawHeaders())
+        self.assertIn(b"ETag", headers)
+        self.assertIn(b"Expires", headers)
+        self.assertEqual(headers[b"Content-Type"], [b"application/json"])
+        self.assertEqual(headers[b"Access-Control-Allow-Origin"], [b"*"])
+        self.assertEqual(headers[b"Access-Control-Expose-Headers"], [b"etag"])
+        self.assertEqual(headers[b"Cache-Control"], [b"no-store"])
+        self.assertEqual(headers[b"Pragma"], [b"no-cache"])
+        self.assertIn("url", channel.json_body)
+        self.assertTrue(channel.json_body["url"].startswith("https://"))
+
+        url = urlparse(channel.json_body["url"])
+        session_endpoint = url.path
+        etag = headers[b"ETag"][0]
+
+        # We can get the data back
+        channel = self.make_request(
+            "GET",
+            session_endpoint,
+            access_token=None,
+        )
+
+        self.assertEqual(channel.code, 200)
+        headers = dict(channel.headers.getAllRawHeaders())
+        self.assertEqual(headers[b"ETag"], [etag])
+        self.assertIn(b"Expires", headers)
+        self.assertEqual(headers[b"Content-Type"], [b"text/plain"])
+        self.assertEqual(headers[b"Access-Control-Allow-Origin"], [b"*"])
+        self.assertEqual(headers[b"Access-Control-Expose-Headers"], [b"etag"])
+        self.assertEqual(headers[b"Cache-Control"], [b"no-store"])
+        self.assertEqual(headers[b"Pragma"], [b"no-cache"])
+        self.assertEqual(channel.text_body, "foo=bar")
+
+        # We can make sure the data hasn't changed
+        channel = self.make_request(
+            "GET",
+            session_endpoint,
+            access_token=None,
+            custom_headers=[("If-None-Match", etag)],
+        )
+
+        self.assertEqual(channel.code, 304)
+
+        # We can update the data
+        channel = self.make_request(
+            "PUT",
+            session_endpoint,
+            "foo=baz",
+            content_type=b"text/plain",
+            access_token=None,
+            custom_headers=[("If-Match", etag)],
+        )
+
+        self.assertEqual(channel.code, 202)
+        headers = dict(channel.headers.getAllRawHeaders())
+        old_etag = etag
+        new_etag = headers[b"ETag"][0]
+
+        # If we try to update it again with the old etag, it should fail
+        channel = self.make_request(
+            "PUT",
+            session_endpoint,
+            "bar=baz",
+            content_type=b"text/plain",
+            access_token=None,
+            custom_headers=[("If-Match", old_etag)],
+        )
+
+        self.assertEqual(channel.code, 412)
+        self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN")
+        self.assertEqual(
+            channel.json_body["org.matrix.msc4108.errcode"], "M_CONCURRENT_WRITE"
+        )
+
+        # If we try to get with the old etag, we should get the updated data
+        channel = self.make_request(
+            "GET",
+            session_endpoint,
+            access_token=None,
+            custom_headers=[("If-None-Match", old_etag)],
+        )
+
+        self.assertEqual(channel.code, 200)
+        headers = dict(channel.headers.getAllRawHeaders())
+        self.assertEqual(headers[b"ETag"], [new_etag])
+        self.assertEqual(channel.text_body, "foo=baz")
+
+        # We can delete the data
+        channel = self.make_request(
+            "DELETE",
+            session_endpoint,
+            access_token=None,
+        )
+
+        self.assertEqual(channel.code, 204)
+
+        # If we try to get the data again, it should fail
+        channel = self.make_request(
+            "GET",
+            session_endpoint,
+            access_token=None,
+        )
+
+        self.assertEqual(channel.code, 404)
+        self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
+
+    @unittest.skip_unless(HAS_AUTHLIB, "requires authlib")
+    @override_config(
+        {
+            "disable_registration": True,
+            "experimental_features": {
+                "msc4108_enabled": True,
+                "msc3861": {
+                    "enabled": True,
+                    "issuer": "https://issuer",
+                    "client_id": "client_id",
+                    "client_auth_method": "client_secret_post",
+                    "client_secret": "client_secret",
+                    "admin_token": "admin_token_value",
+                },
+            },
+        }
+    )
+    def test_msc4108_expiration(self) -> None:
+        """
+        Test that entries are evicted after a TTL.
+        """
+        # Start a new session
+        channel = self.make_request(
+            "POST",
+            msc4108_endpoint,
+            "foo=bar",
+            content_type=b"text/plain",
+            access_token=None,
+        )
+        self.assertEqual(channel.code, 201)
+        session_endpoint = urlparse(channel.json_body["url"]).path
+
+        # Sanity check that we can get the data back
+        channel = self.make_request(
+            "GET",
+            session_endpoint,
+            access_token=None,
+        )
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.text_body, "foo=bar")
+
+        # Advance the clock, TTL of entries is 1 minute
+        self.reactor.advance(60)
+
+        # Get the data back, it should be gone
+        channel = self.make_request(
+            "GET",
+            session_endpoint,
+            access_token=None,
+        )
+        self.assertEqual(channel.code, 404)
+
+    @unittest.skip_unless(HAS_AUTHLIB, "requires authlib")
+    @override_config(
+        {
+            "disable_registration": True,
+            "experimental_features": {
+                "msc4108_enabled": True,
+                "msc3861": {
+                    "enabled": True,
+                    "issuer": "https://issuer",
+                    "client_id": "client_id",
+                    "client_auth_method": "client_secret_post",
+                    "client_secret": "client_secret",
+                    "admin_token": "admin_token_value",
+                },
+            },
+        }
+    )
+    def test_msc4108_capacity(self) -> None:
+        """
+        Test that a capacity limit is enforced on the rendezvous sessions, as old
+        entries are evicted at an interval when the limit is reached.
+        """
+        # Start a new session
+        channel = self.make_request(
+            "POST",
+            msc4108_endpoint,
+            "foo=bar",
+            content_type=b"text/plain",
+            access_token=None,
+        )
+        self.assertEqual(channel.code, 201)
+        session_endpoint = urlparse(channel.json_body["url"]).path
+
+        # Sanity check that we can get the data back
+        channel = self.make_request(
+            "GET",
+            session_endpoint,
+            access_token=None,
+        )
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.text_body, "foo=bar")
+
+        # Start a lot of new sessions
+        for _ in range(100):
+            channel = self.make_request(
+                "POST",
+                msc4108_endpoint,
+                "foo=bar",
+                content_type=b"text/plain",
+                access_token=None,
+            )
+            self.assertEqual(channel.code, 201)
+
+        # Get the data back, it should still be there, as the eviction hasn't run yet
+        channel = self.make_request(
+            "GET",
+            session_endpoint,
+            access_token=None,
+        )
+
+        self.assertEqual(channel.code, 200)
+
+        # Advance the clock, as it will trigger the eviction
+        self.reactor.advance(1)
+
+        # Get the data back, it should be gone
+        channel = self.make_request(
+            "GET",
+            session_endpoint,
+            access_token=None,
+        )
+
+    @unittest.skip_unless(HAS_AUTHLIB, "requires authlib")
+    @override_config(
+        {
+            "disable_registration": True,
+            "experimental_features": {
+                "msc4108_enabled": True,
+                "msc3861": {
+                    "enabled": True,
+                    "issuer": "https://issuer",
+                    "client_id": "client_id",
+                    "client_auth_method": "client_secret_post",
+                    "client_secret": "client_secret",
+                    "admin_token": "admin_token_value",
+                },
+            },
+        }
+    )
+    def test_msc4108_hard_capacity(self) -> None:
+        """
+        Test that a hard capacity limit is enforced on the rendezvous sessions, as old
+        entries are evicted immediately when the limit is reached.
+        """
+        # Start a new session
+        channel = self.make_request(
+            "POST",
+            msc4108_endpoint,
+            "foo=bar",
+            content_type=b"text/plain",
+            access_token=None,
+        )
+        self.assertEqual(channel.code, 201)
+        session_endpoint = urlparse(channel.json_body["url"]).path
+        # We advance the clock to make sure that this entry is the "lowest" in the session list
+        self.reactor.advance(1)
+
+        # Sanity check that we can get the data back
+        channel = self.make_request(
+            "GET",
+            session_endpoint,
+            access_token=None,
+        )
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.text_body, "foo=bar")
+
+        # Start a lot of new sessions
+        for _ in range(200):
+            channel = self.make_request(
+                "POST",
+                msc4108_endpoint,
+                "foo=bar",
+                content_type=b"text/plain",
+                access_token=None,
+            )
+            self.assertEqual(channel.code, 201)
+
+        # Get the data back, it should already be gone as we hit the hard limit
+        channel = self.make_request(
+            "GET",
+            session_endpoint,
+            access_token=None,
+        )
+
+        self.assertEqual(channel.code, 404)
+
+    @unittest.skip_unless(HAS_AUTHLIB, "requires authlib")
+    @override_config(
+        {
+            "disable_registration": True,
+            "experimental_features": {
+                "msc4108_enabled": True,
+                "msc3861": {
+                    "enabled": True,
+                    "issuer": "https://issuer",
+                    "client_id": "client_id",
+                    "client_auth_method": "client_secret_post",
+                    "client_secret": "client_secret",
+                    "admin_token": "admin_token_value",
+                },
+            },
+        }
+    )
+    def test_msc4108_content_type(self) -> None:
+        """
+        Test that the content-type is restricted to text/plain.
+        """
+        # We cannot post invalid content-type arbitrary data to the endpoint
+        channel = self.make_request(
+            "POST",
+            msc4108_endpoint,
+            "foo=bar",
+            content_is_form=True,
+            access_token=None,
+        )
+        self.assertEqual(channel.code, 400)
+        self.assertEqual(channel.json_body["errcode"], "M_INVALID_PARAM")
+
+        # Make a valid request
+        channel = self.make_request(
+            "POST",
+            msc4108_endpoint,
+            "foo=bar",
+            content_type=b"text/plain",
+            access_token=None,
+        )
+        self.assertEqual(channel.code, 201)
+        url = urlparse(channel.json_body["url"])
+        session_endpoint = url.path
+        headers = dict(channel.headers.getAllRawHeaders())
+        etag = headers[b"ETag"][0]
+
+        # We can't update the data with invalid content-type
+        channel = self.make_request(
+            "PUT",
+            session_endpoint,
+            "foo=baz",
+            content_is_form=True,
+            access_token=None,
+            custom_headers=[("If-Match", etag)],
+        )
+        self.assertEqual(channel.code, 400)
+        self.assertEqual(channel.json_body["errcode"], "M_INVALID_PARAM")
diff --git a/tests/server.py b/tests/server.py
index 4aaa91e956..434be3d22c 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -351,6 +351,7 @@ def make_request(
     request: Type[Request] = SynapseRequest,
     shorthand: bool = True,
     federation_auth_origin: Optional[bytes] = None,
+    content_type: Optional[bytes] = None,
     content_is_form: bool = False,
     await_result: bool = True,
     custom_headers: Optional[Iterable[CustomHeaderType]] = None,
@@ -373,6 +374,8 @@ def make_request(
             with the usual REST API path, if it doesn't contain it.
         federation_auth_origin: if set to not-None, we will add a fake
             Authorization header pretenting to be the given server name.
+        content_type: The content-type to use for the request. If not set then will default to
+            application/json unless content_is_form is true.
         content_is_form: Whether the content is URL encoded form data. Adds the
             'Content-Type': 'application/x-www-form-urlencoded' header.
         await_result: whether to wait for the request to complete rendering. If true,
@@ -436,7 +439,9 @@ def make_request(
         )
 
     if content:
-        if content_is_form:
+        if content_type is not None:
+            req.requestHeaders.addRawHeader(b"Content-Type", content_type)
+        elif content_is_form:
             req.requestHeaders.addRawHeader(
                 b"Content-Type", b"application/x-www-form-urlencoded"
             )
diff --git a/tests/unittest.py b/tests/unittest.py
index 6fe0cd4a2d..e6aad9ed40 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -523,6 +523,7 @@ class HomeserverTestCase(TestCase):
         request: Type[Request] = SynapseRequest,
         shorthand: bool = True,
         federation_auth_origin: Optional[bytes] = None,
+        content_type: Optional[bytes] = None,
         content_is_form: bool = False,
         await_result: bool = True,
         custom_headers: Optional[Iterable[CustomHeaderType]] = None,
@@ -541,6 +542,9 @@ class HomeserverTestCase(TestCase):
             with the usual REST API path, if it doesn't contain it.
             federation_auth_origin: if set to not-None, we will add a fake
                 Authorization header pretenting to be the given server name.
+
+            content_type: The content-type to use for the request. If not set then will default to
+                application/json unless content_is_form is true.
             content_is_form: Whether the content is URL encoded form data. Adds the
                 'Content-Type': 'application/x-www-form-urlencoded' header.
 
@@ -566,6 +570,7 @@ class HomeserverTestCase(TestCase):
             request,
             shorthand,
             federation_auth_origin,
+            content_type,
             content_is_form,
             await_result,
             custom_headers,