summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2022-10-20 11:32:47 -0400
committerGitHub <noreply@github.com>2022-10-20 11:32:47 -0400
commit755bfeee3a1ac7077045ab9e5a994b6ca89afba3 (patch)
tree470b311d29d49a5b3910a9fa06f81c033f6abe71
parentStop returning `unsigned.invite_room_state` in `PUT /_matrix/federation/v2/in... (diff)
downloadsynapse-755bfeee3a1ac7077045ab9e5a994b6ca89afba3.tar.xz
Use servlets for /key/ endpoints. (#14229)
To fix the response for unknown endpoints under that prefix.

See MSC3743.
-rw-r--r--changelog.d/14229.misc1
-rw-r--r--synapse/api/urls.py2
-rw-r--r--synapse/app/generic_worker.py20
-rw-r--r--synapse/app/homeserver.py26
-rw-r--r--synapse/rest/key/v2/__init__.py19
-rw-r--r--synapse/rest/key/v2/local_key_resource.py22
-rw-r--r--synapse/rest/key/v2/remote_key_resource.py73
-rw-r--r--tests/app/test_openid_listener.py2
-rw-r--r--tests/rest/key/v2/test_remote_key_resource.py4
9 files changed, 86 insertions, 83 deletions
diff --git a/changelog.d/14229.misc b/changelog.d/14229.misc
new file mode 100644
index 0000000000..b9cd9a34d5
--- /dev/null
+++ b/changelog.d/14229.misc
@@ -0,0 +1 @@
+Refactor `/key/` endpoints to use `RestServlet` classes.
diff --git a/synapse/api/urls.py b/synapse/api/urls.py
index bd49fa6a5f..a918579f50 100644
--- a/synapse/api/urls.py
+++ b/synapse/api/urls.py
@@ -28,7 +28,7 @@ FEDERATION_V1_PREFIX = FEDERATION_PREFIX + "/v1"
 FEDERATION_V2_PREFIX = FEDERATION_PREFIX + "/v2"
 FEDERATION_UNSTABLE_PREFIX = FEDERATION_PREFIX + "/unstable"
 STATIC_PREFIX = "/_matrix/static"
-SERVER_KEY_V2_PREFIX = "/_matrix/key/v2"
+SERVER_KEY_PREFIX = "/_matrix/key"
 MEDIA_R0_PREFIX = "/_matrix/media/r0"
 MEDIA_V3_PREFIX = "/_matrix/media/v3"
 LEGACY_MEDIA_PREFIX = "/_matrix/media/v1"
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index dc49840f73..2a9f039367 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -28,7 +28,7 @@ from synapse.api.urls import (
     LEGACY_MEDIA_PREFIX,
     MEDIA_R0_PREFIX,
     MEDIA_V3_PREFIX,
-    SERVER_KEY_V2_PREFIX,
+    SERVER_KEY_PREFIX,
 )
 from synapse.app import _base
 from synapse.app._base import (
@@ -89,7 +89,7 @@ from synapse.rest.client.register import (
     RegistrationTokenValidityRestServlet,
 )
 from synapse.rest.health import HealthResource
-from synapse.rest.key.v2 import KeyApiV2Resource
+from synapse.rest.key.v2 import KeyResource
 from synapse.rest.synapse.client import build_synapse_client_resource_tree
 from synapse.rest.well_known import well_known_resource
 from synapse.server import HomeServer
@@ -325,13 +325,13 @@ class GenericWorkerServer(HomeServer):
 
                     presence.register_servlets(self, resource)
 
-                    resources.update({CLIENT_API_PREFIX: resource})
+                    resources[CLIENT_API_PREFIX] = resource
 
                     resources.update(build_synapse_client_resource_tree(self))
-                    resources.update({"/.well-known": well_known_resource(self)})
+                    resources["/.well-known"] = well_known_resource(self)
 
                 elif name == "federation":
-                    resources.update({FEDERATION_PREFIX: TransportLayerServer(self)})
+                    resources[FEDERATION_PREFIX] = TransportLayerServer(self)
                 elif name == "media":
                     if self.config.media.can_load_media_repo:
                         media_repo = self.get_media_repository_resource()
@@ -359,16 +359,12 @@ class GenericWorkerServer(HomeServer):
                     # Only load the openid resource separately if federation resource
                     # is not specified since federation resource includes openid
                     # resource.
-                    resources.update(
-                        {
-                            FEDERATION_PREFIX: TransportLayerServer(
-                                self, servlet_groups=["openid"]
-                            )
-                        }
+                    resources[FEDERATION_PREFIX] = TransportLayerServer(
+                        self, servlet_groups=["openid"]
                     )
 
                 if name in ["keys", "federation"]:
-                    resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self)
+                    resources[SERVER_KEY_PREFIX] = KeyResource(self)
 
                 if name == "replication":
                     resources[REPLICATION_PREFIX] = ReplicationRestResource(self)
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 883f2fd2ec..de3f08876f 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -31,7 +31,7 @@ from synapse.api.urls import (
     LEGACY_MEDIA_PREFIX,
     MEDIA_R0_PREFIX,
     MEDIA_V3_PREFIX,
-    SERVER_KEY_V2_PREFIX,
+    SERVER_KEY_PREFIX,
     STATIC_PREFIX,
 )
 from synapse.app import _base
@@ -60,7 +60,7 @@ from synapse.replication.http import REPLICATION_PREFIX, ReplicationRestResource
 from synapse.rest import ClientRestResource
 from synapse.rest.admin import AdminRestResource
 from synapse.rest.health import HealthResource
-from synapse.rest.key.v2 import KeyApiV2Resource
+from synapse.rest.key.v2 import KeyResource
 from synapse.rest.synapse.client import build_synapse_client_resource_tree
 from synapse.rest.well_known import well_known_resource
 from synapse.server import HomeServer
@@ -215,30 +215,22 @@ class SynapseHomeServer(HomeServer):
             consent_resource: Resource = ConsentResource(self)
             if compress:
                 consent_resource = gz_wrap(consent_resource)
-            resources.update({"/_matrix/consent": consent_resource})
+            resources["/_matrix/consent"] = consent_resource
 
         if name == "federation":
             federation_resource: Resource = TransportLayerServer(self)
             if compress:
                 federation_resource = gz_wrap(federation_resource)
-            resources.update({FEDERATION_PREFIX: federation_resource})
+            resources[FEDERATION_PREFIX] = federation_resource
 
         if name == "openid":
-            resources.update(
-                {
-                    FEDERATION_PREFIX: TransportLayerServer(
-                        self, servlet_groups=["openid"]
-                    )
-                }
+            resources[FEDERATION_PREFIX] = TransportLayerServer(
+                self, servlet_groups=["openid"]
             )
 
         if name in ["static", "client"]:
-            resources.update(
-                {
-                    STATIC_PREFIX: StaticResource(
-                        os.path.join(os.path.dirname(synapse.__file__), "static")
-                    )
-                }
+            resources[STATIC_PREFIX] = StaticResource(
+                os.path.join(os.path.dirname(synapse.__file__), "static")
             )
 
         if name in ["media", "federation", "client"]:
@@ -257,7 +249,7 @@ class SynapseHomeServer(HomeServer):
                 )
 
         if name in ["keys", "federation"]:
-            resources[SERVER_KEY_V2_PREFIX] = KeyApiV2Resource(self)
+            resources[SERVER_KEY_PREFIX] = KeyResource(self)
 
         if name == "metrics" and self.config.metrics.enable_metrics:
             metrics_resource: Resource = MetricsResource(RegistryProxy)
diff --git a/synapse/rest/key/v2/__init__.py b/synapse/rest/key/v2/__init__.py
index 7f8c1de1ff..26403facb8 100644
--- a/synapse/rest/key/v2/__init__.py
+++ b/synapse/rest/key/v2/__init__.py
@@ -14,17 +14,20 @@
 
 from typing import TYPE_CHECKING
 
-from twisted.web.resource import Resource
-
-from .local_key_resource import LocalKey
-from .remote_key_resource import RemoteKey
+from synapse.http.server import HttpServer, JsonResource
+from synapse.rest.key.v2.local_key_resource import LocalKey
+from synapse.rest.key.v2.remote_key_resource import RemoteKey
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
 
 
-class KeyApiV2Resource(Resource):
+class KeyResource(JsonResource):
     def __init__(self, hs: "HomeServer"):
-        Resource.__init__(self)
-        self.putChild(b"server", LocalKey(hs))
-        self.putChild(b"query", RemoteKey(hs))
+        super().__init__(hs, canonical_json=True)
+        self.register_servlets(self, hs)
+
+    @staticmethod
+    def register_servlets(http_server: HttpServer, hs: "HomeServer") -> None:
+        LocalKey(hs).register(http_server)
+        RemoteKey(hs).register(http_server)
diff --git a/synapse/rest/key/v2/local_key_resource.py b/synapse/rest/key/v2/local_key_resource.py
index 095993415c..d03e728d42 100644
--- a/synapse/rest/key/v2/local_key_resource.py
+++ b/synapse/rest/key/v2/local_key_resource.py
@@ -13,16 +13,15 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Optional
+import re
+from typing import TYPE_CHECKING, Optional, Tuple
 
-from canonicaljson import encode_canonical_json
 from signedjson.sign import sign_json
 from unpaddedbase64 import encode_base64
 
-from twisted.web.resource import Resource
+from twisted.web.server import Request
 
-from synapse.http.server import respond_with_json_bytes
-from synapse.http.site import SynapseRequest
+from synapse.http.servlet import RestServlet
 from synapse.types import JsonDict
 
 if TYPE_CHECKING:
@@ -31,7 +30,7 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-class LocalKey(Resource):
+class LocalKey(RestServlet):
     """HTTP resource containing encoding the TLS X.509 certificate and NACL
     signature verification keys for this server::
 
@@ -61,18 +60,17 @@ class LocalKey(Resource):
         }
     """
 
-    isLeaf = True
+    PATTERNS = (re.compile("^/_matrix/key/v2/server(/(?P<key_id>[^/]*))?$"),)
 
     def __init__(self, hs: "HomeServer"):
         self.config = hs.config
         self.clock = hs.get_clock()
         self.update_response_body(self.clock.time_msec())
-        Resource.__init__(self)
 
     def update_response_body(self, time_now_msec: int) -> None:
         refresh_interval = self.config.key.key_refresh_interval
         self.valid_until_ts = int(time_now_msec + refresh_interval)
-        self.response_body = encode_canonical_json(self.response_json_object())
+        self.response_body = self.response_json_object()
 
     def response_json_object(self) -> JsonDict:
         verify_keys = {}
@@ -99,9 +97,11 @@ class LocalKey(Resource):
             json_object = sign_json(json_object, self.config.server.server_name, key)
         return json_object
 
-    def render_GET(self, request: SynapseRequest) -> Optional[int]:
+    def on_GET(
+        self, request: Request, key_id: Optional[str] = None
+    ) -> Tuple[int, JsonDict]:
         time_now = self.clock.time_msec()
         # Update the expiry time if less than half the interval remains.
         if time_now + self.config.key.key_refresh_interval / 2 > self.valid_until_ts:
             self.update_response_body(time_now)
-        return respond_with_json_bytes(request, 200, self.response_body)
+        return 200, self.response_body
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 7f8ad29566..19820886f5 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -13,15 +13,20 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Dict, Set
+import re
+from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
 
 from signedjson.sign import sign_json
 
-from synapse.api.errors import Codes, SynapseError
+from twisted.web.server import Request
+
 from synapse.crypto.keyring import ServerKeyFetcher
-from synapse.http.server import DirectServeJsonResource, respond_with_json
-from synapse.http.servlet import parse_integer, parse_json_object_from_request
-from synapse.http.site import SynapseRequest
+from synapse.http.server import HttpServer
+from synapse.http.servlet import (
+    RestServlet,
+    parse_integer,
+    parse_json_object_from_request,
+)
 from synapse.types import JsonDict
 from synapse.util import json_decoder
 from synapse.util.async_helpers import yieldable_gather_results
@@ -32,7 +37,7 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-class RemoteKey(DirectServeJsonResource):
+class RemoteKey(RestServlet):
     """HTTP resource for retrieving the TLS certificate and NACL signature
     verification keys for a collection of servers. Checks that the reported
     X.509 TLS certificate matches the one used in the HTTPS connection. Checks
@@ -88,11 +93,7 @@ class RemoteKey(DirectServeJsonResource):
     }
     """
 
-    isLeaf = True
-
     def __init__(self, hs: "HomeServer"):
-        super().__init__()
-
         self.fetcher = ServerKeyFetcher(hs)
         self.store = hs.get_datastores().main
         self.clock = hs.get_clock()
@@ -101,36 +102,48 @@ class RemoteKey(DirectServeJsonResource):
         )
         self.config = hs.config
 
-    async def _async_render_GET(self, request: SynapseRequest) -> None:
-        assert request.postpath is not None
-        if len(request.postpath) == 1:
-            (server,) = request.postpath
-            query: dict = {server.decode("ascii"): {}}
-        elif len(request.postpath) == 2:
-            server, key_id = request.postpath
+    def register(self, http_server: HttpServer) -> None:
+        http_server.register_paths(
+            "GET",
+            (
+                re.compile(
+                    "^/_matrix/key/v2/query/(?P<server>[^/]*)(/(?P<key_id>[^/]*))?$"
+                ),
+            ),
+            self.on_GET,
+            self.__class__.__name__,
+        )
+        http_server.register_paths(
+            "POST",
+            (re.compile("^/_matrix/key/v2/query$"),),
+            self.on_POST,
+            self.__class__.__name__,
+        )
+
+    async def on_GET(
+        self, request: Request, server: str, key_id: Optional[str] = None
+    ) -> Tuple[int, JsonDict]:
+        if server and key_id:
             minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts")
             arguments = {}
             if minimum_valid_until_ts is not None:
                 arguments["minimum_valid_until_ts"] = minimum_valid_until_ts
-            query = {server.decode("ascii"): {key_id.decode("ascii"): arguments}}
+            query = {server: {key_id: arguments}}
         else:
-            raise SynapseError(404, "Not found %r" % request.postpath, Codes.NOT_FOUND)
+            query = {server: {}}
 
-        await self.query_keys(request, query, query_remote_on_cache_miss=True)
+        return 200, await self.query_keys(query, query_remote_on_cache_miss=True)
 
-    async def _async_render_POST(self, request: SynapseRequest) -> None:
+    async def on_POST(self, request: Request) -> Tuple[int, JsonDict]:
         content = parse_json_object_from_request(request)
 
         query = content["server_keys"]
 
-        await self.query_keys(request, query, query_remote_on_cache_miss=True)
+        return 200, await self.query_keys(query, query_remote_on_cache_miss=True)
 
     async def query_keys(
-        self,
-        request: SynapseRequest,
-        query: JsonDict,
-        query_remote_on_cache_miss: bool = False,
-    ) -> None:
+        self, query: JsonDict, query_remote_on_cache_miss: bool = False
+    ) -> JsonDict:
         logger.info("Handling query for keys %r", query)
 
         store_queries = []
@@ -232,7 +245,7 @@ class RemoteKey(DirectServeJsonResource):
                     for server_name, keys in cache_misses.items()
                 ),
             )
-            await self.query_keys(request, query, query_remote_on_cache_miss=False)
+            return await self.query_keys(query, query_remote_on_cache_miss=False)
         else:
             signed_keys = []
             for key_json_raw in json_results:
@@ -244,6 +257,4 @@ class RemoteKey(DirectServeJsonResource):
 
                 signed_keys.append(key_json)
 
-            response = {"server_keys": signed_keys}
-
-            respond_with_json(request, 200, response, canonical_json=True)
+            return {"server_keys": signed_keys}
diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py
index c7dae58eb5..8d03da7f96 100644
--- a/tests/app/test_openid_listener.py
+++ b/tests/app/test_openid_listener.py
@@ -79,7 +79,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
         self.assertEqual(channel.code, 401)
 
 
-@patch("synapse.app.homeserver.KeyApiV2Resource", new=Mock())
+@patch("synapse.app.homeserver.KeyResource", new=Mock())
 class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
         hs = self.setup_test_homeserver(
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index ac0ac06b7e..7f1fba1086 100644
--- a/tests/rest/key/v2/test_remote_key_resource.py
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -26,7 +26,7 @@ from twisted.web.resource import NoResource, Resource
 
 from synapse.crypto.keyring import PerspectivesKeyFetcher
 from synapse.http.site import SynapseRequest
-from synapse.rest.key.v2 import KeyApiV2Resource
+from synapse.rest.key.v2 import KeyResource
 from synapse.server import HomeServer
 from synapse.storage.keys import FetchKeyResult
 from synapse.types import JsonDict
@@ -46,7 +46,7 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
 
     def create_test_resource(self) -> Resource:
         return create_resource_tree(
-            {"/_matrix/key/v2": KeyApiV2Resource(self.hs)}, root_resource=NoResource()
+            {"/_matrix/key/v2": KeyResource(self.hs)}, root_resource=NoResource()
         )
 
     def expect_outgoing_key_request(