diff options
Diffstat (limited to 'synapse/federation/transaction_queue.py')
-rw-r--r-- | synapse/federation/transaction_queue.py | 95 |
1 files changed, 87 insertions, 8 deletions
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index f8ca93e4c3..c94c74a67e 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -19,6 +19,7 @@ from twisted.internet import defer from .persistence import TransactionActions from .units import Transaction, Edu +from synapse.api.constants import EventTypes, Membership from synapse.api.errors import HttpResponseException from synapse.util.async import run_on_reactor from synapse.util.logcontext import preserve_context_over_fn @@ -26,6 +27,7 @@ from synapse.util.retryutils import ( get_retry_limiter, NotRetryingDestination, ) from synapse.util.metrics import measure_func +from synapse.types import get_domain_from_id from synapse.handlers.presence import format_user_presence_state import synapse.metrics @@ -36,6 +38,12 @@ logger = logging.getLogger(__name__) metrics = synapse.metrics.get_metrics_for(__name__) +client_metrics = synapse.metrics.get_metrics_for("synapse.federation.client") +sent_pdus_destination_dist = client_metrics.register_distribution( + "sent_pdu_destinations" +) +sent_edus_counter = client_metrics.register_counter("sent_edus") + class TransactionQueue(object): """This class makes sure we only have one transaction in flight at @@ -44,13 +52,14 @@ class TransactionQueue(object): It batches pending PDUs into single transactions. """ - def __init__(self, hs, transport_layer): + def __init__(self, hs): self.server_name = hs.hostname self.store = hs.get_datastore() + self.state = hs.get_state_handler() self.transaction_actions = TransactionActions(self.store) - self.transport_layer = transport_layer + self.transport_layer = hs.get_federation_transport_client() self.clock = hs.get_clock() @@ -95,6 +104,11 @@ class TransactionQueue(object): # HACK to get unique tx id self._next_txn_id = int(self.clock.time_msec()) + self._order = 1 + + self._is_processing = False + self._last_poked_id = -1 + def can_send_to(self, destination): """Can we send messages to the given server? @@ -115,11 +129,61 @@ class TransactionQueue(object): else: return not destination.startswith("localhost") - def enqueue_pdu(self, pdu, destinations, order): + @defer.inlineCallbacks + def notify_new_events(self, current_id): + """This gets called when we have some new events we might want to + send out to other servers. + """ + self._last_poked_id = max(current_id, self._last_poked_id) + + if self._is_processing: + return + + try: + self._is_processing = True + while True: + last_token = yield self.store.get_federation_out_pos("events") + next_token, events = yield self.store.get_all_new_events_stream( + last_token, self._last_poked_id, limit=20, + ) + + logger.debug("Handling %s -> %s", last_token, next_token) + + if not events and next_token >= self._last_poked_id: + break + + for event in events: + users_in_room = yield self.state.get_current_user_in_room( + event.room_id, latest_event_ids=[event.event_id], + ) + + destinations = set( + get_domain_from_id(user_id) for user_id in users_in_room + ) + + if event.type == EventTypes.Member: + if event.content["membership"] == Membership.JOIN: + destinations.add(get_domain_from_id(event.state_key)) + + logger.debug("Sending %s to %r", event, destinations) + + self._send_pdu(event, destinations) + + yield self.store.update_federation_out_pos( + "events", next_token + ) + + finally: + self._is_processing = False + + def _send_pdu(self, pdu, destinations): # We loop through all destinations to see whether we already have # a transaction in progress. If we do, stick it in the pending_pdus # table and we'll get back to it later. + order = self._order + self._order += 1 + destinations = set(destinations) destinations = set( dest for dest in destinations if self.can_send_to(dest) @@ -130,6 +194,8 @@ class TransactionQueue(object): if not destinations: return + sent_pdus_destination_dist.inc_by(len(destinations)) + for destination in destinations: self.pending_pdus_by_dest.setdefault(destination, []).append( (pdu, order) @@ -139,7 +205,10 @@ class TransactionQueue(object): self._attempt_new_transaction, destination ) - def enqueue_presence(self, destination, states): + def send_presence(self, destination, states): + if not self.can_send_to(destination): + return + self.pending_presence_by_dest.setdefault(destination, {}).update({ state.user_id: state for state in states }) @@ -148,12 +217,19 @@ class TransactionQueue(object): self._attempt_new_transaction, destination ) - def enqueue_edu(self, edu, key=None): - destination = edu.destination + def send_edu(self, destination, edu_type, content, key=None): + edu = Edu( + origin=self.server_name, + destination=destination, + edu_type=edu_type, + content=content, + ) if not self.can_send_to(destination): return + sent_edus_counter.inc() + if key: self.pending_edus_keyed_by_dest.setdefault( destination, {} @@ -165,7 +241,7 @@ class TransactionQueue(object): self._attempt_new_transaction, destination ) - def enqueue_failure(self, failure, destination): + def send_failure(self, failure, destination): if destination == self.server_name or destination == "localhost": return @@ -180,7 +256,7 @@ class TransactionQueue(object): self._attempt_new_transaction, destination ) - def enqueue_device_messages(self, destination): + def send_device_messages(self, destination): if destination == self.server_name or destination == "localhost": return @@ -191,6 +267,9 @@ class TransactionQueue(object): self._attempt_new_transaction, destination ) + def get_current_token(self): + return 0 + @defer.inlineCallbacks def _attempt_new_transaction(self, destination): # list of (pending_pdu, deferred, order) |