summary refs log tree commit diff
path: root/synapse/federation
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation')
-rw-r--r--synapse/federation/federation_server.py16
-rw-r--r--synapse/federation/transport/server.py137
2 files changed, 111 insertions, 42 deletions
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index aeadc9c564..3da86d4ba6 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -148,6 +148,22 @@ class FederationServer(FederationBase):
 
         logger.debug("[%s] Transaction is new", transaction.transaction_id)
 
+        # Reject if PDU count > 50 and EDU count > 100
+        if (len(transaction.pdus) > 50
+                or (hasattr(transaction, "edus") and len(transaction.edus) > 100)):
+
+            logger.info(
+                "Transaction PDU or EDU count too large. Returning 400",
+            )
+
+            response = {}
+            yield self.transaction_actions.set_response(
+                origin,
+                transaction,
+                400, response
+            )
+            defer.returnValue((400, response))
+
         received_pdus_counter.inc(len(transaction.pdus))
 
         origin_host, _ = parse_server_name(origin)
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 67ae0212c3..a2396ab466 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -43,9 +43,20 @@ logger = logging.getLogger(__name__)
 class TransportLayerServer(JsonResource):
     """Handles incoming federation HTTP requests"""
 
-    def __init__(self, hs):
+    def __init__(self, hs, servlet_groups=None):
+        """Initialize the TransportLayerServer
+
+        Will by default register all servlets. For custom behaviour, pass in
+        a list of servlet_groups to register.
+
+        Args:
+            hs (synapse.server.HomeServer): homeserver
+            servlet_groups (list[str], optional): List of servlet groups to register.
+                Defaults to ``DEFAULT_SERVLET_GROUPS``.
+        """
         self.hs = hs
         self.clock = hs.get_clock()
+        self.servlet_groups = servlet_groups
 
         super(TransportLayerServer, self).__init__(hs, canonical_json=False)
 
@@ -67,6 +78,7 @@ class TransportLayerServer(JsonResource):
             resource=self,
             ratelimiter=self.ratelimiter,
             authenticator=self.authenticator,
+            servlet_groups=self.servlet_groups,
         )
 
 
@@ -1308,10 +1320,12 @@ FEDERATION_SERVLET_CLASSES = (
     FederationClientKeysClaimServlet,
     FederationThirdPartyInviteExchangeServlet,
     On3pidBindServlet,
-    OpenIdUserInfo,
     FederationVersionServlet,
 )
 
+OPENID_SERVLET_CLASSES = (
+    OpenIdUserInfo,
+)
 
 ROOM_LIST_CLASSES = (
     PublicRoomList,
@@ -1350,44 +1364,83 @@ GROUP_ATTESTATION_SERVLET_CLASSES = (
     FederationGroupsRenewAttestaionServlet,
 )
 
+DEFAULT_SERVLET_GROUPS = (
+    "federation",
+    "room_list",
+    "group_server",
+    "group_local",
+    "group_attestation",
+    "openid",
+)
+
+
+def register_servlets(hs, resource, authenticator, ratelimiter, servlet_groups=None):
+    """Initialize and register servlet classes.
 
-def register_servlets(hs, resource, authenticator, ratelimiter):
-    for servletclass in FEDERATION_SERVLET_CLASSES:
-        servletclass(
-            handler=hs.get_federation_server(),
-            authenticator=authenticator,
-            ratelimiter=ratelimiter,
-            server_name=hs.hostname,
-        ).register(resource)
-
-    for servletclass in ROOM_LIST_CLASSES:
-        servletclass(
-            handler=hs.get_room_list_handler(),
-            authenticator=authenticator,
-            ratelimiter=ratelimiter,
-            server_name=hs.hostname,
-        ).register(resource)
-
-    for servletclass in GROUP_SERVER_SERVLET_CLASSES:
-        servletclass(
-            handler=hs.get_groups_server_handler(),
-            authenticator=authenticator,
-            ratelimiter=ratelimiter,
-            server_name=hs.hostname,
-        ).register(resource)
-
-    for servletclass in GROUP_LOCAL_SERVLET_CLASSES:
-        servletclass(
-            handler=hs.get_groups_local_handler(),
-            authenticator=authenticator,
-            ratelimiter=ratelimiter,
-            server_name=hs.hostname,
-        ).register(resource)
-
-    for servletclass in GROUP_ATTESTATION_SERVLET_CLASSES:
-        servletclass(
-            handler=hs.get_groups_attestation_renewer(),
-            authenticator=authenticator,
-            ratelimiter=ratelimiter,
-            server_name=hs.hostname,
-        ).register(resource)
+    Will by default register all servlets. For custom behaviour, pass in
+    a list of servlet_groups to register.
+
+    Args:
+        hs (synapse.server.HomeServer): homeserver
+        resource (TransportLayerServer): resource class to register to
+        authenticator (Authenticator): authenticator to use
+        ratelimiter (util.ratelimitutils.FederationRateLimiter): ratelimiter to use
+        servlet_groups (list[str], optional): List of servlet groups to register.
+            Defaults to ``DEFAULT_SERVLET_GROUPS``.
+    """
+    if not servlet_groups:
+        servlet_groups = DEFAULT_SERVLET_GROUPS
+
+    if "federation" in servlet_groups:
+        for servletclass in FEDERATION_SERVLET_CLASSES:
+            servletclass(
+                handler=hs.get_federation_server(),
+                authenticator=authenticator,
+                ratelimiter=ratelimiter,
+                server_name=hs.hostname,
+            ).register(resource)
+
+    if "openid" in servlet_groups:
+        for servletclass in OPENID_SERVLET_CLASSES:
+            servletclass(
+                handler=hs.get_federation_server(),
+                authenticator=authenticator,
+                ratelimiter=ratelimiter,
+                server_name=hs.hostname,
+            ).register(resource)
+
+    if "room_list" in servlet_groups:
+        for servletclass in ROOM_LIST_CLASSES:
+            servletclass(
+                handler=hs.get_room_list_handler(),
+                authenticator=authenticator,
+                ratelimiter=ratelimiter,
+                server_name=hs.hostname,
+            ).register(resource)
+
+    if "group_server" in servlet_groups:
+        for servletclass in GROUP_SERVER_SERVLET_CLASSES:
+            servletclass(
+                handler=hs.get_groups_server_handler(),
+                authenticator=authenticator,
+                ratelimiter=ratelimiter,
+                server_name=hs.hostname,
+            ).register(resource)
+
+    if "group_local" in servlet_groups:
+        for servletclass in GROUP_LOCAL_SERVLET_CLASSES:
+            servletclass(
+                handler=hs.get_groups_local_handler(),
+                authenticator=authenticator,
+                ratelimiter=ratelimiter,
+                server_name=hs.hostname,
+            ).register(resource)
+
+    if "group_attestation" in servlet_groups:
+        for servletclass in GROUP_ATTESTATION_SERVLET_CLASSES:
+            servletclass(
+                handler=hs.get_groups_attestation_renewer(),
+                authenticator=authenticator,
+                ratelimiter=ratelimiter,
+                server_name=hs.hostname,
+            ).register(resource)