summary refs log tree commit diff
path: root/synapse/_scripts
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/_scripts')
-rwxr-xr-xsynapse/_scripts/generate_workers_map.py302
1 files changed, 302 insertions, 0 deletions
diff --git a/synapse/_scripts/generate_workers_map.py b/synapse/_scripts/generate_workers_map.py
new file mode 100755
index 0000000000..6c08878523
--- /dev/null
+++ b/synapse/_scripts/generate_workers_map.py
@@ -0,0 +1,302 @@
+#!/usr/bin/env python
+# Copyright 2022-2023 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import logging
+import re
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import Dict, Iterable, Optional, Pattern, Set, Tuple
+
+import yaml
+
+from synapse.config.homeserver import HomeServerConfig
+from synapse.federation.transport.server import (
+    TransportLayerServer,
+    register_servlets as register_federation_servlets,
+)
+from synapse.http.server import HttpServer, ServletCallback
+from synapse.rest import ClientRestResource
+from synapse.rest.key.v2 import RemoteKey
+from synapse.server import HomeServer
+from synapse.storage import DataStore
+
+logger = logging.getLogger("generate_workers_map")
+
+
+class MockHomeserver(HomeServer):
+    DATASTORE_CLASS = DataStore  # type: ignore
+
+    def __init__(self, config: HomeServerConfig, worker_app: Optional[str]) -> None:
+        super().__init__(config.server.server_name, config=config)
+        self.config.worker.worker_app = worker_app
+
+
+GROUP_PATTERN = re.compile(r"\(\?P<[^>]+?>(.+?)\)")
+
+
+@dataclass
+class EndpointDescription:
+    """
+    Describes an endpoint and how it should be routed.
+    """
+
+    # The servlet class that handles this endpoint
+    servlet_class: object
+
+    # The category of this endpoint. Is read from the `CATEGORY` constant in the servlet
+    # class.
+    category: Optional[str]
+
+    # TODO:
+    #  - does it need to be routed based on a stream writer config?
+    #  - does it benefit from any optimised, but optional, routing?
+    #  - what 'opinionated synapse worker class' (event_creator, synchrotron, etc) does
+    #    it go in?
+
+
+class EnumerationResource(HttpServer):
+    """
+    Accepts servlet registrations for the purposes of building up a description of
+    all endpoints.
+    """
+
+    def __init__(self, is_worker: bool) -> None:
+        self.registrations: Dict[Tuple[str, str], EndpointDescription] = {}
+        self._is_worker = is_worker
+
+    def register_paths(
+        self,
+        method: str,
+        path_patterns: Iterable[Pattern],
+        callback: ServletCallback,
+        servlet_classname: str,
+    ) -> None:
+        # federation servlet callbacks are wrapped, so unwrap them.
+        callback = getattr(callback, "__wrapped__", callback)
+
+        # fish out the servlet class
+        servlet_class = callback.__self__.__class__  # type: ignore
+
+        if self._is_worker and method in getattr(
+            servlet_class, "WORKERS_DENIED_METHODS", ()
+        ):
+            # This endpoint would cause an error if called on a worker, so pretend it
+            # was never registered!
+            return
+
+        sd = EndpointDescription(
+            servlet_class=servlet_class,
+            category=getattr(servlet_class, "CATEGORY", None),
+        )
+
+        for pat in path_patterns:
+            self.registrations[(method, pat.pattern)] = sd
+
+
+def get_registered_paths_for_hs(
+    hs: HomeServer,
+) -> Dict[Tuple[str, str], EndpointDescription]:
+    """
+    Given a homeserver, get all registered endpoints and their descriptions.
+    """
+
+    enumerator = EnumerationResource(is_worker=hs.config.worker.worker_app is not None)
+    ClientRestResource.register_servlets(enumerator, hs)
+    federation_server = TransportLayerServer(hs)
+
+    # we can't use `federation_server.register_servlets` but this line does the
+    # same thing, only it uses this enumerator
+    register_federation_servlets(
+        federation_server.hs,
+        resource=enumerator,
+        ratelimiter=federation_server.ratelimiter,
+        authenticator=federation_server.authenticator,
+        servlet_groups=federation_server.servlet_groups,
+    )
+
+    # the key server endpoints are separate again
+    RemoteKey(hs).register(enumerator)
+
+    return enumerator.registrations
+
+
+def get_registered_paths_for_default(
+    worker_app: Optional[str], base_config: HomeServerConfig
+) -> Dict[Tuple[str, str], EndpointDescription]:
+    """
+    Given the name of a worker application and a base homeserver configuration,
+    returns:
+
+        Dict from (method, path) to EndpointDescription
+
+    TODO Don't require passing in a config
+    """
+
+    hs = MockHomeserver(base_config, worker_app)
+    # TODO We only do this to avoid an error, but don't need the database etc
+    hs.setup()
+    return get_registered_paths_for_hs(hs)
+
+
+def elide_http_methods_if_unconflicting(
+    registrations: Dict[Tuple[str, str], EndpointDescription],
+    all_possible_registrations: Dict[Tuple[str, str], EndpointDescription],
+) -> Dict[Tuple[str, str], EndpointDescription]:
+    """
+    Elides HTTP methods (by replacing them with `*`) if all possible registered methods
+    can be handled by the worker whose registration map is `registrations`.
+
+    i.e. the only endpoints left with methods (other than `*`) should be the ones where
+    the worker can't handle all possible methods for that path.
+    """
+
+    def paths_to_methods_dict(
+        methods_and_paths: Iterable[Tuple[str, str]]
+    ) -> Dict[str, Set[str]]:
+        """
+        Given (method, path) pairs, produces a dict from path to set of methods
+        available at that path.
+        """
+        result: Dict[str, Set[str]] = {}
+        for method, path in methods_and_paths:
+            result.setdefault(path, set()).add(method)
+        return result
+
+    all_possible_reg_methods = paths_to_methods_dict(all_possible_registrations)
+    reg_methods = paths_to_methods_dict(registrations)
+
+    output = {}
+
+    for path, handleable_methods in reg_methods.items():
+        if handleable_methods == all_possible_reg_methods[path]:
+            any_method = next(iter(handleable_methods))
+            # TODO This assumes that all methods have the same servlet.
+            #      I suppose that's possibly dubious?
+            output[("*", path)] = registrations[(any_method, path)]
+        else:
+            for method in handleable_methods:
+                output[(method, path)] = registrations[(method, path)]
+
+    return output
+
+
+def simplify_path_regexes(
+    registrations: Dict[Tuple[str, str], EndpointDescription]
+) -> Dict[Tuple[str, str], EndpointDescription]:
+    """
+    Simplify all the path regexes for the dict of endpoint descriptions,
+    so that we don't use the Python-specific regex extensions
+    (and also to remove needlessly specific detail).
+    """
+
+    def simplify_path_regex(path: str) -> str:
+        """
+        Given a regex pattern, replaces all named capturing groups (e.g. `(?P<blah>xyz)`)
+        with a simpler version available in more common regex dialects (e.g. `.*`).
+        """
+
+        # TODO it's hard to choose between these two;
+        #      `.*` is a vague simplification
+        # return GROUP_PATTERN.sub(r"\1", path)
+        return GROUP_PATTERN.sub(r".*", path)
+
+    return {(m, simplify_path_regex(p)): v for (m, p), v in registrations.items()}
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(
+        description=(
+            "Updates a synapse database to the latest schema and optionally runs background updates"
+            " on it."
+        )
+    )
+    parser.add_argument("-v", action="store_true")
+    parser.add_argument(
+        "--config-path",
+        type=argparse.FileType("r"),
+        required=True,
+        help="Synapse configuration file",
+    )
+
+    args = parser.parse_args()
+
+    # TODO
+    # logging.basicConfig(**logging_config)
+
+    # Load, process and sanity-check the config.
+    hs_config = yaml.safe_load(args.config_path)
+
+    config = HomeServerConfig()
+    config.parse_config_dict(hs_config, "", "")
+
+    master_paths = get_registered_paths_for_default(None, config)
+    worker_paths = get_registered_paths_for_default(
+        "synapse.app.generic_worker", config
+    )
+
+    all_paths = {**master_paths, **worker_paths}
+
+    elided_worker_paths = elide_http_methods_if_unconflicting(worker_paths, all_paths)
+    elide_http_methods_if_unconflicting(master_paths, all_paths)
+
+    # TODO SSO endpoints (pick_idp etc) NOT REGISTERED BY THIS SCRIPT
+
+    categories_to_methods_and_paths: Dict[
+        Optional[str], Dict[Tuple[str, str], EndpointDescription]
+    ] = defaultdict(dict)
+
+    for (method, path), desc in elided_worker_paths.items():
+        categories_to_methods_and_paths[desc.category][method, path] = desc
+
+    for category, contents in categories_to_methods_and_paths.items():
+        print_category(category, contents)
+
+
+def print_category(
+    category_name: Optional[str],
+    elided_worker_paths: Dict[Tuple[str, str], EndpointDescription],
+) -> None:
+    """
+    Prints out a category, in documentation page style.
+
+    Example:
+    ```
+    # Category name
+    /path/xyz
+
+    GET /path/abc
+    ```
+    """
+
+    if category_name:
+        print(f"# {category_name}")
+    else:
+        print("# (Uncategorised requests)")
+
+    for ln in sorted(
+        p for m, p in simplify_path_regexes(elided_worker_paths) if m == "*"
+    ):
+        print(ln)
+    print()
+    for ln in sorted(
+        f"{m:6} {p}" for m, p in simplify_path_regexes(elided_worker_paths) if m != "*"
+    ):
+        print(ln)
+    print()
+
+
+if __name__ == "__main__":
+    main()