diff options
Diffstat (limited to 'synapse/federation')
-rw-r--r-- | synapse/federation/federation_base.py | 77 | ||||
-rw-r--r-- | synapse/federation/federation_client.py | 217 | ||||
-rw-r--r-- | synapse/federation/federation_server.py | 234 | ||||
-rw-r--r-- | synapse/federation/persistence.py | 15 | ||||
-rw-r--r-- | synapse/federation/send_queue.py | 145 | ||||
-rw-r--r-- | synapse/federation/sender/__init__.py | 70 | ||||
-rw-r--r-- | synapse/federation/sender/per_destination_queue.py | 45 | ||||
-rw-r--r-- | synapse/federation/sender/transaction_manager.py | 38 | ||||
-rw-r--r-- | synapse/federation/transport/client.py | 294 | ||||
-rw-r--r-- | synapse/federation/transport/server.py | 245 | ||||
-rw-r--r-- | synapse/federation/units.py | 33 |
11 files changed, 644 insertions, 769 deletions
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index fc5cfb7d83..1e925b19e7 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -44,8 +44,9 @@ class FederationBase(object): self._clock = hs.get_clock() @defer.inlineCallbacks - def _check_sigs_and_hash_and_fetch(self, origin, pdus, room_version, - outlier=False, include_none=False): + def _check_sigs_and_hash_and_fetch( + self, origin, pdus, room_version, outlier=False, include_none=False + ): """Takes a list of PDUs and checks the signatures and hashs of each one. If a PDU fails its signature check then we check if we have it in the database and if not then request if from the originating server of @@ -79,9 +80,7 @@ class FederationBase(object): if not res: # Check local db. res = yield self.store.get_event( - pdu.event_id, - allow_rejected=True, - allow_none=True, + pdu.event_id, allow_rejected=True, allow_none=True ) if not res and pdu.origin != origin: @@ -98,23 +97,16 @@ class FederationBase(object): if not res: logger.warn( - "Failed to find copy of %s with valid signature", - pdu.event_id, + "Failed to find copy of %s with valid signature", pdu.event_id ) defer.returnValue(res) handle = logcontext.preserve_fn(handle_check_result) - deferreds2 = [ - handle(pdu, deferred) - for pdu, deferred in zip(pdus, deferreds) - ] + deferreds2 = [handle(pdu, deferred) for pdu, deferred in zip(pdus, deferreds)] valid_pdus = yield logcontext.make_deferred_yieldable( - defer.gatherResults( - deferreds2, - consumeErrors=True, - ) + defer.gatherResults(deferreds2, consumeErrors=True) ).addErrback(unwrapFirstError) if include_none: @@ -124,7 +116,7 @@ class FederationBase(object): def _check_sigs_and_hash(self, room_version, pdu): return logcontext.make_deferred_yieldable( - self._check_sigs_and_hashes(room_version, [pdu])[0], + self._check_sigs_and_hashes(room_version, [pdu])[0] ) def _check_sigs_and_hashes(self, room_version, pdus): @@ -159,11 +151,9 @@ class FederationBase(object): # received event was probably a redacted copy (but we then use our # *actual* redacted copy to be on the safe side.) redacted_event = prune_event(pdu) - if ( - set(redacted_event.keys()) == set(pdu.keys()) and - set(six.iterkeys(redacted_event.content)) - == set(six.iterkeys(pdu.content)) - ): + if set(redacted_event.keys()) == set(pdu.keys()) and set( + six.iterkeys(redacted_event.content) + ) == set(six.iterkeys(pdu.content)): logger.info( "Event %s seems to have been redacted; using our redacted " "copy", @@ -172,14 +162,15 @@ class FederationBase(object): else: logger.warning( "Event %s content has been tampered, redacting", - pdu.event_id, pdu.get_pdu_json(), + pdu.event_id, ) return redacted_event if self.spam_checker.check_event_for_spam(pdu): logger.warn( "Event contains spam, redacting %s: %s", - pdu.event_id, pdu.get_pdu_json() + pdu.event_id, + pdu.get_pdu_json(), ) return prune_event(pdu) @@ -190,23 +181,24 @@ class FederationBase(object): with logcontext.PreserveLoggingContext(ctx): logger.warn( "Signature check failed for %s: %s", - pdu.event_id, failure.getErrorMessage(), + pdu.event_id, + failure.getErrorMessage(), ) return failure for deferred, pdu in zip(deferreds, pdus): deferred.addCallbacks( - callback, errback, - callbackArgs=[pdu], - errbackArgs=[pdu], + callback, errback, callbackArgs=[pdu], errbackArgs=[pdu] ) return deferreds -class PduToCheckSig(namedtuple("PduToCheckSig", [ - "pdu", "redacted_pdu_json", "sender_domain", "deferreds", -])): +class PduToCheckSig( + namedtuple( + "PduToCheckSig", ["pdu", "redacted_pdu_json", "sender_domain", "deferreds"] + ) +): pass @@ -260,10 +252,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus): # First we check that the sender event is signed by the sender's domain # (except if its a 3pid invite, in which case it may be sent by any server) - pdus_to_check_sender = [ - p for p in pdus_to_check - if not _is_invite_via_3pid(p.pdu) - ] + pdus_to_check_sender = [p for p in pdus_to_check if not _is_invite_via_3pid(p.pdu)] more_deferreds = keyring.verify_json_objects_for_server( [ @@ -297,7 +286,8 @@ def _check_sigs_on_pdus(keyring, room_version, pdus): # (ie, the room version uses old-style non-hash event IDs). if v.event_format == EventFormatVersions.V1: pdus_to_check_event_id = [ - p for p in pdus_to_check + p + for p in pdus_to_check if p.sender_domain != get_domain_from_id(p.pdu.event_id) ] @@ -315,10 +305,8 @@ def _check_sigs_on_pdus(keyring, room_version, pdus): def event_err(e, pdu_to_check): errmsg = ( - "event id %s: unable to verify signature for event id domain: %s" % ( - pdu_to_check.pdu.event_id, - e.getErrorMessage(), - ) + "event id %s: unable to verify signature for event id domain: %s" + % (pdu_to_check.pdu.event_id, e.getErrorMessage()) ) # XX as above: not really sure if these are the right codes raise SynapseError(400, errmsg, Codes.UNAUTHORIZED) @@ -368,21 +356,18 @@ def event_from_pdu_json(pdu_json, event_format_version, outlier=False): """ # we could probably enforce a bunch of other fields here (room_id, sender, # origin, etc etc) - assert_params_in_dict(pdu_json, ('type', 'depth')) + assert_params_in_dict(pdu_json, ("type", "depth")) - depth = pdu_json['depth'] + depth = pdu_json["depth"] if not isinstance(depth, six.integer_types): - raise SynapseError(400, "Depth %r not an intger" % (depth, ), - Codes.BAD_JSON) + raise SynapseError(400, "Depth %r not an intger" % (depth,), Codes.BAD_JSON) if depth < 0: raise SynapseError(400, "Depth too small", Codes.BAD_JSON) elif depth > MAX_DEPTH: raise SynapseError(400, "Depth too large", Codes.BAD_JSON) - event = event_type_from_format_version(event_format_version)( - pdu_json, - ) + event = event_type_from_format_version(event_format_version)(pdu_json) event.internal_metadata.outlier = outlier diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 70573746d6..3883eb525e 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -57,6 +57,7 @@ class InvalidResponseError(RuntimeError): """Helper for _try_destination_list: indicates that the server returned a response we couldn't parse """ + pass @@ -65,9 +66,7 @@ class FederationClient(FederationBase): super(FederationClient, self).__init__(hs) self.pdu_destination_tried = {} - self._clock.looping_call( - self._clear_tried_cache, 60 * 1000, - ) + self._clock.looping_call(self._clear_tried_cache, 60 * 1000) self.state = hs.get_state_handler() self.transport_layer = hs.get_federation_transport_client() @@ -99,8 +98,14 @@ class FederationClient(FederationBase): self.pdu_destination_tried[event_id] = destination_dict @log_function - def make_query(self, destination, query_type, args, - retry_on_dns_fail=False, ignore_backoff=False): + def make_query( + self, + destination, + query_type, + args, + retry_on_dns_fail=False, + ignore_backoff=False, + ): """Sends a federation Query to a remote homeserver of the given type and arguments. @@ -120,7 +125,10 @@ class FederationClient(FederationBase): sent_queries_counter.labels(query_type).inc() return self.transport_layer.make_query( - destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail, + destination, + query_type, + args, + retry_on_dns_fail=retry_on_dns_fail, ignore_backoff=ignore_backoff, ) @@ -137,9 +145,7 @@ class FederationClient(FederationBase): response """ sent_queries_counter.labels("client_device_keys").inc() - return self.transport_layer.query_client_keys( - destination, content, timeout - ) + return self.transport_layer.query_client_keys(destination, content, timeout) @log_function def query_user_devices(self, destination, user_id, timeout=30000): @@ -147,9 +153,7 @@ class FederationClient(FederationBase): server. """ sent_queries_counter.labels("user_devices").inc() - return self.transport_layer.query_user_devices( - destination, user_id, timeout - ) + return self.transport_layer.query_user_devices(destination, user_id, timeout) @log_function def claim_client_keys(self, destination, content, timeout): @@ -164,9 +168,7 @@ class FederationClient(FederationBase): response """ sent_queries_counter.labels("client_one_time_keys").inc() - return self.transport_layer.claim_client_keys( - destination, content, timeout - ) + return self.transport_layer.claim_client_keys(destination, content, timeout) @defer.inlineCallbacks @log_function @@ -191,7 +193,8 @@ class FederationClient(FederationBase): return transaction_data = yield self.transport_layer.backfill( - dest, room_id, extremities, limit) + dest, room_id, extremities, limit + ) logger.debug("backfill transaction_data=%s", repr(transaction_data)) @@ -204,17 +207,19 @@ class FederationClient(FederationBase): ] # FIXME: We should handle signature failures more gracefully. - pdus[:] = yield logcontext.make_deferred_yieldable(defer.gatherResults( - self._check_sigs_and_hashes(room_version, pdus), - consumeErrors=True, - ).addErrback(unwrapFirstError)) + pdus[:] = yield logcontext.make_deferred_yieldable( + defer.gatherResults( + self._check_sigs_and_hashes(room_version, pdus), consumeErrors=True + ).addErrback(unwrapFirstError) + ) defer.returnValue(pdus) @defer.inlineCallbacks @log_function - def get_pdu(self, destinations, event_id, room_version, outlier=False, - timeout=None): + def get_pdu( + self, destinations, event_id, room_version, outlier=False, timeout=None + ): """Requests the PDU with given origin and ID from the remote home servers. @@ -255,7 +260,7 @@ class FederationClient(FederationBase): try: transaction_data = yield self.transport_layer.get_event( - destination, event_id, timeout=timeout, + destination, event_id, timeout=timeout ) logger.debug( @@ -282,8 +287,7 @@ class FederationClient(FederationBase): except SynapseError as e: logger.info( - "Failed to get PDU %s from %s because %s", - event_id, destination, e, + "Failed to get PDU %s from %s because %s", event_id, destination, e ) continue except NotRetryingDestination as e: @@ -296,8 +300,7 @@ class FederationClient(FederationBase): pdu_attempts[destination] = now logger.info( - "Failed to get PDU %s from %s because %s", - event_id, destination, e, + "Failed to get PDU %s from %s because %s", event_id, destination, e ) continue @@ -326,7 +329,7 @@ class FederationClient(FederationBase): # we have most of the state and auth_chain already. # However, this may 404 if the other side has an old synapse. result = yield self.transport_layer.get_room_state_ids( - destination, room_id, event_id=event_id, + destination, room_id, event_id=event_id ) state_event_ids = result["pdu_ids"] @@ -340,12 +343,10 @@ class FederationClient(FederationBase): logger.warning( "Failed to fetch missing state/auth events for %s: %s", room_id, - failed_to_fetch + failed_to_fetch, ) - event_map = { - ev.event_id: ev for ev in fetched_events - } + event_map = {ev.event_id: ev for ev in fetched_events} pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map] auth_chain = [ @@ -362,15 +363,14 @@ class FederationClient(FederationBase): raise e result = yield self.transport_layer.get_room_state( - destination, room_id, event_id=event_id, + destination, room_id, event_id=event_id ) room_version = yield self.store.get_room_version(room_id) format_ver = room_version_to_event_format(room_version) pdus = [ - event_from_pdu_json(p, format_ver, outlier=True) - for p in result["pdus"] + event_from_pdu_json(p, format_ver, outlier=True) for p in result["pdus"] ] auth_chain = [ @@ -378,9 +378,9 @@ class FederationClient(FederationBase): for p in result.get("auth_chain", []) ] - seen_events = yield self.store.get_events([ - ev.event_id for ev in itertools.chain(pdus, auth_chain) - ]) + seen_events = yield self.store.get_events( + [ev.event_id for ev in itertools.chain(pdus, auth_chain)] + ) signed_pdus = yield self._check_sigs_and_hash_and_fetch( destination, @@ -442,7 +442,7 @@ class FederationClient(FederationBase): batch_size = 20 missing_events = list(missing_events) for i in range(0, len(missing_events), batch_size): - batch = set(missing_events[i:i + batch_size]) + batch = set(missing_events[i : i + batch_size]) deferreds = [ run_in_background( @@ -470,21 +470,17 @@ class FederationClient(FederationBase): @defer.inlineCallbacks @log_function def get_event_auth(self, destination, room_id, event_id): - res = yield self.transport_layer.get_event_auth( - destination, room_id, event_id, - ) + res = yield self.transport_layer.get_event_auth(destination, room_id, event_id) room_version = yield self.store.get_room_version(room_id) format_ver = room_version_to_event_format(room_version) auth_chain = [ - event_from_pdu_json(p, format_ver, outlier=True) - for p in res["auth_chain"] + event_from_pdu_json(p, format_ver, outlier=True) for p in res["auth_chain"] ] signed_auth = yield self._check_sigs_and_hash_and_fetch( - destination, auth_chain, - outlier=True, room_version=room_version, + destination, auth_chain, outlier=True, room_version=room_version ) signed_auth.sort(key=lambda e: e.depth) @@ -527,28 +523,26 @@ class FederationClient(FederationBase): res = yield callback(destination) defer.returnValue(res) except InvalidResponseError as e: - logger.warn( - "Failed to %s via %s: %s", - description, destination, e, - ) + logger.warn("Failed to %s via %s: %s", description, destination, e) except HttpResponseException as e: if not 500 <= e.code < 600: raise e.to_synapse_error() else: logger.warn( "Failed to %s via %s: %i %s", - description, destination, e.code, e.args[0], + description, + destination, + e.code, + e.args[0], ) except Exception: - logger.warn( - "Failed to %s via %s", - description, destination, exc_info=1, - ) + logger.warn("Failed to %s via %s", description, destination, exc_info=1) - raise RuntimeError("Failed to %s via any server" % (description, )) + raise RuntimeError("Failed to %s via any server" % (description,)) - def make_membership_event(self, destinations, room_id, user_id, membership, - content, params): + def make_membership_event( + self, destinations, room_id, user_id, membership, content, params + ): """ Creates an m.room.member event, with context, without participating in the room. @@ -584,14 +578,14 @@ class FederationClient(FederationBase): valid_memberships = {Membership.JOIN, Membership.LEAVE} if membership not in valid_memberships: raise RuntimeError( - "make_membership_event called with membership='%s', must be one of %s" % - (membership, ",".join(valid_memberships)) + "make_membership_event called with membership='%s', must be one of %s" + % (membership, ",".join(valid_memberships)) ) @defer.inlineCallbacks def send_request(destination): ret = yield self.transport_layer.make_membership_event( - destination, room_id, user_id, membership, params, + destination, room_id, user_id, membership, params ) # Note: If not supplied, the room version may be either v1 or v2, @@ -614,16 +608,17 @@ class FederationClient(FederationBase): pdu_dict["prev_state"] = [] ev = builder.create_local_event_from_event_dict( - self._clock, self.hostname, self.signing_key, - format_version=event_format, event_dict=pdu_dict, + self._clock, + self.hostname, + self.signing_key, + format_version=event_format, + event_dict=pdu_dict, ) - defer.returnValue( - (destination, ev, event_format) - ) + defer.returnValue((destination, ev, event_format)) return self._try_destination_list( - "make_" + membership, destinations, send_request, + "make_" + membership, destinations, send_request ) def send_join(self, destinations, pdu, event_format_version): @@ -655,9 +650,7 @@ class FederationClient(FederationBase): create_event = e break else: - raise InvalidResponseError( - "no %s in auth chain" % (EventTypes.Create,), - ) + raise InvalidResponseError("no %s in auth chain" % (EventTypes.Create,)) # the room version should be sane. room_version = create_event.content.get("room_version", "1") @@ -665,9 +658,8 @@ class FederationClient(FederationBase): # This shouldn't be possible, because the remote server should have # rejected the join attempt during make_join. raise InvalidResponseError( - "room appears to have unsupported version %s" % ( - room_version, - )) + "room appears to have unsupported version %s" % (room_version,) + ) @defer.inlineCallbacks def send_request(destination): @@ -691,10 +683,7 @@ class FederationClient(FederationBase): for p in content.get("auth_chain", []) ] - pdus = { - p.event_id: p - for p in itertools.chain(state, auth_chain) - } + pdus = {p.event_id: p for p in itertools.chain(state, auth_chain)} room_version = None for e in state: @@ -710,15 +699,13 @@ class FederationClient(FederationBase): raise SynapseError(400, "No create event in state") valid_pdus = yield self._check_sigs_and_hash_and_fetch( - destination, list(pdus.values()), + destination, + list(pdus.values()), outlier=True, room_version=room_version, ) - valid_pdus_map = { - p.event_id: p - for p in valid_pdus - } + valid_pdus_map = {p.event_id: p for p in valid_pdus} # NB: We *need* to copy to ensure that we don't have multiple # references being passed on, as that causes... issues. @@ -741,11 +728,14 @@ class FederationClient(FederationBase): check_authchain_validity(signed_auth) - defer.returnValue({ - "state": signed_state, - "auth_chain": signed_auth, - "origin": destination, - }) + defer.returnValue( + { + "state": signed_state, + "auth_chain": signed_auth, + "origin": destination, + } + ) + return self._try_destination_list("send_join", destinations, send_request) @defer.inlineCallbacks @@ -854,6 +844,7 @@ class FederationClient(FederationBase): Fails with a ``RuntimeError`` if no servers were reachable. """ + @defer.inlineCallbacks def send_request(destination): time_now = self._clock.time_msec() @@ -869,14 +860,23 @@ class FederationClient(FederationBase): return self._try_destination_list("send_leave", destinations, send_request) - def get_public_rooms(self, destination, limit=None, since_token=None, - search_filter=None, include_all_networks=False, - third_party_instance_id=None): + def get_public_rooms( + self, + destination, + limit=None, + since_token=None, + search_filter=None, + include_all_networks=False, + third_party_instance_id=None, + ): if destination == self.server_name: return return self.transport_layer.get_public_rooms( - destination, limit, since_token, search_filter, + destination, + limit, + since_token, + search_filter, include_all_networks=include_all_networks, third_party_instance_id=third_party_instance_id, ) @@ -891,9 +891,7 @@ class FederationClient(FederationBase): """ time_now = self._clock.time_msec() - send_content = { - "auth_chain": [e.get_pdu_json(time_now) for e in local_auth], - } + send_content = {"auth_chain": [e.get_pdu_json(time_now) for e in local_auth]} code, content = yield self.transport_layer.send_query_auth( destination=destination, @@ -905,13 +903,10 @@ class FederationClient(FederationBase): room_version = yield self.store.get_room_version(room_id) format_ver = room_version_to_event_format(room_version) - auth_chain = [ - event_from_pdu_json(e, format_ver) - for e in content["auth_chain"] - ] + auth_chain = [event_from_pdu_json(e, format_ver) for e in content["auth_chain"]] signed_auth = yield self._check_sigs_and_hash_and_fetch( - destination, auth_chain, outlier=True, room_version=room_version, + destination, auth_chain, outlier=True, room_version=room_version ) signed_auth.sort(key=lambda e: e.depth) @@ -925,8 +920,16 @@ class FederationClient(FederationBase): defer.returnValue(ret) @defer.inlineCallbacks - def get_missing_events(self, destination, room_id, earliest_events_ids, - latest_events, limit, min_depth, timeout): + def get_missing_events( + self, + destination, + room_id, + earliest_events_ids, + latest_events, + limit, + min_depth, + timeout, + ): """Tries to fetch events we are missing. This is called when we receive an event without having received all of its ancestors. @@ -957,12 +960,11 @@ class FederationClient(FederationBase): format_ver = room_version_to_event_format(room_version) events = [ - event_from_pdu_json(e, format_ver) - for e in content.get("events", []) + event_from_pdu_json(e, format_ver) for e in content.get("events", []) ] signed_events = yield self._check_sigs_and_hash_and_fetch( - destination, events, outlier=False, room_version=room_version, + destination, events, outlier=False, room_version=room_version ) except HttpResponseException as e: if not e.code == 400: @@ -982,17 +984,14 @@ class FederationClient(FederationBase): try: yield self.transport_layer.exchange_third_party_invite( - destination=destination, - room_id=room_id, - event_dict=event_dict, + destination=destination, room_id=room_id, event_dict=event_dict ) defer.returnValue(None) except CodeMessageException: raise except Exception as e: logger.exception( - "Failed to send_third_party_invite via %s: %s", - destination, str(e) + "Failed to send_third_party_invite via %s: %s", destination, str(e) ) raise RuntimeError("Failed to send to any server.") diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 4c28c1dc3c..2e0cebb638 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -69,7 +69,6 @@ received_queries_counter = Counter( class FederationServer(FederationBase): - def __init__(self, hs): super(FederationServer, self).__init__(hs) @@ -118,11 +117,13 @@ class FederationServer(FederationBase): # use a linearizer to ensure that we don't process the same transaction # multiple times in parallel. - with (yield self._transaction_linearizer.queue( - (origin, transaction.transaction_id), - )): + with ( + yield self._transaction_linearizer.queue( + (origin, transaction.transaction_id) + ) + ): result = yield self._handle_incoming_transaction( - origin, transaction, request_time, + origin, transaction, request_time ) defer.returnValue(result) @@ -144,7 +145,7 @@ class FederationServer(FederationBase): if response: logger.debug( "[%s] We've already responded to this request", - transaction.transaction_id + transaction.transaction_id, ) defer.returnValue(response) return @@ -152,18 +153,15 @@ class FederationServer(FederationBase): logger.debug("[%s] Transaction is new", transaction.transaction_id) # Reject if PDU count > 50 and EDU count > 100 - if (len(transaction.pdus) > 50 - or (hasattr(transaction, "edus") and len(transaction.edus) > 100)): + if len(transaction.pdus) > 50 or ( + hasattr(transaction, "edus") and len(transaction.edus) > 100 + ): - logger.info( - "Transaction PDU or EDU count too large. Returning 400", - ) + logger.info("Transaction PDU or EDU count too large. Returning 400") response = {} yield self.transaction_actions.set_response( - origin, - transaction, - 400, response + origin, transaction, 400, response ) defer.returnValue((400, response)) @@ -230,9 +228,7 @@ class FederationServer(FederationBase): try: yield self.check_server_matches_acl(origin_host, room_id) except AuthError as e: - logger.warn( - "Ignoring PDUs for room %s from banned server", room_id, - ) + logger.warn("Ignoring PDUs for room %s from banned server", room_id) for pdu in pdus_by_room[room_id]: event_id = pdu.event_id pdu_results[event_id] = e.error_dict() @@ -242,9 +238,7 @@ class FederationServer(FederationBase): event_id = pdu.event_id with nested_logging_context(event_id): try: - yield self._handle_received_pdu( - origin, pdu - ) + yield self._handle_received_pdu(origin, pdu) pdu_results[event_id] = {} except FederationError as e: logger.warn("Error handling PDU %s: %s", event_id, e) @@ -259,29 +253,18 @@ class FederationServer(FederationBase): ) yield concurrently_execute( - process_pdus_for_room, pdus_by_room.keys(), - TRANSACTION_CONCURRENCY_LIMIT, + process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT ) if hasattr(transaction, "edus"): for edu in (Edu(**x) for x in transaction.edus): - yield self.received_edu( - origin, - edu.edu_type, - edu.content - ) + yield self.received_edu(origin, edu.edu_type, edu.content) - response = { - "pdus": pdu_results, - } + response = {"pdus": pdu_results} logger.debug("Returning: %s", str(response)) - yield self.transaction_actions.set_response( - origin, - transaction, - 200, response - ) + yield self.transaction_actions.set_response(origin, transaction, 200, response) defer.returnValue((200, response)) @defer.inlineCallbacks @@ -311,7 +294,8 @@ class FederationServer(FederationBase): resp = yield self._state_resp_cache.wrap( (room_id, event_id), self._on_context_state_request_compute, - room_id, event_id, + room_id, + event_id, ) defer.returnValue((200, resp)) @@ -328,24 +312,17 @@ class FederationServer(FederationBase): if not in_room: raise AuthError(403, "Host not in room.") - state_ids = yield self.handler.get_state_ids_for_pdu( - room_id, event_id, - ) + state_ids = yield self.handler.get_state_ids_for_pdu(room_id, event_id) auth_chain_ids = yield self.store.get_auth_chain_ids(state_ids) - defer.returnValue((200, { - "pdu_ids": state_ids, - "auth_chain_ids": auth_chain_ids, - })) + defer.returnValue( + (200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}) + ) @defer.inlineCallbacks def _on_context_state_request_compute(self, room_id, event_id): - pdus = yield self.handler.get_state_for_pdu( - room_id, event_id, - ) - auth_chain = yield self.store.get_auth_chain( - [pdu.event_id for pdu in pdus] - ) + pdus = yield self.handler.get_state_for_pdu(room_id, event_id) + auth_chain = yield self.store.get_auth_chain([pdu.event_id for pdu in pdus]) for event in auth_chain: # We sign these again because there was a bug where we @@ -355,14 +332,16 @@ class FederationServer(FederationBase): compute_event_signature( event.get_pdu_json(), self.hs.hostname, - self.hs.config.signing_key[0] + self.hs.config.signing_key[0], ) ) - defer.returnValue({ - "pdus": [pdu.get_pdu_json() for pdu in pdus], - "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain], - }) + defer.returnValue( + { + "pdus": [pdu.get_pdu_json() for pdu in pdus], + "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain], + } + ) @defer.inlineCallbacks @log_function @@ -370,9 +349,7 @@ class FederationServer(FederationBase): pdu = yield self.handler.get_persisted_pdu(origin, event_id) if pdu: - defer.returnValue( - (200, self._transaction_from_pdus([pdu]).get_dict()) - ) + defer.returnValue((200, self._transaction_from_pdus([pdu]).get_dict())) else: defer.returnValue((404, "")) @@ -394,10 +371,9 @@ class FederationServer(FederationBase): pdu = yield self.handler.on_make_join_request(room_id, user_id) time_now = self._clock.time_msec() - defer.returnValue({ - "event": pdu.get_pdu_json(time_now), - "room_version": room_version, - }) + defer.returnValue( + {"event": pdu.get_pdu_json(time_now), "room_version": room_version} + ) @defer.inlineCallbacks def on_invite_request(self, origin, content, room_version): @@ -431,12 +407,17 @@ class FederationServer(FederationBase): logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures) res_pdus = yield self.handler.on_send_join_request(origin, pdu) time_now = self._clock.time_msec() - defer.returnValue((200, { - "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]], - "auth_chain": [ - p.get_pdu_json(time_now) for p in res_pdus["auth_chain"] - ], - })) + defer.returnValue( + ( + 200, + { + "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]], + "auth_chain": [ + p.get_pdu_json(time_now) for p in res_pdus["auth_chain"] + ], + }, + ) + ) @defer.inlineCallbacks def on_make_leave_request(self, origin, room_id, user_id): @@ -447,10 +428,9 @@ class FederationServer(FederationBase): room_version = yield self.store.get_room_version(room_id) time_now = self._clock.time_msec() - defer.returnValue({ - "event": pdu.get_pdu_json(time_now), - "room_version": room_version, - }) + defer.returnValue( + {"event": pdu.get_pdu_json(time_now), "room_version": room_version} + ) @defer.inlineCallbacks def on_send_leave_request(self, origin, content, room_id): @@ -475,9 +455,7 @@ class FederationServer(FederationBase): time_now = self._clock.time_msec() auth_pdus = yield self.handler.on_event_auth(event_id) - res = { - "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus], - } + res = {"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus]} defer.returnValue((200, res)) @defer.inlineCallbacks @@ -508,12 +486,11 @@ class FederationServer(FederationBase): format_ver = room_version_to_event_format(room_version) auth_chain = [ - event_from_pdu_json(e, format_ver) - for e in content["auth_chain"] + event_from_pdu_json(e, format_ver) for e in content["auth_chain"] ] signed_auth = yield self._check_sigs_and_hash_and_fetch( - origin, auth_chain, outlier=True, room_version=room_version, + origin, auth_chain, outlier=True, room_version=room_version ) ret = yield self.handler.on_query_auth( @@ -527,17 +504,12 @@ class FederationServer(FederationBase): time_now = self._clock.time_msec() send_content = { - "auth_chain": [ - e.get_pdu_json(time_now) - for e in ret["auth_chain"] - ], + "auth_chain": [e.get_pdu_json(time_now) for e in ret["auth_chain"]], "rejects": ret.get("rejects", []), "missing": ret.get("missing", []), } - defer.returnValue( - (200, send_content) - ) + defer.returnValue((200, send_content)) @log_function def on_query_client_keys(self, origin, content): @@ -566,20 +538,23 @@ class FederationServer(FederationBase): logger.info( "Claimed one-time-keys: %s", - ",".join(( - "%s for %s:%s" % (key_id, user_id, device_id) - for user_id, user_keys in iteritems(json_result) - for device_id, device_keys in iteritems(user_keys) - for key_id, _ in iteritems(device_keys) - )), + ",".join( + ( + "%s for %s:%s" % (key_id, user_id, device_id) + for user_id, user_keys in iteritems(json_result) + for device_id, device_keys in iteritems(user_keys) + for key_id, _ in iteritems(device_keys) + ) + ), ) defer.returnValue({"one_time_keys": json_result}) @defer.inlineCallbacks @log_function - def on_get_missing_events(self, origin, room_id, earliest_events, - latest_events, limit): + def on_get_missing_events( + self, origin, room_id, earliest_events, latest_events, limit + ): with (yield self._server_linearizer.queue((origin, room_id))): origin_host, _ = parse_server_name(origin) yield self.check_server_matches_acl(origin_host, room_id) @@ -587,11 +562,13 @@ class FederationServer(FederationBase): logger.info( "on_get_missing_events: earliest_events: %r, latest_events: %r," " limit: %d", - earliest_events, latest_events, limit, + earliest_events, + latest_events, + limit, ) missing_events = yield self.handler.on_get_missing_events( - origin, room_id, earliest_events, latest_events, limit, + origin, room_id, earliest_events, latest_events, limit ) if len(missing_events) < 5: @@ -603,9 +580,9 @@ class FederationServer(FederationBase): time_now = self._clock.time_msec() - defer.returnValue({ - "events": [ev.get_pdu_json(time_now) for ev in missing_events], - }) + defer.returnValue( + {"events": [ev.get_pdu_json(time_now) for ev in missing_events]} + ) @log_function def on_openid_userinfo(self, token): @@ -666,22 +643,17 @@ class FederationServer(FederationBase): # origin. See bug #1893. This is also true for some third party # invites). if not ( - pdu.type == 'm.room.member' and - pdu.content and - pdu.content.get("membership", None) in ( - Membership.JOIN, Membership.INVITE, - ) + pdu.type == "m.room.member" + and pdu.content + and pdu.content.get("membership", None) + in (Membership.JOIN, Membership.INVITE) ): logger.info( - "Discarding PDU %s from invalid origin %s", - pdu.event_id, origin + "Discarding PDU %s from invalid origin %s", pdu.event_id, origin ) return else: - logger.info( - "Accepting join PDU %s from %s", - pdu.event_id, origin - ) + logger.info("Accepting join PDU %s from %s", pdu.event_id, origin) # We've already checked that we know the room version by this point room_version = yield self.store.get_room_version(pdu.room_id) @@ -690,33 +662,19 @@ class FederationServer(FederationBase): try: pdu = yield self._check_sigs_and_hash(room_version, pdu) except SynapseError as e: - raise FederationError( - "ERROR", - e.code, - e.msg, - affected=pdu.event_id, - ) + raise FederationError("ERROR", e.code, e.msg, affected=pdu.event_id) - yield self.handler.on_receive_pdu( - origin, pdu, sent_to_us_directly=True, - ) + yield self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True) def __str__(self): return "<ReplicationLayer(%s)>" % self.server_name @defer.inlineCallbacks def exchange_third_party_invite( - self, - sender_user_id, - target_user_id, - room_id, - signed, + self, sender_user_id, target_user_id, room_id, signed ): ret = yield self.handler.exchange_third_party_invite( - sender_user_id, - target_user_id, - room_id, - signed, + sender_user_id, target_user_id, room_id, signed ) defer.returnValue(ret) @@ -771,7 +729,7 @@ def server_matches_acl_event(server_name, acl_event): allow_ip_literals = True if not allow_ip_literals: # check for ipv6 literals. These start with '['. - if server_name[0] == '[': + if server_name[0] == "[": return False # check for ipv4 literals. We can just lift the routine from twisted. @@ -805,7 +763,9 @@ def server_matches_acl_event(server_name, acl_event): def _acl_entry_matches(server_name, acl_entry): if not isinstance(acl_entry, six.string_types): - logger.warn("Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry)) + logger.warn( + "Ignoring non-str ACL entry '%s' (is %s)", acl_entry, type(acl_entry) + ) return False regex = glob_to_regex(acl_entry) return regex.match(server_name) @@ -815,6 +775,7 @@ class FederationHandlerRegistry(object): """Allows classes to register themselves as handlers for a given EDU or query type for incoming federation traffic. """ + def __init__(self): self.edu_handlers = {} self.query_handlers = {} @@ -848,9 +809,7 @@ class FederationHandlerRegistry(object): on and the result used as the response to the query request. """ if query_type in self.query_handlers: - raise KeyError( - "Already have a Query handler for %s" % (query_type,) - ) + raise KeyError("Already have a Query handler for %s" % (query_type,)) logger.info("Registering federation query handler for %r", query_type) @@ -905,14 +864,10 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry): handler = self.edu_handlers.get(edu_type) if handler: return super(ReplicationFederationHandlerRegistry, self).on_edu( - edu_type, origin, content, + edu_type, origin, content ) - return self._send_edu( - edu_type=edu_type, - origin=origin, - content=content, - ) + return self._send_edu(edu_type=edu_type, origin=origin, content=content) def on_query(self, query_type, args): """Overrides FederationHandlerRegistry @@ -921,7 +876,4 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry): if handler: return handler(args) - return self._get_query_client( - query_type=query_type, - args=args, - ) + return self._get_query_client(query_type=query_type, args=args) diff --git a/synapse/federation/persistence.py b/synapse/federation/persistence.py index 74ffd13b4f..7535f79203 100644 --- a/synapse/federation/persistence.py +++ b/synapse/federation/persistence.py @@ -46,12 +46,9 @@ class TransactionActions(object): response code and response body. """ if not transaction.transaction_id: - raise RuntimeError("Cannot persist a transaction with no " - "transaction_id") + raise RuntimeError("Cannot persist a transaction with no " "transaction_id") - return self.store.get_received_txn_response( - transaction.transaction_id, origin - ) + return self.store.get_received_txn_response(transaction.transaction_id, origin) @log_function def set_response(self, origin, transaction, code, response): @@ -61,14 +58,10 @@ class TransactionActions(object): Deferred """ if not transaction.transaction_id: - raise RuntimeError("Cannot persist a transaction with no " - "transaction_id") + raise RuntimeError("Cannot persist a transaction with no " "transaction_id") return self.store.set_received_txn_response( - transaction.transaction_id, - origin, - code, - response, + transaction.transaction_id, origin, code, response ) @defer.inlineCallbacks diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 0240b339b0..454456a52d 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -77,12 +77,22 @@ class FederationRemoteSendQueue(object): # lambda binds to the queue rather than to the name of the queue which # changes. ARGH. def register(name, queue): - LaterGauge("synapse_federation_send_queue_%s_size" % (queue_name,), - "", [], lambda: len(queue)) + LaterGauge( + "synapse_federation_send_queue_%s_size" % (queue_name,), + "", + [], + lambda: len(queue), + ) for queue_name in [ - "presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed", - "edus", "device_messages", "pos_time", "presence_destinations", + "presence_map", + "presence_changed", + "keyed_edu", + "keyed_edu_changed", + "edus", + "device_messages", + "pos_time", + "presence_destinations", ]: register(queue_name, getattr(self, queue_name)) @@ -121,9 +131,7 @@ class FederationRemoteSendQueue(object): del self.presence_changed[key] user_ids = set( - user_id - for uids in self.presence_changed.values() - for user_id in uids + user_id for uids in self.presence_changed.values() for user_id in uids ) keys = self.presence_destinations.keys() @@ -285,19 +293,21 @@ class FederationRemoteSendQueue(object): ] for (key, user_id) in dest_user_ids: - rows.append((key, PresenceRow( - state=self.presence_map[user_id], - ))) + rows.append((key, PresenceRow(state=self.presence_map[user_id]))) # Fetch presence to send to destinations i = self.presence_destinations.bisect_right(from_token) j = self.presence_destinations.bisect_right(to_token) + 1 for pos, (user_id, dests) in self.presence_destinations.items()[i:j]: - rows.append((pos, PresenceDestinationsRow( - state=self.presence_map[user_id], - destinations=list(dests), - ))) + rows.append( + ( + pos, + PresenceDestinationsRow( + state=self.presence_map[user_id], destinations=list(dests) + ), + ) + ) # Fetch changes keyed edus i = self.keyed_edu_changed.bisect_right(from_token) @@ -308,10 +318,14 @@ class FederationRemoteSendQueue(object): keyed_edus = {v: k for k, v in self.keyed_edu_changed.items()[i:j]} for ((destination, edu_key), pos) in iteritems(keyed_edus): - rows.append((pos, KeyedEduRow( - key=edu_key, - edu=self.keyed_edu[(destination, edu_key)], - ))) + rows.append( + ( + pos, + KeyedEduRow( + key=edu_key, edu=self.keyed_edu[(destination, edu_key)] + ), + ) + ) # Fetch changed edus i = self.edus.bisect_right(from_token) @@ -327,9 +341,7 @@ class FederationRemoteSendQueue(object): device_messages = {v: k for k, v in self.device_messages.items()[i:j]} for (destination, pos) in iteritems(device_messages): - rows.append((pos, DeviceRow( - destination=destination, - ))) + rows.append((pos, DeviceRow(destination=destination))) # Sort rows based on pos rows.sort() @@ -377,16 +389,14 @@ class BaseFederationRow(object): raise NotImplementedError() -class PresenceRow(BaseFederationRow, namedtuple("PresenceRow", ( - "state", # UserPresenceState -))): +class PresenceRow( + BaseFederationRow, namedtuple("PresenceRow", ("state",)) # UserPresenceState +): TypeId = "p" @staticmethod def from_data(data): - return PresenceRow( - state=UserPresenceState.from_dict(data) - ) + return PresenceRow(state=UserPresenceState.from_dict(data)) def to_data(self): return self.state.as_dict() @@ -395,33 +405,35 @@ class PresenceRow(BaseFederationRow, namedtuple("PresenceRow", ( buff.presence.append(self.state) -class PresenceDestinationsRow(BaseFederationRow, namedtuple("PresenceDestinationsRow", ( - "state", # UserPresenceState - "destinations", # list[str] -))): +class PresenceDestinationsRow( + BaseFederationRow, + namedtuple( + "PresenceDestinationsRow", + ("state", "destinations"), # UserPresenceState # list[str] + ), +): TypeId = "pd" @staticmethod def from_data(data): return PresenceDestinationsRow( - state=UserPresenceState.from_dict(data["state"]), - destinations=data["dests"], + state=UserPresenceState.from_dict(data["state"]), destinations=data["dests"] ) def to_data(self): - return { - "state": self.state.as_dict(), - "dests": self.destinations, - } + return {"state": self.state.as_dict(), "dests": self.destinations} def add_to_buffer(self, buff): buff.presence_destinations.append((self.state, self.destinations)) -class KeyedEduRow(BaseFederationRow, namedtuple("KeyedEduRow", ( - "key", # tuple(str) - the edu key passed to send_edu - "edu", # Edu -))): +class KeyedEduRow( + BaseFederationRow, + namedtuple( + "KeyedEduRow", + ("key", "edu"), # tuple(str) - the edu key passed to send_edu # Edu + ), +): """Streams EDUs that have an associated key that is ued to clobber. For example, typing EDUs clobber based on room_id. """ @@ -430,28 +442,19 @@ class KeyedEduRow(BaseFederationRow, namedtuple("KeyedEduRow", ( @staticmethod def from_data(data): - return KeyedEduRow( - key=tuple(data["key"]), - edu=Edu(**data["edu"]), - ) + return KeyedEduRow(key=tuple(data["key"]), edu=Edu(**data["edu"])) def to_data(self): - return { - "key": self.key, - "edu": self.edu.get_internal_dict(), - } + return {"key": self.key, "edu": self.edu.get_internal_dict()} def add_to_buffer(self, buff): - buff.keyed_edus.setdefault( - self.edu.destination, {} - )[self.key] = self.edu + buff.keyed_edus.setdefault(self.edu.destination, {})[self.key] = self.edu -class EduRow(BaseFederationRow, namedtuple("EduRow", ( - "edu", # Edu -))): +class EduRow(BaseFederationRow, namedtuple("EduRow", ("edu",))): # Edu """Streams EDUs that don't have keys. See KeyedEduRow """ + TypeId = "e" @staticmethod @@ -465,13 +468,12 @@ class EduRow(BaseFederationRow, namedtuple("EduRow", ( buff.edus.setdefault(self.edu.destination, []).append(self.edu) -class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", ( - "destination", # str -))): +class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", ("destination",))): # str """Streams the fact that either a) there is pending to device messages for users on the remote, or b) a local users device has changed and needs to be sent to the remote. """ + TypeId = "d" @staticmethod @@ -487,23 +489,20 @@ class DeviceRow(BaseFederationRow, namedtuple("DeviceRow", ( TypeToRow = { Row.TypeId: Row - for Row in ( - PresenceRow, - PresenceDestinationsRow, - KeyedEduRow, - EduRow, - DeviceRow, - ) + for Row in (PresenceRow, PresenceDestinationsRow, KeyedEduRow, EduRow, DeviceRow) } -ParsedFederationStreamData = namedtuple("ParsedFederationStreamData", ( - "presence", # list(UserPresenceState) - "presence_destinations", # list of tuples of UserPresenceState and destinations - "keyed_edus", # dict of destination -> { key -> Edu } - "edus", # dict of destination -> [Edu] - "device_destinations", # set of destinations -)) +ParsedFederationStreamData = namedtuple( + "ParsedFederationStreamData", + ( + "presence", # list(UserPresenceState) + "presence_destinations", # list of tuples of UserPresenceState and destinations + "keyed_edus", # dict of destination -> { key -> Edu } + "edus", # dict of destination -> [Edu] + "device_destinations", # set of destinations + ), +) def process_rows_for_federation(transaction_queue, rows): @@ -542,7 +541,7 @@ def process_rows_for_federation(transaction_queue, rows): for state, destinations in buff.presence_destinations: transaction_queue.send_presence_to_destinations( - states=[state], destinations=destinations, + states=[state], destinations=destinations ) for destination, edu_map in iteritems(buff.keyed_edus): diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 4f0f939102..766c5a37cd 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -44,8 +44,8 @@ sent_pdus_destination_dist_count = Counter( ) sent_pdus_destination_dist_total = Counter( - "synapse_federation_client_sent_pdu_destinations:total", "" - "Total number of PDUs queued for sending across all destinations", + "synapse_federation_client_sent_pdu_destinations:total", + "" "Total number of PDUs queued for sending across all destinations", ) @@ -63,14 +63,15 @@ class FederationSender(object): self._transaction_manager = TransactionManager(hs) # map from destination to PerDestinationQueue - self._per_destination_queues = {} # type: dict[str, PerDestinationQueue] + self._per_destination_queues = {} # type: dict[str, PerDestinationQueue] LaterGauge( "synapse_federation_transaction_queue_pending_destinations", "", [], lambda: sum( - 1 for d in self._per_destination_queues.values() + 1 + for d in self._per_destination_queues.values() if d.transmission_loop_running ), ) @@ -108,8 +109,9 @@ class FederationSender(object): # awaiting a call to flush_read_receipts_for_room. The presence of an entry # here for a given room means that we are rate-limiting RR flushes to that room, # and that there is a pending call to _flush_rrs_for_room in the system. - self._queues_awaiting_rr_flush_by_room = { - } # type: dict[str, set[PerDestinationQueue]] + self._queues_awaiting_rr_flush_by_room = ( + {} + ) # type: dict[str, set[PerDestinationQueue]] self._rr_txn_interval_per_room_ms = ( 1000.0 / hs.get_config().federation_rr_transactions_per_room_per_second @@ -141,8 +143,7 @@ class FederationSender(object): # fire off a processing loop in the background run_as_background_process( - "process_event_queue_for_federation", - self._process_event_queue_loop, + "process_event_queue_for_federation", self._process_event_queue_loop ) @defer.inlineCallbacks @@ -152,7 +153,7 @@ class FederationSender(object): while True: last_token = yield self.store.get_federation_out_pos("events") next_token, events = yield self.store.get_all_new_events_stream( - last_token, self._last_poked_id, limit=100, + last_token, self._last_poked_id, limit=100 ) logger.debug("Handling %s -> %s", last_token, next_token) @@ -168,6 +169,9 @@ class FederationSender(object): if not is_mine and send_on_behalf_of is None: return + if not event.internal_metadata.should_proactively_send(): + return + try: # Get the state from before the event. # We need to make sure that this is the state from before @@ -176,7 +180,7 @@ class FederationSender(object): # banned then it won't receive the event because it won't # be in the room after the ban. destinations = yield self.state.get_current_hosts_in_room( - event.room_id, latest_event_ids=event.prev_event_ids(), + event.room_id, latest_event_ids=event.prev_event_ids() ) except Exception: logger.exception( @@ -206,37 +210,40 @@ class FederationSender(object): for event in events: events_by_room.setdefault(event.room_id, []).append(event) - yield logcontext.make_deferred_yieldable(defer.gatherResults( - [ - logcontext.run_in_background(handle_room_events, evs) - for evs in itervalues(events_by_room) - ], - consumeErrors=True - )) - - yield self.store.update_federation_out_pos( - "events", next_token + yield logcontext.make_deferred_yieldable( + defer.gatherResults( + [ + logcontext.run_in_background(handle_room_events, evs) + for evs in itervalues(events_by_room) + ], + consumeErrors=True, + ) ) + yield self.store.update_federation_out_pos("events", next_token) + if events: now = self.clock.time_msec() ts = yield self.store.get_received_ts(events[-1].event_id) synapse.metrics.event_processing_lag.labels( - "federation_sender").set(now - ts) + "federation_sender" + ).set(now - ts) synapse.metrics.event_processing_last_ts.labels( - "federation_sender").set(ts) + "federation_sender" + ).set(ts) events_processed_counter.inc(len(events)) - event_processing_loop_room_count.labels( - "federation_sender" - ).inc(len(events_by_room)) + event_processing_loop_room_count.labels("federation_sender").inc( + len(events_by_room) + ) event_processing_loop_counter.labels("federation_sender").inc() synapse.metrics.event_processing_positions.labels( - "federation_sender").set(next_token) + "federation_sender" + ).set(next_token) finally: self._is_processing = False @@ -309,9 +316,7 @@ class FederationSender(object): if not domains: return - queues_pending_flush = self._queues_awaiting_rr_flush_by_room.get( - room_id - ) + queues_pending_flush = self._queues_awaiting_rr_flush_by_room.get(room_id) # if there is no flush yet scheduled, we will send out these receipts with # immediate flushes, and schedule the next flush for this room. @@ -374,10 +379,9 @@ class FederationSender(object): # updates in quick succession are correctly handled. # We only want to send presence for our own users, so lets always just # filter here just in case. - self.pending_presence.update({ - state.user_id: state for state in states - if self.is_mine_id(state.user_id) - }) + self.pending_presence.update( + {state.user_id: state for state in states if self.is_mine_id(state.user_id)} + ) # We then handle the new pending presence in batches, first figuring # out the destinations we need to send each state to and then poking it diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py index 564c57203d..9aab12c0d3 100644 --- a/synapse/federation/sender/per_destination_queue.py +++ b/synapse/federation/sender/per_destination_queue.py @@ -189,11 +189,21 @@ class PerDestinationQueue(object): pending_pdus = [] while True: - device_message_edus, device_stream_id, dev_list_id = ( - # We have to keep 2 free slots for presence and rr_edus - yield self._get_new_device_messages(MAX_EDUS_PER_TRANSACTION - 2) + # We have to keep 2 free slots for presence and rr_edus + limit = MAX_EDUS_PER_TRANSACTION - 2 + + device_update_edus, dev_list_id = ( + yield self._get_device_update_edus(limit) + ) + + limit -= len(device_update_edus) + + to_device_edus, device_stream_id = ( + yield self._get_to_device_message_edus(limit) ) + pending_edus = device_update_edus + to_device_edus + # BEGIN CRITICAL SECTION # # In order to avoid a race condition, we need to make sure that @@ -208,10 +218,6 @@ class PerDestinationQueue(object): # We can only include at most 50 PDUs per transactions pending_pdus, self._pending_pdus = pending_pdus[:50], pending_pdus[50:] - pending_edus = [] - - # We can only include at most 100 EDUs per transactions - # rr_edus and pending_presence take at most one slot each pending_edus.extend(self._get_rr_edus(force_flush=False)) pending_presence = self._pending_presence self._pending_presence = {} @@ -232,7 +238,6 @@ class PerDestinationQueue(object): ) ) - pending_edus.extend(device_message_edus) pending_edus.extend( self._pop_pending_edus(MAX_EDUS_PER_TRANSACTION - len(pending_edus)) ) @@ -272,10 +277,13 @@ class PerDestinationQueue(object): sent_edus_by_type.labels(edu.edu_type).inc() # Remove the acknowledged device messages from the database # Only bother if we actually sent some device messages - if device_message_edus: + if to_device_edus: yield self._store.delete_device_msgs_for_remote( self._destination, device_stream_id ) + + # also mark the device updates as sent + if device_update_edus: logger.info( "Marking as sent %r %r", self._destination, dev_list_id ) @@ -347,12 +355,12 @@ class PerDestinationQueue(object): return pending_edus @defer.inlineCallbacks - def _get_new_device_messages(self, limit): + def _get_device_update_edus(self, limit): last_device_list = self._last_device_list_stream_id # Retrieve list of new device updates to send to the destination now_stream_id, results = yield self._store.get_devices_by_remote( - self._destination, last_device_list, limit=limit, + self._destination, last_device_list, limit=limit ) edus = [ Edu( @@ -366,15 +374,16 @@ class PerDestinationQueue(object): assert len(edus) <= limit, "get_devices_by_remote returned too many EDUs" + defer.returnValue((edus, now_stream_id)) + + @defer.inlineCallbacks + def _get_to_device_message_edus(self, limit): last_device_stream_id = self._last_device_stream_id to_device_stream_id = self._store.get_to_device_stream_token() contents, stream_id = yield self._store.get_new_device_msgs_for_remote( - self._destination, - last_device_stream_id, - to_device_stream_id, - limit - len(edus), + self._destination, last_device_stream_id, to_device_stream_id, limit ) - edus.extend( + edus = [ Edu( origin=self._server_name, destination=self._destination, @@ -382,6 +391,6 @@ class PerDestinationQueue(object): content=content, ) for content in contents - ) + ] - defer.returnValue((edus, stream_id, now_stream_id)) + defer.returnValue((edus, stream_id)) diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py index 35e6b8ff5b..c987bb9a0d 100644 --- a/synapse/federation/sender/transaction_manager.py +++ b/synapse/federation/sender/transaction_manager.py @@ -29,9 +29,10 @@ class TransactionManager(object): shared between PerDestinationQueue objects """ + def __init__(self, hs): self._server_name = hs.hostname - self.clock = hs.get_clock() # nb must be called this for @measure_func + self.clock = hs.get_clock() # nb must be called this for @measure_func self._store = hs.get_datastore() self._transaction_actions = TransactionActions(self._store) self._transport_layer = hs.get_federation_transport_client() @@ -55,9 +56,9 @@ class TransactionManager(object): txn_id = str(self._next_txn_id) logger.debug( - "TX [%s] {%s} Attempting new transaction" - " (pdus: %d, edus: %d)", - destination, txn_id, + "TX [%s] {%s} Attempting new transaction" " (pdus: %d, edus: %d)", + destination, + txn_id, len(pdus), len(edus), ) @@ -79,9 +80,9 @@ class TransactionManager(object): logger.debug("TX [%s] Persisted transaction", destination) logger.info( - "TX [%s] {%s} Sending transaction [%s]," - " (PDUs: %d, EDUs: %d)", - destination, txn_id, + "TX [%s] {%s} Sending transaction [%s]," " (PDUs: %d, EDUs: %d)", + destination, + txn_id, transaction.transaction_id, len(pdus), len(edus), @@ -112,20 +113,12 @@ class TransactionManager(object): response = e.response if e.code in (401, 404, 429) or 500 <= e.code: - logger.info( - "TX [%s] {%s} got %d response", - destination, txn_id, code - ) + logger.info("TX [%s] {%s} got %d response", destination, txn_id, code) raise e - logger.info( - "TX [%s] {%s} got %d response", - destination, txn_id, code - ) + logger.info("TX [%s] {%s} got %d response", destination, txn_id, code) - yield self._transaction_actions.delivered( - transaction, code, response - ) + yield self._transaction_actions.delivered(transaction, code, response) logger.debug("TX [%s] {%s} Marked as delivered", destination, txn_id) @@ -134,13 +127,18 @@ class TransactionManager(object): if "error" in r: logger.warn( "TX [%s] {%s} Remote returned error for %s: %s", - destination, txn_id, e_id, r, + destination, + txn_id, + e_id, + r, ) else: for p in pdus: logger.warn( "TX [%s] {%s} Failed to send event %s", - destination, txn_id, p.event_id, + destination, + txn_id, + p.event_id, ) success = False diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index e424c40fdf..aecd142309 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -48,12 +48,13 @@ class TransportLayerClient(object): Returns: Deferred: Results in a dict received from the remote homeserver. """ - logger.debug("get_room_state dest=%s, room=%s", - destination, room_id) + logger.debug("get_room_state dest=%s, room=%s", destination, room_id) path = _create_v1_path("/state/%s", room_id) return self.client.get_json( - destination, path=path, args={"event_id": event_id}, + destination, + path=path, + args={"event_id": event_id}, try_trailing_slash_on_400=True, ) @@ -71,12 +72,13 @@ class TransportLayerClient(object): Returns: Deferred: Results in a dict received from the remote homeserver. """ - logger.debug("get_room_state_ids dest=%s, room=%s", - destination, room_id) + logger.debug("get_room_state_ids dest=%s, room=%s", destination, room_id) path = _create_v1_path("/state_ids/%s", room_id) return self.client.get_json( - destination, path=path, args={"event_id": event_id}, + destination, + path=path, + args={"event_id": event_id}, try_trailing_slash_on_400=True, ) @@ -94,13 +96,11 @@ class TransportLayerClient(object): Returns: Deferred: Results in a dict received from the remote homeserver. """ - logger.debug("get_pdu dest=%s, event_id=%s", - destination, event_id) + logger.debug("get_pdu dest=%s, event_id=%s", destination, event_id) path = _create_v1_path("/event/%s", event_id) return self.client.get_json( - destination, path=path, timeout=timeout, - try_trailing_slash_on_400=True, + destination, path=path, timeout=timeout, try_trailing_slash_on_400=True ) @log_function @@ -119,7 +119,10 @@ class TransportLayerClient(object): """ logger.debug( "backfill dest=%s, room_id=%s, event_tuples=%s, limit=%s", - destination, room_id, repr(event_tuples), str(limit) + destination, + room_id, + repr(event_tuples), + str(limit), ) if not event_tuples: @@ -128,16 +131,10 @@ class TransportLayerClient(object): path = _create_v1_path("/backfill/%s", room_id) - args = { - "v": event_tuples, - "limit": [str(limit)], - } + args = {"v": event_tuples, "limit": [str(limit)]} return self.client.get_json( - destination, - path=path, - args=args, - try_trailing_slash_on_400=True, + destination, path=path, args=args, try_trailing_slash_on_400=True ) @defer.inlineCallbacks @@ -163,7 +160,8 @@ class TransportLayerClient(object): """ logger.debug( "send_data dest=%s, txid=%s", - transaction.destination, transaction.transaction_id + transaction.destination, + transaction.transaction_id, ) if transaction.destination == self.server_name: @@ -189,8 +187,9 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function - def make_query(self, destination, query_type, args, retry_on_dns_fail, - ignore_backoff=False): + def make_query( + self, destination, query_type, args, retry_on_dns_fail, ignore_backoff=False + ): path = _create_v1_path("/query/%s", query_type) content = yield self.client.get_json( @@ -235,8 +234,8 @@ class TransportLayerClient(object): valid_memberships = {Membership.JOIN, Membership.LEAVE} if membership not in valid_memberships: raise RuntimeError( - "make_membership_event called with membership='%s', must be one of %s" % - (membership, ",".join(valid_memberships)) + "make_membership_event called with membership='%s', must be one of %s" + % (membership, ",".join(valid_memberships)) ) path = _create_v1_path("/make_%s/%s/%s", membership, room_id, user_id) @@ -268,9 +267,7 @@ class TransportLayerClient(object): path = _create_v1_path("/send_join/%s/%s", room_id, event_id) response = yield self.client.put_json( - destination=destination, - path=path, - data=content, + destination=destination, path=path, data=content ) defer.returnValue(response) @@ -284,7 +281,6 @@ class TransportLayerClient(object): destination=destination, path=path, data=content, - # we want to do our best to send this through. The problem is # that if it fails, we won't retry it later, so if the remote # server was just having a momentary blip, the room will be out of @@ -300,10 +296,7 @@ class TransportLayerClient(object): path = _create_v1_path("/invite/%s/%s", room_id, event_id) response = yield self.client.put_json( - destination=destination, - path=path, - data=content, - ignore_backoff=True, + destination=destination, path=path, data=content, ignore_backoff=True ) defer.returnValue(response) @@ -314,26 +307,27 @@ class TransportLayerClient(object): path = _create_v2_path("/invite/%s/%s", room_id, event_id) response = yield self.client.put_json( - destination=destination, - path=path, - data=content, - ignore_backoff=True, + destination=destination, path=path, data=content, ignore_backoff=True ) defer.returnValue(response) @defer.inlineCallbacks @log_function - def get_public_rooms(self, remote_server, limit, since_token, - search_filter=None, include_all_networks=False, - third_party_instance_id=None): + def get_public_rooms( + self, + remote_server, + limit, + since_token, + search_filter=None, + include_all_networks=False, + third_party_instance_id=None, + ): path = _create_v1_path("/publicRooms") - args = { - "include_all_networks": "true" if include_all_networks else "false", - } + args = {"include_all_networks": "true" if include_all_networks else "false"} if third_party_instance_id: - args["third_party_instance_id"] = third_party_instance_id, + args["third_party_instance_id"] = (third_party_instance_id,) if limit: args["limit"] = [str(limit)] if since_token: @@ -342,10 +336,7 @@ class TransportLayerClient(object): # TODO(erikj): Actually send the search_filter across federation. response = yield self.client.get_json( - destination=remote_server, - path=path, - args=args, - ignore_backoff=True, + destination=remote_server, path=path, args=args, ignore_backoff=True ) defer.returnValue(response) @@ -353,12 +344,10 @@ class TransportLayerClient(object): @defer.inlineCallbacks @log_function def exchange_third_party_invite(self, destination, room_id, event_dict): - path = _create_v1_path("/exchange_third_party_invite/%s", room_id,) + path = _create_v1_path("/exchange_third_party_invite/%s", room_id) response = yield self.client.put_json( - destination=destination, - path=path, - data=event_dict, + destination=destination, path=path, data=event_dict ) defer.returnValue(response) @@ -368,10 +357,7 @@ class TransportLayerClient(object): def get_event_auth(self, destination, room_id, event_id): path = _create_v1_path("/event_auth/%s/%s", room_id, event_id) - content = yield self.client.get_json( - destination=destination, - path=path, - ) + content = yield self.client.get_json(destination=destination, path=path) defer.returnValue(content) @@ -381,9 +367,7 @@ class TransportLayerClient(object): path = _create_v1_path("/query_auth/%s/%s", room_id, event_id) content = yield self.client.post_json( - destination=destination, - path=path, - data=content, + destination=destination, path=path, data=content ) defer.returnValue(content) @@ -416,10 +400,7 @@ class TransportLayerClient(object): path = _create_v1_path("/user/keys/query") content = yield self.client.post_json( - destination=destination, - path=path, - data=query_content, - timeout=timeout, + destination=destination, path=path, data=query_content, timeout=timeout ) defer.returnValue(content) @@ -443,9 +424,7 @@ class TransportLayerClient(object): path = _create_v1_path("/user/devices/%s", user_id) content = yield self.client.get_json( - destination=destination, - path=path, - timeout=timeout, + destination=destination, path=path, timeout=timeout ) defer.returnValue(content) @@ -479,18 +458,23 @@ class TransportLayerClient(object): path = _create_v1_path("/user/keys/claim") content = yield self.client.post_json( - destination=destination, - path=path, - data=query_content, - timeout=timeout, + destination=destination, path=path, data=query_content, timeout=timeout ) defer.returnValue(content) @defer.inlineCallbacks @log_function - def get_missing_events(self, destination, room_id, earliest_events, - latest_events, limit, min_depth, timeout): - path = _create_v1_path("/get_missing_events/%s", room_id,) + def get_missing_events( + self, + destination, + room_id, + earliest_events, + latest_events, + limit, + min_depth, + timeout, + ): + path = _create_v1_path("/get_missing_events/%s", room_id) content = yield self.client.post_json( destination=destination, @@ -510,7 +494,7 @@ class TransportLayerClient(object): def get_group_profile(self, destination, group_id, requester_user_id): """Get a group profile """ - path = _create_v1_path("/groups/%s/profile", group_id,) + path = _create_v1_path("/groups/%s/profile", group_id) return self.client.get_json( destination=destination, @@ -529,7 +513,7 @@ class TransportLayerClient(object): requester_user_id (str) content (dict): The new profile of the group """ - path = _create_v1_path("/groups/%s/profile", group_id,) + path = _create_v1_path("/groups/%s/profile", group_id) return self.client.post_json( destination=destination, @@ -543,7 +527,7 @@ class TransportLayerClient(object): def get_group_summary(self, destination, group_id, requester_user_id): """Get a group summary """ - path = _create_v1_path("/groups/%s/summary", group_id,) + path = _create_v1_path("/groups/%s/summary", group_id) return self.client.get_json( destination=destination, @@ -556,7 +540,7 @@ class TransportLayerClient(object): def get_rooms_in_group(self, destination, group_id, requester_user_id): """Get all rooms in a group """ - path = _create_v1_path("/groups/%s/rooms", group_id,) + path = _create_v1_path("/groups/%s/rooms", group_id) return self.client.get_json( destination=destination, @@ -565,11 +549,12 @@ class TransportLayerClient(object): ignore_backoff=True, ) - def add_room_to_group(self, destination, group_id, requester_user_id, room_id, - content): + def add_room_to_group( + self, destination, group_id, requester_user_id, room_id, content + ): """Add a room to a group """ - path = _create_v1_path("/groups/%s/room/%s", group_id, room_id,) + path = _create_v1_path("/groups/%s/room/%s", group_id, room_id) return self.client.post_json( destination=destination, @@ -579,13 +564,13 @@ class TransportLayerClient(object): ignore_backoff=True, ) - def update_room_in_group(self, destination, group_id, requester_user_id, room_id, - config_key, content): + def update_room_in_group( + self, destination, group_id, requester_user_id, room_id, config_key, content + ): """Update room in group """ path = _create_v1_path( - "/groups/%s/room/%s/config/%s", - group_id, room_id, config_key, + "/groups/%s/room/%s/config/%s", group_id, room_id, config_key ) return self.client.post_json( @@ -599,7 +584,7 @@ class TransportLayerClient(object): def remove_room_from_group(self, destination, group_id, requester_user_id, room_id): """Remove a room from a group """ - path = _create_v1_path("/groups/%s/room/%s", group_id, room_id,) + path = _create_v1_path("/groups/%s/room/%s", group_id, room_id) return self.client.delete_json( destination=destination, @@ -612,7 +597,7 @@ class TransportLayerClient(object): def get_users_in_group(self, destination, group_id, requester_user_id): """Get users in a group """ - path = _create_v1_path("/groups/%s/users", group_id,) + path = _create_v1_path("/groups/%s/users", group_id) return self.client.get_json( destination=destination, @@ -625,7 +610,7 @@ class TransportLayerClient(object): def get_invited_users_in_group(self, destination, group_id, requester_user_id): """Get users that have been invited to a group """ - path = _create_v1_path("/groups/%s/invited_users", group_id,) + path = _create_v1_path("/groups/%s/invited_users", group_id) return self.client.get_json( destination=destination, @@ -638,16 +623,10 @@ class TransportLayerClient(object): def accept_group_invite(self, destination, group_id, user_id, content): """Accept a group invite """ - path = _create_v1_path( - "/groups/%s/users/%s/accept_invite", - group_id, user_id, - ) + path = _create_v1_path("/groups/%s/users/%s/accept_invite", group_id, user_id) return self.client.post_json( - destination=destination, - path=path, - data=content, - ignore_backoff=True, + destination=destination, path=path, data=content, ignore_backoff=True ) @log_function @@ -657,14 +636,13 @@ class TransportLayerClient(object): path = _create_v1_path("/groups/%s/users/%s/join", group_id, user_id) return self.client.post_json( - destination=destination, - path=path, - data=content, - ignore_backoff=True, + destination=destination, path=path, data=content, ignore_backoff=True ) @log_function - def invite_to_group(self, destination, group_id, user_id, requester_user_id, content): + def invite_to_group( + self, destination, group_id, user_id, requester_user_id, content + ): """Invite a user to a group """ path = _create_v1_path("/groups/%s/users/%s/invite", group_id, user_id) @@ -686,15 +664,13 @@ class TransportLayerClient(object): path = _create_v1_path("/groups/local/%s/users/%s/invite", group_id, user_id) return self.client.post_json( - destination=destination, - path=path, - data=content, - ignore_backoff=True, + destination=destination, path=path, data=content, ignore_backoff=True ) @log_function - def remove_user_from_group(self, destination, group_id, requester_user_id, - user_id, content): + def remove_user_from_group( + self, destination, group_id, requester_user_id, user_id, content + ): """Remove a user fron a group """ path = _create_v1_path("/groups/%s/users/%s/remove", group_id, user_id) @@ -708,8 +684,9 @@ class TransportLayerClient(object): ) @log_function - def remove_user_from_group_notification(self, destination, group_id, user_id, - content): + def remove_user_from_group_notification( + self, destination, group_id, user_id, content + ): """Sent by group server to inform a user's server that they have been kicked from the group. """ @@ -717,10 +694,7 @@ class TransportLayerClient(object): path = _create_v1_path("/groups/local/%s/users/%s/remove", group_id, user_id) return self.client.post_json( - destination=destination, - path=path, - data=content, - ignore_backoff=True, + destination=destination, path=path, data=content, ignore_backoff=True ) @log_function @@ -732,24 +706,24 @@ class TransportLayerClient(object): path = _create_v1_path("/groups/%s/renew_attestation/%s", group_id, user_id) return self.client.post_json( - destination=destination, - path=path, - data=content, - ignore_backoff=True, + destination=destination, path=path, data=content, ignore_backoff=True ) @log_function - def update_group_summary_room(self, destination, group_id, user_id, room_id, - category_id, content): + def update_group_summary_room( + self, destination, group_id, user_id, room_id, category_id, content + ): """Update a room entry in a group summary """ if category_id: path = _create_v1_path( "/groups/%s/summary/categories/%s/rooms/%s", - group_id, category_id, room_id, + group_id, + category_id, + room_id, ) else: - path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id,) + path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id) return self.client.post_json( destination=destination, @@ -760,17 +734,20 @@ class TransportLayerClient(object): ) @log_function - def delete_group_summary_room(self, destination, group_id, user_id, room_id, - category_id): + def delete_group_summary_room( + self, destination, group_id, user_id, room_id, category_id + ): """Delete a room entry in a group summary """ if category_id: path = _create_v1_path( "/groups/%s/summary/categories/%s/rooms/%s", - group_id, category_id, room_id, + group_id, + category_id, + room_id, ) else: - path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id,) + path = _create_v1_path("/groups/%s/summary/rooms/%s", group_id, room_id) return self.client.delete_json( destination=destination, @@ -783,7 +760,7 @@ class TransportLayerClient(object): def get_group_categories(self, destination, group_id, requester_user_id): """Get all categories in a group """ - path = _create_v1_path("/groups/%s/categories", group_id,) + path = _create_v1_path("/groups/%s/categories", group_id) return self.client.get_json( destination=destination, @@ -796,7 +773,7 @@ class TransportLayerClient(object): def get_group_category(self, destination, group_id, requester_user_id, category_id): """Get category info in a group """ - path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,) + path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id) return self.client.get_json( destination=destination, @@ -806,11 +783,12 @@ class TransportLayerClient(object): ) @log_function - def update_group_category(self, destination, group_id, requester_user_id, category_id, - content): + def update_group_category( + self, destination, group_id, requester_user_id, category_id, content + ): """Update a category in a group """ - path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,) + path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id) return self.client.post_json( destination=destination, @@ -821,11 +799,12 @@ class TransportLayerClient(object): ) @log_function - def delete_group_category(self, destination, group_id, requester_user_id, - category_id): + def delete_group_category( + self, destination, group_id, requester_user_id, category_id + ): """Delete a category in a group """ - path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id,) + path = _create_v1_path("/groups/%s/categories/%s", group_id, category_id) return self.client.delete_json( destination=destination, @@ -838,7 +817,7 @@ class TransportLayerClient(object): def get_group_roles(self, destination, group_id, requester_user_id): """Get all roles in a group """ - path = _create_v1_path("/groups/%s/roles", group_id,) + path = _create_v1_path("/groups/%s/roles", group_id) return self.client.get_json( destination=destination, @@ -851,7 +830,7 @@ class TransportLayerClient(object): def get_group_role(self, destination, group_id, requester_user_id, role_id): """Get a roles info """ - path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id,) + path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id) return self.client.get_json( destination=destination, @@ -861,11 +840,12 @@ class TransportLayerClient(object): ) @log_function - def update_group_role(self, destination, group_id, requester_user_id, role_id, - content): + def update_group_role( + self, destination, group_id, requester_user_id, role_id, content + ): """Update a role in a group """ - path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id,) + path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id) return self.client.post_json( destination=destination, @@ -879,7 +859,7 @@ class TransportLayerClient(object): def delete_group_role(self, destination, group_id, requester_user_id, role_id): """Delete a role in a group """ - path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id,) + path = _create_v1_path("/groups/%s/roles/%s", group_id, role_id) return self.client.delete_json( destination=destination, @@ -889,17 +869,17 @@ class TransportLayerClient(object): ) @log_function - def update_group_summary_user(self, destination, group_id, requester_user_id, - user_id, role_id, content): + def update_group_summary_user( + self, destination, group_id, requester_user_id, user_id, role_id, content + ): """Update a users entry in a group """ if role_id: path = _create_v1_path( - "/groups/%s/summary/roles/%s/users/%s", - group_id, role_id, user_id, + "/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id ) else: - path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id,) + path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id) return self.client.post_json( destination=destination, @@ -910,11 +890,10 @@ class TransportLayerClient(object): ) @log_function - def set_group_join_policy(self, destination, group_id, requester_user_id, - content): + def set_group_join_policy(self, destination, group_id, requester_user_id, content): """Sets the join policy for a group """ - path = _create_v1_path("/groups/%s/settings/m.join_policy", group_id,) + path = _create_v1_path("/groups/%s/settings/m.join_policy", group_id) return self.client.put_json( destination=destination, @@ -925,17 +904,17 @@ class TransportLayerClient(object): ) @log_function - def delete_group_summary_user(self, destination, group_id, requester_user_id, - user_id, role_id): + def delete_group_summary_user( + self, destination, group_id, requester_user_id, user_id, role_id + ): """Delete a users entry in a group """ if role_id: path = _create_v1_path( - "/groups/%s/summary/roles/%s/users/%s", - group_id, role_id, user_id, + "/groups/%s/summary/roles/%s/users/%s", group_id, role_id, user_id ) else: - path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id,) + path = _create_v1_path("/groups/%s/summary/users/%s", group_id, user_id) return self.client.delete_json( destination=destination, @@ -953,10 +932,7 @@ class TransportLayerClient(object): content = {"user_ids": user_ids} return self.client.post_json( - destination=destination, - path=path, - data=content, - ignore_backoff=True, + destination=destination, path=path, data=content, ignore_backoff=True ) @@ -975,9 +951,8 @@ def _create_v1_path(path, *args): Returns: str """ - return ( - FEDERATION_V1_PREFIX - + path % tuple(urllib.parse.quote(arg, "") for arg in args) + return FEDERATION_V1_PREFIX + path % tuple( + urllib.parse.quote(arg, "") for arg in args ) @@ -996,7 +971,6 @@ def _create_v2_path(path, *args): Returns: str """ - return ( - FEDERATION_V2_PREFIX - + path % tuple(urllib.parse.quote(arg, "") for arg in args) + return FEDERATION_V2_PREFIX + path % tuple( + urllib.parse.quote(arg, "") for arg in args ) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 949a5fb2aa..955f0f4308 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -66,8 +66,7 @@ class TransportLayerServer(JsonResource): self.authenticator = Authenticator(hs) self.ratelimiter = FederationRateLimiter( - self.clock, - config=hs.config.rc_federation, + self.clock, config=hs.config.rc_federation ) self.register_servlets() @@ -84,11 +83,13 @@ class TransportLayerServer(JsonResource): class AuthenticationError(SynapseError): """There was a problem authenticating the request""" + pass class NoAuthenticationError(AuthenticationError): """The request had no authentication information""" + pass @@ -105,8 +106,8 @@ class Authenticator(object): def authenticate_request(self, request, content): now = self._clock.time_msec() json_request = { - "method": request.method.decode('ascii'), - "uri": request.uri.decode('ascii'), + "method": request.method.decode("ascii"), + "uri": request.uri.decode("ascii"), "destination": self.server_name, "signatures": {}, } @@ -120,7 +121,7 @@ class Authenticator(object): if not auth_headers: raise NoAuthenticationError( - 401, "Missing Authorization headers", Codes.UNAUTHORIZED, + 401, "Missing Authorization headers", Codes.UNAUTHORIZED ) for auth in auth_headers: @@ -130,14 +131,14 @@ class Authenticator(object): json_request["signatures"].setdefault(origin, {})[key] = sig if ( - self.federation_domain_whitelist is not None and - origin not in self.federation_domain_whitelist + self.federation_domain_whitelist is not None + and origin not in self.federation_domain_whitelist ): raise FederationDeniedError(origin) if not json_request["signatures"]: raise NoAuthenticationError( - 401, "Missing Authorization headers", Codes.UNAUTHORIZED, + 401, "Missing Authorization headers", Codes.UNAUTHORIZED ) yield self.keyring.verify_json_for_server( @@ -177,12 +178,12 @@ def _parse_auth_header(header_bytes): AuthenticationError if the header could not be parsed """ try: - header_str = header_bytes.decode('utf-8') + header_str = header_bytes.decode("utf-8") params = header_str.split(" ")[1].split(",") param_dict = dict(kv.split("=") for kv in params) def strip_quotes(value): - if value.startswith("\""): + if value.startswith('"'): return value[1:-1] else: return value @@ -198,11 +199,11 @@ def _parse_auth_header(header_bytes): except Exception as e: logger.warn( "Error parsing auth header '%s': %s", - header_bytes.decode('ascii', 'replace'), + header_bytes.decode("ascii", "replace"), e, ) raise AuthenticationError( - 400, "Malformed Authorization header", Codes.UNAUTHORIZED, + 400, "Malformed Authorization header", Codes.UNAUTHORIZED ) @@ -242,6 +243,7 @@ class BaseFederationServlet(object): Exception: other exceptions will be caught, logged, and a 500 will be returned. """ + REQUIRE_AUTH = True PREFIX = FEDERATION_V1_PREFIX # Allows specifying the API version @@ -293,9 +295,7 @@ class BaseFederationServlet(object): origin, content, request.args, *args, **kwargs ) else: - response = yield func( - origin, content, request.args, *args, **kwargs - ) + response = yield func(origin, content, request.args, *args, **kwargs) defer.returnValue(response) @@ -343,14 +343,12 @@ class FederationSendServlet(BaseFederationServlet): try: transaction_data = content - logger.debug( - "Decoded %s: %s", - transaction_id, str(transaction_data) - ) + logger.debug("Decoded %s: %s", transaction_id, str(transaction_data)) logger.info( "Received txn %s from %s. (PDUs: %d, EDUs: %d)", - transaction_id, origin, + transaction_id, + origin, len(transaction_data.get("pdus", [])), len(transaction_data.get("edus", [])), ) @@ -361,8 +359,7 @@ class FederationSendServlet(BaseFederationServlet): # Add some extra data to the transaction dict that isn't included # in the request body. transaction_data.update( - transaction_id=transaction_id, - destination=self.server_name + transaction_id=transaction_id, destination=self.server_name ) except Exception as e: @@ -372,7 +369,7 @@ class FederationSendServlet(BaseFederationServlet): try: code, response = yield self.handler.on_incoming_transaction( - origin, transaction_data, + origin, transaction_data ) except Exception: logger.exception("on_incoming_transaction failed") @@ -416,7 +413,7 @@ class FederationBackfillServlet(BaseFederationServlet): PATH = "/backfill/(?P<context>[^/]*)/?" def on_GET(self, origin, content, query, context): - versions = [x.decode('ascii') for x in query[b"v"]] + versions = [x.decode("ascii") for x in query[b"v"]] limit = parse_integer_from_args(query, "limit", None) if not limit: @@ -432,7 +429,7 @@ class FederationQueryServlet(BaseFederationServlet): def on_GET(self, origin, content, query, query_type): return self.handler.on_query_request( query_type, - {k.decode('utf8'): v[0].decode("utf-8") for k, v in query.items()} + {k.decode("utf8"): v[0].decode("utf-8") for k, v in query.items()}, ) @@ -456,15 +453,14 @@ class FederationMakeJoinServlet(BaseFederationServlet): Deferred[(int, object)|None]: either (response code, response object) to return a JSON response, or None if the request has already been handled. """ - versions = query.get(b'ver') + versions = query.get(b"ver") if versions is not None: supported_versions = [v.decode("utf-8") for v in versions] else: supported_versions = ["1"] content = yield self.handler.on_make_join_request( - origin, context, user_id, - supported_versions=supported_versions, + origin, context, user_id, supported_versions=supported_versions ) defer.returnValue((200, content)) @@ -474,9 +470,7 @@ class FederationMakeLeaveServlet(BaseFederationServlet): @defer.inlineCallbacks def on_GET(self, origin, content, query, context, user_id): - content = yield self.handler.on_make_leave_request( - origin, context, user_id, - ) + content = yield self.handler.on_make_leave_request(origin, context, user_id) defer.returnValue((200, content)) @@ -517,7 +511,7 @@ class FederationV1InviteServlet(BaseFederationServlet): # state resolution algorithm, and we don't use that for processing # invites content = yield self.handler.on_invite_request( - origin, content, room_version=RoomVersions.V1.identifier, + origin, content, room_version=RoomVersions.V1.identifier ) # V1 federation API is defined to return a content of `[200, {...}]` @@ -545,7 +539,7 @@ class FederationV2InviteServlet(BaseFederationServlet): event.setdefault("unsigned", {})["invite_room_state"] = invite_room_state content = yield self.handler.on_invite_request( - origin, event, room_version=room_version, + origin, event, room_version=room_version ) defer.returnValue((200, content)) @@ -629,8 +623,10 @@ class On3pidBindServlet(BaseFederationServlet): for invite in content["invites"]: try: if "signed" not in invite or "token" not in invite["signed"]: - message = ("Rejecting received notification of third-" - "party invite without signed: %s" % (invite,)) + message = ( + "Rejecting received notification of third-" + "party invite without signed: %s" % (invite,) + ) logger.info(message) raise SynapseError(400, message) yield self.handler.exchange_third_party_invite( @@ -671,18 +667,23 @@ class OpenIdUserInfo(BaseFederationServlet): def on_GET(self, origin, content, query): token = query.get(b"access_token", [None])[0] if token is None: - defer.returnValue((401, { - "errcode": "M_MISSING_TOKEN", "error": "Access Token required" - })) + defer.returnValue( + (401, {"errcode": "M_MISSING_TOKEN", "error": "Access Token required"}) + ) return - user_id = yield self.handler.on_openid_userinfo(token.decode('ascii')) + user_id = yield self.handler.on_openid_userinfo(token.decode("ascii")) if user_id is None: - defer.returnValue((401, { - "errcode": "M_UNKNOWN_TOKEN", - "error": "Access Token unknown or expired" - })) + defer.returnValue( + ( + 401, + { + "errcode": "M_UNKNOWN_TOKEN", + "error": "Access Token unknown or expired", + }, + ) + ) defer.returnValue((200, {"sub": user_id})) @@ -720,15 +721,15 @@ class PublicRoomList(BaseFederationServlet): PATH = "/publicRooms" - def __init__(self, handler, authenticator, ratelimiter, server_name, deny_access): + def __init__(self, handler, authenticator, ratelimiter, server_name, allow_access): super(PublicRoomList, self).__init__( - handler, authenticator, ratelimiter, server_name, + handler, authenticator, ratelimiter, server_name ) - self.deny_access = deny_access + self.allow_access = allow_access @defer.inlineCallbacks def on_GET(self, origin, content, query): - if self.deny_access: + if not self.allow_access: raise FederationDeniedError(origin) limit = parse_integer_from_args(query, "limit", 0) @@ -748,9 +749,7 @@ class PublicRoomList(BaseFederationServlet): network_tuple = ThirdPartyInstanceID(None, None) data = yield self.handler.get_local_public_room_list( - limit, since_token, - network_tuple=network_tuple, - from_federation=True, + limit, since_token, network_tuple=network_tuple, from_federation=True ) defer.returnValue((200, data)) @@ -761,17 +760,18 @@ class FederationVersionServlet(BaseFederationServlet): REQUIRE_AUTH = False def on_GET(self, origin, content, query): - return defer.succeed((200, { - "server": { - "name": "Synapse", - "version": get_version_string(synapse) - }, - })) + return defer.succeed( + ( + 200, + {"server": {"name": "Synapse", "version": get_version_string(synapse)}}, + ) + ) class FederationGroupsProfileServlet(BaseFederationServlet): """Get/set the basic profile of a group on behalf of a user """ + PATH = "/groups/(?P<group_id>[^/]*)/profile" @defer.inlineCallbacks @@ -780,9 +780,7 @@ class FederationGroupsProfileServlet(BaseFederationServlet): if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - new_content = yield self.handler.get_group_profile( - group_id, requester_user_id - ) + new_content = yield self.handler.get_group_profile(group_id, requester_user_id) defer.returnValue((200, new_content)) @@ -808,9 +806,7 @@ class FederationGroupsSummaryServlet(BaseFederationServlet): if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - new_content = yield self.handler.get_group_summary( - group_id, requester_user_id - ) + new_content = yield self.handler.get_group_summary(group_id, requester_user_id) defer.returnValue((200, new_content)) @@ -818,6 +814,7 @@ class FederationGroupsSummaryServlet(BaseFederationServlet): class FederationGroupsRoomsServlet(BaseFederationServlet): """Get the rooms in a group on behalf of a user """ + PATH = "/groups/(?P<group_id>[^/]*)/rooms" @defer.inlineCallbacks @@ -826,9 +823,7 @@ class FederationGroupsRoomsServlet(BaseFederationServlet): if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - new_content = yield self.handler.get_rooms_in_group( - group_id, requester_user_id - ) + new_content = yield self.handler.get_rooms_in_group(group_id, requester_user_id) defer.returnValue((200, new_content)) @@ -836,6 +831,7 @@ class FederationGroupsRoomsServlet(BaseFederationServlet): class FederationGroupsAddRoomsServlet(BaseFederationServlet): """Add/remove room from group """ + PATH = "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)" @defer.inlineCallbacks @@ -857,7 +853,7 @@ class FederationGroupsAddRoomsServlet(BaseFederationServlet): raise SynapseError(403, "requester_user_id doesn't match origin") new_content = yield self.handler.remove_room_from_group( - group_id, requester_user_id, room_id, + group_id, requester_user_id, room_id ) defer.returnValue((200, new_content)) @@ -866,6 +862,7 @@ class FederationGroupsAddRoomsServlet(BaseFederationServlet): class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet): """Update room config in group """ + PATH = ( "/groups/(?P<group_id>[^/]*)/room/(?P<room_id>[^/]*)" "/config/(?P<config_key>[^/]*)" @@ -878,7 +875,7 @@ class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet): raise SynapseError(403, "requester_user_id doesn't match origin") result = yield self.groups_handler.update_room_in_group( - group_id, requester_user_id, room_id, config_key, content, + group_id, requester_user_id, room_id, config_key, content ) defer.returnValue((200, result)) @@ -887,6 +884,7 @@ class FederationGroupsAddRoomsConfigServlet(BaseFederationServlet): class FederationGroupsUsersServlet(BaseFederationServlet): """Get the users in a group on behalf of a user """ + PATH = "/groups/(?P<group_id>[^/]*)/users" @defer.inlineCallbacks @@ -895,9 +893,7 @@ class FederationGroupsUsersServlet(BaseFederationServlet): if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - new_content = yield self.handler.get_users_in_group( - group_id, requester_user_id - ) + new_content = yield self.handler.get_users_in_group(group_id, requester_user_id) defer.returnValue((200, new_content)) @@ -905,6 +901,7 @@ class FederationGroupsUsersServlet(BaseFederationServlet): class FederationGroupsInvitedUsersServlet(BaseFederationServlet): """Get the users that have been invited to a group """ + PATH = "/groups/(?P<group_id>[^/]*)/invited_users" @defer.inlineCallbacks @@ -923,6 +920,7 @@ class FederationGroupsInvitedUsersServlet(BaseFederationServlet): class FederationGroupsInviteServlet(BaseFederationServlet): """Ask a group server to invite someone to the group """ + PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite" @defer.inlineCallbacks @@ -932,7 +930,7 @@ class FederationGroupsInviteServlet(BaseFederationServlet): raise SynapseError(403, "requester_user_id doesn't match origin") new_content = yield self.handler.invite_to_group( - group_id, user_id, requester_user_id, content, + group_id, user_id, requester_user_id, content ) defer.returnValue((200, new_content)) @@ -941,6 +939,7 @@ class FederationGroupsInviteServlet(BaseFederationServlet): class FederationGroupsAcceptInviteServlet(BaseFederationServlet): """Accept an invitation from the group server """ + PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite" @defer.inlineCallbacks @@ -948,9 +947,7 @@ class FederationGroupsAcceptInviteServlet(BaseFederationServlet): if get_domain_from_id(user_id) != origin: raise SynapseError(403, "user_id doesn't match origin") - new_content = yield self.handler.accept_invite( - group_id, user_id, content, - ) + new_content = yield self.handler.accept_invite(group_id, user_id, content) defer.returnValue((200, new_content)) @@ -958,6 +955,7 @@ class FederationGroupsAcceptInviteServlet(BaseFederationServlet): class FederationGroupsJoinServlet(BaseFederationServlet): """Attempt to join a group """ + PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/join" @defer.inlineCallbacks @@ -965,9 +963,7 @@ class FederationGroupsJoinServlet(BaseFederationServlet): if get_domain_from_id(user_id) != origin: raise SynapseError(403, "user_id doesn't match origin") - new_content = yield self.handler.join_group( - group_id, user_id, content, - ) + new_content = yield self.handler.join_group(group_id, user_id, content) defer.returnValue((200, new_content)) @@ -975,6 +971,7 @@ class FederationGroupsJoinServlet(BaseFederationServlet): class FederationGroupsRemoveUserServlet(BaseFederationServlet): """Leave or kick a user from the group """ + PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove" @defer.inlineCallbacks @@ -984,7 +981,7 @@ class FederationGroupsRemoveUserServlet(BaseFederationServlet): raise SynapseError(403, "requester_user_id doesn't match origin") new_content = yield self.handler.remove_user_from_group( - group_id, user_id, requester_user_id, content, + group_id, user_id, requester_user_id, content ) defer.returnValue((200, new_content)) @@ -993,6 +990,7 @@ class FederationGroupsRemoveUserServlet(BaseFederationServlet): class FederationGroupsLocalInviteServlet(BaseFederationServlet): """A group server has invited a local user """ + PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite" @defer.inlineCallbacks @@ -1000,9 +998,7 @@ class FederationGroupsLocalInviteServlet(BaseFederationServlet): if get_domain_from_id(group_id) != origin: raise SynapseError(403, "group_id doesn't match origin") - new_content = yield self.handler.on_invite( - group_id, user_id, content, - ) + new_content = yield self.handler.on_invite(group_id, user_id, content) defer.returnValue((200, new_content)) @@ -1010,6 +1006,7 @@ class FederationGroupsLocalInviteServlet(BaseFederationServlet): class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet): """A group server has removed a local user """ + PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove" @defer.inlineCallbacks @@ -1018,7 +1015,7 @@ class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet): raise SynapseError(403, "user_id doesn't match origin") new_content = yield self.handler.user_removed_from_group( - group_id, user_id, content, + group_id, user_id, content ) defer.returnValue((200, new_content)) @@ -1027,6 +1024,7 @@ class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet): class FederationGroupsRenewAttestaionServlet(BaseFederationServlet): """A group or user's server renews their attestation """ + PATH = "/groups/(?P<group_id>[^/]*)/renew_attestation/(?P<user_id>[^/]*)" @defer.inlineCallbacks @@ -1047,6 +1045,7 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet): - /groups/:group/summary/rooms/:room_id - /groups/:group/summary/categories/:category/rooms/:room_id """ + PATH = ( "/groups/(?P<group_id>[^/]*)/summary" "(/categories/(?P<category_id>[^/]+))?" @@ -1063,7 +1062,8 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet): raise SynapseError(400, "category_id cannot be empty string") resp = yield self.handler.update_group_summary_room( - group_id, requester_user_id, + group_id, + requester_user_id, room_id=room_id, category_id=category_id, content=content, @@ -1081,9 +1081,7 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet): raise SynapseError(400, "category_id cannot be empty string") resp = yield self.handler.delete_group_summary_room( - group_id, requester_user_id, - room_id=room_id, - category_id=category_id, + group_id, requester_user_id, room_id=room_id, category_id=category_id ) defer.returnValue((200, resp)) @@ -1092,9 +1090,8 @@ class FederationGroupsSummaryRoomsServlet(BaseFederationServlet): class FederationGroupsCategoriesServlet(BaseFederationServlet): """Get all categories for a group """ - PATH = ( - "/groups/(?P<group_id>[^/]*)/categories/?" - ) + + PATH = "/groups/(?P<group_id>[^/]*)/categories/?" @defer.inlineCallbacks def on_GET(self, origin, content, query, group_id): @@ -1102,9 +1099,7 @@ class FederationGroupsCategoriesServlet(BaseFederationServlet): if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - resp = yield self.handler.get_group_categories( - group_id, requester_user_id, - ) + resp = yield self.handler.get_group_categories(group_id, requester_user_id) defer.returnValue((200, resp)) @@ -1112,9 +1107,8 @@ class FederationGroupsCategoriesServlet(BaseFederationServlet): class FederationGroupsCategoryServlet(BaseFederationServlet): """Add/remove/get a category in a group """ - PATH = ( - "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)" - ) + + PATH = "/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)" @defer.inlineCallbacks def on_GET(self, origin, content, query, group_id, category_id): @@ -1138,7 +1132,7 @@ class FederationGroupsCategoryServlet(BaseFederationServlet): raise SynapseError(400, "category_id cannot be empty string") resp = yield self.handler.upsert_group_category( - group_id, requester_user_id, category_id, content, + group_id, requester_user_id, category_id, content ) defer.returnValue((200, resp)) @@ -1153,7 +1147,7 @@ class FederationGroupsCategoryServlet(BaseFederationServlet): raise SynapseError(400, "category_id cannot be empty string") resp = yield self.handler.delete_group_category( - group_id, requester_user_id, category_id, + group_id, requester_user_id, category_id ) defer.returnValue((200, resp)) @@ -1162,9 +1156,8 @@ class FederationGroupsCategoryServlet(BaseFederationServlet): class FederationGroupsRolesServlet(BaseFederationServlet): """Get roles in a group """ - PATH = ( - "/groups/(?P<group_id>[^/]*)/roles/?" - ) + + PATH = "/groups/(?P<group_id>[^/]*)/roles/?" @defer.inlineCallbacks def on_GET(self, origin, content, query, group_id): @@ -1172,9 +1165,7 @@ class FederationGroupsRolesServlet(BaseFederationServlet): if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - resp = yield self.handler.get_group_roles( - group_id, requester_user_id, - ) + resp = yield self.handler.get_group_roles(group_id, requester_user_id) defer.returnValue((200, resp)) @@ -1182,9 +1173,8 @@ class FederationGroupsRolesServlet(BaseFederationServlet): class FederationGroupsRoleServlet(BaseFederationServlet): """Add/remove/get a role in a group """ - PATH = ( - "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)" - ) + + PATH = "/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)" @defer.inlineCallbacks def on_GET(self, origin, content, query, group_id, role_id): @@ -1192,9 +1182,7 @@ class FederationGroupsRoleServlet(BaseFederationServlet): if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") - resp = yield self.handler.get_group_role( - group_id, requester_user_id, role_id - ) + resp = yield self.handler.get_group_role(group_id, requester_user_id, role_id) defer.returnValue((200, resp)) @@ -1208,7 +1196,7 @@ class FederationGroupsRoleServlet(BaseFederationServlet): raise SynapseError(400, "role_id cannot be empty string") resp = yield self.handler.update_group_role( - group_id, requester_user_id, role_id, content, + group_id, requester_user_id, role_id, content ) defer.returnValue((200, resp)) @@ -1223,7 +1211,7 @@ class FederationGroupsRoleServlet(BaseFederationServlet): raise SynapseError(400, "role_id cannot be empty string") resp = yield self.handler.delete_group_role( - group_id, requester_user_id, role_id, + group_id, requester_user_id, role_id ) defer.returnValue((200, resp)) @@ -1236,6 +1224,7 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet): - /groups/:group/summary/users/:user_id - /groups/:group/summary/roles/:role/users/:user_id """ + PATH = ( "/groups/(?P<group_id>[^/]*)/summary" "(/roles/(?P<role_id>[^/]+))?" @@ -1252,7 +1241,8 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet): raise SynapseError(400, "role_id cannot be empty string") resp = yield self.handler.update_group_summary_user( - group_id, requester_user_id, + group_id, + requester_user_id, user_id=user_id, role_id=role_id, content=content, @@ -1270,9 +1260,7 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet): raise SynapseError(400, "role_id cannot be empty string") resp = yield self.handler.delete_group_summary_user( - group_id, requester_user_id, - user_id=user_id, - role_id=role_id, + group_id, requester_user_id, user_id=user_id, role_id=role_id ) defer.returnValue((200, resp)) @@ -1281,14 +1269,13 @@ class FederationGroupsSummaryUsersServlet(BaseFederationServlet): class FederationGroupsBulkPublicisedServlet(BaseFederationServlet): """Get roles in a group """ - PATH = ( - "/get_groups_publicised" - ) + + PATH = "/get_groups_publicised" @defer.inlineCallbacks def on_POST(self, origin, content, query): resp = yield self.handler.bulk_get_publicised_groups( - content["user_ids"], proxy=False, + content["user_ids"], proxy=False ) defer.returnValue((200, resp)) @@ -1297,6 +1284,7 @@ class FederationGroupsBulkPublicisedServlet(BaseFederationServlet): class FederationGroupsSettingJoinPolicyServlet(BaseFederationServlet): """Sets whether a group is joinable without an invite or knock """ + PATH = "/groups/(?P<group_id>[^/]*)/settings/m.join_policy" @defer.inlineCallbacks @@ -1317,6 +1305,7 @@ class RoomComplexityServlet(BaseFederationServlet): Indicates to other servers how complex (and therefore likely resource-intensive) a public room this server knows about is. """ + PATH = "/rooms/(?P<room_id>[^/]*)/complexity" PREFIX = FEDERATION_UNSTABLE_PREFIX @@ -1325,9 +1314,7 @@ class RoomComplexityServlet(BaseFederationServlet): store = self.handler.hs.get_datastore() - is_public = yield store.is_room_world_readable_or_publicly_joinable( - room_id - ) + is_public = yield store.is_room_world_readable_or_publicly_joinable(room_id) if not is_public: raise SynapseError(404, "Room not found", errcode=Codes.INVALID_PARAM) @@ -1362,13 +1349,9 @@ FEDERATION_SERVLET_CLASSES = ( RoomComplexityServlet, ) -OPENID_SERVLET_CLASSES = ( - OpenIdUserInfo, -) +OPENID_SERVLET_CLASSES = (OpenIdUserInfo,) -ROOM_LIST_CLASSES = ( - PublicRoomList, -) +ROOM_LIST_CLASSES = (PublicRoomList,) GROUP_SERVER_SERVLET_CLASSES = ( FederationGroupsProfileServlet, @@ -1399,9 +1382,7 @@ GROUP_LOCAL_SERVLET_CLASSES = ( ) -GROUP_ATTESTATION_SERVLET_CLASSES = ( - FederationGroupsRenewAttestaionServlet, -) +GROUP_ATTESTATION_SERVLET_CLASSES = (FederationGroupsRenewAttestaionServlet,) DEFAULT_SERVLET_GROUPS = ( "federation", @@ -1455,7 +1436,7 @@ def register_servlets(hs, resource, authenticator, ratelimiter, servlet_groups=N authenticator=authenticator, ratelimiter=ratelimiter, server_name=hs.hostname, - deny_access=hs.config.restrict_public_rooms_to_local_users, + allow_access=hs.config.allow_public_rooms_over_federation, ).register(resource) if "group_server" in servlet_groups: diff --git a/synapse/federation/units.py b/synapse/federation/units.py index 025a79c022..14aad8f09d 100644 --- a/synapse/federation/units.py +++ b/synapse/federation/units.py @@ -32,21 +32,11 @@ class Edu(JsonEncodedObject): internal ID or previous references graph. """ - valid_keys = [ - "origin", - "destination", - "edu_type", - "content", - ] + valid_keys = ["origin", "destination", "edu_type", "content"] - required_keys = [ - "edu_type", - ] + required_keys = ["edu_type"] - internal_keys = [ - "origin", - "destination", - ] + internal_keys = ["origin", "destination"] class Transaction(JsonEncodedObject): @@ -75,10 +65,7 @@ class Transaction(JsonEncodedObject): "edus", ] - internal_keys = [ - "transaction_id", - "destination", - ] + internal_keys = ["transaction_id", "destination"] required_keys = [ "transaction_id", @@ -98,9 +85,7 @@ class Transaction(JsonEncodedObject): del kwargs["edus"] super(Transaction, self).__init__( - transaction_id=transaction_id, - pdus=pdus, - **kwargs + transaction_id=transaction_id, pdus=pdus, **kwargs ) @staticmethod @@ -109,13 +94,9 @@ class Transaction(JsonEncodedObject): transaction_id and origin_server_ts keys. """ if "origin_server_ts" not in kwargs: - raise KeyError( - "Require 'origin_server_ts' to construct a Transaction" - ) + raise KeyError("Require 'origin_server_ts' to construct a Transaction") if "transaction_id" not in kwargs: - raise KeyError( - "Require 'transaction_id' to construct a Transaction" - ) + raise KeyError("Require 'transaction_id' to construct a Transaction") kwargs["pdus"] = [p.get_pdu_json() for p in pdus] |