summary refs log tree commit diff
path: root/synapse/federation/sender/per_destination_queue.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation/sender/per_destination_queue.py')
-rw-r--r--synapse/federation/sender/per_destination_queue.py107
1 files changed, 58 insertions, 49 deletions
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index fad980b893..4e698981a4 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -15,11 +15,10 @@
 # limitations under the License.
 import datetime
 import logging
+from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Tuple
 
 from prometheus_client import Counter
 
-from twisted.internet import defer
-
 from synapse.api.errors import (
     FederationDeniedError,
     HttpResponseException,
@@ -30,9 +29,13 @@ from synapse.federation.units import Edu
 from synapse.handlers.presence import format_user_presence_state
 from synapse.metrics import sent_transactions_counter
 from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.storage import UserPresenceState
+from synapse.storage.presence import UserPresenceState
+from synapse.types import ReadReceipt
 from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
 
+if TYPE_CHECKING:
+    import synapse.server
+
 # This is defined in the Matrix spec and enforced by the receiver.
 MAX_EDUS_PER_TRANSACTION = 100
 
@@ -55,13 +58,18 @@ class PerDestinationQueue(object):
     Manages the per-destination transmission queues.
 
     Args:
-        hs (synapse.HomeServer):
-        transaction_sender (TransactionManager):
-        destination (str): the server_name of the destination that we are managing
+        hs
+        transaction_sender
+        destination: the server_name of the destination that we are managing
             transmission for.
     """
 
-    def __init__(self, hs, transaction_manager, destination):
+    def __init__(
+        self,
+        hs: "synapse.server.HomeServer",
+        transaction_manager: "synapse.federation.sender.TransactionManager",
+        destination: str,
+    ):
         self._server_name = hs.hostname
         self._clock = hs.get_clock()
         self._store = hs.get_datastore()
@@ -71,20 +79,23 @@ class PerDestinationQueue(object):
         self.transmission_loop_running = False
 
         # a list of tuples of (pending pdu, order)
-        self._pending_pdus = []  # type: list[tuple[EventBase, int]]
-        self._pending_edus = []  # type: list[Edu]
+        self._pending_pdus = []  # type: List[Tuple[EventBase, int]]
+
+        # XXX this is never actually used: see
+        # https://github.com/matrix-org/synapse/issues/7549
+        self._pending_edus = []  # type: List[Edu]
 
         # Pending EDUs by their "key". Keyed EDUs are EDUs that get clobbered
         # based on their key (e.g. typing events by room_id)
         # Map of (edu_type, key) -> Edu
-        self._pending_edus_keyed = {}  # type: dict[tuple[str, str], Edu]
+        self._pending_edus_keyed = {}  # type: Dict[Tuple[str, Hashable], Edu]
 
         # Map of user_id -> UserPresenceState of pending presence to be sent to this
         # destination
-        self._pending_presence = {}  # type: dict[str, UserPresenceState]
+        self._pending_presence = {}  # type: Dict[str, UserPresenceState]
 
         # room_id -> receipt_type -> user_id -> receipt_dict
-        self._pending_rrs = {}
+        self._pending_rrs = {}  # type: Dict[str, Dict[str, Dict[str, dict]]]
         self._rrs_pending_flush = False
 
         # stream_id of last successfully sent to-device message.
@@ -94,50 +105,50 @@ class PerDestinationQueue(object):
         # stream_id of last successfully sent device list update.
         self._last_device_list_stream_id = 0
 
-    def __str__(self):
+    def __str__(self) -> str:
         return "PerDestinationQueue[%s]" % self._destination
 
-    def pending_pdu_count(self):
+    def pending_pdu_count(self) -> int:
         return len(self._pending_pdus)
 
-    def pending_edu_count(self):
+    def pending_edu_count(self) -> int:
         return (
             len(self._pending_edus)
             + len(self._pending_presence)
             + len(self._pending_edus_keyed)
         )
 
-    def send_pdu(self, pdu, order):
+    def send_pdu(self, pdu: EventBase, order: int) -> None:
         """Add a PDU to the queue, and start the transmission loop if neccessary
 
         Args:
-            pdu (EventBase): pdu to send
-            order (int):
+            pdu: pdu to send
+            order
         """
         self._pending_pdus.append((pdu, order))
         self.attempt_new_transaction()
 
-    def send_presence(self, states):
+    def send_presence(self, states: Iterable[UserPresenceState]) -> None:
         """Add presence updates to the queue. Start the transmission loop if neccessary.
 
         Args:
-            states (iterable[UserPresenceState]): presence to send
+            states: presence to send
         """
         self._pending_presence.update({state.user_id: state for state in states})
         self.attempt_new_transaction()
 
-    def queue_read_receipt(self, receipt):
+    def queue_read_receipt(self, receipt: ReadReceipt) -> None:
         """Add a RR to the list to be sent. Doesn't start the transmission loop yet
         (see flush_read_receipts_for_room)
 
         Args:
-            receipt (synapse.api.receipt_info.ReceiptInfo): receipt to be queued
+            receipt: receipt to be queued
         """
         self._pending_rrs.setdefault(receipt.room_id, {}).setdefault(
             receipt.receipt_type, {}
         )[receipt.user_id] = {"event_ids": receipt.event_ids, "data": receipt.data}
 
-    def flush_read_receipts_for_room(self, room_id):
+    def flush_read_receipts_for_room(self, room_id: str) -> None:
         # if we don't have any read-receipts for this room, it may be that we've already
         # sent them out, so we don't need to flush.
         if room_id not in self._pending_rrs:
@@ -145,15 +156,15 @@ class PerDestinationQueue(object):
         self._rrs_pending_flush = True
         self.attempt_new_transaction()
 
-    def send_keyed_edu(self, edu, key):
+    def send_keyed_edu(self, edu: Edu, key: Hashable) -> None:
         self._pending_edus_keyed[(edu.edu_type, key)] = edu
         self.attempt_new_transaction()
 
-    def send_edu(self, edu):
+    def send_edu(self, edu) -> None:
         self._pending_edus.append(edu)
         self.attempt_new_transaction()
 
-    def attempt_new_transaction(self):
+    def attempt_new_transaction(self) -> None:
         """Try to start a new transaction to this destination
 
         If there is already a transaction in progress to this destination,
@@ -176,31 +187,31 @@ class PerDestinationQueue(object):
             self._transaction_transmission_loop,
         )
 
-    @defer.inlineCallbacks
-    def _transaction_transmission_loop(self):
-        pending_pdus = []
+    async def _transaction_transmission_loop(self) -> None:
+        pending_pdus = []  # type: List[Tuple[EventBase, int]]
         try:
             self.transmission_loop_running = True
 
             # This will throw if we wouldn't retry. We do this here so we fail
             # quickly, but we will later check this again in the http client,
             # hence why we throw the result away.
-            yield get_retry_limiter(self._destination, self._clock, self._store)
+            await get_retry_limiter(self._destination, self._clock, self._store)
 
             pending_pdus = []
             while True:
                 # We have to keep 2 free slots for presence and rr_edus
                 limit = MAX_EDUS_PER_TRANSACTION - 2
 
-                device_update_edus, dev_list_id = (
-                    yield self._get_device_update_edus(limit)
+                device_update_edus, dev_list_id = await self._get_device_update_edus(
+                    limit
                 )
 
                 limit -= len(device_update_edus)
 
-                to_device_edus, device_stream_id = (
-                    yield self._get_to_device_message_edus(limit)
-                )
+                (
+                    to_device_edus,
+                    device_stream_id,
+                ) = await self._get_to_device_message_edus(limit)
 
                 pending_edus = device_update_edus + to_device_edus
 
@@ -267,7 +278,7 @@ class PerDestinationQueue(object):
 
                 # END CRITICAL SECTION
 
-                success = yield self._transaction_manager.send_new_transaction(
+                success = await self._transaction_manager.send_new_transaction(
                     self._destination, pending_pdus, pending_edus
                 )
                 if success:
@@ -278,7 +289,7 @@ class PerDestinationQueue(object):
                     # Remove the acknowledged device messages from the database
                     # Only bother if we actually sent some device messages
                     if to_device_edus:
-                        yield self._store.delete_device_msgs_for_remote(
+                        await self._store.delete_device_msgs_for_remote(
                             self._destination, device_stream_id
                         )
 
@@ -287,7 +298,7 @@ class PerDestinationQueue(object):
                         logger.info(
                             "Marking as sent %r %r", self._destination, dev_list_id
                         )
-                        yield self._store.mark_as_sent_devices_by_remote(
+                        await self._store.mark_as_sent_devices_by_remote(
                             self._destination, dev_list_id
                         )
 
@@ -332,7 +343,7 @@ class PerDestinationQueue(object):
             # We want to be *very* sure we clear this after we stop processing
             self.transmission_loop_running = False
 
-    def _get_rr_edus(self, force_flush):
+    def _get_rr_edus(self, force_flush: bool) -> Iterable[Edu]:
         if not self._pending_rrs:
             return
         if not force_flush and not self._rrs_pending_flush:
@@ -349,38 +360,36 @@ class PerDestinationQueue(object):
         self._rrs_pending_flush = False
         yield edu
 
-    def _pop_pending_edus(self, limit):
+    def _pop_pending_edus(self, limit: int) -> List[Edu]:
         pending_edus = self._pending_edus
         pending_edus, self._pending_edus = pending_edus[:limit], pending_edus[limit:]
         return pending_edus
 
-    @defer.inlineCallbacks
-    def _get_device_update_edus(self, limit):
+    async def _get_device_update_edus(self, limit: int) -> Tuple[List[Edu], int]:
         last_device_list = self._last_device_list_stream_id
 
         # Retrieve list of new device updates to send to the destination
-        now_stream_id, results = yield self._store.get_devices_by_remote(
+        now_stream_id, results = await self._store.get_device_updates_by_remote(
             self._destination, last_device_list, limit=limit
         )
         edus = [
             Edu(
                 origin=self._server_name,
                 destination=self._destination,
-                edu_type="m.device_list_update",
+                edu_type=edu_type,
                 content=content,
             )
-            for content in results
+            for (edu_type, content) in results
         ]
 
-        assert len(edus) <= limit, "get_devices_by_remote returned too many EDUs"
+        assert len(edus) <= limit, "get_device_updates_by_remote returned too many EDUs"
 
         return (edus, now_stream_id)
 
-    @defer.inlineCallbacks
-    def _get_to_device_message_edus(self, limit):
+    async def _get_to_device_message_edus(self, limit: int) -> Tuple[List[Edu], int]:
         last_device_stream_id = self._last_device_stream_id
         to_device_stream_id = self._store.get_to_device_stream_token()
-        contents, stream_id = yield self._store.get_new_device_msgs_for_remote(
+        contents, stream_id = await self._store.get_new_device_msgs_for_remote(
             self._destination, last_device_stream_id, to_device_stream_id, limit
         )
         edus = [