diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index deb519f3ef..cc0d765e5f 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -17,6 +17,7 @@ import datetime
import logging
from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple, cast
+import attr
from prometheus_client import Counter
from synapse.api.errors import (
@@ -93,6 +94,10 @@ class PerDestinationQueue:
self._destination = destination
self.transmission_loop_running = False
+ # Flag to signal to any running transmission loop that there is new data
+ # queued up to be sent.
+ self._new_data_to_send = False
+
# True whilst we are sending events that the remote homeserver missed
# because it was unreachable. We start in this state so we can perform
# catch-up at startup.
@@ -108,7 +113,7 @@ class PerDestinationQueue:
# destination (we are the only updater so this is safe)
self._last_successful_stream_ordering = None # type: Optional[int]
- # a list of pending PDUs
+ # a queue of pending PDUs
self._pending_pdus = [] # type: List[EventBase]
# XXX this is never actually used: see
@@ -208,6 +213,10 @@ class PerDestinationQueue:
transaction in the background.
"""
+ # Mark that we (may) have new things to send, so that any running
+ # transmission loop will recheck whether there is stuff to send.
+ self._new_data_to_send = True
+
if self.transmission_loop_running:
# XXX: this can get stuck on by a never-ending
# request at which point pending_pdus just keeps growing.
@@ -250,125 +259,41 @@ class PerDestinationQueue:
pending_pdus = []
while True:
- # We have to keep 2 free slots for presence and rr_edus
- limit = MAX_EDUS_PER_TRANSACTION - 2
-
- device_update_edus, dev_list_id = await self._get_device_update_edus(
- limit
- )
-
- limit -= len(device_update_edus)
-
- (
- to_device_edus,
- device_stream_id,
- ) = await self._get_to_device_message_edus(limit)
-
- pending_edus = device_update_edus + to_device_edus
-
- # BEGIN CRITICAL SECTION
- #
- # In order to avoid a race condition, we need to make sure that
- # the following code (from popping the queues up to the point
- # where we decide if we actually have any pending messages) is
- # atomic - otherwise new PDUs or EDUs might arrive in the
- # meantime, but not get sent because we hold the
- # transmission_loop_running flag.
-
- pending_pdus = self._pending_pdus
+ self._new_data_to_send = False
- # We can only include at most 50 PDUs per transactions
- pending_pdus, self._pending_pdus = pending_pdus[:50], pending_pdus[50:]
+ async with _TransactionQueueManager(self) as (
+ pending_pdus,
+ pending_edus,
+ ):
+ if not pending_pdus and not pending_edus:
+ logger.debug("TX [%s] Nothing to send", self._destination)
+
+ # If we were told about new things to send while we were
+ # checking for things to send, we try looking again.
+ # Otherwise new PDUs or EDUs might arrive in the meantime
+ # but not get sent, because we hold the
+ # `transmission_loop_running` flag.
+ if self._new_data_to_send:
+ continue
+ else:
+ return
- pending_edus.extend(self._get_rr_edus(force_flush=False))
- pending_presence = self._pending_presence
- self._pending_presence = {}
- if pending_presence:
- pending_edus.append(
- Edu(
- origin=self._server_name,
- destination=self._destination,
- edu_type="m.presence",
- content={
- "push": [
- format_user_presence_state(
- presence, self._clock.time_msec()
- )
- for presence in pending_presence.values()
- ]
- },
+ if pending_pdus:
+ logger.debug(
+ "TX [%s] len(pending_pdus_by_dest[dest]) = %d",
+ self._destination,
+ len(pending_pdus),
)
- )
- pending_edus.extend(
- self._pop_pending_edus(MAX_EDUS_PER_TRANSACTION - len(pending_edus))
- )
- while (
- len(pending_edus) < MAX_EDUS_PER_TRANSACTION
- and self._pending_edus_keyed
- ):
- _, val = self._pending_edus_keyed.popitem()
- pending_edus.append(val)
-
- if pending_pdus:
- logger.debug(
- "TX [%s] len(pending_pdus_by_dest[dest]) = %d",
- self._destination,
- len(pending_pdus),
+ await self._transaction_manager.send_new_transaction(
+ self._destination, pending_pdus, pending_edus
)
- if not pending_pdus and not pending_edus:
- logger.debug("TX [%s] Nothing to send", self._destination)
- self._last_device_stream_id = device_stream_id
- return
-
- # if we've decided to send a transaction anyway, and we have room, we
- # may as well send any pending RRs
- if len(pending_edus) < MAX_EDUS_PER_TRANSACTION:
- pending_edus.extend(self._get_rr_edus(force_flush=True))
-
- # END CRITICAL SECTION
-
- success = await self._transaction_manager.send_new_transaction(
- self._destination, pending_pdus, pending_edus
- )
- if success:
sent_transactions_counter.inc()
sent_edus_counter.inc(len(pending_edus))
for edu in pending_edus:
sent_edus_by_type.labels(edu.edu_type).inc()
- # Remove the acknowledged device messages from the database
- # Only bother if we actually sent some device messages
- if to_device_edus:
- await self._store.delete_device_msgs_for_remote(
- self._destination, device_stream_id
- )
- # also mark the device updates as sent
- if device_update_edus:
- logger.info(
- "Marking as sent %r %r", self._destination, dev_list_id
- )
- await self._store.mark_as_sent_devices_by_remote(
- self._destination, dev_list_id
- )
-
- self._last_device_stream_id = device_stream_id
- self._last_device_list_stream_id = dev_list_id
-
- if pending_pdus:
- # we sent some PDUs and it was successful, so update our
- # last_successful_stream_ordering in the destinations table.
- final_pdu = pending_pdus[-1]
- last_successful_stream_ordering = (
- final_pdu.internal_metadata.stream_ordering
- )
- assert last_successful_stream_ordering
- await self._store.set_destination_last_successful_stream_ordering(
- self._destination, last_successful_stream_ordering
- )
- else:
- break
except NotRetryingDestination as e:
logger.debug(
"TX [%s] not ready for retry yet (next retry at %s) - "
@@ -401,7 +326,7 @@ class PerDestinationQueue:
self._pending_presence = {}
self._pending_rrs = {}
- self._start_catching_up()
+ self._start_catching_up()
except FederationDeniedError as e:
logger.info(e)
except HttpResponseException as e:
@@ -412,7 +337,6 @@ class PerDestinationQueue:
e,
)
- self._start_catching_up()
except RequestSendFailed as e:
logger.warning(
"TX [%s] Failed to send transaction: %s", self._destination, e
@@ -422,16 +346,12 @@ class PerDestinationQueue:
logger.info(
"Failed to send event %s to %s", p.event_id, self._destination
)
-
- self._start_catching_up()
except Exception:
logger.exception("TX [%s] Failed to send transaction", self._destination)
for p in pending_pdus:
logger.info(
"Failed to send event %s to %s", p.event_id, self._destination
)
-
- self._start_catching_up()
finally:
# We want to be *very* sure we clear this after we stop processing
self.transmission_loop_running = False
@@ -499,13 +419,10 @@ class PerDestinationQueue:
rooms = [p.room_id for p in catchup_pdus]
logger.info("Catching up rooms to %s: %r", self._destination, rooms)
- success = await self._transaction_manager.send_new_transaction(
+ await self._transaction_manager.send_new_transaction(
self._destination, catchup_pdus, []
)
- if not success:
- return
-
sent_transactions_counter.inc()
final_pdu = catchup_pdus[-1]
self._last_successful_stream_ordering = cast(
@@ -584,3 +501,135 @@ class PerDestinationQueue:
"""
self._catching_up = True
self._pending_pdus = []
+
+
+@attr.s(slots=True)
+class _TransactionQueueManager:
+ """A helper async context manager for pulling stuff off the queues and
+ tracking what was last successfully sent, etc.
+ """
+
+ queue = attr.ib(type=PerDestinationQueue)
+
+ _device_stream_id = attr.ib(type=Optional[int], default=None)
+ _device_list_id = attr.ib(type=Optional[int], default=None)
+ _last_stream_ordering = attr.ib(type=Optional[int], default=None)
+ _pdus = attr.ib(type=List[EventBase], factory=list)
+
+ async def __aenter__(self) -> Tuple[List[EventBase], List[Edu]]:
+ # First we calculate the EDUs we want to send, if any.
+
+ # We start by fetching device-related EDUs, i.e. device updates and
+ # to-device messages. We have to keep 2 free slots for presence and rr_edus.
+ limit = MAX_EDUS_PER_TRANSACTION - 2
+
+ device_update_edus, dev_list_id = await self.queue._get_device_update_edus(
+ limit
+ )
+
+ if device_update_edus:
+ self._device_list_id = dev_list_id
+ else:
+ self.queue._last_device_list_stream_id = dev_list_id
+
+ limit -= len(device_update_edus)
+
+ (
+ to_device_edus,
+ device_stream_id,
+ ) = await self.queue._get_to_device_message_edus(limit)
+
+ if to_device_edus:
+ self._device_stream_id = device_stream_id
+ else:
+ self.queue._last_device_stream_id = device_stream_id
+
+ pending_edus = device_update_edus + to_device_edus
+
+ # Now add the read receipt EDU.
+ pending_edus.extend(self.queue._get_rr_edus(force_flush=False))
+
+ # And presence EDU.
+ if self.queue._pending_presence:
+ pending_edus.append(
+ Edu(
+ origin=self.queue._server_name,
+ destination=self.queue._destination,
+ edu_type="m.presence",
+ content={
+ "push": [
+ format_user_presence_state(
+ presence, self.queue._clock.time_msec()
+ )
+ for presence in self.queue._pending_presence.values()
+ ]
+ },
+ )
+ )
+ self.queue._pending_presence = {}
+
+ # Finally add any other types of EDUs if there is room.
+ pending_edus.extend(
+ self.queue._pop_pending_edus(MAX_EDUS_PER_TRANSACTION - len(pending_edus))
+ )
+ while (
+ len(pending_edus) < MAX_EDUS_PER_TRANSACTION
+ and self.queue._pending_edus_keyed
+ ):
+ _, val = self.queue._pending_edus_keyed.popitem()
+ pending_edus.append(val)
+
+ # Now we look for any PDUs to send, by getting up to 50 PDUs from the
+ # queue.
+ self._pdus = self.queue._pending_pdus[:50]
+
+ if not self._pdus and not pending_edus:
+ return [], []
+
+ # if we've decided to send a transaction anyway, and we have room, we
+ # may as well send any pending RRs
+ if len(pending_edus) < MAX_EDUS_PER_TRANSACTION:
+ pending_edus.extend(self.queue._get_rr_edus(force_flush=True))
+
+ if self._pdus:
+ self._last_stream_ordering = self._pdus[
+ -1
+ ].internal_metadata.stream_ordering
+ assert self._last_stream_ordering
+
+ return self._pdus, pending_edus
+
+ async def __aexit__(self, exc_type, exc, tb):
+ if exc_type is not None:
+ # Failed to send transaction, so we bail out.
+ return
+
+ # The transaction was sent successfully, so remove the sent PDUs from the
+ # pending queue.
+ if self._pdus:
+ self.queue._pending_pdus = self.queue._pending_pdus[len(self._pdus) :]
+
+ # We successfully sent the transaction, so record how far we have sent up
+ # to in the various streams.
+
+ if self._device_stream_id:
+ await self.queue._store.delete_device_msgs_for_remote(
+ self.queue._destination, self._device_stream_id
+ )
+ self.queue._last_device_stream_id = self._device_stream_id
+
+ # also mark the device updates as sent
+ if self._device_list_id:
+ logger.info(
+ "Marking as sent %r %r", self.queue._destination, self._device_list_id
+ )
+ await self.queue._store.mark_as_sent_devices_by_remote(
+ self.queue._destination, self._device_list_id
+ )
+ self.queue._last_device_list_stream_id = self._device_list_id
+
+ if self._last_stream_ordering:
+ # we sent some PDUs and it was successful, so update our
+ # last_successful_stream_ordering in the destinations table.
+ await self.queue._store.set_destination_last_successful_stream_ordering(
+ self.queue._destination, self._last_stream_ordering
+ )
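
Reviewer note: the `_new_data_to_send` flag above replaces the old "BEGIN/END
CRITICAL SECTION" comments. The loop clears the flag before it gathers work
and, if nothing was found but the flag was set again in the meantime, goes
round once more instead of exiting. The following is a minimal sketch of that
flag-and-recheck pattern; `ToySender`, the batch size and the stand-in send are
all hypothetical, not Synapse code:

import asyncio
from typing import List


class ToySender:
    """Illustrates the flag-and-recheck pattern with made-up names."""

    def __init__(self) -> None:
        self.queue: List[str] = []
        self.new_data = False       # analogous to _new_data_to_send
        self.loop_running = False   # analogous to transmission_loop_running

    def enqueue(self, item: str) -> None:
        self.queue.append(item)
        self.new_data = True        # tell any running loop to take another look
        if not self.loop_running:
            asyncio.ensure_future(self._transmission_loop())

    async def _transmission_loop(self) -> None:
        self.loop_running = True
        try:
            while True:
                # Clear the flag *before* draining, so anything enqueued from
                # here onwards is caught by the recheck below.
                self.new_data = False
                batch, self.queue = self.queue[:50], self.queue[50:]
                if not batch:
                    if self.new_data:
                        continue    # something arrived while we were looking
                    return          # genuinely nothing left to send
                await asyncio.sleep(0)  # stand-in for sending a transaction
        finally:
            self.loop_running = False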
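Reviewer note: `_TransactionQueueManager` also relies on a detail of the async
context manager protocol: `__aexit__` receives any exception raised inside the
`async with` body, so stream positions are only persisted when the send
completed. A stripped-down sketch of that "commit only on success" shape, with
hypothetical names and a plain list standing in for the real queues and stores:

from types import TracebackType
from typing import List, Optional, Type


class CommitOnSuccess:
    """Hands out a batch; only drops it from the queue if the body didn't raise."""

    def __init__(self, queue: List[int]) -> None:
        self.queue = queue
        self.batch: List[int] = []

    async def __aenter__(self) -> List[int]:
        self.batch = self.queue[:50]
        return self.batch

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc: Optional[BaseException],
        tb: Optional[TracebackType],
    ) -> None:
        if exc_type is not None:
            # The send raised: keep the items queued so the next run retries them.
            return
        # The send succeeded: drop the batch and record progress here.
        del self.queue[: len(self.batch)]


async def send_once(queue: List[int]) -> None:
    async with CommitOnSuccess(queue) as batch:
        await fake_send(batch)  # hypothetical transport call; may raise


async def fake_send(batch: List[int]) -> None:
    pass  # stand-in for TransactionManager.send_new_transaction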
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index 2a9cd063c4..07b740c2f2 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -69,15 +69,12 @@ class TransactionManager:
destination: str,
pdus: List[EventBase],
edus: List[Edu],
- ) -> bool:
+ ) -> None:
"""
Args:
destination: The destination to send to (e.g. 'example.org')
pdus: In-order list of PDUs to send
edus: List of EDUs to send
-
- Returns:
- True iff the transaction was successful
"""
# Make a transaction-sending opentracing span. This span follows on from
@@ -96,8 +93,6 @@ class TransactionManager:
edu.strip_context()
with start_active_span_follows_from("send_transaction", span_contexts):
- success = True
-
logger.debug("TX [%s] _attempt_new_transaction", destination)
txn_id = str(self._next_txn_id)
@@ -152,44 +147,29 @@ class TransactionManager:
response = await self._transport_layer.send_transaction(
transaction, json_data_cb
)
- code = 200
except HttpResponseException as e:
code = e.code
response = e.response
- if e.code in (401, 404, 429) or 500 <= e.code:
- logger.info(
- "TX [%s] {%s} got %d response", destination, txn_id, code
- )
- raise e
-
- logger.info("TX [%s] {%s} got %d response", destination, txn_id, code)
-
- if code == 200:
- for e_id, r in response.get("pdus", {}).items():
- if "error" in r:
- logger.warning(
- "TX [%s] {%s} Remote returned error for %s: %s",
- destination,
- txn_id,
- e_id,
- r,
- )
- else:
- for p in pdus:
+ set_tag(tags.ERROR, True)
+
+ logger.info("TX [%s] {%s} got %d response", destination, txn_id, code)
+ raise
+
+ logger.info("TX [%s] {%s} got 200 response", destination, txn_id)
+
+ for e_id, r in response.get("pdus", {}).items():
+ if "error" in r:
logger.warning(
- "TX [%s] {%s} Failed to send event %s",
+ "TX [%s] {%s} Remote returned error for %s: %s",
destination,
txn_id,
- p.event_id,
+ e_id,
+ r,
)
- success = False
- if success and pdus and destination in self._federation_metrics_domains:
+ if pdus and destination in self._federation_metrics_domains:
last_pdu = pdus[-1]
last_pdu_ts_metric.labels(server_name=destination).set(
last_pdu.origin_server_ts / 1000
)
-
- set_tag(tags.ERROR, not success)
- return success
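
Reviewer note: the signature change above (returning None instead of bool)
pairs with the queue-manager change in the first file. Failures from
`send_transaction` now propagate as exceptions, so the per-destination loop's
`except` clauses and `_TransactionQueueManager.__aexit__` see them directly,
and callers no longer thread a success flag around. A before/after sketch with
hypothetical caller and helper names:

# Before: every caller had to check a boolean and decide what failure meant.
async def old_style_caller(txm, destination, pdus, edus) -> None:
    success = await txm.send_new_transaction(destination, pdus, edus)
    if not success:
        return
    record_progress(pdus)  # hypothetical bookkeeping helper


# After: failure is an exception handled by the surrounding try/except (or by
# __aexit__), and falling through implies success.
async def new_style_caller(txm, destination, pdus, edus) -> None:
    await txm.send_new_transaction(destination, pdus, edus)  # raises on error
    record_progress(pdus)


def record_progress(pdus) -> None:
    """Stand-in for updating last_successful_stream_ordering etc."""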
diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py
index b921d63d30..0309661841 100644
--- a/synapse/storage/databases/main/transactions.py
+++ b/synapse/storage/databases/main/transactions.py
@@ -350,11 +350,11 @@ class TransactionStore(TransactionWorkerStore):
self.db_pool.simple_upsert_many_txn(
txn,
- "destination_rooms",
- ["destination", "room_id"],
- rows,
- ["stream_ordering"],
- [(stream_ordering,)] * len(rows),
+ table="destination_rooms",
+ key_names=("destination", "room_id"),
+ key_values=rows,
+ value_names=["stream_ordering"],
+ value_values=[(stream_ordering,)] * len(rows),
)
async def get_destination_last_successful_stream_ordering(
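
Reviewer note: the keyword-argument form above makes explicit which columns key
the upsert (destination, room_id) and which are overwritten (stream_ordering).
For readers unfamiliar with the helper, the following is a rough, purely
illustrative sketch of what an "upsert many" boils down to on a database
without native upsert support; it is not Synapse's actual implementation,
which also has a native-UPSERT fast path:

from typing import Iterable, Sequence


def upsert_many_txn(
    txn,
    table: str,
    key_names: Sequence[str],
    key_values: Iterable[Sequence],
    value_names: Sequence[str],
    value_values: Iterable[Sequence],
) -> None:
    # For each keyed row: update it if it exists, otherwise insert it.
    # (Table/column names are trusted identifiers supplied by the caller.)
    for keys, values in zip(key_values, value_values):
        sets = ", ".join("%s = ?" % v for v in value_names)
        where = " AND ".join("%s = ?" % k for k in key_names)
        txn.execute(
            "UPDATE %s SET %s WHERE %s" % (table, sets, where),
            list(values) + list(keys),
        )
        if txn.rowcount == 0:
            cols = ", ".join(list(key_names) + list(value_names))
            qs = ", ".join("?" for _ in list(key_names) + list(value_names))
            txn.execute(
                "INSERT INTO %s (%s) VALUES (%s)" % (table, cols, qs),
                list(keys) + list(values),
            )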