diff options
Diffstat (limited to 'synapse/federation/transport/server.py')
-rw-r--r-- | synapse/federation/transport/server.py | 196 |
1 files changed, 138 insertions, 58 deletions
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 6d4a26f595..5ba94be2ec 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -21,8 +21,9 @@ import re from twisted.internet import defer import synapse +from synapse.api.constants import RoomVersions from synapse.api.errors import Codes, FederationDeniedError, SynapseError -from synapse.api.urls import FEDERATION_PREFIX as PREFIX +from synapse.api.urls import FEDERATION_V1_PREFIX, FEDERATION_V2_PREFIX from synapse.http.endpoint import parse_and_validate_server_name from synapse.http.server import JsonResource from synapse.http.servlet import ( @@ -42,9 +43,20 @@ logger = logging.getLogger(__name__) class TransportLayerServer(JsonResource): """Handles incoming federation HTTP requests""" - def __init__(self, hs): + def __init__(self, hs, servlet_groups=None): + """Initialize the TransportLayerServer + + Will by default register all servlets. For custom behaviour, pass in + a list of servlet_groups to register. + + Args: + hs (synapse.server.HomeServer): homeserver + servlet_groups (list[str], optional): List of servlet groups to register. + Defaults to ``DEFAULT_SERVLET_GROUPS``. + """ self.hs = hs self.clock = hs.get_clock() + self.servlet_groups = servlet_groups super(TransportLayerServer, self).__init__(hs, canonical_json=False) @@ -66,6 +78,7 @@ class TransportLayerServer(JsonResource): resource=self, ratelimiter=self.ratelimiter, authenticator=self.authenticator, + servlet_groups=self.servlet_groups, ) @@ -227,6 +240,8 @@ class BaseFederationServlet(object): """ REQUIRE_AUTH = True + PREFIX = FEDERATION_V1_PREFIX # Allows specifying the API version + def __init__(self, handler, authenticator, ratelimiter, server_name): self.handler = handler self.authenticator = authenticator @@ -286,7 +301,7 @@ class BaseFederationServlet(object): return new_func def register(self, server): - pattern = re.compile("^" + PREFIX + self.PATH + "$") + pattern = re.compile("^" + self.PREFIX + self.PATH + "$") for method in ("GET", "PUT", "POST"): code = getattr(self, "on_%s" % (method), None) @@ -362,14 +377,6 @@ class FederationSendServlet(BaseFederationServlet): defer.returnValue((code, response)) -class FederationPullServlet(BaseFederationServlet): - PATH = "/pull/" - - # This is for when someone asks us for everything since version X - def on_GET(self, origin, content, query): - return self.handler.on_pull_request(query["origin"][0], query["v"]) - - class FederationEventServlet(BaseFederationServlet): PATH = "/event/(?P<event_id>[^/]*)/" @@ -474,7 +481,7 @@ class FederationSendLeaveServlet(BaseFederationServlet): @defer.inlineCallbacks def on_PUT(self, origin, content, query, room_id, event_id): - content = yield self.handler.on_send_leave_request(origin, content) + content = yield self.handler.on_send_leave_request(origin, content, room_id) defer.returnValue((200, content)) @@ -492,18 +499,50 @@ class FederationSendJoinServlet(BaseFederationServlet): def on_PUT(self, origin, content, query, context, event_id): # TODO(paul): assert that context/event_id parsed from path actually # match those given in content - content = yield self.handler.on_send_join_request(origin, content) + content = yield self.handler.on_send_join_request(origin, content, context) defer.returnValue((200, content)) -class FederationInviteServlet(BaseFederationServlet): +class FederationV1InviteServlet(BaseFederationServlet): PATH = "/invite/(?P<context>[^/]*)/(?P<event_id>[^/]*)" @defer.inlineCallbacks def on_PUT(self, origin, content, query, context, event_id): + # We don't get a room version, so we have to assume its EITHER v1 or + # v2. This is "fine" as the only difference between V1 and V2 is the + # state resolution algorithm, and we don't use that for processing + # invites + content = yield self.handler.on_invite_request( + origin, content, room_version=RoomVersions.V1, + ) + + # V1 federation API is defined to return a content of `[200, {...}]` + # due to a historical bug. + defer.returnValue((200, (200, content))) + + +class FederationV2InviteServlet(BaseFederationServlet): + PATH = "/invite/(?P<context>[^/]*)/(?P<event_id>[^/]*)" + + PREFIX = FEDERATION_V2_PREFIX + + @defer.inlineCallbacks + def on_PUT(self, origin, content, query, context, event_id): # TODO(paul): assert that context/event_id parsed from path actually # match those given in content - content = yield self.handler.on_invite_request(origin, content) + + room_version = content["room_version"] + event = content["event"] + invite_room_state = content["invite_room_state"] + + # Synapse expects invite_room_state to be in unsigned, as it is in v1 + # API + + event.setdefault("unsigned", {})["invite_room_state"] = invite_room_state + + content = yield self.handler.on_invite_request( + origin, event, room_version=room_version, + ) defer.returnValue((200, content)) @@ -1262,7 +1301,6 @@ class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet): FEDERATION_SERVLET_CLASSES = ( FederationSendServlet, - FederationPullServlet, FederationEventServlet, FederationStateServlet, FederationStateIdsServlet, @@ -1273,7 +1311,8 @@ FEDERATION_SERVLET_CLASSES = ( FederationEventServlet, FederationSendJoinServlet, FederationSendLeaveServlet, - FederationInviteServlet, + FederationV1InviteServlet, + FederationV2InviteServlet, FederationQueryAuthServlet, FederationGetMissingEventsServlet, FederationEventAuthServlet, @@ -1282,10 +1321,12 @@ FEDERATION_SERVLET_CLASSES = ( FederationClientKeysClaimServlet, FederationThirdPartyInviteExchangeServlet, On3pidBindServlet, - OpenIdUserInfo, FederationVersionServlet, ) +OPENID_SERVLET_CLASSES = ( + OpenIdUserInfo, +) ROOM_LIST_CLASSES = ( PublicRoomList, @@ -1324,44 +1365,83 @@ GROUP_ATTESTATION_SERVLET_CLASSES = ( FederationGroupsRenewAttestaionServlet, ) +DEFAULT_SERVLET_GROUPS = ( + "federation", + "room_list", + "group_server", + "group_local", + "group_attestation", + "openid", +) + -def register_servlets(hs, resource, authenticator, ratelimiter): - for servletclass in FEDERATION_SERVLET_CLASSES: - servletclass( - handler=hs.get_federation_server(), - authenticator=authenticator, - ratelimiter=ratelimiter, - server_name=hs.hostname, - ).register(resource) - - for servletclass in ROOM_LIST_CLASSES: - servletclass( - handler=hs.get_room_list_handler(), - authenticator=authenticator, - ratelimiter=ratelimiter, - server_name=hs.hostname, - ).register(resource) - - for servletclass in GROUP_SERVER_SERVLET_CLASSES: - servletclass( - handler=hs.get_groups_server_handler(), - authenticator=authenticator, - ratelimiter=ratelimiter, - server_name=hs.hostname, - ).register(resource) - - for servletclass in GROUP_LOCAL_SERVLET_CLASSES: - servletclass( - handler=hs.get_groups_local_handler(), - authenticator=authenticator, - ratelimiter=ratelimiter, - server_name=hs.hostname, - ).register(resource) - - for servletclass in GROUP_ATTESTATION_SERVLET_CLASSES: - servletclass( - handler=hs.get_groups_attestation_renewer(), - authenticator=authenticator, - ratelimiter=ratelimiter, - server_name=hs.hostname, - ).register(resource) +def register_servlets(hs, resource, authenticator, ratelimiter, servlet_groups=None): + """Initialize and register servlet classes. + + Will by default register all servlets. For custom behaviour, pass in + a list of servlet_groups to register. + + Args: + hs (synapse.server.HomeServer): homeserver + resource (TransportLayerServer): resource class to register to + authenticator (Authenticator): authenticator to use + ratelimiter (util.ratelimitutils.FederationRateLimiter): ratelimiter to use + servlet_groups (list[str], optional): List of servlet groups to register. + Defaults to ``DEFAULT_SERVLET_GROUPS``. + """ + if not servlet_groups: + servlet_groups = DEFAULT_SERVLET_GROUPS + + if "federation" in servlet_groups: + for servletclass in FEDERATION_SERVLET_CLASSES: + servletclass( + handler=hs.get_federation_server(), + authenticator=authenticator, + ratelimiter=ratelimiter, + server_name=hs.hostname, + ).register(resource) + + if "openid" in servlet_groups: + for servletclass in OPENID_SERVLET_CLASSES: + servletclass( + handler=hs.get_federation_server(), + authenticator=authenticator, + ratelimiter=ratelimiter, + server_name=hs.hostname, + ).register(resource) + + if "room_list" in servlet_groups: + for servletclass in ROOM_LIST_CLASSES: + servletclass( + handler=hs.get_room_list_handler(), + authenticator=authenticator, + ratelimiter=ratelimiter, + server_name=hs.hostname, + ).register(resource) + + if "group_server" in servlet_groups: + for servletclass in GROUP_SERVER_SERVLET_CLASSES: + servletclass( + handler=hs.get_groups_server_handler(), + authenticator=authenticator, + ratelimiter=ratelimiter, + server_name=hs.hostname, + ).register(resource) + + if "group_local" in servlet_groups: + for servletclass in GROUP_LOCAL_SERVLET_CLASSES: + servletclass( + handler=hs.get_groups_local_handler(), + authenticator=authenticator, + ratelimiter=ratelimiter, + server_name=hs.hostname, + ).register(resource) + + if "group_attestation" in servlet_groups: + for servletclass in GROUP_ATTESTATION_SERVLET_CLASSES: + servletclass( + handler=hs.get_groups_attestation_renewer(), + authenticator=authenticator, + ratelimiter=ratelimiter, + server_name=hs.hostname, + ).register(resource) |