diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index f8ca93e4c3..51b656d74a 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -19,6 +19,7 @@ from twisted.internet import defer
from .persistence import TransactionActions
from .units import Transaction, Edu
+from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import HttpResponseException
from synapse.util.async import run_on_reactor
from synapse.util.logcontext import preserve_context_over_fn
@@ -26,6 +27,7 @@ from synapse.util.retryutils import (
get_retry_limiter, NotRetryingDestination,
)
from synapse.util.metrics import measure_func
+from synapse.types import get_domain_from_id
from synapse.handlers.presence import format_user_presence_state
import synapse.metrics
@@ -36,6 +38,12 @@ logger = logging.getLogger(__name__)
metrics = synapse.metrics.get_metrics_for(__name__)
+client_metrics = synapse.metrics.get_metrics_for("synapse.federation.client")
+sent_pdus_destination_dist = client_metrics.register_distribution(
+ "sent_pdu_destinations"
+)
+sent_edus_counter = client_metrics.register_counter("sent_edus")
+
class TransactionQueue(object):
"""This class makes sure we only have one transaction in flight at
@@ -44,13 +52,14 @@ class TransactionQueue(object):
It batches pending PDUs into single transactions.
"""
- def __init__(self, hs, transport_layer):
+ def __init__(self, hs):
self.server_name = hs.hostname
self.store = hs.get_datastore()
+ self.state = hs.get_state_handler()
self.transaction_actions = TransactionActions(self.store)
- self.transport_layer = transport_layer
+ self.transport_layer = hs.get_federation_transport_client()
self.clock = hs.get_clock()
@@ -95,6 +104,11 @@ class TransactionQueue(object):
# HACK to get unique tx id
self._next_txn_id = int(self.clock.time_msec())
+ self._order = 1
+
+ self._is_processing = False
+ self._last_poked_id = -1
+
def can_send_to(self, destination):
"""Can we send messages to the given server?
@@ -115,11 +129,61 @@ class TransactionQueue(object):
else:
return not destination.startswith("localhost")
- def enqueue_pdu(self, pdu, destinations, order):
+ @defer.inlineCallbacks
+ def notify_new_events(self, current_id):
+ """This gets called when we have some new events we might want to
+ send out to other servers.
+ """
+ self._last_poked_id = max(current_id, self._last_poked_id)
+
+ if self._is_processing:
+ return
+
+ try:
+ self._is_processing = True
+ while True:
+ last_token = yield self.store.get_federation_out_pos("events")
+ next_token, events = yield self.store.get_all_new_events_stream(
+ last_token, self._last_poked_id, limit=20,
+ )
+
+ logger.debug("Handling %s -> %s", last_token, next_token)
+
+ if not events and next_token >= self._last_poked_id:
+ break
+
+ for event in events:
+ users_in_room = yield self.state.get_current_user_in_room(
+ event.room_id, latest_event_ids=[event.event_id],
+ )
+
+ destinations = set(
+ get_domain_from_id(user_id) for user_id in users_in_room
+ )
+
+ if event.type == EventTypes.Member:
+ if event.content["membership"] == Membership.JOIN:
+ destinations.add(get_domain_from_id(event.state_key))
+
+ logger.debug("Sending %s to %r", event, destinations)
+
+ self._send_pdu(event, destinations)
+
+ yield self.store.update_federation_out_pos(
+ "events", next_token
+ )
+
+ finally:
+ self._is_processing = False
+
+ def _send_pdu(self, pdu, destinations):
# We loop through all destinations to see whether we already have
# a transaction in progress. If we do, stick it in the pending_pdus
# table and we'll get back to it later.
+ order = self._order
+ self._order += 1
+
destinations = set(destinations)
destinations = set(
dest for dest in destinations if self.can_send_to(dest)
@@ -130,6 +194,8 @@ class TransactionQueue(object):
if not destinations:
return
+ sent_pdus_destination_dist.inc_by(len(destinations))
+
for destination in destinations:
self.pending_pdus_by_dest.setdefault(destination, []).append(
(pdu, order)
@@ -139,7 +205,10 @@ class TransactionQueue(object):
self._attempt_new_transaction, destination
)
- def enqueue_presence(self, destination, states):
+ def send_presence(self, destination, states):
+ if not self.can_send_to(destination):
+ return
+
self.pending_presence_by_dest.setdefault(destination, {}).update({
state.user_id: state for state in states
})
@@ -148,12 +217,19 @@ class TransactionQueue(object):
self._attempt_new_transaction, destination
)
- def enqueue_edu(self, edu, key=None):
- destination = edu.destination
+ def send_edu(self, destination, edu_type, content, key=None):
+ edu = Edu(
+ origin=self.server_name,
+ destination=destination,
+ edu_type=edu_type,
+ content=content,
+ )
if not self.can_send_to(destination):
return
+ sent_edus_counter.inc()
+
if key:
self.pending_edus_keyed_by_dest.setdefault(
destination, {}
@@ -165,7 +241,7 @@ class TransactionQueue(object):
self._attempt_new_transaction, destination
)
- def enqueue_failure(self, failure, destination):
+ def send_failure(self, failure, destination):
if destination == self.server_name or destination == "localhost":
return
@@ -180,7 +256,7 @@ class TransactionQueue(object):
self._attempt_new_transaction, destination
)
- def enqueue_device_messages(self, destination):
+ def send_device_messages(self, destination):
if destination == self.server_name or destination == "localhost":
return
@@ -191,6 +267,9 @@ class TransactionQueue(object):
self._attempt_new_transaction, destination
)
+ def get_current_token(self):
+ return 0
+
@defer.inlineCallbacks
def _attempt_new_transaction(self, destination):
# list of (pending_pdu, deferred, order)
@@ -383,6 +462,13 @@ class TransactionQueue(object):
code = e.code
response = e.response
+ if e.code == 429 or 500 <= e.code:
+ logger.info(
+ "TX [%s] {%s} got %d response",
+ destination, txn_id, code
+ )
+ raise e
+
logger.info(
"TX [%s] {%s} got %d response",
destination, txn_id, code
|