diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index cb2ef0210c..633c79c352 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -17,7 +17,7 @@
from twisted.internet import defer
from .persistence import TransactionActions
-from .units import Transaction
+from .units import Transaction, Edu
from synapse.api.errors import HttpResponseException
from synapse.util.async import run_on_reactor
@@ -81,6 +81,8 @@ class TransactionQueue(object):
# destination -> list of tuple(failure, deferred)
self.pending_failures_by_dest = {}
+ self.last_device_stream_id_by_dest = {}
+
# HACK to get unique tx id
self._next_txn_id = int(self.clock.time_msec())
@@ -155,6 +157,17 @@ class TransactionQueue(object):
self._attempt_new_transaction, destination
)
+ def enqueue_device_messages(self, destination):
+ if destination == self.server_name or destination == "localhost":
+ return
+
+ if not self.can_send_to(destination):
+ return
+
+ preserve_context_over_fn(
+ self._attempt_new_transaction, destination
+ )
+
@defer.inlineCallbacks
def _attempt_new_transaction(self, destination):
yield run_on_reactor()
@@ -175,6 +188,12 @@ class TransactionQueue(object):
pending_edus = self.pending_edus_by_dest.pop(destination, [])
pending_failures = self.pending_failures_by_dest.pop(destination, [])
+ device_message_edus, device_stream_id = (
+ yield self._get_new_device_messages(destination)
+ )
+
+ pending_edus.extend(device_message_edus)
+
if pending_pdus:
logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
destination, len(pending_pdus))
@@ -184,13 +203,34 @@ class TransactionQueue(object):
return
yield self._send_new_transaction(
- destination, pending_pdus, pending_edus, pending_failures
+ destination, pending_pdus, pending_edus, pending_failures,
+ device_stream_id,
+ should_delete_from_device_stream=bool(device_message_edus)
)
+ @defer.inlineCallbacks
+ def _get_new_device_messages(self, destination):
+ last_device_stream_id = self.last_device_stream_id_by_dest.get(destination, 0)
+ to_device_stream_id = self.store.get_to_device_stream_token()
+ contents, stream_id = yield self.store.get_new_device_msgs_for_remote(
+ destination, last_device_stream_id, to_device_stream_id
+ )
+ edus = [
+ Edu(
+ origin=self.server_name,
+ destination=destination,
+ edu_type="m.direct_to_device",
+ content=content,
+ )
+ for content in contents
+ ]
+ defer.returnValue((edus, stream_id))
+
@measure_func("_send_new_transaction")
@defer.inlineCallbacks
def _send_new_transaction(self, destination, pending_pdus, pending_edus,
- pending_failures):
+ pending_failures, device_stream_id,
+ should_delete_from_device_stream):
# Sort based on the order field
pending_pdus.sort(key=lambda t: t[1])
@@ -215,9 +255,9 @@ class TransactionQueue(object):
"TX [%s] {%s} Attempting new transaction"
" (pdus: %d, edus: %d, failures: %d)",
destination, txn_id,
- len(pending_pdus),
- len(pending_edus),
- len(pending_failures)
+ len(pdus),
+ len(edus),
+ len(failures)
)
logger.debug("TX [%s] Persisting transaction...", destination)
@@ -242,9 +282,9 @@ class TransactionQueue(object):
" (PDUs: %d, EDUs: %d, failures: %d)",
destination, txn_id,
transaction.transaction_id,
- len(pending_pdus),
- len(pending_edus),
- len(pending_failures),
+ len(pdus),
+ len(edus),
+ len(failures),
)
with limiter:
@@ -299,6 +339,13 @@ class TransactionQueue(object):
logger.info(
"Failed to send event %s to %s", p.event_id, destination
)
+ else:
+ # Remove the acknowledged device messages from the database
+ if should_delete_from_device_stream:
+ yield self.store.delete_device_msgs_for_remote(
+ destination, device_stream_id
+ )
+ self.last_device_stream_id_by_dest[destination] = device_stream_id
except NotRetryingDestination:
logger.info(
"TX [%s] not ready for retry yet - "
|