diff options
Diffstat (limited to 'synapse/federation/transport')
-rw-r--r-- | synapse/federation/transport/client.py | 34 | ||||
-rw-r--r-- | synapse/federation/transport/server.py | 176 |
2 files changed, 164 insertions, 46 deletions
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index cd2841c4db..3d088e43cb 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -55,6 +55,28 @@ class TransportLayerClient(object): ) @log_function + def get_room_state_ids(self, destination, room_id, event_id): + """ Requests all state for a given room from the given server at the + given event. Returns the state's event_id's + + Args: + destination (str): The host name of the remote home server we want + to get the state from. + context (str): The name of the context we want the state of + event_id (str): The event we want the context at. + + Returns: + Deferred: Results in a dict received from the remote homeserver. + """ + logger.debug("get_room_state_ids dest=%s, room=%s", + destination, room_id) + + path = PREFIX + "/state_ids/%s/" % room_id + return self.client.get_json( + destination, path=path, args={"event_id": event_id}, + ) + + @log_function def get_event(self, destination, event_id, timeout=None): """ Requests the pdu with give id and origin from the given server. @@ -226,6 +248,18 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function + def get_public_rooms(self, remote_server): + path = PREFIX + "/publicRooms" + + response = yield self.client.get_json( + destination=remote_server, + path=path, + ) + + defer.returnValue(response) + + @defer.inlineCallbacks + @log_function def exchange_third_party_invite(self, destination, room_id, event_dict): path = PREFIX + "/exchange_third_party_invite/%s" % (room_id,) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 5b6c7d11dd..37c0d4fbc4 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -18,13 +18,14 @@ from twisted.internet import defer from synapse.api.urls import FEDERATION_PREFIX as PREFIX from synapse.api.errors import Codes, SynapseError from synapse.http.server import JsonResource -from synapse.http.servlet import parse_json_object_from_request, parse_string +from synapse.http.servlet import parse_json_object_from_request from synapse.util.ratelimitutils import FederationRateLimiter +from synapse.util.versionstring import get_version_string import functools import logging -import simplejson as json import re +import synapse logger = logging.getLogger(__name__) @@ -37,7 +38,7 @@ class TransportLayerServer(JsonResource): self.hs = hs self.clock = hs.get_clock() - super(TransportLayerServer, self).__init__(hs) + super(TransportLayerServer, self).__init__(hs, canonical_json=False) self.authenticator = Authenticator(hs) self.ratelimiter = FederationRateLimiter( @@ -60,6 +61,16 @@ class TransportLayerServer(JsonResource): ) +class AuthenticationError(SynapseError): + """There was a problem authenticating the request""" + pass + + +class NoAuthenticationError(AuthenticationError): + """The request had no authentication information""" + pass + + class Authenticator(object): def __init__(self, hs): self.keyring = hs.get_keyring() @@ -67,7 +78,7 @@ class Authenticator(object): # A method just so we can pass 'self' as the authenticator to the Servlets @defer.inlineCallbacks - def authenticate_request(self, request): + def authenticate_request(self, request, content): json_request = { "method": request.method, "uri": request.uri, @@ -75,17 +86,10 @@ class Authenticator(object): "signatures": {}, } - content = None - origin = None + if content is not None: + json_request["content"] = content - if request.method in ["PUT", "POST"]: - # TODO: Handle other method types? other content types? - try: - content_bytes = request.content.read() - content = json.loads(content_bytes) - json_request["content"] = content - except: - raise SynapseError(400, "Unable to parse JSON", Codes.BAD_JSON) + origin = None def parse_auth_header(header_str): try: @@ -103,14 +107,14 @@ class Authenticator(object): sig = strip_quotes(param_dict["sig"]) return (origin, key, sig) except: - raise SynapseError( + raise AuthenticationError( 400, "Malformed Authorization header", Codes.UNAUTHORIZED ) auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") if not auth_headers: - raise SynapseError( + raise NoAuthenticationError( 401, "Missing Authorization headers", Codes.UNAUTHORIZED, ) @@ -121,7 +125,7 @@ class Authenticator(object): json_request["signatures"].setdefault(origin, {})[key] = sig if not json_request["signatures"]: - raise SynapseError( + raise NoAuthenticationError( 401, "Missing Authorization headers", Codes.UNAUTHORIZED, ) @@ -130,38 +134,59 @@ class Authenticator(object): logger.info("Request from %s", origin) request.authenticated_entity = origin - defer.returnValue((origin, content)) + defer.returnValue(origin) class BaseFederationServlet(object): - def __init__(self, handler, authenticator, ratelimiter, server_name): + REQUIRE_AUTH = True + + def __init__(self, handler, authenticator, ratelimiter, server_name, + room_list_handler): self.handler = handler self.authenticator = authenticator self.ratelimiter = ratelimiter + self.room_list_handler = room_list_handler - def _wrap(self, code): + def _wrap(self, func): authenticator = self.authenticator ratelimiter = self.ratelimiter @defer.inlineCallbacks - @functools.wraps(code) - def new_code(request, *args, **kwargs): + @functools.wraps(func) + def new_func(request, *args, **kwargs): + content = None + if request.method in ["PUT", "POST"]: + # TODO: Handle other method types? other content types? + content = parse_json_object_from_request(request) + try: - (origin, content) = yield authenticator.authenticate_request(request) + origin = yield authenticator.authenticate_request(request, content) + except NoAuthenticationError: + origin = None + if self.REQUIRE_AUTH: + logger.exception("authenticate_request failed") + raise + except: + logger.exception("authenticate_request failed") + raise + + if origin: with ratelimiter.ratelimit(origin) as d: yield d - response = yield code( + response = yield func( origin, content, request.args, *args, **kwargs ) - except: - logger.exception("authenticate_request failed") - raise + else: + response = yield func( + origin, content, request.args, *args, **kwargs + ) + defer.returnValue(response) # Extra logic that functools.wraps() doesn't finish - new_code.__self__ = code.__self__ + new_func.__self__ = func.__self__ - return new_code + return new_func def register(self, server): pattern = re.compile("^" + PREFIX + self.PATH + "$") @@ -269,6 +294,17 @@ class FederationStateServlet(BaseFederationServlet): ) +class FederationStateIdsServlet(BaseFederationServlet): + PATH = "/state_ids/(?P<room_id>[^/]*)/" + + def on_GET(self, origin, content, query, room_id): + return self.handler.on_state_ids_request( + origin, + room_id, + query.get("event_id", [None])[0], + ) + + class FederationBackfillServlet(BaseFederationServlet): PATH = "/backfill/(?P<context>[^/]*)/" @@ -365,10 +401,8 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet): class FederationClientKeysQueryServlet(BaseFederationServlet): PATH = "/user/keys/query" - @defer.inlineCallbacks def on_POST(self, origin, content, query): - response = yield self.handler.on_query_client_keys(origin, content) - defer.returnValue((200, response)) + return self.handler.on_query_client_keys(origin, content) class FederationClientKeysClaimServlet(BaseFederationServlet): @@ -386,7 +420,7 @@ class FederationQueryAuthServlet(BaseFederationServlet): @defer.inlineCallbacks def on_POST(self, origin, content, query, context, event_id): new_content = yield self.handler.on_query_auth_request( - origin, content, event_id + origin, content, context, event_id ) defer.returnValue((200, new_content)) @@ -418,9 +452,10 @@ class FederationGetMissingEventsServlet(BaseFederationServlet): class On3pidBindServlet(BaseFederationServlet): PATH = "/3pid/onbind" + REQUIRE_AUTH = False + @defer.inlineCallbacks - def on_POST(self, request): - content = parse_json_object_from_request(request) + def on_POST(self, origin, content, query): if "invites" in content: last_exception = None for invite in content["invites"]: @@ -442,11 +477,6 @@ class On3pidBindServlet(BaseFederationServlet): raise last_exception defer.returnValue((200, {})) - # Avoid doing remote HS authorization checks which are done by default by - # BaseFederationServlet. - def _wrap(self, code): - return code - class OpenIdUserInfo(BaseFederationServlet): """ @@ -467,9 +497,11 @@ class OpenIdUserInfo(BaseFederationServlet): PATH = "/openid/userinfo" + REQUIRE_AUTH = False + @defer.inlineCallbacks - def on_GET(self, request): - token = parse_string(request, "access_token") + def on_GET(self, origin, content, query): + token = query.get("access_token", [None])[0] if token is None: defer.returnValue((401, { "errcode": "M_MISSING_TOKEN", "error": "Access Token required" @@ -486,10 +518,58 @@ class OpenIdUserInfo(BaseFederationServlet): defer.returnValue((200, {"sub": user_id})) - # Avoid doing remote HS authorization checks which are done by default by - # BaseFederationServlet. - def _wrap(self, code): - return code + +class PublicRoomList(BaseFederationServlet): + """ + Fetch the public room list for this server. + + This API returns information in the same format as /publicRooms on the + client API, but will only ever include local public rooms and hence is + intended for consumption by other home servers. + + GET /publicRooms HTTP/1.1 + + HTTP/1.1 200 OK + Content-Type: application/json + + { + "chunk": [ + { + "aliases": [ + "#test:localhost" + ], + "guest_can_join": false, + "name": "test room", + "num_joined_members": 3, + "room_id": "!whkydVegtvatLfXmPN:localhost", + "world_readable": false + } + ], + "end": "END", + "start": "START" + } + """ + + PATH = "/publicRooms" + + @defer.inlineCallbacks + def on_GET(self, origin, content, query): + data = yield self.room_list_handler.get_local_public_room_list() + defer.returnValue((200, data)) + + +class FederationVersionServlet(BaseFederationServlet): + PATH = "/version" + + REQUIRE_AUTH = False + + def on_GET(self, origin, content, query): + return defer.succeed((200, { + "server": { + "name": "Synapse", + "version": get_version_string(synapse) + }, + })) SERVLET_CLASSES = ( @@ -497,6 +577,7 @@ SERVLET_CLASSES = ( FederationPullServlet, FederationEventServlet, FederationStateServlet, + FederationStateIdsServlet, FederationBackfillServlet, FederationQueryServlet, FederationMakeJoinServlet, @@ -513,6 +594,8 @@ SERVLET_CLASSES = ( FederationThirdPartyInviteExchangeServlet, On3pidBindServlet, OpenIdUserInfo, + PublicRoomList, + FederationVersionServlet, ) @@ -523,4 +606,5 @@ def register_servlets(hs, resource, authenticator, ratelimiter): authenticator=authenticator, ratelimiter=ratelimiter, server_name=hs.hostname, + room_list_handler=hs.get_room_list_handler(), ).register(resource) |