summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/__init__.py2
-rw-r--r--synapse/federation/federation_base.py126
-rw-r--r--synapse/federation/federation_server.py20
-rw-r--r--synapse/federation/persistence.py8
-rw-r--r--synapse/federation/transport/server.py2
-rw-r--r--synapse/handlers/auth.py8
-rw-r--r--synapse/handlers/e2e_keys.py5
-rw-r--r--synapse/handlers/federation.py6
-rw-r--r--synapse/handlers/room_list.py2
-rw-r--r--synapse/handlers/search.py14
-rw-r--r--synapse/handlers/sync.py6
11 files changed, 148 insertions, 51 deletions
diff --git a/synapse/__init__.py b/synapse/__init__.py
index e62901b761..e395a0a9ee 100644
--- a/synapse/__init__.py
+++ b/synapse/__init__.py
@@ -17,4 +17,4 @@
 """ This is a reference implementation of a Matrix home server.
 """
 
-__version__ = "0.33.3"
+__version__ = "0.33.3.1"
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index c11798093d..5be8e66fb8 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -13,17 +13,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from collections import namedtuple
 
 import six
 
 from twisted.internet import defer
+from twisted.internet.defer import DeferredList
 
-from synapse.api.constants import MAX_DEPTH
+from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
 from synapse.api.errors import Codes, SynapseError
 from synapse.crypto.event_signing import check_event_content_hash
 from synapse.events import FrozenEvent
 from synapse.events.utils import prune_event
 from synapse.http.servlet import assert_params_in_dict
+from synapse.types import get_domain_from_id
 from synapse.util import logcontext, unwrapFirstError
 
 logger = logging.getLogger(__name__)
@@ -133,34 +136,25 @@ class FederationBase(object):
               * throws a SynapseError if the signature check failed.
             The deferreds run their callbacks in the sentinel logcontext.
         """
-
-        redacted_pdus = [
-            prune_event(pdu)
-            for pdu in pdus
-        ]
-
-        deferreds = self.keyring.verify_json_objects_for_server([
-            (p.origin, p.get_pdu_json())
-            for p in redacted_pdus
-        ])
+        deferreds = _check_sigs_on_pdus(self.keyring, pdus)
 
         ctx = logcontext.LoggingContext.current_context()
 
-        def callback(_, pdu, redacted):
+        def callback(_, pdu):
             with logcontext.PreserveLoggingContext(ctx):
                 if not check_event_content_hash(pdu):
                     logger.warn(
                         "Event content has been tampered, redacting %s: %s",
                         pdu.event_id, pdu.get_pdu_json()
                     )
-                    return redacted
+                    return prune_event(pdu)
 
                 if self.spam_checker.check_event_for_spam(pdu):
                     logger.warn(
                         "Event contains spam, redacting %s: %s",
                         pdu.event_id, pdu.get_pdu_json()
                     )
-                    return redacted
+                    return prune_event(pdu)
 
                 return pdu
 
@@ -173,16 +167,116 @@ class FederationBase(object):
                 )
             return failure
 
-        for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus):
+        for deferred, pdu in zip(deferreds, pdus):
             deferred.addCallbacks(
                 callback, errback,
-                callbackArgs=[pdu, redacted],
+                callbackArgs=[pdu],
                 errbackArgs=[pdu],
             )
 
         return deferreds
 
 
+class PduToCheckSig(namedtuple("PduToCheckSig", [
+    "pdu", "redacted_pdu_json", "event_id_domain", "sender_domain", "deferreds",
+])):
+    pass
+
+
+def _check_sigs_on_pdus(keyring, pdus):
+    """Check that the given events are correctly signed
+
+    Args:
+        keyring (synapse.crypto.Keyring): keyring object to do the checks
+        pdus (Collection[EventBase]): the events to be checked
+
+    Returns:
+        List[Deferred]: a Deferred for each event in pdus, which will either succeed if
+           the signatures are valid, or fail (with a SynapseError) if not.
+    """
+
+    # (currently this is written assuming the v1 room structure; we'll probably want a
+    # separate function for checking v2 rooms)
+
+    # we want to check that the event is signed by:
+    #
+    # (a) the server which created the event_id
+    #
+    # (b) the sender's server.
+    #
+    #     - except in the case of invites created from a 3pid invite, which are exempt
+    #     from this check, because the sender has to match that of the original 3pid
+    #     invite, but the event may come from a different HS, for reasons that I don't
+    #     entirely grok (why do the senders have to match? and if they do, why doesn't the
+    #     joining server ask the inviting server to do the switcheroo with
+    #     exchange_third_party_invite?).
+    #
+    #     That's pretty awful, since redacting such an invite will render it invalid
+    #     (because it will then look like a regular invite without a valid signature),
+    #     and signatures are *supposed* to be valid whether or not an event has been
+    #     redacted. But this isn't the worst of the ways that 3pid invites are broken.
+    #
+    # let's start by getting the domain for each pdu, and flattening the event back
+    # to JSON.
+    pdus_to_check = [
+        PduToCheckSig(
+            pdu=p,
+            redacted_pdu_json=prune_event(p).get_pdu_json(),
+            event_id_domain=get_domain_from_id(p.event_id),
+            sender_domain=get_domain_from_id(p.sender),
+            deferreds=[],
+        )
+        for p in pdus
+    ]
+
+    # first make sure that the event is signed by the event_id's domain
+    deferreds = keyring.verify_json_objects_for_server([
+        (p.event_id_domain, p.redacted_pdu_json)
+        for p in pdus_to_check
+    ])
+
+    for p, d in zip(pdus_to_check, deferreds):
+        p.deferreds.append(d)
+
+    # now let's look for events where the sender's domain is different to the
+    # event id's domain (normally only the case for joins/leaves), and add additional
+    # checks.
+    pdus_to_check_sender = [
+        p for p in pdus_to_check
+        if p.sender_domain != p.event_id_domain and not _is_invite_via_3pid(p.pdu)
+    ]
+
+    more_deferreds = keyring.verify_json_objects_for_server([
+        (p.sender_domain, p.redacted_pdu_json)
+        for p in pdus_to_check_sender
+    ])
+
+    for p, d in zip(pdus_to_check_sender, more_deferreds):
+        p.deferreds.append(d)
+
+    # replace lists of deferreds with single Deferreds
+    return [_flatten_deferred_list(p.deferreds) for p in pdus_to_check]
+
+
+def _flatten_deferred_list(deferreds):
+    """Given a list of one or more deferreds, either return the single deferred, or
+    combine into a DeferredList.
+    """
+    if len(deferreds) > 1:
+        return DeferredList(deferreds, fireOnOneErrback=True, consumeErrors=True)
+    else:
+        assert len(deferreds) == 1
+        return deferreds[0]
+
+
+def _is_invite_via_3pid(event):
+    return (
+        event.type == EventTypes.Member
+        and event.membership == Membership.INVITE
+        and "third_party_invite" in event.content
+    )
+
+
 def event_from_pdu_json(pdu_json, outlier=False):
     """Construct a FrozenEvent from an event json received over federation
 
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 6e52c4b6b5..dbee404ea7 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -99,7 +99,7 @@ class FederationServer(FederationBase):
 
     @defer.inlineCallbacks
     @log_function
-    def on_incoming_transaction(self, transaction_data):
+    def on_incoming_transaction(self, origin, transaction_data):
         # keep this as early as possible to make the calculated origin ts as
         # accurate as possible.
         request_time = self._clock.time_msec()
@@ -108,34 +108,33 @@ class FederationServer(FederationBase):
 
         if not transaction.transaction_id:
             raise Exception("Transaction missing transaction_id")
-        if not transaction.origin:
-            raise Exception("Transaction missing origin")
 
         logger.debug("[%s] Got transaction", transaction.transaction_id)
 
         # use a linearizer to ensure that we don't process the same transaction
         # multiple times in parallel.
         with (yield self._transaction_linearizer.queue(
-                (transaction.origin, transaction.transaction_id),
+                (origin, transaction.transaction_id),
         )):
             result = yield self._handle_incoming_transaction(
-                transaction, request_time,
+                origin, transaction, request_time,
             )
 
         defer.returnValue(result)
 
     @defer.inlineCallbacks
-    def _handle_incoming_transaction(self, transaction, request_time):
+    def _handle_incoming_transaction(self, origin, transaction, request_time):
         """ Process an incoming transaction and return the HTTP response
 
         Args:
+            origin (unicode): the server making the request
             transaction (Transaction): incoming transaction
             request_time (int): timestamp that the HTTP request arrived at
 
         Returns:
             Deferred[(int, object)]: http response code and body
         """
-        response = yield self.transaction_actions.have_responded(transaction)
+        response = yield self.transaction_actions.have_responded(origin, transaction)
 
         if response:
             logger.debug(
@@ -149,7 +148,7 @@ class FederationServer(FederationBase):
 
         received_pdus_counter.inc(len(transaction.pdus))
 
-        origin_host, _ = parse_server_name(transaction.origin)
+        origin_host, _ = parse_server_name(origin)
 
         pdus_by_room = {}
 
@@ -190,7 +189,7 @@ class FederationServer(FederationBase):
                 event_id = pdu.event_id
                 try:
                     yield self._handle_received_pdu(
-                        transaction.origin, pdu
+                        origin, pdu
                     )
                     pdu_results[event_id] = {}
                 except FederationError as e:
@@ -212,7 +211,7 @@ class FederationServer(FederationBase):
         if hasattr(transaction, "edus"):
             for edu in (Edu(**x) for x in transaction.edus):
                 yield self.received_edu(
-                    transaction.origin,
+                    origin,
                     edu.edu_type,
                     edu.content
                 )
@@ -224,6 +223,7 @@ class FederationServer(FederationBase):
         logger.debug("Returning: %s", str(response))
 
         yield self.transaction_actions.set_response(
+            origin,
             transaction,
             200, response
         )
diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py
index 9146215c21..74ffd13b4f 100644
--- a/synapse/federation/persistence.py
+++ b/synapse/federation/persistence.py
@@ -36,7 +36,7 @@ class TransactionActions(object):
         self.store = datastore
 
     @log_function
-    def have_responded(self, transaction):
+    def have_responded(self, origin, transaction):
         """ Have we already responded to a transaction with the same id and
         origin?
 
@@ -50,11 +50,11 @@ class TransactionActions(object):
                                "transaction_id")
 
         return self.store.get_received_txn_response(
-            transaction.transaction_id, transaction.origin
+            transaction.transaction_id, origin
         )
 
     @log_function
-    def set_response(self, transaction, code, response):
+    def set_response(self, origin, transaction, code, response):
         """ Persist how we responded to a transaction.
 
         Returns:
@@ -66,7 +66,7 @@ class TransactionActions(object):
 
         return self.store.set_received_txn_response(
             transaction.transaction_id,
-            transaction.origin,
+            origin,
             code,
             response,
         )
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 7a993fd1cf..3972922ff9 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -353,7 +353,7 @@ class FederationSendServlet(BaseFederationServlet):
 
         try:
             code, response = yield self.handler.on_incoming_transaction(
-                transaction_data
+                origin, transaction_data,
             )
         except Exception:
             logger.exception("on_incoming_transaction failed")
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 4a81bd2ba9..2a5eab124f 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -895,22 +895,24 @@ class AuthHandler(BaseHandler):
 
         Args:
             password (unicode): Password to hash.
-            stored_hash (unicode): Expected hash value.
+            stored_hash (bytes): Expected hash value.
 
         Returns:
             Deferred(bool): Whether self.hash(password) == stored_hash.
         """
-
         def _do_validate_hash():
             # Normalise the Unicode in the password
             pw = unicodedata.normalize("NFKC", password)
 
             return bcrypt.checkpw(
                 pw.encode('utf8') + self.hs.config.password_pepper.encode("utf8"),
-                stored_hash.encode('utf8')
+                stored_hash
             )
 
         if stored_hash:
+            if not isinstance(stored_hash, bytes):
+                stored_hash = stored_hash.encode('ascii')
+
             return make_deferred_yieldable(
                 threads.deferToThreadPool(
                     self.hs.get_reactor(),
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 5816bf8b4f..578e9250fb 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -330,7 +330,8 @@ class E2eKeysHandler(object):
                         (algorithm, key_id, ex_json, key)
                     )
             else:
-                new_keys.append((algorithm, key_id, encode_canonical_json(key)))
+                new_keys.append((
+                    algorithm, key_id, encode_canonical_json(key).decode('ascii')))
 
         yield self.store.add_e2e_one_time_keys(
             user_id, device_id, time_now, new_keys
@@ -358,7 +359,7 @@ def _exception_to_failure(e):
     # Note that some Exceptions (notably twisted's ResponseFailed etc) don't
     # give a string for e.message, which json then fails to serialize.
     return {
-        "status": 503, "message": str(e.message),
+        "status": 503, "message": str(e),
     }
 
 
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 3fa7a98445..0c68e8a472 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -594,7 +594,7 @@ class FederationHandler(BaseHandler):
 
         required_auth = set(
             a_id
-            for event in events + state_events.values() + auth_events.values()
+            for event in events + list(state_events.values()) + list(auth_events.values())
             for a_id, _ in event.auth_events
         )
         auth_events.update({
@@ -802,7 +802,7 @@ class FederationHandler(BaseHandler):
                     )
                     continue
                 except NotRetryingDestination as e:
-                    logger.info(e.message)
+                    logger.info(str(e))
                     continue
                 except FederationDeniedError as e:
                     logger.info(e)
@@ -1358,7 +1358,7 @@ class FederationHandler(BaseHandler):
         )
 
         if state_groups:
-            _, state = state_groups.items().pop()
+            _, state = list(state_groups.items()).pop()
             results = state
 
             if event.is_state():
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 37e41afd61..38e1737ec9 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -162,7 +162,7 @@ class RoomListHandler(BaseHandler):
         # Filter out rooms that we don't want to return
         rooms_to_scan = [
             r for r in sorted_rooms
-            if r not in newly_unpublished and rooms_to_num_joined[room_id] > 0
+            if r not in newly_unpublished and rooms_to_num_joined[r] > 0
         ]
 
         total_room_count = len(rooms_to_scan)
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index c464adbd0b..0c1d52fd11 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -54,7 +54,7 @@ class SearchHandler(BaseHandler):
         batch_token = None
         if batch:
             try:
-                b = decode_base64(batch)
+                b = decode_base64(batch).decode('ascii')
                 batch_group, batch_group_key, batch_token = b.split("\n")
 
                 assert batch_group is not None
@@ -258,18 +258,18 @@ class SearchHandler(BaseHandler):
                 # it returns more from the same group (if applicable) rather
                 # than reverting to searching all results again.
                 if batch_group and batch_group_key:
-                    global_next_batch = encode_base64("%s\n%s\n%s" % (
+                    global_next_batch = encode_base64(("%s\n%s\n%s" % (
                         batch_group, batch_group_key, pagination_token
-                    ))
+                    )).encode('ascii'))
                 else:
-                    global_next_batch = encode_base64("%s\n%s\n%s" % (
+                    global_next_batch = encode_base64(("%s\n%s\n%s" % (
                         "all", "", pagination_token
-                    ))
+                    )).encode('ascii'))
 
                 for room_id, group in room_groups.items():
-                    group["next_batch"] = encode_base64("%s\n%s\n%s" % (
+                    group["next_batch"] = encode_base64(("%s\n%s\n%s" % (
                         "room_id", room_id, pagination_token
-                    ))
+                    )).encode('ascii'))
 
             allowed_events.extend(room_events)
 
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index ef20c2296c..0091ceb80e 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -545,7 +545,7 @@ class SyncHandler(object):
 
         member_ids = {
             state_key: event_id
-            for (t, state_key), event_id in state_ids.iteritems()
+            for (t, state_key), event_id in iteritems(state_ids)
             if t == EventTypes.Member
         }
         name_id = state_ids.get((EventTypes.Name, ''))
@@ -774,7 +774,7 @@ class SyncHandler(object):
                     logger.debug("filtering state from %r...", state_ids)
                     state_ids = {
                         t: event_id
-                        for t, event_id in state_ids.iteritems()
+                        for t, event_id in iteritems(state_ids)
                         if cache.get(t[1]) != event_id
                     }
                     logger.debug("...to %r", state_ids)
@@ -1753,7 +1753,7 @@ def _calculate_state(
 
     if lazy_load_members:
         p_ids.difference_update(
-            e for t, e in timeline_start.iteritems()
+            e for t, e in iteritems(timeline_start)
             if t[0] == EventTypes.Member
         )