summary refs log tree commit diff
path: root/synapse/federation/federation_client.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation/federation_client.py')
-rw-r--r--synapse/federation/federation_client.py342
1 files changed, 195 insertions, 147 deletions
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 87a92f6ea9..c9f3c2d352 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -21,25 +21,25 @@ import random
 
 from six.moves import range
 
+from prometheus_client import Counter
+
 from twisted.internet import defer
 
-from synapse.api.constants import Membership
+from synapse.api.constants import KNOWN_ROOM_VERSIONS, EventTypes, Membership
 from synapse.api.errors import (
-    CodeMessageException, HttpResponseException, SynapseError, FederationDeniedError
+    CodeMessageException,
+    FederationDeniedError,
+    HttpResponseException,
+    SynapseError,
 )
 from synapse.events import builder
-from synapse.federation.federation_base import (
-    FederationBase,
-    event_from_pdu_json,
-)
+from synapse.federation.federation_base import FederationBase, event_from_pdu_json
 from synapse.util import logcontext, unwrapFirstError
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.logcontext import make_deferred_yieldable, run_in_background
 from synapse.util.logutils import log_function
 from synapse.util.retryutils import NotRetryingDestination
 
-from prometheus_client import Counter
-
 logger = logging.getLogger(__name__)
 
 sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["type"])
@@ -48,6 +48,13 @@ sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["t
 PDU_RETRY_TIME_MS = 1 * 60 * 1000
 
 
+class InvalidResponseError(RuntimeError):
+    """Helper for _try_destination_list: indicates that the server returned a response
+    we couldn't parse
+    """
+    pass
+
+
 class FederationClient(FederationBase):
     def __init__(self, hs):
         super(FederationClient, self).__init__(hs)
@@ -458,8 +465,63 @@ class FederationClient(FederationBase):
         defer.returnValue(signed_auth)
 
     @defer.inlineCallbacks
+    def _try_destination_list(self, description, destinations, callback):
+        """Try an operation on a series of servers, until it succeeds
+
+        Args:
+            description (unicode): description of the operation we're doing, for logging
+
+            destinations (Iterable[unicode]): list of server_names to try
+
+            callback (callable):  Function to run for each server. Passed a single
+                argument: the server_name to try. May return a deferred.
+
+                If the callback raises a CodeMessageException with a 300/400 code,
+                attempts to perform the operation stop immediately and the exception is
+                reraised.
+
+                Otherwise, if the callback raises an Exception the error is logged and the
+                next server tried. Normally the stacktrace is logged but this is
+                suppressed if the exception is an InvalidResponseError.
+
+        Returns:
+            The [Deferred] result of callback, if it succeeds
+
+        Raises:
+            SynapseError if the chosen remote server returns a 300/400 code.
+
+            RuntimeError if no servers were reachable.
+        """
+        for destination in destinations:
+            if destination == self.server_name:
+                continue
+
+            try:
+                res = yield callback(destination)
+                defer.returnValue(res)
+            except InvalidResponseError as e:
+                logger.warn(
+                    "Failed to %s via %s: %s",
+                    description, destination, e,
+                )
+            except HttpResponseException as e:
+                if not 500 <= e.code < 600:
+                    raise e.to_synapse_error()
+                else:
+                    logger.warn(
+                        "Failed to %s via %s: %i %s",
+                        description, destination, e.code, e.message,
+                    )
+            except Exception:
+                logger.warn(
+                    "Failed to %s via %s",
+                    description, destination, exc_info=1,
+                )
+
+        raise RuntimeError("Failed to %s via any server" % (description, ))
+
     def make_membership_event(self, destinations, room_id, user_id, membership,
-                              content={},):
+                              content, params):
         """
         Creates an m.room.member event, with context, without participating in the room.
 
@@ -475,13 +537,15 @@ class FederationClient(FederationBase):
             user_id (str): The user whose membership is being evented.
             membership (str): The "membership" property of the event. Must be
                 one of "join" or "leave".
-            content (object): Any additional data to put into the content field
+            content (dict): Any additional data to put into the content field
                 of the event.
+            params (dict[str, str|Iterable[str]]): Query parameters to include in the
+                request.
         Return:
             Deferred: resolves to a tuple of (origin (str), event (object))
             where origin is the remote homeserver which generated the event.
 
-            Fails with a ``CodeMessageException`` if the chosen remote server
+            Fails with a ``SynapseError`` if the chosen remote server
             returns a 300/400 code.
 
             Fails with a ``RuntimeError`` if no servers were reachable.
@@ -492,50 +556,37 @@ class FederationClient(FederationBase):
                 "make_membership_event called with membership='%s', must be one of %s" %
                 (membership, ",".join(valid_memberships))
             )
-        for destination in destinations:
-            if destination == self.server_name:
-                continue
 
-            try:
-                ret = yield self.transport_layer.make_membership_event(
-                    destination, room_id, user_id, membership
-                )
+        @defer.inlineCallbacks
+        def send_request(destination):
+            ret = yield self.transport_layer.make_membership_event(
+                destination, room_id, user_id, membership, params,
+            )
 
-                pdu_dict = ret["event"]
+            pdu_dict = ret.get("event", None)
+            if not isinstance(pdu_dict, dict):
+                raise InvalidResponseError("Bad 'event' field in response")
 
-                logger.debug("Got response to make_%s: %s", membership, pdu_dict)
+            logger.debug("Got response to make_%s: %s", membership, pdu_dict)
 
-                pdu_dict["content"].update(content)
+            pdu_dict["content"].update(content)
 
-                # The protoevent received over the JSON wire may not have all
-                # the required fields. Lets just gloss over that because
-                # there's some we never care about
-                if "prev_state" not in pdu_dict:
-                    pdu_dict["prev_state"] = []
+            # The protoevent received over the JSON wire may not have all
+            # the required fields. Lets just gloss over that because
+            # there's some we never care about
+            if "prev_state" not in pdu_dict:
+                pdu_dict["prev_state"] = []
 
-                ev = builder.EventBuilder(pdu_dict)
+            ev = builder.EventBuilder(pdu_dict)
 
-                defer.returnValue(
-                    (destination, ev)
-                )
-                break
-            except CodeMessageException as e:
-                if not 500 <= e.code < 600:
-                    raise
-                else:
-                    logger.warn(
-                        "Failed to make_%s via %s: %s",
-                        membership, destination, e.message
-                    )
-            except Exception as e:
-                logger.warn(
-                    "Failed to make_%s via %s: %s",
-                    membership, destination, e.message
-                )
+            defer.returnValue(
+                (destination, ev)
+            )
 
-        raise RuntimeError("Failed to send to any server.")
+        return self._try_destination_list(
+            "make_" + membership, destinations, send_request,
+        )
 
-    @defer.inlineCallbacks
     def send_join(self, destinations, pdu):
         """Sends a join event to one of a list of homeservers.
 
@@ -552,103 +603,111 @@ class FederationClient(FederationBase):
             giving the serer the event was sent to, ``state`` (?) and
             ``auth_chain``.
 
-            Fails with a ``CodeMessageException`` if the chosen remote server
+            Fails with a ``SynapseError`` if the chosen remote server
             returns a 300/400 code.
 
             Fails with a ``RuntimeError`` if no servers were reachable.
         """
 
-        for destination in destinations:
-            if destination == self.server_name:
-                continue
-
-            try:
-                time_now = self._clock.time_msec()
-                _, content = yield self.transport_layer.send_join(
-                    destination=destination,
-                    room_id=pdu.room_id,
-                    event_id=pdu.event_id,
-                    content=pdu.get_pdu_json(time_now),
+        def check_authchain_validity(signed_auth_chain):
+            for e in signed_auth_chain:
+                if e.type == EventTypes.Create:
+                    create_event = e
+                    break
+            else:
+                raise InvalidResponseError(
+                    "no %s in auth chain" % (EventTypes.Create,),
                 )
 
-                logger.debug("Got content: %s", content)
+            # the room version should be sane.
+            room_version = create_event.content.get("room_version", "1")
+            if room_version not in KNOWN_ROOM_VERSIONS:
+                # This shouldn't be possible, because the remote server should have
+                # rejected the join attempt during make_join.
+                raise InvalidResponseError(
+                    "room appears to have unsupported version %s" % (
+                        room_version,
+                    ))
+
+        @defer.inlineCallbacks
+        def send_request(destination):
+            time_now = self._clock.time_msec()
+            _, content = yield self.transport_layer.send_join(
+                destination=destination,
+                room_id=pdu.room_id,
+                event_id=pdu.event_id,
+                content=pdu.get_pdu_json(time_now),
+            )
 
-                state = [
-                    event_from_pdu_json(p, outlier=True)
-                    for p in content.get("state", [])
-                ]
+            logger.debug("Got content: %s", content)
 
-                auth_chain = [
-                    event_from_pdu_json(p, outlier=True)
-                    for p in content.get("auth_chain", [])
-                ]
+            state = [
+                event_from_pdu_json(p, outlier=True)
+                for p in content.get("state", [])
+            ]
 
-                pdus = {
-                    p.event_id: p
-                    for p in itertools.chain(state, auth_chain)
-                }
+            auth_chain = [
+                event_from_pdu_json(p, outlier=True)
+                for p in content.get("auth_chain", [])
+            ]
 
-                valid_pdus = yield self._check_sigs_and_hash_and_fetch(
-                    destination, list(pdus.values()),
-                    outlier=True,
-                )
+            pdus = {
+                p.event_id: p
+                for p in itertools.chain(state, auth_chain)
+            }
 
-                valid_pdus_map = {
-                    p.event_id: p
-                    for p in valid_pdus
-                }
-
-                # NB: We *need* to copy to ensure that we don't have multiple
-                # references being passed on, as that causes... issues.
-                signed_state = [
-                    copy.copy(valid_pdus_map[p.event_id])
-                    for p in state
-                    if p.event_id in valid_pdus_map
-                ]
+            valid_pdus = yield self._check_sigs_and_hash_and_fetch(
+                destination, list(pdus.values()),
+                outlier=True,
+            )
 
-                signed_auth = [
-                    valid_pdus_map[p.event_id]
-                    for p in auth_chain
-                    if p.event_id in valid_pdus_map
-                ]
+            valid_pdus_map = {
+                p.event_id: p
+                for p in valid_pdus
+            }
 
-                # NB: We *need* to copy to ensure that we don't have multiple
-                # references being passed on, as that causes... issues.
-                for s in signed_state:
-                    s.internal_metadata = copy.deepcopy(s.internal_metadata)
+            # NB: We *need* to copy to ensure that we don't have multiple
+            # references being passed on, as that causes... issues.
+            signed_state = [
+                copy.copy(valid_pdus_map[p.event_id])
+                for p in state
+                if p.event_id in valid_pdus_map
+            ]
 
-                auth_chain.sort(key=lambda e: e.depth)
+            signed_auth = [
+                valid_pdus_map[p.event_id]
+                for p in auth_chain
+                if p.event_id in valid_pdus_map
+            ]
 
-                defer.returnValue({
-                    "state": signed_state,
-                    "auth_chain": signed_auth,
-                    "origin": destination,
-                })
-            except CodeMessageException as e:
-                if not 500 <= e.code < 600:
-                    raise
-                else:
-                    logger.exception(
-                        "Failed to send_join via %s: %s",
-                        destination, e.message
-                    )
-            except Exception as e:
-                logger.exception(
-                    "Failed to send_join via %s: %s",
-                    destination, e.message
-                )
+            # NB: We *need* to copy to ensure that we don't have multiple
+            # references being passed on, as that causes... issues.
+            for s in signed_state:
+                s.internal_metadata = copy.deepcopy(s.internal_metadata)
 
-        raise RuntimeError("Failed to send to any server.")
+            check_authchain_validity(signed_auth)
+
+            defer.returnValue({
+                "state": signed_state,
+                "auth_chain": signed_auth,
+                "origin": destination,
+            })
+        return self._try_destination_list("send_join", destinations, send_request)
 
     @defer.inlineCallbacks
     def send_invite(self, destination, room_id, event_id, pdu):
         time_now = self._clock.time_msec()
-        code, content = yield self.transport_layer.send_invite(
-            destination=destination,
-            room_id=room_id,
-            event_id=event_id,
-            content=pdu.get_pdu_json(time_now),
-        )
+        try:
+            code, content = yield self.transport_layer.send_invite(
+                destination=destination,
+                room_id=room_id,
+                event_id=event_id,
+                content=pdu.get_pdu_json(time_now),
+            )
+        except HttpResponseException as e:
+            if e.code == 403:
+                raise e.to_synapse_error()
+            raise
 
         pdu_dict = content["event"]
 
@@ -663,7 +722,6 @@ class FederationClient(FederationBase):
 
         defer.returnValue(pdu)
 
-    @defer.inlineCallbacks
     def send_leave(self, destinations, pdu):
         """Sends a leave event to one of a list of homeservers.
 
@@ -680,35 +738,25 @@ class FederationClient(FederationBase):
         Return:
             Deferred: resolves to None.
 
-            Fails with a ``CodeMessageException`` if the chosen remote server
-            returns a non-200 code.
+            Fails with a ``SynapseError`` if the chosen remote server
+            returns a 300/400 code.
 
             Fails with a ``RuntimeError`` if no servers were reachable.
         """
-        for destination in destinations:
-            if destination == self.server_name:
-                continue
-
-            try:
-                time_now = self._clock.time_msec()
-                _, content = yield self.transport_layer.send_leave(
-                    destination=destination,
-                    room_id=pdu.room_id,
-                    event_id=pdu.event_id,
-                    content=pdu.get_pdu_json(time_now),
-                )
+        @defer.inlineCallbacks
+        def send_request(destination):
+            time_now = self._clock.time_msec()
+            _, content = yield self.transport_layer.send_leave(
+                destination=destination,
+                room_id=pdu.room_id,
+                event_id=pdu.event_id,
+                content=pdu.get_pdu_json(time_now),
+            )
 
-                logger.debug("Got content: %s", content)
-                defer.returnValue(None)
-            except CodeMessageException:
-                raise
-            except Exception as e:
-                logger.exception(
-                    "Failed to send_leave via %s: %s",
-                    destination, e.message
-                )
+            logger.debug("Got content: %s", content)
+            defer.returnValue(None)
 
-        raise RuntimeError("Failed to send to any server.")
+        return self._try_destination_list("send_leave", destinations, send_request)
 
     def get_public_rooms(self, destination, limit=None, since_token=None,
                          search_filter=None, include_all_networks=False,