summary refs log tree commit diff
path: root/synapse/federation
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation')
-rw-r--r--synapse/federation/federation_base.py158
-rw-r--r--synapse/federation/federation_client.py8
-rw-r--r--synapse/federation/federation_server.py30
-rw-r--r--synapse/federation/persistence.py8
-rw-r--r--synapse/federation/transaction_queue.py12
-rw-r--r--synapse/federation/transport/client.py5
-rw-r--r--synapse/federation/transport/server.py26
7 files changed, 186 insertions, 61 deletions
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index c11798093d..b7ad729c63 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -13,17 +13,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from collections import namedtuple
 
 import six
 
 from twisted.internet import defer
+from twisted.internet.defer import DeferredList
 
-from synapse.api.constants import MAX_DEPTH
+from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
 from synapse.api.errors import Codes, SynapseError
 from synapse.crypto.event_signing import check_event_content_hash
 from synapse.events import FrozenEvent
 from synapse.events.utils import prune_event
 from synapse.http.servlet import assert_params_in_dict
+from synapse.types import get_domain_from_id
 from synapse.util import logcontext, unwrapFirstError
 
 logger = logging.getLogger(__name__)
@@ -133,34 +136,45 @@ class FederationBase(object):
               * throws a SynapseError if the signature check failed.
             The deferreds run their callbacks in the sentinel logcontext.
         """
-
-        redacted_pdus = [
-            prune_event(pdu)
-            for pdu in pdus
-        ]
-
-        deferreds = self.keyring.verify_json_objects_for_server([
-            (p.origin, p.get_pdu_json())
-            for p in redacted_pdus
-        ])
+        deferreds = _check_sigs_on_pdus(self.keyring, pdus)
 
         ctx = logcontext.LoggingContext.current_context()
 
-        def callback(_, pdu, redacted):
+        def callback(_, pdu):
             with logcontext.PreserveLoggingContext(ctx):
                 if not check_event_content_hash(pdu):
-                    logger.warn(
-                        "Event content has been tampered, redacting %s: %s",
-                        pdu.event_id, pdu.get_pdu_json()
-                    )
-                    return redacted
+                    # let's try to distinguish between failures because the event was
+                    # redacted (which are somewhat expected) vs actual ball-tampering
+                    # incidents.
+                    #
+                    # This is just a heuristic, so we just assume that if the keys are
+                    # about the same between the redacted and received events, then the
+                    # received event was probably a redacted copy (but we then use our
+                    # *actual* redacted copy to be on the safe side.)
+                    redacted_event = prune_event(pdu)
+                    if (
+                        set(redacted_event.keys()) == set(pdu.keys()) and
+                        set(six.iterkeys(redacted_event.content))
+                            == set(six.iterkeys(pdu.content))
+                    ):
+                        logger.info(
+                            "Event %s seems to have been redacted; using our redacted "
+                            "copy",
+                            pdu.event_id,
+                        )
+                    else:
+                        logger.warning(
+                            "Event %s content has been tampered, redacting",
+                            pdu.event_id, pdu.get_pdu_json(),
+                        )
+                    return redacted_event
 
                 if self.spam_checker.check_event_for_spam(pdu):
                     logger.warn(
                         "Event contains spam, redacting %s: %s",
                         pdu.event_id, pdu.get_pdu_json()
                     )
-                    return redacted
+                    return prune_event(pdu)
 
                 return pdu
 
@@ -168,21 +182,121 @@ class FederationBase(object):
             failure.trap(SynapseError)
             with logcontext.PreserveLoggingContext(ctx):
                 logger.warn(
-                    "Signature check failed for %s",
-                    pdu.event_id,
+                    "Signature check failed for %s: %s",
+                    pdu.event_id, failure.getErrorMessage(),
                 )
             return failure
 
-        for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus):
+        for deferred, pdu in zip(deferreds, pdus):
             deferred.addCallbacks(
                 callback, errback,
-                callbackArgs=[pdu, redacted],
+                callbackArgs=[pdu],
                 errbackArgs=[pdu],
             )
 
         return deferreds
 
 
+class PduToCheckSig(namedtuple("PduToCheckSig", [
+    "pdu", "redacted_pdu_json", "event_id_domain", "sender_domain", "deferreds",
+])):
+    pass
+
+
+def _check_sigs_on_pdus(keyring, pdus):
+    """Check that the given events are correctly signed
+
+    Args:
+        keyring (synapse.crypto.Keyring): keyring object to do the checks
+        pdus (Collection[EventBase]): the events to be checked
+
+    Returns:
+        List[Deferred]: a Deferred for each event in pdus, which will either succeed if
+           the signatures are valid, or fail (with a SynapseError) if not.
+    """
+
+    # (currently this is written assuming the v1 room structure; we'll probably want a
+    # separate function for checking v2 rooms)
+
+    # we want to check that the event is signed by:
+    #
+    # (a) the server which created the event_id
+    #
+    # (b) the sender's server.
+    #
+    #     - except in the case of invites created from a 3pid invite, which are exempt
+    #     from this check, because the sender has to match that of the original 3pid
+    #     invite, but the event may come from a different HS, for reasons that I don't
+    #     entirely grok (why do the senders have to match? and if they do, why doesn't the
+    #     joining server ask the inviting server to do the switcheroo with
+    #     exchange_third_party_invite?).
+    #
+    #     That's pretty awful, since redacting such an invite will render it invalid
+    #     (because it will then look like a regular invite without a valid signature),
+    #     and signatures are *supposed* to be valid whether or not an event has been
+    #     redacted. But this isn't the worst of the ways that 3pid invites are broken.
+    #
+    # let's start by getting the domain for each pdu, and flattening the event back
+    # to JSON.
+    pdus_to_check = [
+        PduToCheckSig(
+            pdu=p,
+            redacted_pdu_json=prune_event(p).get_pdu_json(),
+            event_id_domain=get_domain_from_id(p.event_id),
+            sender_domain=get_domain_from_id(p.sender),
+            deferreds=[],
+        )
+        for p in pdus
+    ]
+
+    # first make sure that the event is signed by the event_id's domain
+    deferreds = keyring.verify_json_objects_for_server([
+        (p.event_id_domain, p.redacted_pdu_json)
+        for p in pdus_to_check
+    ])
+
+    for p, d in zip(pdus_to_check, deferreds):
+        p.deferreds.append(d)
+
+    # now let's look for events where the sender's domain is different to the
+    # event id's domain (normally only the case for joins/leaves), and add additional
+    # checks.
+    pdus_to_check_sender = [
+        p for p in pdus_to_check
+        if p.sender_domain != p.event_id_domain and not _is_invite_via_3pid(p.pdu)
+    ]
+
+    more_deferreds = keyring.verify_json_objects_for_server([
+        (p.sender_domain, p.redacted_pdu_json)
+        for p in pdus_to_check_sender
+    ])
+
+    for p, d in zip(pdus_to_check_sender, more_deferreds):
+        p.deferreds.append(d)
+
+    # replace lists of deferreds with single Deferreds
+    return [_flatten_deferred_list(p.deferreds) for p in pdus_to_check]
+
+
+def _flatten_deferred_list(deferreds):
+    """Given a list of one or more deferreds, either return the single deferred, or
+    combine into a DeferredList.
+    """
+    if len(deferreds) > 1:
+        return DeferredList(deferreds, fireOnOneErrback=True, consumeErrors=True)
+    else:
+        assert len(deferreds) == 1
+        return deferreds[0]
+
+
+def _is_invite_via_3pid(event):
+    return (
+        event.type == EventTypes.Member
+        and event.membership == Membership.INVITE
+        and "third_party_invite" in event.content
+    )
+
+
 def event_from_pdu_json(pdu_json, outlier=False):
     """Construct a FrozenEvent from an event json received over federation
 
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index c9f3c2d352..fe67b2ff42 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -271,10 +271,10 @@ class FederationClient(FederationBase):
                     event_id, destination, e,
                 )
             except NotRetryingDestination as e:
-                logger.info(e.message)
+                logger.info(str(e))
                 continue
             except FederationDeniedError as e:
-                logger.info(e.message)
+                logger.info(str(e))
                 continue
             except Exception as e:
                 pdu_attempts[destination] = now
@@ -510,7 +510,7 @@ class FederationClient(FederationBase):
                 else:
                     logger.warn(
                         "Failed to %s via %s: %i %s",
-                        description, destination, e.code, e.message,
+                        description, destination, e.code, e.args[0],
                     )
             except Exception:
                 logger.warn(
@@ -875,7 +875,7 @@ class FederationClient(FederationBase):
             except Exception as e:
                 logger.exception(
                     "Failed to send_third_party_invite via %s: %s",
-                    destination, e.message
+                    destination, str(e)
                 )
 
         raise RuntimeError("Failed to send to any server.")
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 3e0cd294a1..dbee404ea7 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -99,7 +99,7 @@ class FederationServer(FederationBase):
 
     @defer.inlineCallbacks
     @log_function
-    def on_incoming_transaction(self, transaction_data):
+    def on_incoming_transaction(self, origin, transaction_data):
         # keep this as early as possible to make the calculated origin ts as
         # accurate as possible.
         request_time = self._clock.time_msec()
@@ -108,34 +108,33 @@ class FederationServer(FederationBase):
 
         if not transaction.transaction_id:
             raise Exception("Transaction missing transaction_id")
-        if not transaction.origin:
-            raise Exception("Transaction missing origin")
 
         logger.debug("[%s] Got transaction", transaction.transaction_id)
 
         # use a linearizer to ensure that we don't process the same transaction
         # multiple times in parallel.
         with (yield self._transaction_linearizer.queue(
-                (transaction.origin, transaction.transaction_id),
+                (origin, transaction.transaction_id),
         )):
             result = yield self._handle_incoming_transaction(
-                transaction, request_time,
+                origin, transaction, request_time,
             )
 
         defer.returnValue(result)
 
     @defer.inlineCallbacks
-    def _handle_incoming_transaction(self, transaction, request_time):
+    def _handle_incoming_transaction(self, origin, transaction, request_time):
         """ Process an incoming transaction and return the HTTP response
 
         Args:
+            origin (unicode): the server making the request
             transaction (Transaction): incoming transaction
             request_time (int): timestamp that the HTTP request arrived at
 
         Returns:
             Deferred[(int, object)]: http response code and body
         """
-        response = yield self.transaction_actions.have_responded(transaction)
+        response = yield self.transaction_actions.have_responded(origin, transaction)
 
         if response:
             logger.debug(
@@ -149,7 +148,7 @@ class FederationServer(FederationBase):
 
         received_pdus_counter.inc(len(transaction.pdus))
 
-        origin_host, _ = parse_server_name(transaction.origin)
+        origin_host, _ = parse_server_name(origin)
 
         pdus_by_room = {}
 
@@ -190,7 +189,7 @@ class FederationServer(FederationBase):
                 event_id = pdu.event_id
                 try:
                     yield self._handle_received_pdu(
-                        transaction.origin, pdu
+                        origin, pdu
                     )
                     pdu_results[event_id] = {}
                 except FederationError as e:
@@ -212,7 +211,7 @@ class FederationServer(FederationBase):
         if hasattr(transaction, "edus"):
             for edu in (Edu(**x) for x in transaction.edus):
                 yield self.received_edu(
-                    transaction.origin,
+                    origin,
                     edu.edu_type,
                     edu.content
                 )
@@ -224,6 +223,7 @@ class FederationServer(FederationBase):
         logger.debug("Returning: %s", str(response))
 
         yield self.transaction_actions.set_response(
+            origin,
             transaction,
             200, response
         )
@@ -838,9 +838,9 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
             )
 
         return self._send_edu(
-                edu_type=edu_type,
-                origin=origin,
-                content=content,
+            edu_type=edu_type,
+            origin=origin,
+            content=content,
         )
 
     def on_query(self, query_type, args):
@@ -851,6 +851,6 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
             return handler(args)
 
         return self._get_query_client(
-                query_type=query_type,
-                args=args,
+            query_type=query_type,
+            args=args,
         )
diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py
index 9146215c21..74ffd13b4f 100644
--- a/synapse/federation/persistence.py
+++ b/synapse/federation/persistence.py
@@ -36,7 +36,7 @@ class TransactionActions(object):
         self.store = datastore
 
     @log_function
-    def have_responded(self, transaction):
+    def have_responded(self, origin, transaction):
         """ Have we already responded to a transaction with the same id and
         origin?
 
@@ -50,11 +50,11 @@ class TransactionActions(object):
                                "transaction_id")
 
         return self.store.get_received_txn_response(
-            transaction.transaction_id, transaction.origin
+            transaction.transaction_id, origin
         )
 
     @log_function
-    def set_response(self, transaction, code, response):
+    def set_response(self, origin, transaction, code, response):
         """ Persist how we responded to a transaction.
 
         Returns:
@@ -66,7 +66,7 @@ class TransactionActions(object):
 
         return self.store.set_received_txn_response(
             transaction.transaction_id,
-            transaction.origin,
+            origin,
             code,
             response,
         )
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 94d7423d01..8cbf8c4f7f 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -463,7 +463,19 @@ class TransactionQueue(object):
                 # pending_transactions flag.
 
                 pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
+
+                # We can only include at most 50 PDUs per transactions
+                pending_pdus, leftover_pdus = pending_pdus[:50], pending_pdus[50:]
+                if leftover_pdus:
+                    self.pending_pdus_by_dest[destination] = leftover_pdus
+
                 pending_edus = self.pending_edus_by_dest.pop(destination, [])
+
+                # We can only include at most 100 EDUs per transactions
+                pending_edus, leftover_edus = pending_edus[:100], pending_edus[100:]
+                if leftover_edus:
+                    self.pending_edus_by_dest[destination] = leftover_edus
+
                 pending_presence = self.pending_presence_by_dest.pop(destination, {})
 
                 pending_edus.extend(
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 1054441ca5..2ab973d6c8 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -15,7 +15,8 @@
 # limitations under the License.
 
 import logging
-import urllib
+
+from six.moves import urllib
 
 from twisted.internet import defer
 
@@ -951,4 +952,4 @@ def _create_path(prefix, path, *args):
     Returns:
         str
     """
-    return prefix + path % tuple(urllib.quote(arg, "") for arg in args)
+    return prefix + path % tuple(urllib.parse.quote(arg, "") for arg in args)
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 7a993fd1cf..2f874b4838 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -90,8 +90,8 @@ class Authenticator(object):
     @defer.inlineCallbacks
     def authenticate_request(self, request, content):
         json_request = {
-            "method": request.method,
-            "uri": request.uri,
+            "method": request.method.decode('ascii'),
+            "uri": request.uri.decode('ascii'),
             "destination": self.server_name,
             "signatures": {},
         }
@@ -252,7 +252,7 @@ class BaseFederationServlet(object):
                     by the callback method. None if the request has already been handled.
             """
             content = None
-            if request.method in ["PUT", "POST"]:
+            if request.method in [b"PUT", b"POST"]:
                 # TODO: Handle other method types? other content types?
                 content = parse_json_object_from_request(request)
 
@@ -353,7 +353,7 @@ class FederationSendServlet(BaseFederationServlet):
 
         try:
             code, response = yield self.handler.on_incoming_transaction(
-                transaction_data
+                origin, transaction_data,
             )
         except Exception:
             logger.exception("on_incoming_transaction failed")
@@ -386,7 +386,7 @@ class FederationStateServlet(BaseFederationServlet):
         return self.handler.on_context_state_request(
             origin,
             context,
-            query.get("event_id", [None])[0],
+            parse_string_from_args(query, "event_id", None),
         )
 
 
@@ -397,7 +397,7 @@ class FederationStateIdsServlet(BaseFederationServlet):
         return self.handler.on_state_ids_request(
             origin,
             room_id,
-            query.get("event_id", [None])[0],
+            parse_string_from_args(query, "event_id", None),
         )
 
 
@@ -405,14 +405,12 @@ class FederationBackfillServlet(BaseFederationServlet):
     PATH = "/backfill/(?P<context>[^/]*)/"
 
     def on_GET(self, origin, content, query, context):
-        versions = query["v"]
-        limits = query["limit"]
+        versions = [x.decode('ascii') for x in query[b"v"]]
+        limit = parse_integer_from_args(query, "limit", None)
 
-        if not limits:
+        if not limit:
             return defer.succeed((400, {"error": "Did not include limit param"}))
 
-        limit = int(limits[-1])
-
         return self.handler.on_backfill_request(origin, context, versions, limit)
 
 
@@ -423,7 +421,7 @@ class FederationQueryServlet(BaseFederationServlet):
     def on_GET(self, origin, content, query, query_type):
         return self.handler.on_query_request(
             query_type,
-            {k: v[0].decode("utf-8") for k, v in query.items()}
+            {k.decode('utf8'): v[0].decode("utf-8") for k, v in query.items()}
         )
 
 
@@ -630,14 +628,14 @@ class OpenIdUserInfo(BaseFederationServlet):
 
     @defer.inlineCallbacks
     def on_GET(self, origin, content, query):
-        token = query.get("access_token", [None])[0]
+        token = query.get(b"access_token", [None])[0]
         if token is None:
             defer.returnValue((401, {
                 "errcode": "M_MISSING_TOKEN", "error": "Access Token required"
             }))
             return
 
-        user_id = yield self.handler.on_openid_userinfo(token)
+        user_id = yield self.handler.on_openid_userinfo(token.decode('ascii'))
 
         if user_id is None:
             defer.returnValue((401, {