diff options
Diffstat (limited to 'synapse/federation')
-rw-r--r-- | synapse/federation/__init__.py | 7 | ||||
-rw-r--r-- | synapse/federation/federation_client.py | 62 | ||||
-rw-r--r-- | synapse/federation/replication.py | 5 | ||||
-rw-r--r-- | synapse/federation/send_queue.py | 298 | ||||
-rw-r--r-- | synapse/federation/transaction_queue.py | 102 |
5 files changed, 395 insertions, 79 deletions
diff --git a/synapse/federation/__init__.py b/synapse/federation/__init__.py index 979fdf2431..2e32d245ba 100644 --- a/synapse/federation/__init__.py +++ b/synapse/federation/__init__.py @@ -17,10 +17,9 @@ """ from .replication import ReplicationLayer -from .transport.client import TransportLayerClient -def initialize_http_replication(homeserver): - transport = TransportLayerClient(homeserver) +def initialize_http_replication(hs): + transport = hs.get_federation_transport_client() - return ReplicationLayer(homeserver, transport) + return ReplicationLayer(hs, transport) diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 94e76b1978..b255709165 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -18,7 +18,6 @@ from twisted.internet import defer from .federation_base import FederationBase from synapse.api.constants import Membership -from .units import Edu from synapse.api.errors import ( CodeMessageException, HttpResponseException, SynapseError, @@ -45,10 +44,6 @@ logger = logging.getLogger(__name__) # synapse.federation.federation_client is a silly name metrics = synapse.metrics.get_metrics_for("synapse.federation.client") -sent_pdus_destination_dist = metrics.register_distribution("sent_pdu_destinations") - -sent_edus_counter = metrics.register_counter("sent_edus") - sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"]) @@ -93,63 +88,6 @@ class FederationClient(FederationBase): self._get_pdu_cache.start() @log_function - def send_pdu(self, pdu, destinations): - """Informs the replication layer about a new PDU generated within the - home server that should be transmitted to others. - - TODO: Figure out when we should actually resolve the deferred. - - Args: - pdu (Pdu): The new Pdu. - - Returns: - Deferred: Completes when we have successfully processed the PDU - and replicated it to any interested remote home servers. - """ - order = self._order - self._order += 1 - - sent_pdus_destination_dist.inc_by(len(destinations)) - - logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id) - - # TODO, add errback, etc. - self._transaction_queue.enqueue_pdu(pdu, destinations, order) - - logger.debug( - "[%s] transaction_layer.enqueue_pdu... done", - pdu.event_id - ) - - def send_presence(self, destination, states): - if destination != self.server_name: - self._transaction_queue.enqueue_presence(destination, states) - - @log_function - def send_edu(self, destination, edu_type, content, key=None): - edu = Edu( - origin=self.server_name, - destination=destination, - edu_type=edu_type, - content=content, - ) - - sent_edus_counter.inc() - - self._transaction_queue.enqueue_edu(edu, key=key) - - @log_function - def send_device_messages(self, destination): - """Sends the device messages in the local database to the remote - destination""" - self._transaction_queue.enqueue_device_messages(destination) - - @log_function - def send_failure(self, failure, destination): - self._transaction_queue.enqueue_failure(failure, destination) - return defer.succeed(None) - - @log_function def make_query(self, destination, query_type, args, retry_on_dns_fail=False): """Sends a federation Query to a remote homeserver of the given type diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py index ea66a5dcbc..62d865ec4b 100644 --- a/synapse/federation/replication.py +++ b/synapse/federation/replication.py @@ -20,8 +20,6 @@ a given transport. from .federation_client import FederationClient from .federation_server import FederationServer -from .transaction_queue import TransactionQueue - from .persistence import TransactionActions import logging @@ -66,9 +64,6 @@ class ReplicationLayer(FederationClient, FederationServer): self._clock = hs.get_clock() self.transaction_actions = TransactionActions(self.store) - self._transaction_queue = TransactionQueue(hs, transport_layer) - - self._order = 0 self.hs = hs diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py new file mode 100644 index 0000000000..5c9f7a86f0 --- /dev/null +++ b/synapse/federation/send_queue.py @@ -0,0 +1,298 @@ +# -*- coding: utf-8 -*- +# Copyright 2014-2016 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A federation sender that forwards things to be sent across replication to +a worker process. + +It assumes there is a single worker process feeding off of it. + +Each row in the replication stream consists of a type and some json, where the +types indicate whether they are presence, or edus, etc. + +Ephemeral or non-event data are queued up in-memory. When the worker requests +updates since a particular point, all in-memory data since before that point is +dropped. We also expire things in the queue after 5 minutes, to ensure that a +dead worker doesn't cause the queues to grow limitlessly. + +Events are replicated via a separate events stream. +""" + +from .units import Edu + +from synapse.util.metrics import Measure +import synapse.metrics + +from blist import sorteddict +import ujson + + +metrics = synapse.metrics.get_metrics_for(__name__) + + +PRESENCE_TYPE = "p" +KEYED_EDU_TYPE = "k" +EDU_TYPE = "e" +FAILURE_TYPE = "f" +DEVICE_MESSAGE_TYPE = "d" + + +class FederationRemoteSendQueue(object): + """A drop in replacement for TransactionQueue""" + + def __init__(self, hs): + self.server_name = hs.hostname + self.clock = hs.get_clock() + + self.presence_map = {} + self.presence_changed = sorteddict() + + self.keyed_edu = {} + self.keyed_edu_changed = sorteddict() + + self.edus = sorteddict() + + self.failures = sorteddict() + + self.device_messages = sorteddict() + + self.pos = 1 + self.pos_time = sorteddict() + + # EVERYTHING IS SAD. In particular, python only makes new scopes when + # we make a new function, so we need to make a new function so the inner + # lambda binds to the queue rather than to the name of the queue which + # changes. ARGH. + def register(name, queue): + metrics.register_callback( + queue_name + "_size", + lambda: len(queue), + ) + + for queue_name in [ + "presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed", + "edus", "failures", "device_messages", "pos_time", + ]: + register(queue_name, getattr(self, queue_name)) + + self.clock.looping_call(self._clear_queue, 30 * 1000) + + def _next_pos(self): + pos = self.pos + self.pos += 1 + self.pos_time[self.clock.time_msec()] = pos + return pos + + def _clear_queue(self): + """Clear the queues for anything older than N minutes""" + + FIVE_MINUTES_AGO = 5 * 60 * 1000 + now = self.clock.time_msec() + + keys = self.pos_time.keys() + time = keys.bisect_left(now - FIVE_MINUTES_AGO) + if not keys[:time]: + return + + position_to_delete = max(keys[:time]) + for key in keys[:time]: + del self.pos_time[key] + + self._clear_queue_before_pos(position_to_delete) + + def _clear_queue_before_pos(self, position_to_delete): + """Clear all the queues from before a given position""" + with Measure(self.clock, "send_queue._clear"): + # Delete things out of presence maps + keys = self.presence_changed.keys() + i = keys.bisect_left(position_to_delete) + for key in keys[:i]: + del self.presence_changed[key] + + user_ids = set( + user_id for uids in self.presence_changed.values() for _, user_id in uids + ) + + to_del = [ + user_id for user_id in self.presence_map if user_id not in user_ids + ] + for user_id in to_del: + del self.presence_map[user_id] + + # Delete things out of keyed edus + keys = self.keyed_edu_changed.keys() + i = keys.bisect_left(position_to_delete) + for key in keys[:i]: + del self.keyed_edu_changed[key] + + live_keys = set() + for edu_key in self.keyed_edu_changed.values(): + live_keys.add(edu_key) + + to_del = [edu_key for edu_key in self.keyed_edu if edu_key not in live_keys] + for edu_key in to_del: + del self.keyed_edu[edu_key] + + # Delete things out of edu map + keys = self.edus.keys() + i = keys.bisect_left(position_to_delete) + for key in keys[:i]: + del self.edus[key] + + # Delete things out of failure map + keys = self.failures.keys() + i = keys.bisect_left(position_to_delete) + for key in keys[:i]: + del self.failures[key] + + # Delete things out of device map + keys = self.device_messages.keys() + i = keys.bisect_left(position_to_delete) + for key in keys[:i]: + del self.device_messages[key] + + def notify_new_events(self, current_id): + """As per TransactionQueue""" + # We don't need to replicate this as it gets sent down a different + # stream. + pass + + def send_edu(self, destination, edu_type, content, key=None): + """As per TransactionQueue""" + pos = self._next_pos() + + edu = Edu( + origin=self.server_name, + destination=destination, + edu_type=edu_type, + content=content, + ) + + if key: + assert isinstance(key, tuple) + self.keyed_edu[(destination, key)] = edu + self.keyed_edu_changed[pos] = (destination, key) + else: + self.edus[pos] = edu + + def send_presence(self, destination, states): + """As per TransactionQueue""" + pos = self._next_pos() + + self.presence_map.update({ + state.user_id: state + for state in states + }) + + self.presence_changed[pos] = [ + (destination, state.user_id) for state in states + ] + + def send_failure(self, failure, destination): + """As per TransactionQueue""" + pos = self._next_pos() + + self.failures[pos] = (destination, str(failure)) + + def send_device_messages(self, destination): + """As per TransactionQueue""" + pos = self._next_pos() + self.device_messages[pos] = destination + + def get_current_token(self): + return self.pos - 1 + + def get_replication_rows(self, token, limit, federation_ack=None): + """ + Args: + token (int) + limit (int) + federation_ack (int): Optional. The position where the worker is + explicitly acknowledged it has handled. Allows us to drop + data from before that point + """ + # TODO: Handle limit. + + # To handle restarts where we wrap around + if token > self.pos: + token = -1 + + rows = [] + + # There should be only one reader, so lets delete everything its + # acknowledged its seen. + if federation_ack: + self._clear_queue_before_pos(federation_ack) + + # Fetch changed presence + keys = self.presence_changed.keys() + i = keys.bisect_right(token) + dest_user_ids = set( + (pos, dest_user_id) + for pos in keys[i:] + for dest_user_id in self.presence_changed[pos] + ) + + for (key, (dest, user_id)) in dest_user_ids: + rows.append((key, PRESENCE_TYPE, ujson.dumps({ + "destination": dest, + "state": self.presence_map[user_id].as_dict(), + }))) + + # Fetch changes keyed edus + keys = self.keyed_edu_changed.keys() + i = keys.bisect_right(token) + keyed_edus = set((k, self.keyed_edu_changed[k]) for k in keys[i:]) + + for (pos, (destination, edu_key)) in keyed_edus: + rows.append( + (pos, KEYED_EDU_TYPE, ujson.dumps({ + "key": edu_key, + "edu": self.keyed_edu[(destination, edu_key)].get_internal_dict(), + })) + ) + + # Fetch changed edus + keys = self.edus.keys() + i = keys.bisect_right(token) + edus = set((k, self.edus[k]) for k in keys[i:]) + + for (pos, edu) in edus: + rows.append((pos, EDU_TYPE, ujson.dumps(edu.get_internal_dict()))) + + # Fetch changed failures + keys = self.failures.keys() + i = keys.bisect_right(token) + failures = set((k, self.failures[k]) for k in keys[i:]) + + for (pos, (destination, failure)) in failures: + rows.append((pos, FAILURE_TYPE, ujson.dumps({ + "destination": destination, + "failure": failure, + }))) + + # Fetch changed device messages + keys = self.device_messages.keys() + i = keys.bisect_right(token) + device_messages = set((k, self.device_messages[k]) for k in keys[i:]) + + for (pos, destination) in device_messages: + rows.append((pos, DEVICE_MESSAGE_TYPE, ujson.dumps({ + "destination": destination, + }))) + + # Sort rows based on pos + rows.sort() + + return rows diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index f8ca93e4c3..51b656d74a 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -19,6 +19,7 @@ from twisted.internet import defer from .persistence import TransactionActions from .units import Transaction, Edu +from synapse.api.constants import EventTypes, Membership from synapse.api.errors import HttpResponseException from synapse.util.async import run_on_reactor from synapse.util.logcontext import preserve_context_over_fn @@ -26,6 +27,7 @@ from synapse.util.retryutils import ( get_retry_limiter, NotRetryingDestination, ) from synapse.util.metrics import measure_func +from synapse.types import get_domain_from_id from synapse.handlers.presence import format_user_presence_state import synapse.metrics @@ -36,6 +38,12 @@ logger = logging.getLogger(__name__) metrics = synapse.metrics.get_metrics_for(__name__) +client_metrics = synapse.metrics.get_metrics_for("synapse.federation.client") +sent_pdus_destination_dist = client_metrics.register_distribution( + "sent_pdu_destinations" +) +sent_edus_counter = client_metrics.register_counter("sent_edus") + class TransactionQueue(object): """This class makes sure we only have one transaction in flight at @@ -44,13 +52,14 @@ class TransactionQueue(object): It batches pending PDUs into single transactions. """ - def __init__(self, hs, transport_layer): + def __init__(self, hs): self.server_name = hs.hostname self.store = hs.get_datastore() + self.state = hs.get_state_handler() self.transaction_actions = TransactionActions(self.store) - self.transport_layer = transport_layer + self.transport_layer = hs.get_federation_transport_client() self.clock = hs.get_clock() @@ -95,6 +104,11 @@ class TransactionQueue(object): # HACK to get unique tx id self._next_txn_id = int(self.clock.time_msec()) + self._order = 1 + + self._is_processing = False + self._last_poked_id = -1 + def can_send_to(self, destination): """Can we send messages to the given server? @@ -115,11 +129,61 @@ class TransactionQueue(object): else: return not destination.startswith("localhost") - def enqueue_pdu(self, pdu, destinations, order): + @defer.inlineCallbacks + def notify_new_events(self, current_id): + """This gets called when we have some new events we might want to + send out to other servers. + """ + self._last_poked_id = max(current_id, self._last_poked_id) + + if self._is_processing: + return + + try: + self._is_processing = True + while True: + last_token = yield self.store.get_federation_out_pos("events") + next_token, events = yield self.store.get_all_new_events_stream( + last_token, self._last_poked_id, limit=20, + ) + + logger.debug("Handling %s -> %s", last_token, next_token) + + if not events and next_token >= self._last_poked_id: + break + + for event in events: + users_in_room = yield self.state.get_current_user_in_room( + event.room_id, latest_event_ids=[event.event_id], + ) + + destinations = set( + get_domain_from_id(user_id) for user_id in users_in_room + ) + + if event.type == EventTypes.Member: + if event.content["membership"] == Membership.JOIN: + destinations.add(get_domain_from_id(event.state_key)) + + logger.debug("Sending %s to %r", event, destinations) + + self._send_pdu(event, destinations) + + yield self.store.update_federation_out_pos( + "events", next_token + ) + + finally: + self._is_processing = False + + def _send_pdu(self, pdu, destinations): # We loop through all destinations to see whether we already have # a transaction in progress. If we do, stick it in the pending_pdus # table and we'll get back to it later. + order = self._order + self._order += 1 + destinations = set(destinations) destinations = set( dest for dest in destinations if self.can_send_to(dest) @@ -130,6 +194,8 @@ class TransactionQueue(object): if not destinations: return + sent_pdus_destination_dist.inc_by(len(destinations)) + for destination in destinations: self.pending_pdus_by_dest.setdefault(destination, []).append( (pdu, order) @@ -139,7 +205,10 @@ class TransactionQueue(object): self._attempt_new_transaction, destination ) - def enqueue_presence(self, destination, states): + def send_presence(self, destination, states): + if not self.can_send_to(destination): + return + self.pending_presence_by_dest.setdefault(destination, {}).update({ state.user_id: state for state in states }) @@ -148,12 +217,19 @@ class TransactionQueue(object): self._attempt_new_transaction, destination ) - def enqueue_edu(self, edu, key=None): - destination = edu.destination + def send_edu(self, destination, edu_type, content, key=None): + edu = Edu( + origin=self.server_name, + destination=destination, + edu_type=edu_type, + content=content, + ) if not self.can_send_to(destination): return + sent_edus_counter.inc() + if key: self.pending_edus_keyed_by_dest.setdefault( destination, {} @@ -165,7 +241,7 @@ class TransactionQueue(object): self._attempt_new_transaction, destination ) - def enqueue_failure(self, failure, destination): + def send_failure(self, failure, destination): if destination == self.server_name or destination == "localhost": return @@ -180,7 +256,7 @@ class TransactionQueue(object): self._attempt_new_transaction, destination ) - def enqueue_device_messages(self, destination): + def send_device_messages(self, destination): if destination == self.server_name or destination == "localhost": return @@ -191,6 +267,9 @@ class TransactionQueue(object): self._attempt_new_transaction, destination ) + def get_current_token(self): + return 0 + @defer.inlineCallbacks def _attempt_new_transaction(self, destination): # list of (pending_pdu, deferred, order) @@ -383,6 +462,13 @@ class TransactionQueue(object): code = e.code response = e.response + if e.code == 429 or 500 <= e.code: + logger.info( + "TX [%s] {%s} got %d response", + destination, txn_id, code + ) + raise e + logger.info( "TX [%s] {%s} got %d response", destination, txn_id, code |