summary refs log tree commit diff
path: root/synapse/federation/transaction_queue.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation/transaction_queue.py')
-rw-r--r--synapse/federation/transaction_queue.py102
1 files changed, 94 insertions, 8 deletions
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index f8ca93e4c3..51b656d74a 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -19,6 +19,7 @@ from twisted.internet import defer
 from .persistence import TransactionActions
 from .units import Transaction, Edu
 
+from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import HttpResponseException
 from synapse.util.async import run_on_reactor
 from synapse.util.logcontext import preserve_context_over_fn
@@ -26,6 +27,7 @@ from synapse.util.retryutils import (
     get_retry_limiter, NotRetryingDestination,
 )
 from synapse.util.metrics import measure_func
+from synapse.types import get_domain_from_id
 from synapse.handlers.presence import format_user_presence_state
 import synapse.metrics
 
@@ -36,6 +38,12 @@ logger = logging.getLogger(__name__)
 
 metrics = synapse.metrics.get_metrics_for(__name__)
 
+client_metrics = synapse.metrics.get_metrics_for("synapse.federation.client")
+sent_pdus_destination_dist = client_metrics.register_distribution(
+    "sent_pdu_destinations"
+)
+sent_edus_counter = client_metrics.register_counter("sent_edus")
+
 
 class TransactionQueue(object):
     """This class makes sure we only have one transaction in flight at
@@ -44,13 +52,14 @@ class TransactionQueue(object):
     It batches pending PDUs into single transactions.
     """
 
-    def __init__(self, hs, transport_layer):
+    def __init__(self, hs):
         self.server_name = hs.hostname
 
         self.store = hs.get_datastore()
+        self.state = hs.get_state_handler()
         self.transaction_actions = TransactionActions(self.store)
 
-        self.transport_layer = transport_layer
+        self.transport_layer = hs.get_federation_transport_client()
 
         self.clock = hs.get_clock()
 
@@ -95,6 +104,11 @@ class TransactionQueue(object):
         # HACK to get unique tx id
         self._next_txn_id = int(self.clock.time_msec())
 
+        self._order = 1
+
+        self._is_processing = False
+        self._last_poked_id = -1
+
     def can_send_to(self, destination):
         """Can we send messages to the given server?
 
@@ -115,11 +129,61 @@ class TransactionQueue(object):
         else:
             return not destination.startswith("localhost")
 
-    def enqueue_pdu(self, pdu, destinations, order):
+    @defer.inlineCallbacks
+    def notify_new_events(self, current_id):
+        """This gets called when we have some new events we might want to
+        send out to other servers.
+        """
+        self._last_poked_id = max(current_id, self._last_poked_id)
+
+        if self._is_processing:
+            return
+
+        try:
+            self._is_processing = True
+            while True:
+                last_token = yield self.store.get_federation_out_pos("events")
+                next_token, events = yield self.store.get_all_new_events_stream(
+                    last_token, self._last_poked_id, limit=20,
+                )
+
+                logger.debug("Handling %s -> %s", last_token, next_token)
+
+                if not events and next_token >= self._last_poked_id:
+                    break
+
+                for event in events:
+                    users_in_room = yield self.state.get_current_user_in_room(
+                        event.room_id, latest_event_ids=[event.event_id],
+                    )
+
+                    destinations = set(
+                        get_domain_from_id(user_id) for user_id in users_in_room
+                    )
+
+                    if event.type == EventTypes.Member:
+                        if event.content["membership"] == Membership.JOIN:
+                            destinations.add(get_domain_from_id(event.state_key))
+
+                    logger.debug("Sending %s to %r", event, destinations)
+
+                    self._send_pdu(event, destinations)
+
+                yield self.store.update_federation_out_pos(
+                    "events", next_token
+                )
+
+        finally:
+            self._is_processing = False
+
+    def _send_pdu(self, pdu, destinations):
         # We loop through all destinations to see whether we already have
         # a transaction in progress. If we do, stick it in the pending_pdus
         # table and we'll get back to it later.
 
+        order = self._order
+        self._order += 1
+
         destinations = set(destinations)
         destinations = set(
             dest for dest in destinations if self.can_send_to(dest)
@@ -130,6 +194,8 @@ class TransactionQueue(object):
         if not destinations:
             return
 
+        sent_pdus_destination_dist.inc_by(len(destinations))
+
         for destination in destinations:
             self.pending_pdus_by_dest.setdefault(destination, []).append(
                 (pdu, order)
@@ -139,7 +205,10 @@ class TransactionQueue(object):
                 self._attempt_new_transaction, destination
             )
 
-    def enqueue_presence(self, destination, states):
+    def send_presence(self, destination, states):
+        if not self.can_send_to(destination):
+            return
+
         self.pending_presence_by_dest.setdefault(destination, {}).update({
             state.user_id: state for state in states
         })
@@ -148,12 +217,19 @@ class TransactionQueue(object):
             self._attempt_new_transaction, destination
         )
 
-    def enqueue_edu(self, edu, key=None):
-        destination = edu.destination
+    def send_edu(self, destination, edu_type, content, key=None):
+        edu = Edu(
+            origin=self.server_name,
+            destination=destination,
+            edu_type=edu_type,
+            content=content,
+        )
 
         if not self.can_send_to(destination):
             return
 
+        sent_edus_counter.inc()
+
         if key:
             self.pending_edus_keyed_by_dest.setdefault(
                 destination, {}
@@ -165,7 +241,7 @@ class TransactionQueue(object):
             self._attempt_new_transaction, destination
         )
 
-    def enqueue_failure(self, failure, destination):
+    def send_failure(self, failure, destination):
         if destination == self.server_name or destination == "localhost":
             return
 
@@ -180,7 +256,7 @@ class TransactionQueue(object):
             self._attempt_new_transaction, destination
         )
 
-    def enqueue_device_messages(self, destination):
+    def send_device_messages(self, destination):
         if destination == self.server_name or destination == "localhost":
             return
 
@@ -191,6 +267,9 @@ class TransactionQueue(object):
             self._attempt_new_transaction, destination
         )
 
+    def get_current_token(self):
+        return 0
+
     @defer.inlineCallbacks
     def _attempt_new_transaction(self, destination):
         # list of (pending_pdu, deferred, order)
@@ -383,6 +462,13 @@ class TransactionQueue(object):
                     code = e.code
                     response = e.response
 
+                    if e.code == 429 or 500 <= e.code:
+                        logger.info(
+                            "TX [%s] {%s} got %d response",
+                            destination, txn_id, code
+                        )
+                        raise e
+
                 logger.info(
                     "TX [%s] {%s} got %d response",
                     destination, txn_id, code