diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 67ae0212c3..a2396ab466 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -43,9 +43,20 @@ logger = logging.getLogger(__name__)
class TransportLayerServer(JsonResource):
"""Handles incoming federation HTTP requests"""
- def __init__(self, hs):
+ def __init__(self, hs, servlet_groups=None):
+ """Initialize the TransportLayerServer
+
+ Will by default register all servlets. For custom behaviour, pass in
+ a list of servlet_groups to register.
+
+ Args:
+ hs (synapse.server.HomeServer): homeserver
+ servlet_groups (list[str], optional): List of servlet groups to register.
+ Defaults to ``DEFAULT_SERVLET_GROUPS``.
+ """
self.hs = hs
self.clock = hs.get_clock()
+ self.servlet_groups = servlet_groups
super(TransportLayerServer, self).__init__(hs, canonical_json=False)
@@ -67,6 +78,7 @@ class TransportLayerServer(JsonResource):
resource=self,
ratelimiter=self.ratelimiter,
authenticator=self.authenticator,
+ servlet_groups=self.servlet_groups,
)
@@ -1308,10 +1320,12 @@ FEDERATION_SERVLET_CLASSES = (
FederationClientKeysClaimServlet,
FederationThirdPartyInviteExchangeServlet,
On3pidBindServlet,
- OpenIdUserInfo,
FederationVersionServlet,
)
+OPENID_SERVLET_CLASSES = (
+ OpenIdUserInfo,
+)
ROOM_LIST_CLASSES = (
PublicRoomList,
@@ -1350,44 +1364,83 @@ GROUP_ATTESTATION_SERVLET_CLASSES = (
FederationGroupsRenewAttestaionServlet,
)
+DEFAULT_SERVLET_GROUPS = (
+ "federation",
+ "room_list",
+ "group_server",
+ "group_local",
+ "group_attestation",
+ "openid",
+)
+
+
+def register_servlets(hs, resource, authenticator, ratelimiter, servlet_groups=None):
+ """Initialize and register servlet classes.
-def register_servlets(hs, resource, authenticator, ratelimiter):
- for servletclass in FEDERATION_SERVLET_CLASSES:
- servletclass(
- handler=hs.get_federation_server(),
- authenticator=authenticator,
- ratelimiter=ratelimiter,
- server_name=hs.hostname,
- ).register(resource)
-
- for servletclass in ROOM_LIST_CLASSES:
- servletclass(
- handler=hs.get_room_list_handler(),
- authenticator=authenticator,
- ratelimiter=ratelimiter,
- server_name=hs.hostname,
- ).register(resource)
-
- for servletclass in GROUP_SERVER_SERVLET_CLASSES:
- servletclass(
- handler=hs.get_groups_server_handler(),
- authenticator=authenticator,
- ratelimiter=ratelimiter,
- server_name=hs.hostname,
- ).register(resource)
-
- for servletclass in GROUP_LOCAL_SERVLET_CLASSES:
- servletclass(
- handler=hs.get_groups_local_handler(),
- authenticator=authenticator,
- ratelimiter=ratelimiter,
- server_name=hs.hostname,
- ).register(resource)
-
- for servletclass in GROUP_ATTESTATION_SERVLET_CLASSES:
- servletclass(
- handler=hs.get_groups_attestation_renewer(),
- authenticator=authenticator,
- ratelimiter=ratelimiter,
- server_name=hs.hostname,
- ).register(resource)
+ Will by default register all servlets. For custom behaviour, pass in
+ a list of servlet_groups to register.
+
+ Args:
+ hs (synapse.server.HomeServer): homeserver
+ resource (TransportLayerServer): resource class to register to
+ authenticator (Authenticator): authenticator to use
+ ratelimiter (util.ratelimitutils.FederationRateLimiter): ratelimiter to use
+ servlet_groups (list[str], optional): List of servlet groups to register.
+ Defaults to ``DEFAULT_SERVLET_GROUPS``.
+ """
+ if not servlet_groups:
+ servlet_groups = DEFAULT_SERVLET_GROUPS
+
+ if "federation" in servlet_groups:
+ for servletclass in FEDERATION_SERVLET_CLASSES:
+ servletclass(
+ handler=hs.get_federation_server(),
+ authenticator=authenticator,
+ ratelimiter=ratelimiter,
+ server_name=hs.hostname,
+ ).register(resource)
+
+ if "openid" in servlet_groups:
+ for servletclass in OPENID_SERVLET_CLASSES:
+ servletclass(
+ handler=hs.get_federation_server(),
+ authenticator=authenticator,
+ ratelimiter=ratelimiter,
+ server_name=hs.hostname,
+ ).register(resource)
+
+ if "room_list" in servlet_groups:
+ for servletclass in ROOM_LIST_CLASSES:
+ servletclass(
+ handler=hs.get_room_list_handler(),
+ authenticator=authenticator,
+ ratelimiter=ratelimiter,
+ server_name=hs.hostname,
+ ).register(resource)
+
+ if "group_server" in servlet_groups:
+ for servletclass in GROUP_SERVER_SERVLET_CLASSES:
+ servletclass(
+ handler=hs.get_groups_server_handler(),
+ authenticator=authenticator,
+ ratelimiter=ratelimiter,
+ server_name=hs.hostname,
+ ).register(resource)
+
+ if "group_local" in servlet_groups:
+ for servletclass in GROUP_LOCAL_SERVLET_CLASSES:
+ servletclass(
+ handler=hs.get_groups_local_handler(),
+ authenticator=authenticator,
+ ratelimiter=ratelimiter,
+ server_name=hs.hostname,
+ ).register(resource)
+
+ if "group_attestation" in servlet_groups:
+ for servletclass in GROUP_ATTESTATION_SERVLET_CLASSES:
+ servletclass(
+ handler=hs.get_groups_attestation_renewer(),
+ authenticator=authenticator,
+ ratelimiter=ratelimiter,
+ server_name=hs.hostname,
+ ).register(resource)
|