diff options
Diffstat (limited to 'synapse/federation')
-rw-r--r-- | synapse/federation/federation_base.py | 158 | ||||
-rw-r--r-- | synapse/federation/federation_client.py | 46 | ||||
-rw-r--r-- | synapse/federation/federation_server.py | 62 | ||||
-rw-r--r-- | synapse/federation/persistence.py | 8 | ||||
-rw-r--r-- | synapse/federation/send_queue.py | 4 | ||||
-rw-r--r-- | synapse/federation/transaction_queue.py | 48 | ||||
-rw-r--r-- | synapse/federation/transport/client.py | 5 | ||||
-rw-r--r-- | synapse/federation/transport/server.py | 26 |
8 files changed, 227 insertions, 130 deletions
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index c11798093d..b7ad729c63 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -13,17 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +from collections import namedtuple import six from twisted.internet import defer +from twisted.internet.defer import DeferredList -from synapse.api.constants import MAX_DEPTH +from synapse.api.constants import MAX_DEPTH, EventTypes, Membership from synapse.api.errors import Codes, SynapseError from synapse.crypto.event_signing import check_event_content_hash from synapse.events import FrozenEvent from synapse.events.utils import prune_event from synapse.http.servlet import assert_params_in_dict +from synapse.types import get_domain_from_id from synapse.util import logcontext, unwrapFirstError logger = logging.getLogger(__name__) @@ -133,34 +136,45 @@ class FederationBase(object): * throws a SynapseError if the signature check failed. The deferreds run their callbacks in the sentinel logcontext. """ - - redacted_pdus = [ - prune_event(pdu) - for pdu in pdus - ] - - deferreds = self.keyring.verify_json_objects_for_server([ - (p.origin, p.get_pdu_json()) - for p in redacted_pdus - ]) + deferreds = _check_sigs_on_pdus(self.keyring, pdus) ctx = logcontext.LoggingContext.current_context() - def callback(_, pdu, redacted): + def callback(_, pdu): with logcontext.PreserveLoggingContext(ctx): if not check_event_content_hash(pdu): - logger.warn( - "Event content has been tampered, redacting %s: %s", - pdu.event_id, pdu.get_pdu_json() - ) - return redacted + # let's try to distinguish between failures because the event was + # redacted (which are somewhat expected) vs actual ball-tampering + # incidents. + # + # This is just a heuristic, so we just assume that if the keys are + # about the same between the redacted and received events, then the + # received event was probably a redacted copy (but we then use our + # *actual* redacted copy to be on the safe side.) + redacted_event = prune_event(pdu) + if ( + set(redacted_event.keys()) == set(pdu.keys()) and + set(six.iterkeys(redacted_event.content)) + == set(six.iterkeys(pdu.content)) + ): + logger.info( + "Event %s seems to have been redacted; using our redacted " + "copy", + pdu.event_id, + ) + else: + logger.warning( + "Event %s content has been tampered, redacting", + pdu.event_id, pdu.get_pdu_json(), + ) + return redacted_event if self.spam_checker.check_event_for_spam(pdu): logger.warn( "Event contains spam, redacting %s: %s", pdu.event_id, pdu.get_pdu_json() ) - return redacted + return prune_event(pdu) return pdu @@ -168,21 +182,121 @@ class FederationBase(object): failure.trap(SynapseError) with logcontext.PreserveLoggingContext(ctx): logger.warn( - "Signature check failed for %s", - pdu.event_id, + "Signature check failed for %s: %s", + pdu.event_id, failure.getErrorMessage(), ) return failure - for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus): + for deferred, pdu in zip(deferreds, pdus): deferred.addCallbacks( callback, errback, - callbackArgs=[pdu, redacted], + callbackArgs=[pdu], errbackArgs=[pdu], ) return deferreds +class PduToCheckSig(namedtuple("PduToCheckSig", [ + "pdu", "redacted_pdu_json", "event_id_domain", "sender_domain", "deferreds", +])): + pass + + +def _check_sigs_on_pdus(keyring, pdus): + """Check that the given events are correctly signed + + Args: + keyring (synapse.crypto.Keyring): keyring object to do the checks + pdus (Collection[EventBase]): the events to be checked + + Returns: + List[Deferred]: a Deferred for each event in pdus, which will either succeed if + the signatures are valid, or fail (with a SynapseError) if not. + """ + + # (currently this is written assuming the v1 room structure; we'll probably want a + # separate function for checking v2 rooms) + + # we want to check that the event is signed by: + # + # (a) the server which created the event_id + # + # (b) the sender's server. + # + # - except in the case of invites created from a 3pid invite, which are exempt + # from this check, because the sender has to match that of the original 3pid + # invite, but the event may come from a different HS, for reasons that I don't + # entirely grok (why do the senders have to match? and if they do, why doesn't the + # joining server ask the inviting server to do the switcheroo with + # exchange_third_party_invite?). + # + # That's pretty awful, since redacting such an invite will render it invalid + # (because it will then look like a regular invite without a valid signature), + # and signatures are *supposed* to be valid whether or not an event has been + # redacted. But this isn't the worst of the ways that 3pid invites are broken. + # + # let's start by getting the domain for each pdu, and flattening the event back + # to JSON. + pdus_to_check = [ + PduToCheckSig( + pdu=p, + redacted_pdu_json=prune_event(p).get_pdu_json(), + event_id_domain=get_domain_from_id(p.event_id), + sender_domain=get_domain_from_id(p.sender), + deferreds=[], + ) + for p in pdus + ] + + # first make sure that the event is signed by the event_id's domain + deferreds = keyring.verify_json_objects_for_server([ + (p.event_id_domain, p.redacted_pdu_json) + for p in pdus_to_check + ]) + + for p, d in zip(pdus_to_check, deferreds): + p.deferreds.append(d) + + # now let's look for events where the sender's domain is different to the + # event id's domain (normally only the case for joins/leaves), and add additional + # checks. + pdus_to_check_sender = [ + p for p in pdus_to_check + if p.sender_domain != p.event_id_domain and not _is_invite_via_3pid(p.pdu) + ] + + more_deferreds = keyring.verify_json_objects_for_server([ + (p.sender_domain, p.redacted_pdu_json) + for p in pdus_to_check_sender + ]) + + for p, d in zip(pdus_to_check_sender, more_deferreds): + p.deferreds.append(d) + + # replace lists of deferreds with single Deferreds + return [_flatten_deferred_list(p.deferreds) for p in pdus_to_check] + + +def _flatten_deferred_list(deferreds): + """Given a list of one or more deferreds, either return the single deferred, or + combine into a DeferredList. + """ + if len(deferreds) > 1: + return DeferredList(deferreds, fireOnOneErrback=True, consumeErrors=True) + else: + assert len(deferreds) == 1 + return deferreds[0] + + +def _is_invite_via_3pid(event): + return ( + event.type == EventTypes.Member + and event.membership == Membership.INVITE + and "third_party_invite" in event.content + ) + + def event_from_pdu_json(pdu_json, outlier=False): """Construct a FrozenEvent from an event json received over federation diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index c9f3c2d352..d05ed91d64 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -66,6 +66,14 @@ class FederationClient(FederationBase): self.state = hs.get_state_handler() self.transport_layer = hs.get_federation_transport_client() + self._get_pdu_cache = ExpiringCache( + cache_name="get_pdu_cache", + clock=self._clock, + max_len=1000, + expiry_ms=120 * 1000, + reset_expiry_on_get=False, + ) + def _clear_tried_cache(self): """Clear pdu_destination_tried cache""" now = self._clock.time_msec() @@ -82,17 +90,6 @@ class FederationClient(FederationBase): if destination_dict: self.pdu_destination_tried[event_id] = destination_dict - def start_get_pdu_cache(self): - self._get_pdu_cache = ExpiringCache( - cache_name="get_pdu_cache", - clock=self._clock, - max_len=1000, - expiry_ms=120 * 1000, - reset_expiry_on_get=False, - ) - - self._get_pdu_cache.start() - @log_function def make_query(self, destination, query_type, args, retry_on_dns_fail=False, ignore_backoff=False): @@ -212,8 +209,6 @@ class FederationClient(FederationBase): Will attempt to get the PDU from each destination in the list until one succeeds. - This will persist the PDU locally upon receipt. - Args: destinations (list): Which home servers to query event_id (str): event to fetch @@ -229,10 +224,9 @@ class FederationClient(FederationBase): # TODO: Rate limit the number of times we try and get the same event. - if self._get_pdu_cache: - ev = self._get_pdu_cache.get(event_id) - if ev: - defer.returnValue(ev) + ev = self._get_pdu_cache.get(event_id) + if ev: + defer.returnValue(ev) pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {}) @@ -271,10 +265,10 @@ class FederationClient(FederationBase): event_id, destination, e, ) except NotRetryingDestination as e: - logger.info(e.message) + logger.info(str(e)) continue except FederationDeniedError as e: - logger.info(e.message) + logger.info(str(e)) continue except Exception as e: pdu_attempts[destination] = now @@ -285,7 +279,7 @@ class FederationClient(FederationBase): ) continue - if self._get_pdu_cache is not None and signed_pdu: + if signed_pdu: self._get_pdu_cache[event_id] = signed_pdu defer.returnValue(signed_pdu) @@ -293,8 +287,7 @@ class FederationClient(FederationBase): @defer.inlineCallbacks @log_function def get_state_for_room(self, destination, room_id, event_id): - """Requests all of the `current` state PDUs for a given room from - a remote home server. + """Requests all of the room state at a given event from a remote home server. Args: destination (str): The remote homeserver to query for the state. @@ -302,9 +295,10 @@ class FederationClient(FederationBase): event_id (str): The id of the event we want the state at. Returns: - Deferred: Results in a list of PDUs. + Deferred[Tuple[List[EventBase], List[EventBase]]]: + A list of events in the state, and a list of events in the auth chain + for the given event. """ - try: # First we try and ask for just the IDs, as thats far quicker if # we have most of the state and auth_chain already. @@ -510,7 +504,7 @@ class FederationClient(FederationBase): else: logger.warn( "Failed to %s via %s: %i %s", - description, destination, e.code, e.message, + description, destination, e.code, e.args[0], ) except Exception: logger.warn( @@ -875,7 +869,7 @@ class FederationClient(FederationBase): except Exception as e: logger.exception( "Failed to send_third_party_invite via %s: %s", - destination, e.message + destination, str(e) ) raise RuntimeError("Failed to send to any server.") diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 3e0cd294a1..819e8f7331 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -46,6 +46,7 @@ from synapse.replication.http.federation import ( from synapse.types import get_domain_from_id from synapse.util.async_helpers import Linearizer, concurrently_execute from synapse.util.caches.response_cache import ResponseCache +from synapse.util.logcontext import nested_logging_context from synapse.util.logutils import log_function # when processing incoming transactions, we try to handle multiple rooms in @@ -99,7 +100,7 @@ class FederationServer(FederationBase): @defer.inlineCallbacks @log_function - def on_incoming_transaction(self, transaction_data): + def on_incoming_transaction(self, origin, transaction_data): # keep this as early as possible to make the calculated origin ts as # accurate as possible. request_time = self._clock.time_msec() @@ -108,34 +109,33 @@ class FederationServer(FederationBase): if not transaction.transaction_id: raise Exception("Transaction missing transaction_id") - if not transaction.origin: - raise Exception("Transaction missing origin") logger.debug("[%s] Got transaction", transaction.transaction_id) # use a linearizer to ensure that we don't process the same transaction # multiple times in parallel. with (yield self._transaction_linearizer.queue( - (transaction.origin, transaction.transaction_id), + (origin, transaction.transaction_id), )): result = yield self._handle_incoming_transaction( - transaction, request_time, + origin, transaction, request_time, ) defer.returnValue(result) @defer.inlineCallbacks - def _handle_incoming_transaction(self, transaction, request_time): + def _handle_incoming_transaction(self, origin, transaction, request_time): """ Process an incoming transaction and return the HTTP response Args: + origin (unicode): the server making the request transaction (Transaction): incoming transaction request_time (int): timestamp that the HTTP request arrived at Returns: Deferred[(int, object)]: http response code and body """ - response = yield self.transaction_actions.have_responded(transaction) + response = yield self.transaction_actions.have_responded(origin, transaction) if response: logger.debug( @@ -149,7 +149,7 @@ class FederationServer(FederationBase): received_pdus_counter.inc(len(transaction.pdus)) - origin_host, _ = parse_server_name(transaction.origin) + origin_host, _ = parse_server_name(origin) pdus_by_room = {} @@ -188,21 +188,22 @@ class FederationServer(FederationBase): for pdu in pdus_by_room[room_id]: event_id = pdu.event_id - try: - yield self._handle_received_pdu( - transaction.origin, pdu - ) - pdu_results[event_id] = {} - except FederationError as e: - logger.warn("Error handling PDU %s: %s", event_id, e) - pdu_results[event_id] = {"error": str(e)} - except Exception as e: - f = failure.Failure() - pdu_results[event_id] = {"error": str(e)} - logger.error( - "Failed to handle PDU %s: %s", - event_id, f.getTraceback().rstrip(), - ) + with nested_logging_context(event_id): + try: + yield self._handle_received_pdu( + origin, pdu + ) + pdu_results[event_id] = {} + except FederationError as e: + logger.warn("Error handling PDU %s: %s", event_id, e) + pdu_results[event_id] = {"error": str(e)} + except Exception as e: + f = failure.Failure() + pdu_results[event_id] = {"error": str(e)} + logger.error( + "Failed to handle PDU %s: %s", + event_id, f.getTraceback().rstrip(), + ) yield concurrently_execute( process_pdus_for_room, pdus_by_room.keys(), @@ -212,7 +213,7 @@ class FederationServer(FederationBase): if hasattr(transaction, "edus"): for edu in (Edu(**x) for x in transaction.edus): yield self.received_edu( - transaction.origin, + origin, edu.edu_type, edu.content ) @@ -224,6 +225,7 @@ class FederationServer(FederationBase): logger.debug("Returning: %s", str(response)) yield self.transaction_actions.set_response( + origin, transaction, 200, response ) @@ -618,7 +620,7 @@ class FederationServer(FederationBase): ) yield self.handler.on_receive_pdu( - origin, pdu, get_missing=True, sent_to_us_directly=True, + origin, pdu, sent_to_us_directly=True, ) def __str__(self): @@ -838,9 +840,9 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry): ) return self._send_edu( - edu_type=edu_type, - origin=origin, - content=content, + edu_type=edu_type, + origin=origin, + content=content, ) def on_query(self, query_type, args): @@ -851,6 +853,6 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry): return handler(args) return self._get_query_client( - query_type=query_type, - args=args, + query_type=query_type, + args=args, ) diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py index 9146215c21..74ffd13b4f 100644 --- a/synapse/federation/persistence.py +++ b/synapse/federation/persistence.py @@ -36,7 +36,7 @@ class TransactionActions(object): self.store = datastore @log_function - def have_responded(self, transaction): + def have_responded(self, origin, transaction): """ Have we already responded to a transaction with the same id and origin? @@ -50,11 +50,11 @@ class TransactionActions(object): "transaction_id") return self.store.get_received_txn_response( - transaction.transaction_id, transaction.origin + transaction.transaction_id, origin ) @log_function - def set_response(self, transaction, code, response): + def set_response(self, origin, transaction, code, response): """ Persist how we responded to a transaction. Returns: @@ -66,7 +66,7 @@ class TransactionActions(object): return self.store.set_received_txn_response( transaction.transaction_id, - transaction.origin, + origin, code, response, ) diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 0bb468385d..6f5995735a 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -32,7 +32,7 @@ Events are replicated via a separate events stream. import logging from collections import namedtuple -from six import iteritems, itervalues +from six import iteritems from sortedcontainers import SortedDict @@ -117,7 +117,7 @@ class FederationRemoteSendQueue(object): user_ids = set( user_id - for uids in itervalues(self.presence_changed) + for uids in self.presence_changed.values() for user_id in uids ) diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index 94d7423d01..98b5950800 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -137,26 +137,6 @@ class TransactionQueue(object): self._processing_pending_presence = False - def can_send_to(self, destination): - """Can we send messages to the given server? - - We can't send messages to ourselves. If we are running on localhost - then we can only federation with other servers running on localhost. - Otherwise we only federate with servers on a public domain. - - Args: - destination(str): The server we are possibly trying to send to. - Returns: - bool: True if we can send to the server. - """ - - if destination == self.server_name: - return False - if self.server_name.startswith("localhost"): - return destination.startswith("localhost") - else: - return not destination.startswith("localhost") - def notify_new_events(self, current_id): """This gets called when we have some new events we might want to send out to other servers. @@ -279,10 +259,7 @@ class TransactionQueue(object): self._order += 1 destinations = set(destinations) - destinations = set( - dest for dest in destinations if self.can_send_to(dest) - ) - + destinations.discard(self.server_name) logger.debug("Sending to: %s", str(destinations)) if not destinations: @@ -358,7 +335,7 @@ class TransactionQueue(object): for destinations, states in hosts_and_states: for destination in destinations: - if not self.can_send_to(destination): + if destination == self.server_name: continue self.pending_presence_by_dest.setdefault( @@ -377,7 +354,8 @@ class TransactionQueue(object): content=content, ) - if not self.can_send_to(destination): + if destination == self.server_name: + logger.info("Not sending EDU to ourselves") return sent_edus_counter.inc() @@ -392,10 +370,8 @@ class TransactionQueue(object): self._attempt_new_transaction(destination) def send_device_messages(self, destination): - if destination == self.server_name or destination == "localhost": - return - - if not self.can_send_to(destination): + if destination == self.server_name: + logger.info("Not sending device update to ourselves") return self._attempt_new_transaction(destination) @@ -463,7 +439,19 @@ class TransactionQueue(object): # pending_transactions flag. pending_pdus = self.pending_pdus_by_dest.pop(destination, []) + + # We can only include at most 50 PDUs per transactions + pending_pdus, leftover_pdus = pending_pdus[:50], pending_pdus[50:] + if leftover_pdus: + self.pending_pdus_by_dest[destination] = leftover_pdus + pending_edus = self.pending_edus_by_dest.pop(destination, []) + + # We can only include at most 100 EDUs per transactions + pending_edus, leftover_edus = pending_edus[:100], pending_edus[100:] + if leftover_edus: + self.pending_edus_by_dest[destination] = leftover_edus + pending_presence = self.pending_presence_by_dest.pop(destination, {}) pending_edus.extend( diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 1054441ca5..2ab973d6c8 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -15,7 +15,8 @@ # limitations under the License. import logging -import urllib + +from six.moves import urllib from twisted.internet import defer @@ -951,4 +952,4 @@ def _create_path(prefix, path, *args): Returns: str """ - return prefix + path % tuple(urllib.quote(arg, "") for arg in args) + return prefix + path % tuple(urllib.parse.quote(arg, "") for arg in args) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 7a993fd1cf..2f874b4838 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -90,8 +90,8 @@ class Authenticator(object): @defer.inlineCallbacks def authenticate_request(self, request, content): json_request = { - "method": request.method, - "uri": request.uri, + "method": request.method.decode('ascii'), + "uri": request.uri.decode('ascii'), "destination": self.server_name, "signatures": {}, } @@ -252,7 +252,7 @@ class BaseFederationServlet(object): by the callback method. None if the request has already been handled. """ content = None - if request.method in ["PUT", "POST"]: + if request.method in [b"PUT", b"POST"]: # TODO: Handle other method types? other content types? content = parse_json_object_from_request(request) @@ -353,7 +353,7 @@ class FederationSendServlet(BaseFederationServlet): try: code, response = yield self.handler.on_incoming_transaction( - transaction_data + origin, transaction_data, ) except Exception: logger.exception("on_incoming_transaction failed") @@ -386,7 +386,7 @@ class FederationStateServlet(BaseFederationServlet): return self.handler.on_context_state_request( origin, context, - query.get("event_id", [None])[0], + parse_string_from_args(query, "event_id", None), ) @@ -397,7 +397,7 @@ class FederationStateIdsServlet(BaseFederationServlet): return self.handler.on_state_ids_request( origin, room_id, - query.get("event_id", [None])[0], + parse_string_from_args(query, "event_id", None), ) @@ -405,14 +405,12 @@ class FederationBackfillServlet(BaseFederationServlet): PATH = "/backfill/(?P<context>[^/]*)/" def on_GET(self, origin, content, query, context): - versions = query["v"] - limits = query["limit"] + versions = [x.decode('ascii') for x in query[b"v"]] + limit = parse_integer_from_args(query, "limit", None) - if not limits: + if not limit: return defer.succeed((400, {"error": "Did not include limit param"})) - limit = int(limits[-1]) - return self.handler.on_backfill_request(origin, context, versions, limit) @@ -423,7 +421,7 @@ class FederationQueryServlet(BaseFederationServlet): def on_GET(self, origin, content, query, query_type): return self.handler.on_query_request( query_type, - {k: v[0].decode("utf-8") for k, v in query.items()} + {k.decode('utf8'): v[0].decode("utf-8") for k, v in query.items()} ) @@ -630,14 +628,14 @@ class OpenIdUserInfo(BaseFederationServlet): @defer.inlineCallbacks def on_GET(self, origin, content, query): - token = query.get("access_token", [None])[0] + token = query.get(b"access_token", [None])[0] if token is None: defer.returnValue((401, { "errcode": "M_MISSING_TOKEN", "error": "Access Token required" })) return - user_id = yield self.handler.on_openid_userinfo(token) + user_id = yield self.handler.on_openid_userinfo(token.decode('ascii')) if user_id is None: defer.returnValue((401, { |