diff options
Diffstat (limited to '')
-rw-r--r-- | synapse/federation/federation_client.py | 116 | ||||
-rw-r--r-- | synapse/federation/federation_server.py | 123 | ||||
-rw-r--r-- | synapse/federation/transaction_queue.py | 6 | ||||
-rw-r--r-- | synapse/federation/transport/__init__.py | 12 | ||||
-rw-r--r-- | synapse/federation/transport/client.py | 19 | ||||
-rw-r--r-- | synapse/federation/transport/server.py | 45 |
6 files changed, 259 insertions, 62 deletions
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index cd3c962d50..f131941f45 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -19,14 +19,18 @@ from twisted.internet import defer from .federation_base import FederationBase from .units import Edu -from synapse.api.errors import CodeMessageException, SynapseError +from synapse.api.errors import ( + CodeMessageException, HttpResponseException, SynapseError, +) from synapse.util.expiringcache import ExpiringCache from synapse.util.logutils import log_function from synapse.events import FrozenEvent from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination +import itertools import logging +import random logger = logging.getLogger(__name__) @@ -439,6 +443,116 @@ class FederationClient(FederationBase): defer.returnValue(ret) + @defer.inlineCallbacks + def get_missing_events(self, destination, room_id, earliest_events_ids, + latest_events, limit, min_depth): + """Tries to fetch events we are missing. This is called when we receive + an event without having received all of its ancestors. + + Args: + destination (str) + room_id (str) + earliest_events_ids (list): List of event ids. Effectively the + events we expected to receive, but haven't. `get_missing_events` + should only return events that didn't happen before these. + latest_events (list): List of events we have received that we don't + have all previous events for. + limit (int): Maximum number of events to return. + min_depth (int): Minimum depth of events tor return. + """ + try: + content = yield self.transport_layer.get_missing_events( + destination=destination, + room_id=room_id, + earliest_events=earliest_events_ids, + latest_events=[e.event_id for e in latest_events], + limit=limit, + min_depth=min_depth, + ) + + events = [ + self.event_from_pdu_json(e) + for e in content.get("events", []) + ] + + signed_events = yield self._check_sigs_and_hash_and_fetch( + destination, events, outlier=True + ) + + have_gotten_all_from_destination = True + except HttpResponseException as e: + if not e.code == 400: + raise + + # We are probably hitting an old server that doesn't support + # get_missing_events + signed_events = [] + have_gotten_all_from_destination = False + + if len(signed_events) >= limit: + defer.returnValue(signed_events) + + servers = yield self.store.get_joined_hosts_for_room(room_id) + + servers = set(servers) + servers.discard(self.server_name) + + failed_to_fetch = set() + + while len(signed_events) < limit: + # Are we missing any? + + seen_events = set(earliest_events_ids) + seen_events.update(e.event_id for e in signed_events) + + missing_events = {} + for e in itertools.chain(latest_events, signed_events): + if e.depth > min_depth: + missing_events.update({ + e_id: e.depth for e_id, _ in e.prev_events + if e_id not in seen_events + and e_id not in failed_to_fetch + }) + + if not missing_events: + break + + have_seen = yield self.store.have_events(missing_events) + + for k in have_seen: + missing_events.pop(k, None) + + if not missing_events: + break + + # Okay, we haven't gotten everything yet. Lets get them. + ordered_missing = sorted(missing_events.items(), key=lambda x: x[0]) + + if have_gotten_all_from_destination: + servers.discard(destination) + + def random_server_list(): + srvs = list(servers) + random.shuffle(srvs) + return srvs + + deferreds = [ + self.get_pdu( + destinations=random_server_list(), + event_id=e_id, + ) + for e_id, depth in ordered_missing[:limit - len(signed_events)] + ] + + res = yield defer.DeferredList(deferreds, consumeErrors=True) + for (result, val), (e_id, _) in zip(res, ordered_missing): + if result: + signed_events.append(val) + else: + failed_to_fetch.add(e_id) + + defer.returnValue(signed_events) + def event_from_pdu_json(self, pdu_json, outlier=False): event = FrozenEvent( pdu_json diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 22b9663831..9c7dcdba96 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -112,17 +112,20 @@ class FederationServer(FederationBase): logger.debug("[%s] Transaction is new", transaction.transaction_id) with PreserveLoggingContext(): - dl = [] + results = [] + for pdu in pdu_list: d = self._handle_new_pdu(transaction.origin, pdu) - def handle_failure(failure): - failure.trap(FederationError) - self.send_failure(failure.value, transaction.origin) - - d.addErrback(handle_failure) - - dl.append(d) + try: + yield d + results.append({}) + except FederationError as e: + self.send_failure(e, transaction.origin) + results.append({"error": str(e)}) + except Exception as e: + results.append({"error": str(e)}) + logger.exception("Failed to handle PDU") if hasattr(transaction, "edus"): for edu in [Edu(**x) for x in transaction.edus]: @@ -135,21 +138,11 @@ class FederationServer(FederationBase): for failure in getattr(transaction, "pdu_failures", []): logger.info("Got failure %r", failure) - results = yield defer.DeferredList(dl, consumeErrors=True) - - ret = [] - for r in results: - if r[0]: - ret.append({}) - else: - logger.exception(r[1]) - ret.append({"error": str(r[1].value)}) - - logger.debug("Returning: %s", str(ret)) + logger.debug("Returning: %s", str(results)) response = { "pdus": dict(zip( - (p.event_id for p in pdu_list), ret + (p.event_id for p in pdu_list), results )), } @@ -305,6 +298,20 @@ class FederationServer(FederationBase): (200, send_content) ) + @defer.inlineCallbacks + @log_function + def on_get_missing_events(self, origin, room_id, earliest_events, + latest_events, limit, min_depth): + missing_events = yield self.handler.on_get_missing_events( + origin, room_id, earliest_events, latest_events, limit, min_depth + ) + + time_now = self._clock.time_msec() + + defer.returnValue({ + "events": [ev.get_pdu_json(time_now) for ev in missing_events], + }) + @log_function def _get_persisted_pdu(self, origin, event_id, do_auth=True): """ Get a PDU from the database with given origin and id. @@ -331,7 +338,7 @@ class FederationServer(FederationBase): @defer.inlineCallbacks @log_function - def _handle_new_pdu(self, origin, pdu, max_recursion=10): + def _handle_new_pdu(self, origin, pdu, get_missing=True): # We reprocess pdus when we have seen them only as outliers existing = yield self._get_persisted_pdu( origin, pdu.event_id, do_auth=False @@ -383,48 +390,54 @@ class FederationServer(FederationBase): pdu.room_id, min_depth ) + prevs = {e_id for e_id, _ in pdu.prev_events} + seen = set(have_seen.keys()) + if min_depth and pdu.depth < min_depth: # This is so that we don't notify the user about this # message, to work around the fact that some events will # reference really really old events we really don't want to # send to the clients. pdu.internal_metadata.outlier = True - elif min_depth and pdu.depth > min_depth and max_recursion > 0: - for event_id, hashes in pdu.prev_events: - if event_id not in have_seen: - logger.debug( - "_handle_new_pdu requesting pdu %s", - event_id + elif min_depth and pdu.depth > min_depth: + if get_missing and prevs - seen: + latest_tuples = yield self.store.get_latest_events_in_room( + pdu.room_id + ) + + # We add the prev events that we have seen to the latest + # list to ensure the remote server doesn't give them to us + latest = set(e_id for e_id, _, _ in latest_tuples) + latest |= seen + + missing_events = yield self.get_missing_events( + origin, + pdu.room_id, + earliest_events_ids=list(latest), + latest_events=[pdu], + limit=10, + min_depth=min_depth, + ) + + # We want to sort these by depth so we process them and + # tell clients about them in order. + missing_events.sort(key=lambda x: x.depth) + + for e in missing_events: + yield self._handle_new_pdu( + origin, + e, + get_missing=False ) - try: - new_pdu = yield self.federation_client.get_pdu( - [origin, pdu.origin], - event_id=event_id, - ) - - if new_pdu: - yield self._handle_new_pdu( - origin, - new_pdu, - max_recursion=max_recursion-1 - ) - - logger.debug("Processed pdu %s", event_id) - else: - logger.warn("Failed to get PDU %s", event_id) - fetch_state = True - except: - # TODO(erikj): Do some more intelligent retries. - logger.exception("Failed to get PDU") - fetch_state = True - else: - prevs = {e_id for e_id, _ in pdu.prev_events} - seen = set(have_seen.keys()) - if prevs - seen: - fetch_state = True - else: - fetch_state = True + have_seen = yield self.store.have_events( + [ev for ev, _ in pdu.prev_events] + ) + + prevs = {e_id for e_id, _ in pdu.prev_events} + seen = set(have_seen.keys()) + if prevs - seen: + fetch_state = True if fetch_state: # We need to get the state at this event, since we haven't diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index 7d30c924d1..741a4e7a1a 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -224,6 +224,8 @@ class TransactionQueue(object): ] try: + self.pending_transactions[destination] = 1 + limiter = yield get_retry_limiter( destination, self._clock, @@ -239,8 +241,6 @@ class TransactionQueue(object): len(pending_failures) ) - self.pending_transactions[destination] = 1 - logger.debug("TX [%s] Persisting transaction...", destination) transaction = Transaction.create_new( @@ -287,7 +287,7 @@ class TransactionQueue(object): code = 200 if response: - for e_id, r in getattr(response, "pdus", {}).items(): + for e_id, r in response.get("pdus", {}).items(): if "error" in r: logger.warn( "Transaction returned error for %s: %s", diff --git a/synapse/federation/transport/__init__.py b/synapse/federation/transport/__init__.py index 6800ac46c5..2a671b9aec 100644 --- a/synapse/federation/transport/__init__.py +++ b/synapse/federation/transport/__init__.py @@ -24,6 +24,8 @@ communicate over a different (albeit still reliable) protocol. from .server import TransportLayerServer from .client import TransportLayerClient +from synapse.util.ratelimitutils import FederationRateLimiter + class TransportLayer(TransportLayerServer, TransportLayerClient): """This is a basic implementation of the transport layer that translates @@ -55,8 +57,18 @@ class TransportLayer(TransportLayerServer, TransportLayerClient): send requests """ self.keyring = homeserver.get_keyring() + self.clock = homeserver.get_clock() self.server_name = server_name self.server = server self.client = client self.request_handler = None self.received_handler = None + + self.ratelimiter = FederationRateLimiter( + self.clock, + window_size=homeserver.config.federation_rc_window_size, + sleep_limit=homeserver.config.federation_rc_sleep_limit, + sleep_msec=homeserver.config.federation_rc_sleep_delay, + reject_limit=homeserver.config.federation_rc_reject_limit, + concurrent_requests=homeserver.config.federation_rc_concurrent, + ) diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index 8b137e7128..80d03012b7 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -219,3 +219,22 @@ class TransportLayerClient(object): ) defer.returnValue(content) + + @defer.inlineCallbacks + @log_function + def get_missing_events(self, destination, room_id, earliest_events, + latest_events, limit, min_depth): + path = PREFIX + "/get_missing_events/%s" % (room_id,) + + content = yield self.client.post_json( + destination=destination, + path=path, + data={ + "limit": int(limit), + "min_depth": int(min_depth), + "earliest_events": earliest_events, + "latest_events": latest_events, + } + ) + + defer.returnValue(content) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 2ffb37aa18..ece6dbcf62 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -98,15 +98,23 @@ class TransportLayerServer(object): def new_handler(request, *args, **kwargs): try: (origin, content) = yield self._authenticate_request(request) - response = yield handler( - origin, content, request.args, *args, **kwargs - ) + with self.ratelimiter.ratelimit(origin) as d: + yield d + response = yield handler( + origin, content, request.args, *args, **kwargs + ) except: logger.exception("_authenticate_request failed") raise defer.returnValue(response) return new_handler + def rate_limit_origin(self, handler): + def new_handler(origin, *args, **kwargs): + response = yield handler(origin, *args, **kwargs) + defer.returnValue(response) + return new_handler() + @log_function def register_received_handler(self, handler): """ Register a handler that will be fired when we receive data. @@ -234,6 +242,7 @@ class TransportLayerServer(object): ) ) ) + self.server.register_path( "POST", re.compile("^" + PREFIX + "/query_auth/([^/]*)/([^/]*)$"), @@ -245,6 +254,17 @@ class TransportLayerServer(object): ) ) + self.server.register_path( + "POST", + re.compile("^" + PREFIX + "/get_missing_events/([^/]*)/?$"), + self._with_authentication( + lambda origin, content, query, room_id: + self._get_missing_events( + origin, content, room_id, + ) + ) + ) + @defer.inlineCallbacks @log_function def _on_send_request(self, origin, content, query, transaction_id): @@ -344,3 +364,22 @@ class TransportLayerServer(object): ) defer.returnValue((200, new_content)) + + @defer.inlineCallbacks + @log_function + def _get_missing_events(self, origin, content, room_id): + limit = int(content.get("limit", 10)) + min_depth = int(content.get("min_depth", 0)) + earliest_events = content.get("earliest_events", []) + latest_events = content.get("latest_events", []) + + content = yield self.request_handler.on_get_missing_events( + origin, + room_id=room_id, + earliest_events=earliest_events, + latest_events=latest_events, + min_depth=min_depth, + limit=limit, + ) + + defer.returnValue((200, content)) |