summary refs log tree commit diff
path: root/synapse/federation/sender/per_destination_queue.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation/sender/per_destination_queue.py')
-rw-r--r--synapse/federation/sender/per_destination_queue.py88
1 files changed, 45 insertions, 43 deletions
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index 5012aaea35..e13cd20ffa 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -15,11 +15,11 @@
 # limitations under the License.
 import datetime
 import logging
+from typing import Dict, Hashable, Iterable, List, Tuple
 
 from prometheus_client import Counter
 
-from twisted.internet import defer
-
+import synapse.server
 from synapse.api.errors import (
     FederationDeniedError,
     HttpResponseException,
@@ -31,7 +31,7 @@ from synapse.handlers.presence import format_user_presence_state
 from synapse.metrics import sent_transactions_counter
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.presence import UserPresenceState
-from synapse.types import StateMap
+from synapse.types import ReadReceipt
 from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
 
 # This is defined in the Matrix spec and enforced by the receiver.
@@ -56,13 +56,18 @@ class PerDestinationQueue(object):
     Manages the per-destination transmission queues.
 
     Args:
-        hs (synapse.HomeServer):
-        transaction_sender (TransactionManager):
-        destination (str): the server_name of the destination that we are managing
+        hs
+        transaction_sender
+        destination: the server_name of the destination that we are managing
             transmission for.
     """
 
-    def __init__(self, hs, transaction_manager, destination):
+    def __init__(
+        self,
+        hs: "synapse.server.HomeServer",
+        transaction_manager: "synapse.federation.sender.TransactionManager",
+        destination: str,
+    ):
         self._server_name = hs.hostname
         self._clock = hs.get_clock()
         self._store = hs.get_datastore()
@@ -72,20 +77,20 @@ class PerDestinationQueue(object):
         self.transmission_loop_running = False
 
         # a list of tuples of (pending pdu, order)
-        self._pending_pdus = []  # type: list[tuple[EventBase, int]]
-        self._pending_edus = []  # type: list[Edu]
+        self._pending_pdus = []  # type: List[Tuple[EventBase, int]]
+        self._pending_edus = []  # type: List[Edu]
 
         # Pending EDUs by their "key". Keyed EDUs are EDUs that get clobbered
         # based on their key (e.g. typing events by room_id)
         # Map of (edu_type, key) -> Edu
-        self._pending_edus_keyed = {}  # type: StateMap[Edu]
+        self._pending_edus_keyed = {}  # type: Dict[Tuple[str, Hashable], Edu]
 
         # Map of user_id -> UserPresenceState of pending presence to be sent to this
         # destination
-        self._pending_presence = {}  # type: dict[str, UserPresenceState]
+        self._pending_presence = {}  # type: Dict[str, UserPresenceState]
 
         # room_id -> receipt_type -> user_id -> receipt_dict
-        self._pending_rrs = {}
+        self._pending_rrs = {}  # type: Dict[str, Dict[str, Dict[str, dict]]]
         self._rrs_pending_flush = False
 
         # stream_id of last successfully sent to-device message.
@@ -95,50 +100,50 @@ class PerDestinationQueue(object):
         # stream_id of last successfully sent device list update.
         self._last_device_list_stream_id = 0
 
-    def __str__(self):
+    def __str__(self) -> str:
         return "PerDestinationQueue[%s]" % self._destination
 
-    def pending_pdu_count(self):
+    def pending_pdu_count(self) -> int:
         return len(self._pending_pdus)
 
-    def pending_edu_count(self):
+    def pending_edu_count(self) -> int:
         return (
             len(self._pending_edus)
             + len(self._pending_presence)
             + len(self._pending_edus_keyed)
         )
 
-    def send_pdu(self, pdu, order):
+    def send_pdu(self, pdu: EventBase, order: int) -> None:
         """Add a PDU to the queue, and start the transmission loop if neccessary
 
         Args:
-            pdu (EventBase): pdu to send
-            order (int):
+            pdu: pdu to send
+            order
         """
         self._pending_pdus.append((pdu, order))
         self.attempt_new_transaction()
 
-    def send_presence(self, states):
+    def send_presence(self, states: Iterable[UserPresenceState]) -> None:
         """Add presence updates to the queue. Start the transmission loop if neccessary.
 
         Args:
-            states (iterable[UserPresenceState]): presence to send
+            states: presence to send
         """
         self._pending_presence.update({state.user_id: state for state in states})
         self.attempt_new_transaction()
 
-    def queue_read_receipt(self, receipt):
+    def queue_read_receipt(self, receipt: ReadReceipt) -> None:
         """Add a RR to the list to be sent. Doesn't start the transmission loop yet
         (see flush_read_receipts_for_room)
 
         Args:
-            receipt (synapse.api.receipt_info.ReceiptInfo): receipt to be queued
+            receipt: receipt to be queued
         """
         self._pending_rrs.setdefault(receipt.room_id, {}).setdefault(
             receipt.receipt_type, {}
         )[receipt.user_id] = {"event_ids": receipt.event_ids, "data": receipt.data}
 
-    def flush_read_receipts_for_room(self, room_id):
+    def flush_read_receipts_for_room(self, room_id: str) -> None:
         # if we don't have any read-receipts for this room, it may be that we've already
         # sent them out, so we don't need to flush.
         if room_id not in self._pending_rrs:
@@ -146,15 +151,15 @@ class PerDestinationQueue(object):
         self._rrs_pending_flush = True
         self.attempt_new_transaction()
 
-    def send_keyed_edu(self, edu, key):
+    def send_keyed_edu(self, edu: Edu, key: Hashable) -> None:
         self._pending_edus_keyed[(edu.edu_type, key)] = edu
         self.attempt_new_transaction()
 
-    def send_edu(self, edu):
+    def send_edu(self, edu) -> None:
         self._pending_edus.append(edu)
         self.attempt_new_transaction()
 
-    def attempt_new_transaction(self):
+    def attempt_new_transaction(self) -> None:
         """Try to start a new transaction to this destination
 
         If there is already a transaction in progress to this destination,
@@ -177,23 +182,22 @@ class PerDestinationQueue(object):
             self._transaction_transmission_loop,
         )
 
-    @defer.inlineCallbacks
-    def _transaction_transmission_loop(self):
-        pending_pdus = []
+    async def _transaction_transmission_loop(self) -> None:
+        pending_pdus = []  # type: List[Tuple[EventBase, int]]
         try:
             self.transmission_loop_running = True
 
             # This will throw if we wouldn't retry. We do this here so we fail
             # quickly, but we will later check this again in the http client,
             # hence why we throw the result away.
-            yield get_retry_limiter(self._destination, self._clock, self._store)
+            await get_retry_limiter(self._destination, self._clock, self._store)
 
             pending_pdus = []
             while True:
                 # We have to keep 2 free slots for presence and rr_edus
                 limit = MAX_EDUS_PER_TRANSACTION - 2
 
-                device_update_edus, dev_list_id = yield self._get_device_update_edus(
+                device_update_edus, dev_list_id = await self._get_device_update_edus(
                     limit
                 )
 
@@ -202,7 +206,7 @@ class PerDestinationQueue(object):
                 (
                     to_device_edus,
                     device_stream_id,
-                ) = yield self._get_to_device_message_edus(limit)
+                ) = await self._get_to_device_message_edus(limit)
 
                 pending_edus = device_update_edus + to_device_edus
 
@@ -269,7 +273,7 @@ class PerDestinationQueue(object):
 
                 # END CRITICAL SECTION
 
-                success = yield self._transaction_manager.send_new_transaction(
+                success = await self._transaction_manager.send_new_transaction(
                     self._destination, pending_pdus, pending_edus
                 )
                 if success:
@@ -280,7 +284,7 @@ class PerDestinationQueue(object):
                     # Remove the acknowledged device messages from the database
                     # Only bother if we actually sent some device messages
                     if to_device_edus:
-                        yield self._store.delete_device_msgs_for_remote(
+                        await self._store.delete_device_msgs_for_remote(
                             self._destination, device_stream_id
                         )
 
@@ -289,7 +293,7 @@ class PerDestinationQueue(object):
                         logger.info(
                             "Marking as sent %r %r", self._destination, dev_list_id
                         )
-                        yield self._store.mark_as_sent_devices_by_remote(
+                        await self._store.mark_as_sent_devices_by_remote(
                             self._destination, dev_list_id
                         )
 
@@ -334,7 +338,7 @@ class PerDestinationQueue(object):
             # We want to be *very* sure we clear this after we stop processing
             self.transmission_loop_running = False
 
-    def _get_rr_edus(self, force_flush):
+    def _get_rr_edus(self, force_flush: bool) -> Iterable[Edu]:
         if not self._pending_rrs:
             return
         if not force_flush and not self._rrs_pending_flush:
@@ -351,17 +355,16 @@ class PerDestinationQueue(object):
         self._rrs_pending_flush = False
         yield edu
 
-    def _pop_pending_edus(self, limit):
+    def _pop_pending_edus(self, limit: int) -> List[Edu]:
         pending_edus = self._pending_edus
         pending_edus, self._pending_edus = pending_edus[:limit], pending_edus[limit:]
         return pending_edus
 
-    @defer.inlineCallbacks
-    def _get_device_update_edus(self, limit):
+    async def _get_device_update_edus(self, limit: int) -> Tuple[List[Edu], int]:
         last_device_list = self._last_device_list_stream_id
 
         # Retrieve list of new device updates to send to the destination
-        now_stream_id, results = yield self._store.get_device_updates_by_remote(
+        now_stream_id, results = await self._store.get_device_updates_by_remote(
             self._destination, last_device_list, limit=limit
         )
         edus = [
@@ -378,11 +381,10 @@ class PerDestinationQueue(object):
 
         return (edus, now_stream_id)
 
-    @defer.inlineCallbacks
-    def _get_to_device_message_edus(self, limit):
+    async def _get_to_device_message_edus(self, limit: int) -> Tuple[List[Edu], int]:
         last_device_stream_id = self._last_device_stream_id
         to_device_stream_id = self._store.get_to_device_stream_token()
-        contents, stream_id = yield self._store.get_new_device_msgs_for_remote(
+        contents, stream_id = await self._store.get_new_device_msgs_for_remote(
             self._destination, last_device_stream_id, to_device_stream_id, limit
         )
         edus = [